diff --git a/.github/ISSUE_TEMPLATE/bug_report.yaml b/.github/ISSUE_TEMPLATE/bug_report.yaml
index 0068ac861e17db6b78af2c834bdac3a5ea5e9dab..4c7798012f351c1f30e266c72ca4f28f08d87933 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.yaml
+++ b/.github/ISSUE_TEMPLATE/bug_report.yaml
@@ -1,18 +1,21 @@
 name: 🪲 Bug Report
 description: Something went wrong? Let us know! 🐣
-title: "[Bug]: "
 labels: ["bug"]
 body:
   - type: markdown
     attributes:
       value: |
-        Before submitting a bug, please make sure the issue hasn't been already addressed by searching through the existing and past issues.
+        **Before submitting a bug report, please read the following instructions:**
+
+        - Make sure the issue hasn't already been addressed by searching through existing and past issues.
+        - Use a clear and concise title for your bug report.
+        - Fill out all relevant sections below to help us understand and reproduce the issue.
 
   - type: textarea
     id: describe-the-bug
     attributes:
       label: Describe the bug
-      description: Short and clear description of what the bug is.
+      description: Provide a clear and concise description of the bug.
     validations:
       required: True
 
@@ -20,7 +23,7 @@ body:
     id: expected-behaviour
     attributes:
       label: Expected behaviour
-      description: A description of what you expected to happen.
+      description: Describe what you expected to happen.
     validations:
       required: True
 
@@ -29,7 +32,15 @@ body:
     attributes:
       label: To Reproduce
       description: |
-        If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as minimal as possible. We will copy-paste your code, and we expect to get the same result as you did: avoid any external data, and include the relevant imports.
+        If relevant, add a minimal example or detailed steps to reproduce the error. You can share code directly using Google Colab:
+        1. Visit [Google Colab](https://colab.research.google.com/).
+        2. Create a new notebook.
+        3. Paste your code into the notebook.
+        4. Share the notebook by clicking on "Share" in the top-right corner.
+        5. Share the notebook's link here.
+
+        In the worst case, provide detailed steps to reproduce the behavior.
+
       placeholder: "```python #your code ``` \n ```yaml #your yaml code ```"
     validations:
       required: False
@@ -37,27 +48,29 @@ body:
   - type: textarea
     id: versions
     attributes:
-      label: Versions
-      description: "Please tell us more about your current SpeechBrain version and/or git hash (if installed via cloning+editable install). You can also add other setup information that might be relevant."
+      label: Environment Details
+      description: Provide information about your SpeechBrain version, setup, and any other relevant environment details.
     validations:
       required: False
 
   - type: textarea
     id: logs
     attributes:
-      label: Relevant log output
-      description: Please copy and paste any relevant log output.
+      label: Relevant Log Output
+      description: Copy and paste any relevant log output here.
       render: shell
+    validations:
+      required: False
 
   - type: textarea
     id: add-context
     attributes:
-      label: Additional context
-      description: "Add any other context about the problem here."
+      label: Additional Context
+      description: Share any other context about the problem or your environment that may help in troubleshooting.
     validations:
       required: False
 
   - type: markdown
     attributes:
       value: |
-        Thanks for contributing to SpeechBrain!
+        **Thank you for contributing to SpeechBrain!** Your bug report helps us improve the project's reliability.
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index 835b4b860b5d952e44ebba55f58e918bb01e7ccf..c04d6d1613bb83163a4a1cb3e56ccc1b7fe96a20 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -1,28 +1,45 @@
-# Contribution in a nutshell
-Hey, this could help our community 🌱
+## What does this PR do?
 
-# Scope
-* [ ] I want to get done ...
-* [ ] ... and hope to also achieve ...
+<!--
+Please include a summary of the change and which issue is fixed.
+Please also include relevant motivation and context.
+List any dependencies that are required for this change.
 
-# Notes for reviewing (optional)
-This change has these implication which might need attention over here; —how should we tackle this?
+-->
 
-# Pre-review
-* [ ] (if applicable) add an `extra_requirements.txt` file
-* [ ] (if applicable) add database preparation scripts & use symlinks for nested folders (to the level of task READMEs)
-* [ ] (if applicable) add a recipe test entry in the depending CSV file under: tests/recipes
-* [ ] create a fresh testing environment (install SpeechBrain from cloned repo branch of this PR)
-* [ ] (if applicable) run a recipe test for each yaml/your recipe dataset
-* [ ] check function comments: are there docstrings w/ arguments & returns? If you're not the verbose type, put a comment every three lines of code (better: every line)
-* [ ] use CI locally: `pre-commit run -a` to check linters; run `pytest tests/consistency`
-* [ ] (optional) run `tests/.run-doctests.sh` & `tests/.run-unittests.sh`
-* [ ] exhausted patience before clicking « Ready for review » in the merge box 🍄
+Fixes #<issue_number>
 
----
+<!-- Does your PR introduce any breaking changes? If yes, please list them. -->
 
-Note: when merged, we desire to include your PR title in our contributions list, check out one of our past version releases
-—https://github.com/speechbrain/speechbrain/releases/tag/v0.5.14
+<details>
+  <summary><b>Before submitting</b></summary>
 
-Tip: below, on the « Create Pull Request » use the drop-down to select: « Create Draft Pull Request » – your PR will be in draft mode until you declare it « Ready for review »
+- [ ] Did you read the [contributor guideline](https://speechbrain.readthedocs.io/en/latest/contributing.html)?
+- [ ] Did you make sure your **PR does only one thing**, instead of bundling different changes together?
+- [ ] Did you make sure to **update the documentation** with your changes? (if necessary)
+- [ ] Did you write any **new necessary tests**? (not for typos and docs)
+- [ ] Did you verify new and **existing [tests](https://github.com/speechbrain/speechbrain/tree/develop/tests) pass** locally with your changes?
+- [ ] Did you list all the **breaking changes** introduced by this pull request?
+- [ ] Does your code adhere to project-specific code style and conventions?
 
+</details>
+
+## PR review
+
+<details>
+  <summary>Reviewer checklist</summary>
+
+- [ ] Is this pull request ready for review? (if not, please submit in draft mode)
+- [ ] Check that all items from **Before submitting** are resolved
+- [ ] Make sure the title is self-explanatory and the description concisely explains the PR
+- [ ] Add labels and milestones (and optionally projects) to the PR so it can be classified
+- [ ] Confirm that the changes adhere to compatibility requirements (e.g., Python version, platform)
+- [ ] Review the self-review checklist to ensure the code is ready for review
+
+</details>
+
+<!--
+
+🎩 Magic happens when you code. Keep the spells flowing!
+
+-->
diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
index 377fca5debd0a71bffea804087bfb0994ff94ac4..fd71f2716d2fc06657e5a22158ea2a77357d8bf7 100644
--- a/.github/workflows/pre-commit.yml
+++ b/.github/workflows/pre-commit.yml
@@ -12,5 +12,5 @@ jobs:
       - uses: actions/checkout@v2
       - uses: actions/setup-python@v2
         with:
-          python-version: '3.8'
+          python-version: '3.9'
       - uses: pre-commit/action@v2.0.3
diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml
index dca91b29d1685da4be660730642a4e7bd3f6c711..55e4c882c00c358be84eec265bf8c61ce9005450 100644
--- a/.github/workflows/pythonapp.yml
+++ b/.github/workflows/pythonapp.yml
@@ -21,22 +21,29 @@ jobs:
               uses: actions/setup-python@v1
               with:
                   python-version: ${{ matrix.python-version }}
-            - name: Install libsndfile
+            - name: Install sox
               run: |
                   sudo apt-get update
-                  sudo apt-get install -y libsndfile1
-            - name: Install ffmpeg
-              run: |
-                  sudo apt-get update
-                  sudo apt-get install -y ffmpeg
+                  sudo apt install sox libsox-dev
+            # Installing only SoX for now due to FFmpeg issues on the CI server with Torchaudio 2.1.
+            # FFmpeg works fine on all other machines. We'll switch back when the CI server is fixed.
+            #- name: Install ffmpeg
+            #  run: |
+            #      sudo apt-get update
+            #      sudo apt-get install -y ffmpeg
             - name: Display Python version
               run: python -c "import sys; print(sys.version)"
             - name: Full dependencies
               run: |
                   sudo apt-get update
+                  # up to k2 compatible torch version
+                  pip install torch==2.1.2 torchaudio==2.1.2
                   pip install -r requirements.txt
                   pip install --editable .
                   pip install ctc-segmentation
+                  pip install k2==1.24.4.dev20231220+cpu.torch2.1.2 -f https://k2-fsa.github.io/k2/cpu.html
+                  pip install protobuf
+                  pip install kaldilm==1.15
             - name: Consistency tests with pytest
               run: |
                   pytest tests/consistency
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index b48b4c032fd2713da74a291f6c05b9d022edf25c..67fb026ea0b2e8e8cdbd2ff598847bd759a873cc 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -17,7 +17,7 @@ jobs:
           ref: main
       - uses: actions/setup-python@v2
         with:
-          python-version: 3.8
+          python-version: 3.9
       - name: Install pypa/build
         run: python -m pip install build --user
       - name: Build binary wheel and source tarball
diff --git a/.github/workflows/verify-docs-gen.yml b/.github/workflows/verify-docs-gen.yml
index ac279a9cd8369b3140a9355f2a3ed4fad82d25b5..3c87aba2b1bdc05fe2b9f389589f2e8d9fdffe76 100644
--- a/.github/workflows/verify-docs-gen.yml
+++ b/.github/workflows/verify-docs-gen.yml
@@ -11,15 +11,18 @@ jobs:
         runs-on: ubuntu-latest
         steps:
             - uses: actions/checkout@v2
-            - name: Setup Python 3.8
+            - name: Setup Python 3.9
               uses: actions/setup-python@v2
               with:
-                  python-version: '3.8'
+                  python-version: '3.9'
             - name: Full dependencies
               run: |
+                  # up to k2 compatible torch version
+                  pip install torch==2.1.2 torchaudio==2.1.2
                   pip install -r requirements.txt
                   pip install --editable .
                   pip install -r docs/docs-requirements.txt
+                  pip install k2==1.24.4.dev20231220+cpu.torch2.1.2 -f https://k2-fsa.github.io/k2/cpu.html
             - name: Generate docs
               run: |
                   cd docs
diff --git a/.gitignore b/.gitignore
index 6606d1085924b4f3912bf4b629e2ec4695380791..598748dd87186941dc8f093e14f7b44a26cb5db6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -72,11 +72,8 @@ instance/
 .scrapy
 
 # Sphinx documentation
-docs/_build/
-docs/source/*.rst
-!docs/source/index.rst
-!docs/source/_templates
-!docs/source/_static
+docs/build/
+docs/API/*.rst
 
 # PyBuilder
 target/
@@ -158,4 +155,4 @@ dmypy.json
 **/log/
 
 # Mac OS
-.DS_Store
\ No newline at end of file
+.DS_Store
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index ed8451a3c0368e35df064d0ff2df07fff824e7b0..f445a2ec62f326f45dc0047769cbc2f99aa76f6f 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -1,13 +1,15 @@
 # .readthedocs.yaml
 
+version: 2
+
 build:
-  image: latest
+  os: ubuntu-20.04
+  tools:
+    python: "3.9"
 
 python:
-  version: 3.8
-  pip_install: True
+  install:
+    - requirements: docs/readthedocs-requirements.txt
 
 # Don't build any extra formats
 formats: []
-
-requirements_file: docs/docs-requirements.txt
diff --git a/CITATION.cff b/CITATION.cff
new file mode 100644
index 0000000000000000000000000000000000000000..8b48809676d505ef70e89bbb86189446d5b70585
--- /dev/null
+++ b/CITATION.cff
@@ -0,0 +1,116 @@
+# This CITATION.cff file was generated with cffinit.
+# Visit https://bit.ly/cffinit to generate yours today!
+
+cff-version: 1.2.0
+title: SpeechBrain
+message: A PyTorch-based Speech Toolkit
+type: software
+authors:
+  - given-names: Mirco
+    family-names: Ravanelli
+    affiliation: 'Mila - Quebec AI Institute, Université de Montréal'
+  - given-names: Titouan
+    family-names: Parcollet
+    affiliation: >-
+      LIA - Avignon Université, CaMLSys - University of
+      Cambridge
+  - given-names: Peter
+    family-names: Plantinga
+    affiliation: Ohio State University
+  - given-names: Aku
+    family-names: Rouhe
+    affiliation: Aalto University
+  - given-names: Samuele
+    family-names: Cornell
+    affiliation: Università Politecnica delle Marche
+  - given-names: Loren
+    family-names: Lugosch
+    affiliation: 'Mila - Quebec AI Institute, McGill University'
+  - given-names: Cem
+    family-names: Subakan
+    affiliation: Mila - Quebec AI Institute
+  - given-names: Nauman
+    family-names: Dawalatabad
+    affiliation: Indian Institute of Technology Madras
+  - given-names: Abdelwahab
+    family-names: Heba
+    affiliation: IRIT - Université Paul Sabatier
+  - given-names: Jianyuan
+    family-names: Zhong
+    affiliation: Mila - Quebec AI Institute
+  - given-names: Ju-Chieh
+    family-names: Chou
+    affiliation: Toyota Technological Institute at Chicago
+  - given-names: Sung-Lin
+    family-names: Yeh
+    affiliation: University of Edinburgh
+  - given-names: Szu-Wei
+    family-names: Fu
+    affiliation: 'Academia Sinica, Taiwan'
+  - given-names: Chien-Feng
+    family-names: Liao
+    affiliation: 'Academia Sinica, Taiwan'
+  - given-names: Elena
+    family-names: Rastorgueva
+    affiliation: NVIDIA
+  - given-names: François
+    family-names: Grondin
+    affiliation: Université de Sherbrooke
+  - given-names: William
+    family-names: Aris
+    affiliation: Université de Sherbrooke
+  - given-names: Hwidong
+    family-names: Na
+    affiliation: Samsung-SAIT
+  - given-names: Yan
+    family-names: Gao
+    affiliation: CaMLSys - University of Cambridge
+  - given-names: Renato
+    name-particle: De
+    family-names: Mori
+    affiliation: 'LIA - Avignon Université, McGill University'
+  - given-names: Yoshua
+    family-names: Bengio
+    affiliation: 'Mila - Quebec AI Institute, Université de Montréal'
+identifiers:
+  - type: doi
+    value: 10.48550/arXiv.2106.04624
+    description: 'SpeechBrain: A General-Purpose Speech Toolkit'
+repository-code: 'https://github.com/speechbrain/speechbrain/'
+url: 'https://speechbrain.github.io/'
+abstract: >-
+  SpeechBrain is an open-source and all-in-one speech
+  toolkit. It is designed to facilitate the research and
+  development of neural speech processing technologies by
+  being simple, flexible, user-friendly, and
+  well-documented. This paper describes the core
+  architecture designed to support several tasks of common
+  interest, allowing users to naturally conceive, compare
+  and share novel speech processing pipelines. SpeechBrain
+  achieves competitive or state-of-the-art performance in a
+  wide range of speech benchmarks. It also provides training
+  recipes, pretrained models, and inference scripts for
+  popular speech datasets, as well as tutorials which allow
+  anyone with basic Python proficiency to familiarize
+  themselves with speech technologies.
+keywords:
+  - speech toolkit
+  - audio
+  - deep learning
+  - PyTorch
+  - transformers
+  - voice recognition
+  - speech recognition
+  - speech-to-text
+  - language model
+  - speaker recognition
+  - speaker verification
+  - speech processing
+  - audio processing
+  - ASR
+  - speaker diarization
+  - speech separation
+  - speech enhancement
+  - spoken language understanding
+  - HuggingFace
+license: Apache-2.0
diff --git a/PERFORMANCE.md b/PERFORMANCE.md
new file mode 100644
index 0000000000000000000000000000000000000000..045f6dd0c2b2942debf0354090f5ce18ef8a5413
--- /dev/null
+++ b/PERFORMANCE.md
@@ -0,0 +1,479 @@
+# SpeechBrain Performance Report
+    This document provides an overview of the performance achieved on key datasets and tasks supported by SpeechBrain.
+
+## AISHELL-1 Dataset
+
+### ASR
+
+| Model | Checkpoints | HuggingFace | Test-CER |
+| --------| --------| --------| --------|
+ | recipes/AISHELL-1/ASR/CTC/hparams/train_with_wav2vec.yaml | [here](https://www.dropbox.com/sh/e4bth1bylk7c6h8/AADFq3cWzBBKxuDv09qjvUMta?dl=0) | [here](https://huggingface.co/speechbrain/asr-wav2vec2-ctc-aishell) | 5.06 |
+ | recipes/AISHELL-1/ASR/seq2seq/hparams/train.yaml | [here](https://www.dropbox.com/sh/kefuzzf6jaljqbr/AADBRWRzHz74GCMDqJY9BES4a?dl=0) | - | 7.51 |
+ | recipes/AISHELL-1/ASR/transformer/hparams/train_ASR_transformer.yaml | [here](https://www.dropbox.com/sh/tp6tjmysorgvsr4/AAD7KNqi1ot0gR4N406JbKM6a?dl=0) | [here](https://huggingface.co/speechbrain/asr-transformer-aishell) | 6.04 |
+ | recipes/AISHELL-1/ASR/transformer/hparams/train_ASR_transformer_with_wav2vect.yaml | [here](https://www.dropbox.com/sh/tp6tjmysorgvsr4/AAD7KNqi1ot0gR4N406JbKM6a?dl=0) | [here](https://huggingface.co/speechbrain/asr-wav2vec2-transformer-aishell) | 5.58 |
+
+
+## Aishell1Mix Dataset
+
+### Separation
+
+| Model | Checkpoints | HuggingFace | SI-SNRi |
+| --------| --------| --------| --------|
+ | recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix2.yaml | [here](https://www.dropbox.com/sh/6x9356yuybj8lue/AABPlpS03Vcci_E3jA69oKoXa?dl=0) | - | 13.4dB |
+ | recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix3.yaml | [here](https://www.dropbox.com/sh/6x9356yuybj8lue/AABPlpS03Vcci_E3jA69oKoXa?dl=0) | - | 11.2dB |
+
+
+## BinauralWSJ0Mix Dataset
+
+### Separation
+
+| Model | Checkpoints | HuggingFace | SI-SNRi |
+| --------| --------| --------| --------|
+ | recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-cross.yaml | [here](https://www.dropbox.com/sh/i7fhu7qswjb84gw/AABsX1zP-GOTmyl86PtU8GGua?dl=0) | - | 12.39dB |
+ | recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-independent.yaml | [here](https://www.dropbox.com/sh/i7fhu7qswjb84gw/AABsX1zP-GOTmyl86PtU8GGua?dl=0) | - | 11.90dB |
+ | recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-parallel-noise.yaml | [here](https://www.dropbox.com/sh/i7fhu7qswjb84gw/AABsX1zP-GOTmyl86PtU8GGua?dl=0) | - | 18.25dB |
+ | recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-parallel-reverb.yaml | [here](https://www.dropbox.com/sh/i7fhu7qswjb84gw/AABsX1zP-GOTmyl86PtU8GGua?dl=0) | - | 6.95dB |
+ | recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-parallel.yaml | [here](https://www.dropbox.com/sh/i7fhu7qswjb84gw/AABsX1zP-GOTmyl86PtU8GGua?dl=0) | - | 16.93dB |
+
+
+## CVSS Dataset
+
+### S2ST
+
+| Model | Checkpoints | HuggingFace | Test-sacrebleu |
+| --------| --------| --------| --------|
+ | recipes/CVSS/S2ST/hparams/train_fr-en.yaml | [here]( https://www.dropbox.com/sh/woz4i1p8pkfkqhf/AACmOvr3sS7p95iXl3twCj_xa?dl=0) | [here]( ) | 24.47 |
+
+
+## CommonLanguage Dataset
+
+### Language-id
+
+| Model | Checkpoints | HuggingFace | Error |
+| --------| --------| --------| --------|
+ | recipes/CommonLanguage/lang_id/hparams/train_ecapa_tdnn.yaml | [here](https://www.dropbox.com/sh/1fxpzyv67ouwd2c/AAAeMUWYP2f1ycpE1Lp1CwEla?dl=0) | [here](https://huggingface.co/speechbrain/lang-id-commonlanguage_ecapa) | 15.1% |
+
+
+## CommonVoice Dataset
+
+### ASR-transformer
+
+| Model | Checkpoints | HuggingFace | Test-WER |
+| --------| --------| --------| --------|
+ | recipes/CommonVoice/ASR/transformer/hparams/train_fr.yaml | [here](https://www.dropbox.com/sh/zvu9h9pctksnuvp/AAD1kyS3-N0YtmcoMgjM-_Tba?dl=0) | - | 17.61% |
+ | recipes/CommonVoice/ASR/transformer/hparams/train_it.yaml | [here](https://www.dropbox.com/sh/yy8du12jgbkm3qe/AACBHhTCM-cU-oGvAKJ9kTtaa?dl=0) | - | 16.80% |
+ | recipes/CommonVoice/ASR/transformer/hparams/train_de.yaml | [here](https://www.dropbox.com/sh/umfq986o3d9o1px/AAARNF2BFYELOWx3xhIOEoZka?dl=0) | - | 16.76% |
+ | recipes/CommonVoice/ASR/transformer/hparams/train_ar_hf_whisper.yaml | [here](https://www.dropbox.com/sh/0e4vtvbg6hf2e13/AAD-tfzCZGUrh85aeAeJj8I9a?dl=0) | [here](https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-ar) | 16.96% |
+ | recipes/CommonVoice/ASR/transformer/hparams/train_fa_hf_whisper.yaml | [here](https://www.dropbox.com/sh/w1urihacmtoulmi/AADMtK3qeAF5mLYk5LMHyiOra?dl=0) | [here](https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-fa) | 31.75% |
+ | recipes/CommonVoice/ASR/transformer/hparams/train_fr_hf_whisper.yaml | [here](https://www.dropbox.com/sh/7zlk07yxnslk4yy/AAANcI3EaG0ZFy6UrKk1Mm2Ga?dl=0) | [here](https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-fr) | 10.62% |
+ | recipes/CommonVoice/ASR/transformer/hparams/train_sr_hf_whisper.yaml | [here](https://www.dropbox.com/sh/5lhk230q45sd97z/AAD-U9b_Ws_vFPs-cazsbOY0a?dl=0) | [here](https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-sr) | 22.29% |
+ | recipes/CommonVoice/ASR/transformer/hparams/train_mn_hf_whisper.yaml | [here](https://www.dropbox.com/sh/6fbhmey7q1udykf/AAAiGObWTTe2cdXHt2Uv2VQXa?dl=0) | [here](https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-mn) | 67.84% |
+ | recipes/CommonVoice/ASR/transformer/hparams/train_hi_hf_whisper.yaml | [here](https://www.dropbox.com/sh/z9vriyy3i6xqvif/AAB7ql-40yWTjKEQJiuhYUr5a?dl=0) | [here](https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-hi) | 15.27% |
+ | recipes/CommonVoice/ASR/transformer/hparams/train_it_hf_whisper.yaml | [here](https://www.dropbox.com/sh/u5tex3nvzzs5pex/AAD-J7cOBE_fNfBono8waTKCa?dl=0) | [here](https://huggingface.co/speechbrain/asr-whisper-medium-commonvoice-it) | 9.63% |
+
+
+### ASR-CTC
+
+| Model | Checkpoints | HuggingFace | Test-WER |
+| --------| --------| --------| --------|
+ | recipes/CommonVoice/ASR/CTC/hparams/train_en_with_wav2vec.yaml | [here](https://www.dropbox.com/sh/ch10cnbhf1faz3w/AACdHFG65LC6582H0Tet_glTa?dl=0) | [here](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-en) | 16.16% |
+ | recipes/CommonVoice/ASR/CTC/hparams/train_fr_with_wav2vec.yaml | [here](https://www.dropbox.com/sh/0i7esfa8jp3rxpp/AAArdi8IuCRmob2WAS7lg6M4a?dl=0) | [here](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-fr) | 9.71% |
+ | recipes/CommonVoice/ASR/CTC/hparams/train_it_with_wav2vec.yaml | [here](https://www.dropbox.com/sh/hthxqzh5boq15rn/AACftSab_FM6EFWWPgHpKw82a?dl=0) | [here](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-it) | 7.99% |
+ | recipes/CommonVoice/ASR/CTC/hparams/train_rw_with_wav2vec.yaml | [here](https://www.dropbox.com/sh/4iax0l4yfry37gn/AABuQ31JY-Sbyi1VlOJfV7haa?dl=0) | [here](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-rw) | 22.52% |
+ | recipes/CommonVoice/ASR/CTC/hparams/train_de_with_wav2vec.yaml | [here](https://www.dropbox.com/sh/dn7plq4wfsujsi1/AABS1kqB_uqLJVkg-bFkyPpVa?dl=0) | [here](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-de) | 8.39% |
+ | recipes/CommonVoice/ASR/CTC/hparams/train_ar_with_wav2vec.yaml | [here](https://www.dropbox.com/sh/7tnuqqbr4vy96cc/AAA_5_R0RmqFIiyR0o1nVS4Ia?dl=0) | [here](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-ar) | 28.53% |
+ | recipes/CommonVoice/ASR/CTC/hparams/train_es_with_wav2vec.yaml | [here](https://www.dropbox.com/sh/ejvzgl3d3g8g9su/AACYtbSWbDHvBr06lAb7A4mVa?dl=0) | [here](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-es) | 12.67% |
+ | recipes/CommonVoice/ASR/CTC/hparams/train_pt_with_wav2vec.yaml | [here](https://www.dropbox.com/sh/80wucrvijdvao2a/AAD6-SZ2_ZZXmlAjOTw6fVloa?dl=0) | [here](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-pt) | 21.69% |
+ | recipes/CommonVoice/ASR/CTC/hparams/train_zh-CN_with_wav2vec.yaml | [here](https://www.dropbox.com/sh/2bikr81vgufoglf/AABMpD0rLIaZBxjtwBHgrNpga?dl=0) | [here](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-zh-CN) | 23.17% |
+
+
+### ASR-transducer
+
+| Model | Checkpoints | HuggingFace | Test-WER |
+| --------| --------| --------| --------|
+ | recipes/CommonVoice/ASR/transducer/hparams/train_fr.yaml | [here](https://www.dropbox.com/sh/nv2pnpo5n3besn3/AADZ7l41oLt11ZuOE4MqoJhCa?dl=0) | [here](speechbrain/asr-transducer-commonvoice-14-fr) | 17.58% |
+ | recipes/CommonVoice/ASR/transducer/hparams/train_it.yaml | [here](https://www.dropbox.com/sh/ksm08x0wwiomrgs/AABnjPePWGPxqIqW7bJHp1jea?dl=0) | [here](speechbrain/asr-transducer-commonvoice-14-it) | 14.88% |
+ | recipes/CommonVoice/ASR/transducer/hparams/train_de.yaml | [here](https://www.dropbox.com/sh/jfge6ixbtoje64t/AADeAgL5un0A8uEjPSM84ex8a?dl=0) | [here](speechbrain/asr-transducer-commonvoice-14-de) | 15.25% |
+
+
+### ASR-seq2seq
+
+| Model | Checkpoints | HuggingFace | Test-WER |
+| --------| --------| --------| --------|
+ | recipes/CommonVoice/ASR/seq2seq/hparams/train_de.yaml | [here](https://www.dropbox.com/sh/zgatirb118f79ef/AACmjh-D94nNDWcnVI4Ef5K7a?dl=0) | [here](https://huggingface.co/speechbrain/asr-crdnn-commonvoice-14-de) | 12.25% |
+ | recipes/CommonVoice/ASR/seq2seq/hparams/train_en.yaml | [here](https://www.dropbox.com/sh/h8ged0yu3ztypkh/AAAu-12k_Ceg-tTjuZnrg7dza?dl=0) | [here](https://huggingface.co/speechbrain/asr-crdnn-commonvoice-14-en) | 23.88% |
+ | recipes/CommonVoice/ASR/seq2seq/hparams/train_fr.yaml | [here](https://www.dropbox.com/sh/07a5lt21wxp98x5/AABhNwmWFaNFyA734bNZUO03a?dl=0) | [here](https://huggingface.co/speechbrain/asr-crdnn-commonvoice-14-fr) | 14.88% |
+ | recipes/CommonVoice/ASR/seq2seq/hparams/train_it.yaml | [here](https://www.dropbox.com/sh/ss59uu0j5boscvp/AAASsiFhlB1nDWPkFX410bzna?dl=0) | [here](https://huggingface.co/speechbrain/asr-crdnn-commonvoice-14-it) | 17.02% |
+ | recipes/CommonVoice/ASR/seq2seq/hparams/train_rw.yaml | [here](https://www.dropbox.com/sh/i1fv4f8miilqgii/AAB3gE97kmFDA0ISkIDSUW_La?dl=0) | [here](https://huggingface.co/speechbrain/asr-crdnn-commonvoice-14-rw) | 29.22% |
+ | recipes/CommonVoice/ASR/seq2seq/hparams/train_es.yaml | [here](https://www.dropbox.com/sh/r3w0b2tm1p73vft/AADCxdhUwDN6j4PVT9TYe-d5a?dl=0) | [here](https://huggingface.co/speechbrain/asr-crdnn-commonvoice-14-es) | 14.77% |
+
+
+## DNS Dataset
+
+### Enhancement
+
+| Model | Checkpoints | HuggingFace | valid-PESQ | test-SIG | test-BAK | test-OVRL |
+| --------| --------| --------| --------| --------| --------| --------|
+ | recipes/DNS/enhancement/hparams/sepformer-dns-16k.yaml | [here](https://www.dropbox.com/sh/d3rp5d3gjysvy7c/AACmwcEkm_IFvaW1lt2GdtQka?dl=0) | [here](https://huggingface.co/speechbrain/sepformer-dns4-16k-enhancement) | 2.06 | 2.999 | 3.076 | 2.437 |
+
+
+## DVoice Dataset
+
+### ASR-CTC
+
+| Model | Checkpoints | HuggingFace | Test-WER |
+| --------| --------| --------| --------|
+ | recipes/DVoice/ASR/CTC/hparams/train_amh_with_wav2vec.yaml | [here](https://www.dropbox.com/sh/pyu40jq1ebv6hcc/AADQO_lAD-F9Q0vlVq8KoXHqa?dl=0) | [here](https://huggingface.co/speechbrain/asr-wav2vec2-dvoice-amharic) | 24.92% |
+ | recipes/DVoice/ASR/CTC/hparams/train_dar_with_wav2vec.yaml | [here](https://www.dropbox.com/sh/pyu40jq1ebv6hcc/AADQO_lAD-F9Q0vlVq8KoXHqa?dl=0) | [here](https://huggingface.co/speechbrain/asr-wav2vec2-dvoice-darija) | 18.28% |
+ | recipes/DVoice/ASR/CTC/hparams/train_fon_with_wav2vec.yaml | [here](https://www.dropbox.com/sh/pyu40jq1ebv6hcc/AADQO_lAD-F9Q0vlVq8KoXHqa?dl=0) | [here](https://huggingface.co/speechbrain/asr-wav2vec2-dvoice-fongbe) | 9.00% |
+ | recipes/DVoice/ASR/CTC/hparams/train_sw_with_wav2vec.yaml | [here](https://www.dropbox.com/sh/pyu40jq1ebv6hcc/AADQO_lAD-F9Q0vlVq8KoXHqa?dl=0) | [here](https://huggingface.co/speechbrain/asr-wav2vec2-dvoice-swahili) | 23.16% |
+ | recipes/DVoice/ASR/CTC/hparams/train_wol_with_wav2vec.yaml | [here](https://www.dropbox.com/sh/pyu40jq1ebv6hcc/AADQO_lAD-F9Q0vlVq8KoXHqa?dl=0) | [here](https://huggingface.co/speechbrain/asr-wav2vec2-dvoice-wolof) | 16.05% |
+
+
+### Multilingual-ASR-CTC
+
+| Model | Checkpoints | HuggingFace | WER-Darija | WER-Swahili | WER-Fongbe | Fongbe-Wolof | WER-Amharic |
+| --------| --------| --------| --------| --------| --------| --------| --------|
+ | recipes/DVoice/ASR/CTC/hparams/train_multi_with_wav2vec.yaml | [here](https://www.dropbox.com/sh/pyu40jq1ebv6hcc/AADQO_lAD-F9Q0vlVq8KoXHqa?dl=0) | - | 13.27% | 29.31% | 10.26% | 21.54% | 31.15% |
+
+
+## ESC50 Dataset
+
+### SoundClassification
+
+| Model | Checkpoints | HuggingFace | Accuracy |
+| --------| --------| --------| --------|
+ | recipes/ESC50/classification/hparams/cnn14_classifier.yaml | [here](https://www.dropbox.com/sh/fbe7l14o3n8f5rw/AACABE1BQGBbX4j6A1dIhBcSa?dl=0) | - | 82% |
+ | recipes/ESC50/classification/hparams/conv2d_classifier.yaml | [here](https://www.dropbox.com/sh/tl2pbfkreov3z7e/AADwwhxBLw1sKvlSWzp6DMEia?dl=0) | - | 75% |
+
+
+## Fisher-Callhome-Spanish Dataset
+
+### Speech_Translation
+
+| Model | Checkpoints | HuggingFace | Test-sacrebleu |
+| --------| --------| --------| --------|
+ | recipes/Fisher-Callhome-Spanish/ST/transformer/hparams/transformer.yaml | [here](https://www.dropbox.com/sh/tmh7op8xwthdta0/AACuU9xHDHPs8ToxIIwoTLB0a?dl=0) | - | 47.31 |
+ | recipes/Fisher-Callhome-Spanish/ST/transformer/hparams/conformer.yaml | [here](https://www.dropbox.com/sh/tmh7op8xwthdta0/AACuU9xHDHPs8ToxIIwoTLB0a?dl=0) | - | 48.04 |
+
+
+## Google-speech-commands Dataset
+
+### Command_recognition
+
+| Model | Checkpoints | HuggingFace | Test-accuracy |
+| --------| --------| --------| --------|
+ | recipes/Google-speech-commands/hparams/xvect.yaml | [here](https://www.dropbox.com/sh/9n9q42pugbx0g7a/AADihpfGKuWf6gkwQznEFINDa?dl=0) | [here](https://huggingface.co/speechbrain/google_speech_command_xvector) | 97.43% |
+ | recipes/Google-speech-commands/hparams/xvect_leaf.yaml | [here](https://www.dropbox.com/sh/r63w4gytft4s1x6/AAApP8-pp179QKGCZHV_OuD8a?dl=0) | - | 96.79% |
+
+
+## IEMOCAP Dataset
+
+### Emotion_recognition
+
+| Model | Checkpoints | HuggingFace | Test-Accuracy |
+| --------| --------| --------| --------|
+ | recipes/IEMOCAP/emotion_recognition/hparams/train_with_wav2vec2.yaml | [here](https://www.dropbox.com/sh/lmebg4li83sgkhg/AACooPKbNlwd-7n5qSJMbc7ya?dl=0) | [here](https://huggingface.co/speechbrain/emotion-recognition-wav2vec2-IEMOCAP/) | 65.7% |
+ | recipes/IEMOCAP/emotion_recognition/hparams/train.yaml | [here](https://www.dropbox.com/sh/ke4fxiry97z58m8/AACPEOM5bIyxo9HxG2mT9v_aa?dl=0) | - | 77.0% |
+
+
+## IWSLT22_lowresource Dataset
+
+### Speech_Translation
+
+| Model | Checkpoints | HuggingFace | Test-BLEU |
+| --------| --------| --------| --------|
+ | recipes/IWSLT22_lowresource/AST/transformer/hparams/train_w2v2_mbart_st.yaml | [here](https://www.dropbox.com/sh/xjo0ou739oksnus/AAAgyrCwywmDRRuUiDnUva2za?dl=0) | - | 7.73 |
+ | recipes/IWSLT22_lowresource/AST/transformer/hparams/train_w2v2_nllb_st.yaml | [here](https://www.dropbox.com/sh/spp2ijgfdbzuz26/AABkJ97e72D7aKzNLTm1qmWEa?dl=0) | - | 8.70 |
+ | recipes/IWSLT22_lowresource/AST/transformer/hparams/train_samu_mbart_st.yaml | [here](https://www.dropbox.com/sh/98s1xyc3chreaw6/AABom3FnwY5SsIvg4en9tWC2a?dl=0) | - | 10.28 |
+ | recipes/IWSLT22_lowresource/AST/transformer/hparams/train_samu_nllb_st.yaml | [here](https://www.dropbox.com/sh/ekkpl9c3kxsgllj/AABa0q2LrJe_o7JF-TTbfxZ-a?dl=0) | - | 11.32 |
+
+
+## KsponSpeech Dataset
+
+### ASR
+
+| Model | Checkpoints | HuggingFace | clean-WER | others-WER |
+| --------| --------| --------| --------| --------|
+ | recipes/KsponSpeech/ASR/transformer/hparams/conformer_medium.yaml | [here](https://www.dropbox.com/sh/uibokbz83o8ybv3/AACtO5U7mUbu_XhtcoOphAjza?dl=0) | [here](https://huggingface.co/speechbrain/asr-conformer-transformerlm-ksponspeech) | 20.78% | 25.73% |
+
+
+## LibriMix Dataset
+
+### Separation
+
+| Model | Checkpoints | HuggingFace | SI-SNR |
+| --------| --------| --------| --------|
+ | recipes/LibriMix/separation/hparams/sepformer-libri2mix.yaml | [here](https://www.dropbox.com/sh/skkiozml92xtgdo/AAD0eJxgbCTK03kAaILytGtVa?dl=0) | - | 20.4dB |
+ | recipes/LibriMix/separation/hparams/sepformer-libri3mix.yaml | [here](https://www.dropbox.com/sh/kmyz7tts9tyg198/AACsDcRwKvelXxEB-k5q1OaIa?dl=0) | - | 19.0dB |
+
+
+## LibriParty Dataset
+
+### VAD
+
+| Model | Checkpoints | HuggingFace | Test-Precision | Recall | F-Score |
+| --------| --------| --------| --------| --------| --------|
+ | recipes/LibriParty/VAD/hparams/train.yaml | [here](https://www.dropbox.com/sh/6yguuzn4pybjasd/AABpUF8LAQ8d2TJyC8aK2OBga?dl=0 ) | [here](https://huggingface.co/speechbrain/vad-crdnn-libriparty) | 0.9518 | 0.9437 | 0.9477 |
+
+
+## LibriSpeech Dataset
+
+### ASR-Seq2Seq
+
+| Model | Checkpoints | HuggingFace | Test_clean-WER | Test_other-WER |
+| --------| --------| --------| --------| --------|
+ | recipes/LibriSpeech/ASR/seq2seq/hparams/train_BPE_5000.yaml | [here](https://www.dropbox.com/sh/1ycv07gyxdq8hdl/AABUDYzza4SLYtY45RcGf2_0a?dl=0) | [here](https://huggingface.co/speechbrain/asr-crdnn-transformerlm-librispeech) | 2.89% | 8.09% |
+
+
+### ASR-CTC
+
+| Model | Checkpoints | HuggingFace | Test_clean-WER | Test_other-WER |
+| --------| --------| --------| --------| --------|
+ | recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec.yaml | [here](https://www.dropbox.com/sh/qj2ps85g8oiicrj/AAAxlkQw5Pfo0M9EyHMi8iAra?dl=0) | [here](https://huggingface.co/speechbrain/asr-wav2vec2-librispeech) | 1.65% | 3.67% |
+ | recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec_transformer_rescoring.yaml | [here](https://www.dropbox.com/sh/ijqalvre7mm08ng/AAD_hsN-8dBneUMMkELsOOxga?dl=0) | - | 1.57% | 3.37% |
+
+
+### ASR-Transducers
+
+| Model | Checkpoints | HuggingFace | Test_clean-WER | Test_other-WER |
+| --------| --------| --------| --------| --------|
+ | recipes/LibriSpeech/ASR/transducer/hparams/conformer_transducer.yaml | [here](https://drive.google.com/drive/folders/1QtQz1Bkd_QPYnf3CyxhJ57ovbSZC2EhN?usp=sharing) | - | 2.72% | 6.47% |
+
+
+### ASR-Transformers
+
+| Model | Checkpoints | HuggingFace | Test_clean-WER | Test_other-WER |
+| --------| --------| --------| --------| --------|
+ | recipes/LibriSpeech/ASR/transformer/hparams/conformer_small.yaml | [here](https://www.dropbox.com/sh/s0x6ni124858b8i/AAALaCH6sGTMRUVTjh8Tm8Jwa?dl=0) | [here](https://huggingface.co/speechbrain/asr-conformersmall-transformerlm-librispeech) | 2.49% | 6.10% |
+ | recipes/LibriSpeech/ASR/transformer/hparams/transformer.yaml | [here](https://www.dropbox.com/sh/653kq8h2k87md4p/AAByAaAryXtQKpRzYtzV9ih5a?dl=0) | [here](https://huggingface.co/speechbrain/asr-transformer-transformerlm-librispeech) | 2.27% | 5.53% |
+ | recipes/LibriSpeech/ASR/transformer/hparams/conformer_large.yaml | [here](https://www.dropbox.com/sh/ef3chrau8i45ip1/AAD9un8oabOB1a9OiSomZEhZa?dl=0) | - | 2.01% | 4.52% |
+ | recipes/LibriSpeech/ASR/transformer/hparams/branchformer_large.yaml | [here](https://www.dropbox.com/sh/gxkye4efa6hvl2c/AADO85EkkfbIGe5KjBAU6BrEa?dl=0) | - | 2.04% | 4.12% |
+ | recipes/LibriSpeech/ASR/transformer/hparams/hyperconformer_22M.yaml | [here](https://www.dropbox.com/sh/gxkye4efa6hvl2c/AADO85EkkfbIGe5KjBAU6BrEa?dl=0) | - | 2.23% | 4.54% |
+ | recipes/LibriSpeech/ASR/transformer/hparams/hyperconformer_8M.yaml | [here](https://www.dropbox.com/sh/gxkye4efa6hvl2c/AADO85EkkfbIGe5KjBAU6BrEa?dl=0) | - | 2.55% | 6.61% |
+ | recipes/LibriSpeech/ASR/transformer/hparams/hyperbranchformer_25M.yaml | - | - | 2.36% | 6.89% |
+ | recipes/LibriSpeech/ASR/transformer/hparams/hyperbranchformer_13M.yaml | - | - | 2.54% | 6.58% |
+ | recipes/LibriSpeech/ASR/transformer/hparams/train_hf_whisper.yaml | - | - |  |
+ | recipes/LibriSpeech/ASR/transformer/hparams/bayesspeech.yaml | [here](https://www.dropbox.com/scl/fo/cdken4jqfj96ev1v84jxm/h?rlkey=25eu1ytgm5ac51zqj8p65zwxd&dl=0) | - | 2.84% | 6.27% |
+
+
+### G2P
+
+| Model | Checkpoints | HuggingFace | PER-Test |
+| --------| --------| --------| --------|
+ | recipes/LibriSpeech/G2P/hparams/hparams_g2p_rnn.yaml | [here](https://www.dropbox.com/sh/qmcl1obp8pxqaap/AAC3yXvjkfJ3mL-RKyAUxPdNa?dl=0) | - | 2.72% |
+ | recipes/LibriSpeech/G2P/hparams/hparams_g2p_transformer.yaml | [here](https://www.dropbox.com/sh/zhrxg7anuhje7e8/AADTeJtdsja_wClkE2DsF9Ewa?dl=0) | [here](https://huggingface.co/speechbrain/soundchoice-g2p) | 2.89% |
+
+
+## MEDIA Dataset
+
+### SLU
+
+| Model | Checkpoints | HuggingFace | Test-ChER | Test-CER | Test-CVER |
+| --------| --------| --------| --------| --------| --------|
+ | recipes/MEDIA/SLU/CTC/hparams/train_hf_wav2vec_full.yaml | - | [here](https://huggingface.co/speechbrain/slu-wav2vec2-ctc-MEDIA-relax) | 7.46% | 20.10% | 31.41% |
+ | recipes/MEDIA/SLU/CTC/hparams/train_hf_wav2vec_relax.yaml | - | [here](https://huggingface.co/speechbrain/slu-wav2vec2-ctc-MEDIA-full) | 7.78% | 24.88% | 35.77% |
+
+
+### ASR
+
+| Model | Checkpoints | HuggingFace | Test-ChER | Test-CER |
+| --------| --------| --------| --------| --------|
+ | recipes/MEDIA/ASR/CTC/hparams/train_hf_wav2vec.yaml | - | [here](https://huggingface.co/speechbrain/asr-wav2vec2-ctc-MEDIA) | 7.78% | 4.78% |
+
+
+## MultiWOZ Dataset
+
+### Response-Generation
+
+| Model | Checkpoints | HuggingFace | Test-PPL | Test_BLEU-4 |
+| --------| --------| --------| --------| --------|
+ | recipes/MultiWOZ/response_generation/gpt/hparams/train_gpt.yaml | [here](https://www.dropbox.com/sh/vm8f5iavohr4zz9/AACrkOxXuxsrvJy4Cjpih9bQa?dl=0) | [here](https://huggingface.co/speechbrain/MultiWOZ-GPT-Response_Generation) | 4.01 | 2.54e-04 |
+ | recipes/MultiWOZ/response_generation/llama2/hparams/train_llama2.yaml | [here](https://www.dropbox.com/sh/d093vsje1d7ijj9/AAA-nHEd_MwNEFJfBGLmXxJra?dl=0) | [here](https://huggingface.co/speechbrain/MultiWOZ-Llama2-Response_Generation) | 2.90 | 7.45e-04 |
+
+
+## REAL-M Dataset
+
+### Sisnr-estimation
+
+| Model | Checkpoints | HuggingFace | L1-Error |
+| --------| --------| --------| --------|
+ | recipes/REAL-M/sisnr-estimation/hparams/pool_sisnrestimator.yaml | [here](https://www.dropbox.com/sh/n55lm8i5z51pbm1/AABHfByOEy__UP_bmT4GJvSba?dl=0) | [here](https://huggingface.co/speechbrain/REAL-M-sisnr-estimator) | 1.71dB |
+
+
+## RescueSpeech Dataset
+
+### ASR+enhancement
+
+| Model | Checkpoints | HuggingFace | SISNRi | SDRi | PESQ | STOI | WER |
+| --------| --------| --------| --------| --------| --------| --------| --------|
+ | recipes/RescueSpeech/ASR/noise-robust/hparams/robust_asr_16k.yaml | [here](https://www.dropbox.com/sh/kqs2ld14fm20cxl/AACiobSLdNtXhm-4Y3IIbTeia?dl=0) | [here](https://huggingface.co/sangeet2020/noisy-whisper-resucespeech) | 7.482 | 8.011 | 2.083 | 0.854 | 45.29% |
+
+
+## SLURP Dataset
+
+### SLU
+
+| Model | Checkpoints | HuggingFace | scenario-accuracy | action-accuracy | intent-accuracy |
+| --------| --------| --------| --------| --------| --------|
+ | recipes/SLURP/NLU/hparams/train.yaml | [here](https://www.dropbox.com/scl/fo/c0rm2ja8oxus8q27om8ve/h?rlkey=irxzl1ea8g7e6ipk0vuc288zh&dl=0 ) | - | 90.81% | 88.29% | 87.28% |
+ | recipes/SLURP/direct/hparams/train.yaml | [here](https://www.dropbox.com/scl/fo/c0rm2ja8oxus8q27om8ve/h?rlkey=irxzl1ea8g7e6ipk0vuc288zh&dl=0 ) | - | 81.73% | 77.11% | 75.05% |
+ | recipes/SLURP/direct/hparams/train_with_wav2vec2.yaml | [here](https://www.dropbox.com/scl/fo/c0rm2ja8oxus8q27om8ve/h?rlkey=irxzl1ea8g7e6ipk0vuc288zh&dl=0 ) | [here](https://huggingface.co/speechbrain/SLU-direct-SLURP-hubert-enc) | 91.24% | 88.47% | 87.55% |
+
+
+## Switchboard Dataset
+
+### ASR
+
+| Model | Checkpoints | HuggingFace | Swbd-WER | Callhome-WER | Eval2000-WER |
+| --------| --------| --------| --------| --------| --------|
+ | recipes/Switchboard/ASR/CTC/hparams/train_with_wav2vec.yaml | - | [here](https://huggingface.co/speechbrain/asr-wav2vec2-switchboard) | 8.76% | 14.67% | 11.78% |
+ | recipes/Switchboard/ASR/seq2seq/hparams/train_BPE_2000.yaml | - | [here](https://huggingface.co/speechbrain/asr-crdnn-switchboard) | 16.90% | 25.12% | 20.71% |
+ | recipes/Switchboard/ASR/transformer/hparams/transformer.yaml | - | [here](https://huggingface.co/speechbrain/asr-transformer-switchboard) | 9.80% | 17.89% | 13.94% |
+
+
+## TIMIT Dataset
+
+### ASR
+
+| Model | Checkpoints | HuggingFace | Test-PER |
+| --------| --------| --------| --------|
+ | recipes/TIMIT/ASR/CTC/hparams/train.yaml | [here](https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0) | - | 14.78% |
+ | recipes/TIMIT/ASR/seq2seq/hparams/train.yaml | [here](https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0) | - | 14.07% |
+ | recipes/TIMIT/ASR/seq2seq/hparams/train_with_wav2vec2.yaml | [here](https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0) | - | 8.04% |
+ | recipes/TIMIT/ASR/transducer/hparams/train.yaml | [here](https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0) | - | 14.12% |
+ | recipes/TIMIT/ASR/transducer/hparams/train_wav2vec.yaml | [here](https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0) | - | 8.91% |
+
+
+## Tedlium2 Dataset
+
+### ASR
+
+| Model | Checkpoints | HuggingFace | Test-WER_No_LM |
+| --------| --------| --------| --------|
+ | recipes/Tedlium2/ASR/transformer/hparams/branchformer_large.yaml | [here](https://www.dropbox.com/sh/el523uofs96czfi/AADgTd838pKo2aR8fhqVOh-Oa?dl=0) | [here](https://huggingface.co/speechbrain/asr-branchformer-large-tedlium2) | 8.11% |
+
+
+## UrbanSound8k Dataset
+
+### SoundClassification
+
+| Model | Checkpoints | HuggingFace | Accuracy |
+| --------| --------| --------| --------|
+ | recipes/UrbanSound8k/SoundClassification/hparams/train_ecapa_tdnn.yaml | [here](https://www.dropbox.com/sh/f61325e3w8h5yy2/AADm3E3PXFi1NYA7-QW3H-Ata?dl=0 ) | [here](https://huggingface.co/speechbrain/urbansound8k_ecapa) | 75.4% |
+
+
+## Voicebank Dataset
+
+### Dereverberation
+
+| Model | Checkpoints | HuggingFace | PESQ |
+| --------| --------| --------| --------|
+ | recipes/Voicebank/dereverb/MetricGAN-U/hparams/train_dereverb.yaml | [here](https://www.dropbox.com/sh/r94qn1f5lq9r3p7/AAAZfisBhhkS8cwpzy1O5ADUa?dl=0 ) | - | 2.07 |
+ | recipes/Voicebank/dereverb/spectral_mask/hparams/train.yaml | [here](https://www.dropbox.com/sh/pw8aer8gcsrdbx7/AADknh7plHF5GBeTRK9VkIKga?dl=0 ) | - | 2.35 |
+
+
+### ASR+enhancement
+
+| Model | Checkpoints | HuggingFace | PESQ | COVL | test-WER |
+| --------| --------| --------| --------| --------| --------|
+ | recipes/Voicebank/MTL/ASR_enhance/hparams/robust_asr.yaml | [here](https://www.dropbox.com/sh/azvcbvu8g5hpgm1/AACDc6QxtNMGZ3IoZLrDiU0Va?dl=0) | [here](https://huggingface.co/speechbrain/mtl-mimic-voicebank) | 3.05 | 3.74 | 2.80 |
+
+
+### Enhancement
+
+| Model | Checkpoints | HuggingFace | PESQ |
+| --------| --------| --------| --------|
+ | recipes/Voicebank/enhance/MetricGAN/hparams/train.yaml | [here](https://www.dropbox.com/sh/n5q9vjn0yn1qvk6/AAB-S7i2-XzVm6ux0MrXCvqya?dl=0 ) | [here](https://huggingface.co/speechbrain/metricgan-plus-voicebank) | 3.15 |
+ | recipes/Voicebank/enhance/SEGAN/hparams/train.yaml | [here](https://www.dropbox.com/sh/ez0folswdbqiad4/AADDasepeoCkneyiczjCcvaOa?dl=0 ) | - | 2.38 |
+ | recipes/Voicebank/enhance/spectral_mask/hparams/train.yaml | [here](https://www.dropbox.com/sh/n5q9vjn0yn1qvk6/AAB-S7i2-XzVm6ux0MrXCvqya?dl=0 ) | - | 2.65 |
+
+
+### ASR
+
+| Model | Checkpoints | HuggingFace | Test-PER |
+| --------| --------| --------| --------|
+ | recipes/Voicebank/ASR/CTC/hparams/train.yaml | [here](https://www.dropbox.com/sh/w4j0auezgmmo005/AAAjKcoJMdLDp0Pqe3m7CLVaa?dl=0) | - | 10.12% |
+
+
+## VoxCeleb Dataset
+
+### Speaker_recognition
+
+| Model | Checkpoints | HuggingFace | EER |
+| --------| --------| --------| --------|
+ | recipes/VoxCeleb/SpeakerRec/hparams/train_ecapa_tdnn.yaml | [here](https://www.dropbox.com/sh/ab1ma1lnmskedo8/AADsmgOLPdEjSF6wV3KyhNG1a?dl=0) | [here](https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb) | 0.80% |
+ | recipes/VoxCeleb/SpeakerRec/hparams/train_x_vectors.yaml | [here](https://www.dropbox.com/sh/ab1ma1lnmskedo8/AADsmgOLPdEjSF6wV3KyhNG1a?dl=0) | [here](https://huggingface.co/speechbrain/spkrec-xvect-voxceleb) | 3.23% |
+ | recipes/VoxCeleb/SpeakerRec/hparams/train_resnet.yaml | [here](https://www.dropbox.com/sh/ab1ma1lnmskedo8/AADsmgOLPdEjSF6wV3KyhNG1a?dl=0) | [here](https://huggingface.co/speechbrain/spkrec-resnet-voxceleb) | 0.95% |
+
+
+## VoxLingua107 Dataset
+
+### Language-id
+
+| Model | Checkpoints | HuggingFace | Accuracy |
+| --------| --------| --------| --------|
+ | recipes/VoxLingua107/lang_id/hparams/train_ecapa.yaml | [here](https://www.dropbox.com/sh/72gpuic5m4x8ztz/AAB5R-RVIEsXJtRH8SGkb_oCa?dl=0 ) | [here](https://huggingface.co/speechbrain/lang-id-voxlingua107-ecapa) | 93.3% |
+
+
+## WHAMandWHAMR Dataset
+
+### Separation
+
+| Model | Checkpoints | HuggingFace | SI-SNR |
+| --------| --------| --------| --------|
+ | recipes/WHAMandWHAMR/separation/hparams/sepformer-wham.yaml | [here](https://www.dropbox.com/sh/sfrgb3xivri432e/AACQodNmiDIKrB9vCeCFUDWUa?dl=0) | [here](https://huggingface.co/speechbrain/sepformer-whamr) | 16.5 |
+ | recipes/WHAMandWHAMR/separation/hparams/sepformer-whamr.yaml | [here](https://www.dropbox.com/sh/1sia32z01xbfgvu/AADditsqaTyfN3N6tzfEFPica?dl=0) | [here](https://huggingface.co/speechbrain/sepformer-wham) | 14.0 |
+
+
+### Enhancement
+
+| Model | Checkpoints | HuggingFace | SI-SNR | PESQ |
+| --------| --------| --------| --------| --------|
+ | recipes/WHAMandWHAMR/enhancement/hparams/sepformer-wham.yaml | [here](https://www.dropbox.com/sh/pxz2xbj76ijd5ci/AAD3c3dHyszk4oHJaa26K1_ha?dl=0) | [here](https://huggingface.co/speechbrain/sepformer-wham-enhancement) | 14.4 | 3.05 |
+ | recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr.yaml | [here](https://www.dropbox.com/sh/kb0xrvi5k168ou2/AAAPB2U6HyyUT1gMoUH8gxQCa?dl=0) | [here](https://huggingface.co/speechbrain/sepformer-whamr-enhancement) | 10.6 | 2.84 |
+
+
+## WSJ0Mix Dataset
+
+### Separation (2mix)
+
+| Model | Checkpoints | HuggingFace | SI-SNRi |
+| --------| --------| --------| --------|
+ | recipes/WSJ0Mix/separation/hparams/convtasnet.yaml | [here](https://www.dropbox.com/sh/hdpxj47signsay7/AABbDjGoyQesnFxjg0APxl7qa?dl=0) | - | 14.8dB |
+ | recipes/WSJ0Mix/separation/hparams/dprnn.yaml | [here](https://www.dropbox.com/sh/o8fohu5s07h4bnw/AADPNyR1E3Q4aRobg3FtXTwVa?dl=0) | - | 18.5dB |
+ | recipes/WSJ0Mix/separation/hparams/resepformer.yaml | [here](https://www.dropbox.com/sh/obnu87zhubn1iia/AAAbn_jzqzIfeqaE9YQ7ujyQa?dl=0) | [here](https://huggingface.co/speechbrain/resepformer-wsj02mix) | 18.6dB |
+ | recipes/WSJ0Mix/separation/hparams/sepformer.yaml | [here](https://www.dropbox.com/sh/9klsqadkhin6fw1/AADEqGdT98rcqxVgFlfki7Gva?dl=0 ) | [here](https://huggingface.co/speechbrain/sepformer-wsj02mix) | 22.4dB |
+ | recipes/WSJ0Mix/separation/hparams/skim.yaml | [here](https://www.dropbox.com/sh/zy0l5rc8abxdfp3/AAA2ngB74fugqpWXmjZo5v3wa?dl=0) | [here](https://huggingface.co/speechbrain/resepformer-wsj02mix ) | 18.1dB |
+
+
+## ZaionEmotionDataset Dataset
+
+### Emotion_Diarization
+
+| Model | Checkpoints | HuggingFace | EDER |
+| --------| --------| --------| --------|
+ | recipes/ZaionEmotionDataset/emotion_diarization/hparams/train.yaml | [here](https://www.dropbox.com/sh/woudm1v31a7vyp5/AADAMxpQOXaxf8E_1hX202GJa?dl=0) | [here](https://huggingface.co/speechbrain/emotion-diarization-wavlm-large) | 30.2% |
+
+
+## fluent-speech-commands Dataset
+
+### SLU
+
+| Model | Checkpoints | HuggingFace | Test-accuracy |
+| --------| --------| --------| --------|
+ | recipes/fluent-speech-commands/direct/hparams/train.yaml | [here](https://www.dropbox.com/sh/wal9ap0go9f66qw/AADBVlGs_E2pEU4vYJgEe3Fba?dl=0) | - | 99.60% |
+
+
+## timers-and-such Dataset
+
+### SLU
+
+| Model | Checkpoints | HuggingFace | Accuracy-Test_real |
+| --------| --------| --------| --------|
+ | recipes/timers-and-such/decoupled/hparams/train_TAS_LM.yaml | [here](https://www.dropbox.com/sh/gmmum179ig9wz0x/AAAOSOi11yVymGXHp9LzYNrqa?dl=0) | - | 46.8% |
+ | recipes/timers-and-such/direct/hparams/train.yaml | [here](https://www.dropbox.com/sh/gmmum179ig9wz0x/AAAOSOi11yVymGXHp9LzYNrqa?dl=0) | [here](https://huggingface.co/speechbrain/slu-timers-and-such-direct-librispeech-asr) | 77.5% |
+ | recipes/timers-and-such/direct/hparams/train_with_wav2vec2.yaml | [here](https://www.dropbox.com/sh/gmmum179ig9wz0x/AAAOSOi11yVymGXHp9LzYNrqa?dl=0) | - | 94.0% |
+ | recipes/timers-and-such/multistage/hparams/train_TAS_LM.yaml | [here](https://www.dropbox.com/sh/gmmum179ig9wz0x/AAAOSOi11yVymGXHp9LzYNrqa?dl=0) | - | 72.6% |
+
+
diff --git a/README.md b/README.md
index 42ce585e6f338f99662f09d62097d88bcffcbb97..3aa8c2cfc47b54e817245c776cee54dcff18d297 100644
--- a/README.md
+++ b/README.md
@@ -2,292 +2,262 @@
   <img src="https://raw.githubusercontent.com/speechbrain/speechbrain/develop/docs/images/speechbrain-logo.svg" alt="SpeechBrain Logo"/>
 </p>
 
-[![Tweet](https://img.shields.io/twitter/url/http/shields.io.svg?style=social)](https://twitter.com/SpeechBrain1/)
-[![Discord](https://dcbadge.vercel.app/api/server/3wYvAaz3Ck?style=flat)](https://discord.gg/3wYvAaz3Ck)
+[![Typing SVG](https://readme-typing-svg.demolab.com?font=Fira+Code&size=40&duration=7000&pause=1000&random=false&width=1200&height=100&lines=Simplify+Conversational+AI+Development)](https://git.io/typing-svg)
 
 
-SpeechBrain is an **open-source** and **all-in-one** conversational AI toolkit based on PyTorch.
+| 📘 [Tutorials](https://speechbrain.github.io/tutorial_basics.html) | 🌐 [Website](https://speechbrain.github.io/) | 📚 [Documentation](https://speechbrain.readthedocs.io/en/latest/index.html) | 🤝 [Contributing](https://speechbrain.readthedocs.io/en/latest/contributing.html) | 🤗 [HuggingFace](https://huggingface.co/speechbrain) | ▶️ [YouTube](https://www.youtube.com/@SpeechBrainProject) | 🐦 [X](https://twitter.com/SpeechBrain1) |
 
-The goal is to create a **single**, **flexible**, and **user-friendly** toolkit that can be used to easily develop **state-of-the-art speech technologies**, including systems for **speech recognition**, **speaker recognition**, **speech enhancement**, **speech separation**, **language identification**, **multi-microphone signal processing**, and many others.
+![GitHub Repo stars](https://img.shields.io/github/stars/speechbrain/speechbrain?style=social) *Please, help our community project. Star on GitHub!*
 
-<img src="https://github.blog/wp-content/uploads/2020/09/github-stars-logo_Color.png" alt="drawing" width="25"/> **Please, star our project on github (see top-right corner) if you appreciate our contribution to the community!**
+**Exciting News (January, 2024):** Discover what is new in SpeechBrain 1.0 [here](https://colab.research.google.com/drive/1IEPfKRuvJRSjoxu22GZhb3czfVHsAy0s?usp=sharing)!
+#
+# 🗣️💬 What SpeechBrain Offers
 
-*SpeechBrain is currently in beta*.
+- SpeechBrain is an **open-source** [PyTorch](https://pytorch.org/) toolkit that accelerates **Conversational AI** development, i.e., the technology behind *speech assistants*, *chatbots*, and *large language models*.
 
-| **[Tutorials](https://speechbrain.github.io/tutorial_basics.html)** | **[Website](https://speechbrain.github.io/)** | **[Documentation](https://speechbrain.readthedocs.io/en/latest/index.html)** | **[Contributing](https://speechbrain.readthedocs.io/en/latest/contributing.html)** | **[HuggingFace](https://huggingface.co/speechbrain)** |
+- It is crafted for fast and easy creation of advanced technologies for **Speech** and **Text** Processing.
 
-# PyTorch 2.0 considerations
 
-In March 2023, PyTorch introduced a new version, PyTorch 2.0, which offers numerous enhancements to the community. At present, the majority of SpeechBrain is compatible with PyTorch 2.0. However, certain sections of the code remain incompatible, and we are actively working towards full compatibility with PyTorch 2.0. For the time being, we recommend users continue utilizing PyTorch 1.13, as this is the version employed in our experiments.
+## 🌐  Vision
+- With the rise of [deep learning](https://www.deeplearningbook.org/), once-distant domains like speech processing and NLP are now very close. A well-designed neural network and large datasets are all you need.
 
-If you wish to use SpeechBrain alongside PyTorch 2.0 and encounter any issues, kindly inform us by responding to this [issue](https://github.com/speechbrain/speechbrain/issues/1897).
-
-# Key features
-
-SpeechBrain provides various useful tools to speed up and facilitate research on speech and language technologies:
-- Various pretrained models nicely integrated with <img src="https://huggingface.co/front/assets/huggingface_logo.svg" alt="drawing" width="40"/> <sub>(HuggingFace)</sub> in our official [organization account](https://huggingface.co/speechbrain). These models are coupled with easy-inference interfaces that facilitate their use.  To help everyone replicate our results, we also provide all the experimental results and folders (including logs, training curves, etc.) in a shared Google Drive folder.
-- The `Brain` class is a fully-customizable tool for managing training and evaluation loops over data. The annoying details of training loops are handled for you while retaining complete flexibility to override any part of the process when needed.
-- A YAML-based hyperparameter file that specifies all the hyperparameters, from individual numbers (e.g., learning rate) to complete objects (e.g., custom models). This elegant solution dramatically simplifies the training script.
-- Multi-GPU training and inference with PyTorch Data-Parallel or Distributed Data-Parallel.
-- Mixed-precision for faster training.
-- A transparent and entirely customizable data input and output pipeline. SpeechBrain follows the PyTorch data loading style and enables users to customize the I/O pipelines (e.g., adding on-the-fly downsampling, BPE tokenization, sorting, threshold ...).
-- On-the-fly dynamic batching
-- Efficient reading of large datasets from a shared  Network File System (NFS) via [WebDataset](https://github.com/webdataset/webdataset).
-- Interface with [HuggingFace](https://huggingface.co/speechbrain) for popular models such as wav2vec2  and Hubert.
-- Interface with [Orion](https://github.com/Epistimio/orion) for hyperparameter tuning.
-
-
-### Speech recognition
-
-SpeechBrain supports state-of-the-art methods for end-to-end speech recognition:
-- Support of wav2vec 2.0 pretrained model with finetuning.
-- State-of-the-art performance or comparable with other existing toolkits in several ASR benchmarks.
-- Easily customizable neural language models, including RNNLM and TransformerLM. We also share several pre-trained models that you can easily use (more to come!). We support the Hugging Face `dataset` to facilitate the training over a large text dataset.
-- Hybrid CTC/Attention end-to-end ASR:
-    - Many available encoders: CRDNN (VGG + {LSTM,GRU,Li-GRU} + DNN), ResNet, SincNet, vanilla transformers, whisper, context net-based transformers or conformers. Thanks to the flexibility of SpeechBrain, any fully customized encoder could be connected to the CTC/attention decoder and trained in a few hours of work. The decoder is fully customizable: LSTM, GRU, LiGRU, transformer, or your neural network!
-    - Optimised and fast beam search on both CPUs and GPUs.
-- Transducer end-to-end ASR with both a custom Numba loss and the torchaudio one. Any encoder or decoder can be plugged into the transducer ranging from VGG+RNN+DNN to conformers.
-- Pre-trained ASR models for transcribing an audio file or extracting features for a downstream task.
-- Fully customizable with the possibility to add external Beam Search decoders, if the ones offered natively by SpeechBrain are not sufficient, such as [PyCTCDecode](https://github.com/kensho-technologies/pyctcdecode) like in our LibriSpeech CTC wav2vec recipe.
-
-### Feature extraction and augmentation
-
-SpeechBrain provides efficient (GPU-friendly) speech augmentation and feature extraction pipelines:
-- On-the-fly and fully-differentiable acoustic feature extraction: filter banks can be learned. This strategy simplifies the training pipeline (you don't have to dump features on disk).
-- On-the-fly feature normalization (global, sentence, batch, or speaker level).
-- On-the-fly environmental corruptions based on noise, reverberation, and babble for robust model training.
-- On-the-fly frequency and time domain SpecAugment with speed augmentation.
-- We support both SinConv and LEAF convolutional frontends.
-
-### Speech enhancement and separation
-- Recipes for spectral masking, spectral mapping, and time-domain speech enhancement.
-- Multiple sophisticated enhancement losses, including differentiable STOI loss, MetricGAN, and mimic loss.
-- State-of-the-art performance on speech separation with Conv-TasNet, DualPath RNN, SepFormer, and RE-SepFormer.
-
-### Speaker recognition, identification and diarization
-SpeechBrain provides different models for speaker recognition, identification, and diarization on different datasets:
-- State-of-the-art performance on speaker recognition and diarization based on ECAPA-TDNN models.
-- Original Xvectors implementation (inspired by Kaldi) with PLDA.
-- Spectral clustering for speaker diarization (combined with speakers embeddings).
-- Libraries to extract speaker embeddings with a pre-trained model on your data.
-
-### Text-to-Speech (TTS) and Vocoders
-- Recipes for training TTS systems such as [Tacotron2](https://github.com/speechbrain/speechbrain/tree/develop/recipes/LJSpeech/) and [FastSpeech2](https://github.com/speechbrain/speechbrain/tree/develop/recipes/LJSpeech/) with LJSpeech.
-- Recipes for training Vocoders such as [HiFIGAN](https://github.com/speechbrain/speechbrain/tree/develop/recipes/LJSpeech).
-
-### Grapheme-to-Phoneme (G2P)
-We have models for converting characters into a sequence of phonemes. In particular, we have Transformer- and RNN-based models operating at the sentence level (i.e, converting a full sentence into a corresponding sequence of phonemes). The models are trained with both data from Wikipedia and LibriSpeech.
-
-### Language Identification
-SpeechBrain provides different models for language identification.
-In particular, our best model is based on an ECAPA-TDNN trained with the [voxlingua107 dataset](http://bark.phon.ioc.ee/voxlingua107/).
-
-### Speech Translation
-- Recipes for transformer and conformer-based end-to-end speech translation.
-- Possibility to choose between normal training (Attention), multi-objectives (CTC+Attention), and multitasks (ST + ASR).
-
-### Self-Supervised Learning of Speech Representations
-- Recipes for wav2vec 2.0 pre-training with multiple GPUs compatible with HuggingFace models.
-
-### Multi-microphone processing
-Combining multiple microphones is a powerful approach to achieving robustness in adverse acoustic environments:
-- Delay-and-sum, MVDR, and GeV beamforming.
-- Speaker localization.
-
-### Emotion Recognition
-- Recipes for emotion recognition using SSL and ECAPA-TDNN models on the [IEMOCAP](https://sail.usc.edu/iemocap/iemocap_release.htm) dataset.
-- Recipe for emotion diarization using SSL models on the [ZaionEmotionDataset](https://zaion.ai/en/resources/zaion-lab-blog/zaion-emotion-dataset/).
-
-### Interpretability
-- Recipes for various intepretability techniques on the ESC50 dataset.
-
-### Spoken Language Understanding
-- Recipes for training wav2vec 2.0 models on, [SLURP](https://zenodo.org/record/4274930#.YEFCYHVKg5k), [MEDIA](https://catalogue.elra.info/en-us/repository/browse/ELRA-E0024/) and [timers-and-such](https://zenodo.org/record/4623772#.YGeMMHVKg5k) datasets.
-
-### Performance
-The recipes released with speechbrain implement speech processing systems with competitive or state-of-the-art performance. In the following, we report the best performance achieved on some popular benchmarks:
-
-| Dataset        | Task           | System  | Performance  |
-| ------------- |:-------------:| -----:|-----:|
-| LibriSpeech      | Speech Recognition | wav2vec2 | WER=1.90% (test-clean) |
-| LibriSpeech      | Speech Recognition | CNN + Conformer | WER=2.0% (test-clean) |
-| TIMIT      | Speech Recognition | CRDNN + distillation | PER=13.1% (test) |
-| TIMIT      | Speech Recognition | wav2vec2 + CTC/Att. | PER=8.04% (test) |
-| CommonVoice (English) | Speech Recognition | wav2vec2 + CTC | WER=15.69% (test) |
-| CommonVoice (French) | Speech Recognition | wav2vec2 + CTC | WER=9.96% (test) |
-| CommonVoice (Italian) | Speech Recognition | wav2vec2 + seq2seq | WER=9.86% (test) |
-| CommonVoice (Kinyarwanda) | Speech Recognition | wav2vec2 + seq2seq | WER=18.91% (test) |
-| AISHELL (Mandarin) | Speech Recognition | wav2vec2 + CTC | CER=5.06% (test) |
-| Fisher-callhome (spanish) | Speech translation | conformer (ST + ASR) | BLEU=48.04 (test) |
-| VoxCeleb2      | Speaker Verification | ECAPA-TDNN | EER=0.80% (vox1-test) |
-| AMI      | Speaker Diarization | ECAPA-TDNN | DER=3.01% (eval)|
-| VoiceBank      | Speech Enhancement | MetricGAN+| PESQ=3.08 (test)|
-| WSJ2MIX      | Speech Separation | SepFormer| SDRi=22.6 dB (test)|
-| WSJ3MIX      | Speech Separation | SepFormer| SDRi=20.0 dB (test)|
-| WHAM!     | Speech Separation | SepFormer| SDRi= 16.4 dB (test)|
-| WHAMR!     | Speech Separation | SepFormer| SDRi= 14.0 dB (test)|
-| Libri2Mix     | Speech Separation | SepFormer| SDRi= 20.6 dB (test-clean)|
-| Libri3Mix     | Speech Separation | SepFormer| SDRi= 18.7 dB (test-clean)|
-| LibryParty | Voice Activity Detection | CRDNN | F-score=0.9477 (test) |
-| IEMOCAP | Emotion Recognition | wav2vec2 | Accuracy=79.8% (test) |
-| CommonLanguage | Language Recognition | ECAPA-TDNN | Accuracy=84.9% (test) |
-| Timers and Such | Spoken Language Understanding | CRDNN | Intent Accuracy=89.2% (test) |
-| SLURP | Spoken Language Understanding | HuBERT | Intent Accuracy=87.54% (test) |
-| VoxLingua 107 | Identification | ECAPA-TDNN | Sentence Accuracy=93.3% (test) |
-
-For more details, take a look at the corresponding implementation in recipes/dataset/.
-
-### Pretrained Models
-
-Beyond providing recipes for training the models from scratch, SpeechBrain shares several pre-trained models (coupled with easy-inference functions) on [HuggingFace](https://huggingface.co/speechbrain). In the following, we report some of them:
-
-| Task        | Dataset | Model |
-| ------------- |:-------------:| -----:|
-| Speech Recognition | LibriSpeech | [CNN + Transformer](https://huggingface.co/speechbrain/asr-transformer-transformerlm-librispeech) |
-| Speech Recognition | LibriSpeech | [CRDNN](https://huggingface.co/speechbrain/asr-crdnn-transformerlm-librispeech) |
-| Speech Recognition | CommonVoice(English) | [wav2vec + CTC](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-en) |
-| Speech Recognition | CommonVoice(French) | [wav2vec + CTC](https://huggingface.co/speechbrain/asr-crdnn-commonvoice-fr) |
-| Speech Recognition | CommonVoice(Italian) | [wav2vec + CTC](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-it) |
-| Speech Recognition | CommonVoice(Kinyarwanda) | [wav2vec + CTC](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-rw) |
-| Speech Recognition | AISHELL(Mandarin) | [wav2vec + seq2seq](https://huggingface.co/speechbrain/asr-wav2vec2-transformer-aishell) |
-| Text-to-Speech | LJSpeech | [Tacotron2](https://huggingface.co/speechbrain/tts-tacotron2-ljspeech) |
-| Speaker Recognition | Voxceleb | [ECAPA-TDNN](https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb) |
-| Speech Separation | WHAMR! | [SepFormer](https://huggingface.co/speechbrain/sepformer-whamr) |
-| Speech Enhancement | Voicebank | [MetricGAN+](https://huggingface.co/speechbrain/metricgan-plus-voicebank) |
-| Speech Enhancement | WHAMR! | [SepFormer](https://huggingface.co/speechbrain/sepformer-whamr-enhancement) |
-| Spoken Language Understanding | Timers and Such | [CRDNN](https://huggingface.co/speechbrain/slu-timers-and-such-direct-librispeech-asr) |
-| Language Identification | CommonLanguage | [ECAPA-TDNN](https://huggingface.co/speechbrain/lang-id-commonlanguage_ecapa) |
-
-The full list of pre-trained models can be found on [HuggingFace](https://huggingface.co/speechbrain)
-
-### Documentation & Tutorials
-SpeechBrain is designed to speed up the research and development of speech technologies. Hence, our code is backed-up with different levels of documentation:
-- **Educational-level:** we provide various Google Colab (i.e., interactive) tutorials describing all the building blocks of SpeechBrain ranging from the core of the toolkit to a specific model designed for a particular task. The tutorials are designed not only to help people familiarize themselves with SpeechBrain but, more in general, to help them familiarize themselves with speech and language technologies.
-- **Functional-level:** all classes in SpeechBrain contains a detailed docstring. It describes the input and output formats, the different arguments, the usage of the function, the potentially associated bibliography, and a function example used for test integration during pull requests.
-- **Low-level:** The code also uses a lot of in-line comments to describe nontrivial parts of the code.
-
-### Under development
-We are currently implementing speech synthesis pipelines and real-time speech processing pipelines. An interface with the Finite State Transducers (FST) implemented by the [Kaldi 2 team](https://github.com/k2-fsa/k2) is under development.
-
-# Where is what, a link list.
-```
-                  (documentation)           (tutorials)
-                  .—————————————.            .———————.
-                  | readthedocs |       ‚––> | Colab |
-                  \—————————————/      ∕     \———————/
-                         ^       ‚––––‘          |
-    (release)            |      ∕                v
-    .——————.       .———————————. (landing) .———————————.
-    | PyPI | –––>  | github.io |  (page)   | templates |   (reference)
-    \——————/       \———————————/       ‚–> \———————————/ (implementation)
-        |                |        ‚–––‘          |
-        v                v       ∕               v
-.———————————–—.   .———————————–—.           .—————————.           .~~~~~~~~~~~~~.
-| HyperPyYAML |~~~| speechbrain | ––––––––> | recipes | ––––––––> | HuggingFace |
-\————————————–/   \————————————–/           \—————————/     ∕     \~~~~~~~~~~~~~/
-  (usability)     (source/modules)          (use cases)    ∕    (pretrained models)
-                                                          ∕
-                        |                        |       ∕               |
-                        v                        v      ∕                v
-                  .~~~~~~~~~~~~~.            .~~~~~~~~.            .———————————.
-                  |   PyTorch   | ––––––––-> | GDrive |            | Inference |
-                  \~~~~~~~~~~~~~/            \~~~~~~~~/            \———————————/
-                   (checkpoints)             (results)            (code snippets)
+- We think it is now time for a **holistic toolkit** that, mimicking the human brain, jointly supports diverse technologies for complex Conversational AI systems.
+
+- This spans *speech recognition*, *speaker recognition*, *speech enhancement*, *speech separation*, *language modeling*, *dialogue*, and beyond.
+
+
+
+## 📚 Training Recipes
+- We share over 200 competitive training [recipes](https://github.com/speechbrain/speechbrain/tree/develop/recipes) on more than 40 datasets supporting 20 speech and text processing tasks (see below).
+
+- We support both training from scratch and fine-tuning pretrained models such as [Whisper](https://huggingface.co/openai/whisper-large), [Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2), [WavLM](https://huggingface.co/docs/transformers/model_doc/wavlm), [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert), [GPT2](https://huggingface.co/gpt2), [Llama2](https://huggingface.co/docs/transformers/model_doc/llama2), and beyond. The models on [HuggingFace](https://huggingface.co/) can be easily plugged in and fine-tuned.
+
+- For any task, you train the model using these commands:
+```python
+python train.py hparams/train.yaml
 ```
 
-* https://speechbrain.github.io/
-  * via: https://github.com/speechbrain/speechbrain.github.io
-  * pointing to several tutorials on Google Colab
-* https://github.com/speechbrain/speechbrain
-  * [docs](https://github.com/speechbrain/speechbrain/tree/develop/docs) for https://speechbrain.readthedocs.io/
-  * [recipes](https://github.com/speechbrain/speechbrain/tree/develop/recipes)
-  * [speechbrain](https://github.com/speechbrain/speechbrain/tree/develop/speechbrain), heavily tied with [HyperPyYAML](https://github.com/speechbrain/HyperPyYAML); released on [PyPI](https://pypi.org/project/speechbrain/)
-  * [templates](https://github.com/speechbrain/speechbrain/tree/develop/templates)
-  * [tools](https://github.com/speechbrain/speechbrain/tree/develop/tools) for non-core functionality
-* https://huggingface.co/speechbrain/
-  * hosting several model cards (pretrained models with code snippets)
-* Gdrive
-  * hosting training results; checkpoints; ...
+- The hyperparameters are encapsulated in a YAML file, while the training process is orchestrated through a Python script.
 
-# Conference Tutorials
-SpeechBrain has been presented at Interspeech 2021 and 2022 as well as ASRU 2021. When possible, we will provide some ressources here:
-- [Interspeech 2022 slides.](https://drive.google.com/drive/folders/1d6GAquxw6rZBI-7JvfUQ_-upeiKstJEo)
-- [Interspeech 2021 YouTube recordings.](https://www.youtube.com/results?search_query=Interspeech+speechbrain+)
+- We maintained a consistent code structure across different tasks.
 
-# Quick installation
-SpeechBrain is constantly evolving. New features, tutorials, and documentation will appear over time.
-SpeechBrain can be installed via PyPI. Moreover,  a local installation can be used by those users who want to run experiments and modify/customize the toolkit. SpeechBrain supports both CPU and GPU computations. For most all the recipes, however, a GPU is necessary during training. Please note that CUDA must be properly installed to use GPUs.
+- For better replicability, training logs and checkpoints are hosted on Dropbox.
 
+## <a href="https://huggingface.co/speechbrain" target="_blank"> <img src="https://huggingface.co/front/assets/huggingface_logo.svg" alt="drawing" width="40"/> </a> Pretrained Models and Inference
 
-## Install via PyPI
+- Access over 100 pretrained models hosted on [HuggingFace](https://huggingface.co/speechbrain).
+- Each model comes with a user-friendly interface for seamless inference. For example, transcribing speech using a pretrained model requires just three lines of code:
 
-Once you have created your Python environment (Python 3.7+) you can simply type:
+```python
+from speechbrain.pretrained import EncoderDecoderASR
 
-```
-pip install speechbrain
+asr_model = EncoderDecoderASR.from_hparams(source="speechbrain/asr-conformer-transformerlm-librispeech", savedir="pretrained_models/asr-transformer-transformerlm-librispeech")
+asr_model.transcribe_file("speechbrain/asr-conformer-transformerlm-librispeech/example.wav")
 ```
 
-Then you can access SpeechBrain with:
+##  <a href="https://speechbrain.github.io/" target="_blank"> <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/d/d0/Google_Colaboratory_SVG_Logo.svg/1200px-Google_Colaboratory_SVG_Logo.svg.png" alt="drawing" width="50"/> </a>  Documentation
+- We are deeply dedicated to promoting inclusivity and education.
+- We have authored over 30 [tutorials](https://speechbrain.github.io/) on Google Colab that not only describe how SpeechBrain works but also help users familiarize themselves with Conversational AI.
+- Every class or function has clear explanations and examples that you can run. Check out the [documentation](https://speechbrain.readthedocs.io/en/latest/index.html) for more details 📚.
 
-```
-import speechbrain as sb
-```
 
-## Install with GitHub
 
-Once you have created your Python environment (Python 3.7+) you can simply type:
+## 🎯 Use Cases
+- 🚀 **Research Acceleration**: Speeding up academic and industrial research. You can develop and integrate new models effortlessly, comparing their performance against our baselines.
 
-```
-git clone https://github.com/speechbrain/speechbrain.git
-cd speechbrain
-pip install -r requirements.txt
-pip install --editable .
-```
+- ⚡️ **Rapid Prototyping**: Ideal for quick prototyping in time-sensitive projects.
 
-Then you can access SpeechBrain with:
+- 🎓 **Educational Tool**: SpeechBrain's simplicity makes it a valuable educational resource. It is used by institutions like [Mila](https://mila.quebec/en/), [Concordia University](https://www.concordia.ca/), [Avignon University](https://univ-avignon.fr/en/), and many others for student training.
 
-```
-import speechbrain as sb
-```
+#
+# 🚀 Quick Start
 
-Any modification made to the `speechbrain` package will be automatically interpreted as we installed it with the `--editable` flag.
+To get started with SpeechBrain, follow these simple steps:
 
-## Test Installation
-Please, run the following script to make sure your installation is working:
-```
+## 🛠️ Installation
+
+### Install via PyPI
+
+1. Install SpeechBrain using PyPI:
+
+    ```bash
+    pip install speechbrain
+    ```
+
+2. Access SpeechBrain in your Python code:
+
+    ```python
+    import speechbrain as sb
+    ```
+
+### Install from GitHub
+This installation is recommended for users who wish to conduct experiments and customize the toolkit according to their needs.
+
+1. Clone the GitHub repository and install the requirements:
+
+    ```bash
+    git clone https://github.com/speechbrain/speechbrain.git
+    cd speechbrain
+    pip install -r requirements.txt
+    pip install --editable .
+    ```
+
+2. Access SpeechBrain in your Python code:
+
+    ```python
+    import speechbrain as sb
+    ```
+
+Any modifications made to the `speechbrain` package will be automatically reflected, thanks to the `--editable` flag.
+
+## ✔️ Test Installation
+
+Ensure your installation is correct by running the following commands:
+
+```bash
 pytest tests
 pytest --doctest-modules speechbrain
 ```
 
-# Running an experiment
-In SpeechBrain, you can run experiments in this way:
+## 🏃‍♂️ Running an Experiment
 
+In SpeechBrain, you can train a model for any task using the following steps:
+
+```python
+cd recipes/<dataset>/<task>/
+python experiment.py params.yaml
 ```
-> cd recipes/<dataset>/<task>/
-> python experiment.py params.yaml
-```
 
-The results will be saved in the `output_folder` specified in the yaml file. The folder is created by calling `sb.core.create_experiment_directory()` in `experiment.py`. Both detailed logs and experiment outputs are saved there. Furthermore, less verbose logs are output to stdout.
+The results will be saved in the `output_folder` specified in the YAML file.
+
+## 📘 Learning SpeechBrain
+
+- **Website:** Explore general information on the [official website](https://speechbrain.github.io).
+
+- **Tutorials:** Start with [basic tutorials](https://speechbrain.github.io/tutorial_basics.html) covering fundamental functionalities. Find advanced tutorials and topics in the Tutorials menu on the [SpeechBrain website](https://speechbrain.github.io).
+
+- **Documentation:** Detailed information on the SpeechBrain API, contribution guidelines, and code is available in the [documentation](https://speechbrain.readthedocs.io/en/latest/index.html).
+
+#
+# 🔧 Supported Technologies
+- SpeechBrain is a versatile framework designed for implementing a wide range of technologies within the field of Conversational AI.
+- It excels not only in individual task implementations but also in combining various technologies into complex pipelines.
+
+## 🎙️ Speech/Audio Processing
+| Tasks        | Datasets           | Technologies/Models  |
+| ------------- |-------------| -----|
+| Speech Recognition      | [AISHELL-1](https://github.com/speechbrain/speechbrain/tree/develop/recipes/AISHELL-1), [CommonVoice](https://github.com/speechbrain/speechbrain/tree/develop/recipes/CommonVoice), [DVoice](https://github.com/speechbrain/speechbrain/tree/develop/recipes/DVoice), [KsponSpeech](https://github.com/speechbrain/speechbrain/tree/develop/recipes/KsponSpeech), [LibriSpeech](https://github.com/speechbrain/speechbrain/tree/develop/recipes/LibriSpeech), [MEDIA](https://github.com/speechbrain/speechbrain/tree/develop/recipes/MEDIA), [RescueSpeech](https://github.com/speechbrain/speechbrain/tree/develop/recipes/RescueSpeech), [Switchboard](https://github.com/speechbrain/speechbrain/tree/develop/recipes/Switchboard), [TIMIT](https://github.com/speechbrain/speechbrain/tree/develop/recipes/TIMIT), [Tedlium2](https://github.com/speechbrain/speechbrain/tree/develop/recipes/Tedlium2), [Voicebank](https://github.com/speechbrain/speechbrain/tree/develop/recipes/Voicebank) | [CTC](https://www.cs.toronto.edu/~graves/icml_2006.pdf), [Tranducers](https://arxiv.org/pdf/1211.3711.pdf?origin=publication_detail), [Transformers](https://arxiv.org/abs/1706.03762), [Seq2Seq](http://zhaoshuaijiang.com/file/Hybrid_CTC_Attention_Architecture_for_End-to-End_Speech_Recognition.pdf), [Beamsearch techniques for CTC](https://arxiv.org/pdf/1911.01629.pdf),[seq2seq](https://arxiv.org/abs/1904.02619.pdf),[transducers](https://www.merl.com/publications/docs/TR2017-190.pdf)), [Rescoring](https://arxiv.org/pdf/1612.02695.pdf), [Conformer](https://arxiv.org/abs/2005.08100), [Branchformer](https://arxiv.org/abs/2207.02971), [Hyperconformer](https://arxiv.org/abs/2305.18281), [Kaldi2-FST](https://github.com/k2-fsa/k2) |
+| Speaker Recognition      | [VoxCeleb](https://github.com/speechbrain/speechbrain/tree/develop/recipes/VoxCeleb) | [ECAPA-TDNN](https://arxiv.org/abs/2005.07143), [ResNET](https://arxiv.org/pdf/1910.12592.pdf), [Xvectors](https://www.danielpovey.com/files/2018_icassp_xvectors.pdf), [PLDA](https://ieeexplore.ieee.org/document/6639151), [Score Normalization](https://www.sciencedirect.com/science/article/abs/pii/S1051200499903603) |
+| Speech Separation      | [WSJ0Mix](https://github.com/speechbrain/speechbrain/tree/develop/recipes/WSJ0Mix), [LibriMix](https://github.com/speechbrain/speechbrain/tree/develop/recipes/LibriMix), [WHAM!](https://github.com/speechbrain/speechbrain/tree/develop/recipes/WHAMandWHAMR), [WHAMR!](https://github.com/speechbrain/speechbrain/tree/develop/recipes/WHAMandWHAMR), [Aishell1Mix](https://github.com/speechbrain/speechbrain/tree/develop/recipes/Aishell1Mix), [BinauralWSJ0Mix](https://github.com/speechbrain/speechbrain/tree/develop/recipes/BinauralWSJ0Mix) | [SepFormer](https://arxiv.org/abs/2010.13154), [RESepFormer](https://arxiv.org/abs/2206.09507), [SkiM](https://arxiv.org/abs/2201.10800), [DualPath RNN](https://arxiv.org/abs/1910.06379), [ConvTasNET](https://arxiv.org/abs/1809.07454) |
+| Speech Enhancement      | [DNS](https://github.com/speechbrain/speechbrain/tree/develop/recipes/DNS), [Voicebank](https://github.com/speechbrain/speechbrain/tree/develop/recipes/Voicebank) | [SepFormer](https://arxiv.org/abs/2010.13154), [MetricGAN](https://arxiv.org/abs/1905.04874), [MetricGAN-U](https://arxiv.org/abs/2110.05866), [SEGAN](https://arxiv.org/abs/1703.09452), [spectral masking](http://staff.ustc.edu.cn/~jundu/Publications/publications/Trans2015_Xu.pdf), [time masking](http://staff.ustc.edu.cn/~jundu/Publications/publications/Trans2015_Xu.pdf) |
+| Text-to-Speech      | [LJSpeech](https://github.com/speechbrain/speechbrain/tree/develop/recipes/LJSpeech), [LibriTTS](https://github.com/speechbrain/speechbrain/tree/develop/recipes/LibriTTS) | [Tacotron2](https://arxiv.org/abs/1712.05884), [Zero-Shot Multi-Speaker Tacotron2](https://arxiv.org/abs/2112.02418), [FastSpeech2](https://arxiv.org/abs/2006.04558) |
+| Vocoding      | [LJSpeech](https://github.com/speechbrain/speechbrain/tree/develop/recipes/LJSpeech), [LibriTTS](https://github.com/speechbrain/speechbrain/tree/develop/recipes/LibriTTS) | [HiFiGAN](https://arxiv.org/abs/2010.05646), [DiffWave](https://arxiv.org/abs/2009.09761)
+| Spoken Language Understanding | [MEDIA](https://github.com/speechbrain/speechbrain/tree/develop/recipes/MEDIA), [SLURP](https://github.com/speechbrain/speechbrain/tree/develop/recipes/SLURP), [Fluent Speech Commands](https://github.com/speechbrain/speechbrain/tree/develop/recipes/fluent-speech-commands), [Timers-and-Such](https://github.com/speechbrain/speechbrain/tree/develop/recipes/timers-and-such)  | [Direct SLU](https://arxiv.org/abs/2104.01604), [Decoupled SLU](https://arxiv.org/abs/2104.01604), [Multistage SLU](https://arxiv.org/abs/2104.01604) |
+| Speech-to-Speech Translation  | [CVSS](https://github.com/speechbrain/speechbrain/tree/develop/recipes/CVSS) | [Discrete Hubert](https://arxiv.org/pdf/2106.07447.pdf), [HiFiGAN](https://arxiv.org/abs/2010.05646), [wav2vec2](https://arxiv.org/abs/2006.11477) |
+| Speech Translation  | [Fisher CallHome (Spanish)](https://github.com/speechbrain/speechbrain/tree/develop/recipes/Fisher-Callhome-Spanish), [IWSLT22(lowresource)](https://github.com/speechbrain/speechbrain/tree/develop/recipes/IWSLT22_lowresource) | [wav2vec2](https://arxiv.org/abs/2006.11477) |
+| Emotion Classification      | [IEMOCAP](https://github.com/speechbrain/speechbrain/tree/develop/recipes/IEMOCAP), [ZaionEmotionDataset](https://github.com/speechbrain/speechbrain/tree/develop/recipes/ZaionEmotionDataset) | [ECAPA-TDNN](https://arxiv.org/abs/2005.07143), [wav2vec2](https://arxiv.org/abs/2006.11477), [Emotion Diarization](https://arxiv.org/abs/2306.12991) |
+| Language Identification | [VoxLingua107](https://github.com/speechbrain/speechbrain/tree/develop/recipes/VoxLingua107), [CommonLanguage](https://github.com/speechbrain/speechbrain/tree/develop/recipes/CommonLanguage)| [ECAPA-TDNN](https://arxiv.org/abs/2005.07143) |
+| Voice Activity Detection  | [LibriParty](https://github.com/speechbrain/speechbrain/tree/develop/recipes/LibriParty) | [CRDNN](https://arxiv.org/abs/2106.04624) |
+| Sound Classification  | [ESC50](https://github.com/speechbrain/speechbrain/tree/develop/recipes/ESC50), [UrbanSound](https://github.com/speechbrain/speechbrain/tree/develop/recipes/UrbanSound8k) | [CNN14](https://github.com/ranchlai/sound_classification), [ECAPA-TDNN](https://arxiv.org/abs/2005.07143) |
+| Self-Supervised Learning | [CommonVoice](https://github.com/speechbrain/speechbrain/tree/develop/recipes/CommonVoice), [LibriSpeech](https://github.com/speechbrain/speechbrain/tree/develop/recipes/LibriSpeech) | [wav2vec2](https://arxiv.org/abs/2006.11477) |
+| Interpretabiliy | [ESC50](https://github.com/speechbrain/speechbrain/tree/develop/recipes/ESC50) | [Learning-to-Interpret (L2I)](https://proceedings.neurips.cc/paper_files/paper/2022/file/e53280d73dd5389e820f4a6250365b0e-Paper-Conference.pdf), [Non-Negative Matrix Factorization (NMF)](https://proceedings.neurips.cc/paper_files/paper/2022/file/e53280d73dd5389e820f4a6250365b0e-Paper-Conference.pdf), [PIQ](https://arxiv.org/abs/2303.12659) |
+| Speech Generation | [AudioMNIST](https://github.com/speechbrain/speechbrain/tree/develop/recipes/AudioMNIST) | [Diffusion](https://arxiv.org/abs/2006.11239), [Latent Diffusion](https://arxiv.org/abs/2112.10752) |
+| Metric Learning | [REAL-M](https://github.com/speechbrain/speechbrain/tree/develop/recipes/REAL-M/sisnr-estimation), [Voicebank](https://github.com/speechbrain/speechbrain/tree/develop/recipes/Voicebank) | [Blind SNR-Estimation](https://arxiv.org/abs/2002.08909), [PESQ Learning](https://arxiv.org/abs/2110.05866) |
+| Allignment | [TIMIT](https://github.com/speechbrain/speechbrain/tree/develop/recipes/TIMIT) | [CTC](https://www.cs.toronto.edu/~graves/icml_2006.pdf), [Viterbi](https://www.cs.cmu.edu/~cga/behavior/rabiner1.pdf), [Forward Forward](https://www.cs.cmu.edu/~cga/behavior/rabiner1.pdf) |
+| Diarization | [AMI](https://github.com/speechbrain/speechbrain/tree/develop/recipes/AMI) | [ECAPA-TDNN](https://arxiv.org/abs/2005.07143), [X-vectors](https://www.danielpovey.com/files/2018_icassp_xvectors.pdf), [Spectral Clustering](http://www.ifp.illinois.edu/~hning2/papers/Ning_spectral.pdf) |
+
+## 📝 Text Processing
+| Tasks        | Datasets           | Technologies/Models  |
+| ------------- |-------------| -----|
+| Language Modeling | [CommonVoice](https://github.com/speechbrain/speechbrain/tree/develop/recipes/CommonVoice), [LibriSpeech](https://github.com/speechbrain/speechbrain/tree/unstable-v0.6/recipes/LibriSpeech)| [n-grams](https://web.stanford.edu/~jurafsky/slp3/3.pdf), [RNNLM](https://www.fit.vutbr.cz/research/groups/speech/publi/2010/mikolov_interspeech2010_IS100722.pdf), [TransformerLM](https://arxiv.org/abs/1706.03762) |
+| Response Generation | [MultiWOZ](https://github.com/speechbrain/speechbrain/tree/unstable-v0.6/recipes/MultiWOZ/response_generation)| [GPT2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf), [Llama2](https://arxiv.org/abs/2307.09288) |
+| Grapheme-to-Phoneme | [LibriSpeech](https://github.com/speechbrain/speechbrain/tree/develop/recipes/LibriSpeech) | [RNN](https://arxiv.org/abs/2207.13703), [Transformer](https://arxiv.org/abs/2207.13703), [Curriculum Learning](https://arxiv.org/abs/2207.13703), [Homograph loss](https://arxiv.org/abs/2207.13703) |
+
+## 🔍 Additional Features
+
+SpeechBrain includes a range of native functionalities that enhance the development of Conversational AI technologies. Here are some examples:
+
+- **Training Orchestration:** The `Brain` class serves as a fully customizable tool for managing training and evaluation loops over data. It simplifies training loops while providing the flexibility to override any part of the process.
+
+- **Hyperparameter Management:** A YAML-based hyperparameter file specifies all hyperparameters, from individual numbers (e.g., learning rate) to complete objects (e.g., custom models). This elegant solution drastically simplifies the training script.
+
+- **Dynamic Dataloader:** Enables flexible and efficient data reading.
+
+- **GPU Training:** Supports single and multi-GPU training, including distributed training.
+
+- **Dynamic Batching:** On-the-fly dynamic batching enhances the efficient processing of variable-length signals.
+
+- **Mixed-Precision Training:** Accelerates training through mixed-precision techniques.
+
+- **Efficient Data Reading:** Reads large datasets efficiently from a shared Network File System (NFS) via [WebDataset](https://github.com/webdataset/webdataset).
+
+- **Hugging Face Integration:** Interfaces seamlessly with [HuggingFace](https://huggingface.co/speechbrain) for popular models such as wav2vec2 and Hubert.
+
+- **Orion Integration:** Interfaces with [Orion](https://github.com/Epistimio/orion) for hyperparameter tuning.
+
+- **Speech Augmentation Techniques:** Includes SpecAugment, Noise, Reverberation, and more.
+
+- **Data Preparation Scripts:** Includes scripts for preparing data for supported datasets.
+
+SpeechBrain is rapidly evolving, with ongoing efforts to support a growing array of technologies in the future.
+
+
+## 📊 Performance
+
+- SpeechBrain integrates a variety of technologies, including those that achieves competitive or state-of-the-art performance.
+
+- For a comprehensive overview of the achieved performance across different tasks, datasets, and technologies, please visit [here](https://github.com/speechbrain/speechbrain/blob/develop/PERFORMANCE.md).
+
+#
+# 📜 License
+
+- SpeechBrain is released under the [Apache License, version 2.0](https://www.apache.org/licenses/LICENSE-2.0), a popular BSD-like license.
+- You are free to redistribute SpeechBrain for both free and commercial purposes, with the condition of retaining license headers. Unlike the GPL, the Apache License is not viral, meaning you are not obligated to release modifications to the source code.
+
+#
+# 🔮Future Plans
+
+We have ambitious plans for the future, with a focus on the following priorities:
+
+- **Scale Up:** Our aim is to provide comprehensive recipes and technologies for training massive models on extensive datasets.
+
+- **Scale Down:** While scaling up delivers unprecedented performance, we recognize the challenges of deploying large models in production scenarios. We are focusing on real-time, streamable, and small-footprint Conversational AI.
+
+#
+# 🤝 Contributing
+
+- SpeechBrain is a community-driven project, led by a core team with the support of numerous international collaborators.
+- We welcome contributions and ideas from the community. For more information, check [here](https://speechbrain.github.io/contributing.html).
+
+#
+# 🙏 Sponsors
+
+- SpeechBrain is an academically driven project and relies on the passion and enthusiasm of its contributors.
+- As we cannot rely on the resources of a large company, we deeply appreciate any form of support, including donations or collaboration with the core team.
+- If you're interested in sponsoring SpeechBrain, please reach out to us at speechbrainproject@gmail.com.
+- A heartfelt thank you to all our sponsors, including the current ones:
+
 
-# SpeechBrain Roadmap
 
-As a community-based and open-source project, SpeechBrain needs the help of its community to grow in the right direction. Opening the roadmap to our users enables the toolkit to benefit from new ideas, new research axes, or even new technologies. The roadmap will be available in our [GitHub Discussions](https://github.com/speechbrain/speechbrain/discussions/categories/announcements) and will list all the changes and updates that need to be done in the current version of SpeechBrain. Users are more than welcome to propose new items via new Discussions topics!
+[<img src="https://huggingface.co/front/assets/huggingface_logo.svg" alt="Image 1" width="250"/>](https://speechbrain.github.io/img/hf.ico) &nbsp; &nbsp;
+[<img src="https://speechbrain.github.io/img/sponsors/logo_vd.png" alt="Image 3" width="250"/>](https://viadialog.com/en/) &nbsp; &nbsp;
+[<img src="https://speechbrain.github.io/img/sponsors/logo_nle.png" alt="Image 4" width="250"/>](https://europe.naverlabs.com/)
 
-# Learning SpeechBrain
+<br><br>
 
-We provide users with different resources to learn how to use SpeechBrain:
-- General information can be found on the [website](https://speechbrain.github.io).
-- We offer many tutorials, you can start from the [basic ones](https://speechbrain.github.io/tutorial_basics.html) about SpeechBrain's basic functionalities and building blocks. We provide also more advanced tutorials (e.g SpeechBrain advanced, signal processing ...). You can browse them via the Tutorials drop-down menu on [SpeechBrain website](https://speechbrain.github.io) in the upper right.
-- Details on the SpeechBrain API, how to contribute, and the code are given in the [documentation](https://speechbrain.readthedocs.io/en/latest/index.html).
+[<img src="https://speechbrain.github.io/img/sponsors/logo_ovh.png" alt="Image 5" width="250"/>](https://www.ovhcloud.com/en-ca/) &nbsp; &nbsp;
+[<img src="https://speechbrain.github.io/img/sponsors/logo_badu.png" alt="Image 2" width="250"/>](https://usa.baidu.com/) &nbsp; &nbsp;
+[<img src="https://speechbrain.github.io/img/sponsors/samsung_official.png" alt="Image 6" width="250"/>](https://research.samsung.com/aicenter_cambridge)
 
-# License
-SpeechBrain is released under the Apache License, version 2.0. The Apache license is a popular BSD-like license. SpeechBrain can be redistributed for free, even for commercial purposes, although you can not take off the license headers (and under some circumstances, you may have to distribute a license document). Apache is not a viral license like the GPL, which forces you to release your modifications to the source code. Note that this project has no connection to the Apache Foundation, other than that we use the same license terms.
+<br><br>
 
-# Social Media
-We constantly update the community using Twitter. [Feel free to follow us](https://twitter.com/speechbrain1)
+[<img src="https://speechbrain.github.io/img/sponsors/logo_mila_small.png" alt="Image 7" width="250"/>](https://mila.quebec/en/) &nbsp; &nbsp;
+[<img src="https://www.concordia.ca/content/dam/common/logos/Concordia-logo.jpeg" alt="Image 9" width="250"/>](https://www.concordia.ca/) &nbsp; &nbsp;
+[<img src="https://speechbrain.github.io/img/partners/logo_lia.png" alt="Image 8" width="250"/>](https://lia.univ-avignon.fr/) &nbsp; &nbsp;
+#
+# 📖 Citing SpeechBrain
 
-# Citing SpeechBrain
-Please, cite SpeechBrain if you use it for your research or business.
+If you use SpeechBrain in your research or business, please cite it using the following BibTeX entry:
 
 ```bibtex
 @misc{speechbrain,
diff --git a/conftest.py b/conftest.py
index 2f9c5106116ae7c90aa3340f6dd716cdfb26b9a4..b38b98cb63506bb6ff55db2d1a0aace14e2b6f8a 100644
--- a/conftest.py
+++ b/conftest.py
@@ -15,6 +15,10 @@ try:
     import numba  # noqa: F401
 except ModuleNotFoundError:
     collect_ignore.append("speechbrain/nnet/loss/transducer_loss.py")
+try:
+    import kenlm  # noqa: F401
+except ModuleNotFoundError:
+    collect_ignore.append("speechbrain/decoders/language_model.py")
 try:
     import fairseq  # noqa: F401
 except ModuleNotFoundError:
@@ -22,12 +26,41 @@ except ModuleNotFoundError:
 try:
     from transformers import Wav2Vec2Model  # noqa: F401
 except ModuleNotFoundError:
-    collect_ignore.append("speechbrain/lobes/models/huggingface_wav2vec.py")
+    collect_ignore.append(
+        "speechbrain/lobes/models/huggingface_transformers/wav2vec2.py"
+    )
 try:
     from transformers import WhisperModel  # noqa: F401
 except ModuleNotFoundError:
-    collect_ignore.append("speechbrain/lobes/models/huggingface_whisper.py")
+    collect_ignore.append(
+        "speechbrain/lobes/models/huggingface_transformers/whisper.py"
+    )
+try:
+    import sklearn  # noqa: F401
+except ModuleNotFoundError:
+    collect_ignore.append("speechbrain/utils/kmeans.py")
+    collect_ignore.append(
+        "speechbrain/lobes/models/huggingface_transformers/discrete_hubert.py"
+    )
+    collect_ignore.append(
+        "speechbrain/lobes/models/huggingface_transformers/discrete_wav2vec2.py"
+    )
+    collect_ignore.append(
+        "speechbrain/lobes/models/huggingface_transformers/discrete_wavlm.py"
+    )
+try:
+    import peft  # noqa: F401
+except ModuleNotFoundError:
+    collect_ignore.append(
+        "speechbrain/lobes/models/huggingface_transformers/llama2.py"
+    )
 try:
     import sacrebleu  # noqa: F401
 except ModuleNotFoundError:
     collect_ignore.append("speechbrain/utils/bleu.py")
+try:
+    import vocos  # noqa: F401
+except ModuleNotFoundError:
+    collect_ignore.append(
+        "speechbrain/lobes/models/huggingface_transformers/vocos.py"
+    )
diff --git a/docs/audioloading.rst b/docs/audioloading.rst
new file mode 100644
index 0000000000000000000000000000000000000000..bef9673720776753c55778a80bd7b3fe8ba0486e
--- /dev/null
+++ b/docs/audioloading.rst
@@ -0,0 +1,107 @@
+=============================
+Audio loading troubleshooting
+=============================
+
+This page is intended to document how to install torchaudio backends and
+provides troubleshooting steps for your audio loading troubles.
+
+Introduction
+============
+
+SpeechBrain relies on
+`torchaudio <https://pytorch.org/audio/stable/index.html>`_
+for loading audio files in most cases. Please first try to **update torchaudio**
+if you are encountering issues. Please also ensure that you are using the
+correct PyTorch version for your installed torchaudio version.
+
+As of torchaudio `2.2.0`, three backends are supported: ``ffmpeg``, ``sox`` and
+``soundfile``. torchaudio documents how their backends are found in their
+`optional dependency docs <https://pytorch.org/audio/stable/installation.html#optional-dependencies>`_.
+
+You can determine which backends are available in your environment by running
+:func:`torchaudio.list_audio_backends`.
+
+.. warning::
+    **A backend can *silently* fail to load** if initialization failed and will be
+    omitted from this list.
+
+.. warning::
+    **Not every backend can support any codec.** For instance, at the time of
+    writing, the torchaudio SoX backend cannot handle MP3 and the SoundFile
+    backend cannot handle AAC (usually ``.m4a``), both of which are found in
+    certain popular speech datasets.
+    However, most common formats are typically well supported by all backends
+    (``.wav``/``.ogg`` vorbis/opus/``.flac``).
+
+Recommended install steps
+=========================
+
+Often, torchaudio will work out of the box. On certain systems, there might not
+be a working backend installed. We recommend you try if any of those steps fixes
+your issue:
+
+- On Linux, if you have superuser rights, install ffmpeg and/or libsndfile
+  and/or SoX through your distribution's package manager.
+
+- On Windows/Linux/macOS, you can try installing ffmpeg through Conda
+  (see `ffmpeg`_), which does not require superuser rights (provided Conda is
+  available).
+
+- On macOS, alternatively, it appears to be possible to install ffmpeg through
+  Homebrew. Make sure that you are installing a version compatible with
+  torchaudio (see `ffmpeg`_).
+
+- On Windows/Linux/macOS, `SoundFile <https://pypi.org/project/soundfile/>`_
+  has started shipping with a prebuilt ``libsndfile``, which does not require
+  admin rights. Try installing or updating it. See the linked page for more
+  details.
+
+Note for developers & breaking torchaudio `2.x` changes
+=======================================================
+
+With torchaudio `<2.x`, backends were selected through
+``torchaudio.set_audio_backend``. This function was deprecated and then
+removed in the `2.x` branch of torchaudio and is no longer used in SpeechBrain.
+Since then, the backend is (optionally) selected through the ``backend``
+argument of :func:`torchaudio.load` and :func:`torchaudio.info`.
+
+Installing/troubleshooting backends
+===================================
+
+ffmpeg
+------
+
+torchaudio compiles their ffmpeg backend for a **specific range** of ffmpeg
+versions.
+
+ffmpeg is commonly already installed on common Linux distributions.
+On Ubuntu, it can be installed through ``sudo apt install ffmpeg``.
+
+Depending on your OS version, it is possible that your installed ffmpeg version
+is not supported by torchaudio (if too recent or too old).
+If you believe this to be the case, you can try installing a specific version
+of the ``ffmpeg`` package as supplied by
+`conda-forge <https://anaconda.org/conda-forge/ffmpeg>`_.
+
+See `torchaudio documentation on optional dependencies <https://pytorch.org/audio/stable/installation.html#optional-dependencies>`_ for more details.
+
+SoundFile
+---------
+
+torchaudio can use `soundfile <https://pypi.org/project/soundfile/>`_ as an
+audio backend, which depends on ``libsndfile``.
+
+Starting with SoundFile 0.12.0, this package bundles a prebuilt ``libsndfile``
+for a number of platforms. Refer to the project page for more details.
+
+SoX
+---
+
+Starting with torchaudio 0.12.0, the SoX backend no longer supports mp3 files.
+
+Starting with torchaudio 2.1.0, torchaudio no longer compiles and bundles SoX
+by itself, and expects it to be provided by the system.
+
+If you have upgraded from an earlier version and can no longer load audio files,
+it may be due to this. In this case, you may need to install SoX or use a
+different backend.
diff --git a/docs/conf.py b/docs/conf.py
index 4b7e2d06996a4c3a01bb3d62e0d7257b3f59213d..494d2930fc91820164001730daf489daaadcd089 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -61,6 +61,7 @@ intersphinx_mapping = {
     "python": ("https://docs.python.org/", None),
     "numpy": ("http://docs.scipy.org/doc/numpy/", None),
     "torch": ("https://pytorch.org/docs/master/", None),
+    "torchaudio": ("https://pytorch.org/audio/stable/", None),
 }
 
 # AUTODOC:
@@ -85,6 +86,9 @@ templates_path = ["_templates"]
 # This pattern also affects html_static_path and html_extra_path.
 exclude_patterns = ["_apidoc_templates"]
 
+# Make backticks behave as inline code blocks rather than italics
+default_role = "code"
+
 # -- Better apidoc -----------------------------------------------------------
 
 
@@ -93,20 +97,6 @@ def run_apidoc(app):
     import better_apidoc
 
     better_apidoc.APP = app
-
-    better_apidoc.main(
-        [
-            "better-apidoc",
-            "-t",
-            "_apidoc_templates",
-            "--force",
-            "--no-toc",
-            "--separate",
-            "-o",
-            "API",
-            os.path.dirname(hyperpyyaml.__file__),
-        ]
-    )
     better_apidoc.main(
         [
             "better-apidoc",
@@ -118,6 +108,7 @@ def run_apidoc(app):
             "-o",
             "API",
             os.path.join("../", "speechbrain"),
+            os.path.dirname(hyperpyyaml.__file__),
         ]
     )
 
@@ -131,6 +122,7 @@ html_theme = "sphinx_rtd_theme"
 # See https://sphinx-rtd-theme.readthedocs.io/en/stable/configuring.html
 # for rtd theme options
 html_theme_options = {
+    "logo_only": True,
     # Toc options
     "collapse_navigation": False,
     "sticky_navigation": True,
@@ -138,6 +130,8 @@ html_theme_options = {
     "includehidden": True,
 }
 
+html_logo = "images/speechbrain-logo.svg"
+
 
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
diff --git a/docs/experiment.md b/docs/experiment.md
index 34eccf7a7cd1f105612a156b298e8c7605c1b709..135094ecb304723f8ec2424bd53afb6e6d1f30cd 100644
--- a/docs/experiment.md
+++ b/docs/experiment.md
@@ -54,7 +54,6 @@ SpeechBrain defines a set of running arguments that can be set from the command
 - `debug`: a flag that enables debug mode, only running a few iterations to verify that program won't crash.
 - `data_parallel_backend`: a flag that enables `data_parallel` for multigpu training on a single machine.
 - `data_parallel_count`: default "-1" (use all gpus), if > 0, use a subset of gpus available `[0, 1, ..., data_parallel_count]`.
-- `distributed_launch`: A flag that enables training with `ddp` for multiGPU training. Assumes `torch.distributed.launch` was used to start script. the `local_rank` and `rank` UNIX arguments are parsed.
 - `distributed_backend`: default "nccl", options: `["nccl", "gloo", "mpi"]`, this backend will be used as a DDP communication protocol. See PyTorch documentation for more details.
 - Additional runtime arguments are documented in the Brain class.
 
diff --git a/docs/index.rst b/docs/index.rst
index 071665a519e35843495e63c214efe461f2445546..b0bbf41bf19c8889aad36c27469007211b097167 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -3,21 +3,21 @@
    You can adapt this file completely to your liking, but it should at least
    contain the root `toctree` directive.
 
-.. image:: images/speechbrain-logo.svg
-  :width: 400
-  :align: center
+==========
+User guide
+==========
 
 SpeechBrain is an open-source and all-in-one speech toolkit based on PyTorch.
 This documentation is intended to give SpeechBrain users all the API
 information necessary to develop their projects. For tutorials,
 please refer to the official `Github <https://github.com/speechbrain/speechbrain>`_
-or the official `Website <https://speechbrain.github.io>`
+or the official `Website <https://speechbrain.github.io>`_.
 
 
 License
 -------
 
-SpeechBrain is released under the Apache license, version 2.0. The Apache license is a popular BSD-like license.
+SpeechBrain is released under the `Apache License, version 2.0 <https://github.com/speechbrain/speechbrain/blob/develop/LICENSE>`_. The Apache license is a popular BSD-like license.
 SpeechBrain can be redistributed for free, even for commercial purposes, although you can not take off the license headers (and under some circumstances you may have to distribute a license document).
 Apache is not a viral license like the GPL, which forces you to release your modifications to the source code. Also note that this project has no connection to the Apache Foundation, other than that we use the same license terms.
 
@@ -40,12 +40,25 @@ Referencing SpeechBrain
 
 .. toctree::
    :maxdepth: 1
-   :caption: Getting started:
+   :caption: Getting started
 
    installation.md
    experiment.md
    multigpu.md
    tutorials.md
+
+
+.. toctree::
+   :maxdepth: 1
+   :caption: Tips & tricks
+
+   audioloading.rst
+
+
+.. toctree::
+   :maxdepth: 1
+   :caption: Contributing
+
    contributing.md
    guidance.md
    coverage.md
@@ -65,11 +78,17 @@ API Documentation
 
    speechbrain
    speechbrain.alignment
+   speechbrain.augment
    speechbrain.dataio
    speechbrain.decoders
+   speechbrain.inference
+   speechbrain.k2_integration
    speechbrain.lm
    speechbrain.lobes
    speechbrain.nnet
    speechbrain.processing
    speechbrain.tokenizers
    speechbrain.utils
+   speechbrain.wordemb
+
+   hyperpyyaml.core
\ No newline at end of file
diff --git a/docs/installation.md b/docs/installation.md
index 13295dfa153ed056319a124aac36cfacb4315fb0..3dad4a728408fc8bb2e1e36a498af995a8c06367 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -1,11 +1,11 @@
 
 # Quick installation
 
-SpeechBrain is constantly evolving. New features, tutorials, and documentation will appear over time. SpeechBrain can be installed via PyPI to rapidly use the standard library. Moreover, a local installation can be used to run experiments and modify/customize the toolkit.
+SpeechBrain is constantly evolving. New features, tutorials, and documentation will appear over time. SpeechBrain can be installed via PyPI to rapidly use the standard library. Moreover, a local installation can be used to run experiments and modify/customize the toolkit and its recipes.
 
-SpeechBrain supports both CPU and GPU computations. For most recipes, however, a GPU is necessary during training. Please note that CUDA must be properly installed to use GPUs.
+SpeechBrain supports both CPU and GPU computation. For most recipes, however, a GPU is necessary during training. Please note that CUDA must be properly installed to use GPUs.
 
-We support pytorch >= 1.7 (https://pytorch.org/) and Python >= 3.7.
+We support [PyTorch](https://pytorch.org/get-started/locally/) 1.9+ and Python 3.9-3.11 (newer Python versions may work if supported by PyTorch).
 
 ## Install via PyPI
 
@@ -15,6 +15,8 @@ Once you have created your Python environment (see instructions below) you can s
 pip install speechbrain
 ```
 
+Depending on your OS, audio loading may require the install of optional torchaudio dependencies to work. If it does not work out-of-the box for you, please visit [audio troubleshooting](audioloading.html).
+
 Then you can then access SpeechBrain with:
 
 ```
@@ -59,20 +61,20 @@ tests/.run-doctests.sh
 SpeechBrain supports Linux-based distributions and macOS. A solution for windows users can be found
 in this [GitHub issue](https://github.com/speechbrain/speechbrain/issues/512).
 
-## Anaconda and venv
+## Setting up a Conda environment/virtualenv
 
 A good practice is to have different python environments for your different tools
 and toolkits, so they do not interfere with each other. This can be done either with
 [Anaconda](https://www.anaconda.com/products/distribution) or [venv](https://docs.python.org/3.8/library/venv.html).
 
-Anaconda can be installed by simply following [this tutorial](https://docs.anaconda.com/anaconda/install/linux/). In practice, it is a matter of downloading the installation script and executing it.
+Anaconda can be installed by simply following [this tutorial](https://docs.anaconda.com/free/anaconda/install/linux/). In practice, it is a matter of downloading the installation script and executing it.
 
-## Anaconda setup
+### Conda
 
-Once Anaconda is installed, you can create a new environment with:
+Once Conda is installed, you can create a new environment with:
 
 ```
-conda create --name speechbrain python=3.9
+conda create --name speechbrain python=3.11
 ```
 
 Then, activate it with:
diff --git a/docs/multigpu.md b/docs/multigpu.md
index 18585731d76c0209510336e9c31f43cc6196e425..57c996da2b64a33ffe93cb1521b30e67c2bacc3a 100644
--- a/docs/multigpu.md
+++ b/docs/multigpu.md
@@ -27,7 +27,7 @@ Using SpeechBrain, this would look like:
 
 ```bash
 cd recipes/<dataset>/<task>/
-python -m torch.distributed.launch --nproc_per_node=4 experiment.py hyperparams.yaml --distributed_launch
+torchrun --nproc_per_node=4 experiment.py hyperparams.yaml
 ```
 
 ... where `nproc_per_node` is the the number of processes to spawn/GPUs to use.
@@ -45,27 +45,25 @@ While DDP is more efficient than `DataParallel`, it is somewhat prone to exhibit
 
 Let's start with a simple example where a user is able to connect to each node directly. Consider that we have 2 nodes with 2 GPUs each (for a total of 4 GPUs).
 
-We use `torch.distributed.launch` once on each machine, with the following parameters:
+We use `torchrun` once on each machine, with the following parameters:
 
 - `--nproc_per_node=2` means we will spawn 2 processes per node, which equates to 2 GPUs per nodes.
 - `--nnodes=2` means we will be using two nodes in total.
 - `--node_rank=0` and `--node_rank=1` refer to the rank/"index" we are attributing to the node/machine.
 - `--master_addr`/`--master_port` define the IP address and the port of the "master" machine. In this case, we're arbitrarily choosing the first machine to be the "master" of everyone else (the 2nd machine in our case). Note that `5555` might be taken by a different process if you are unlucky or if you would run multiple different training scripts on that node, so you may need to choose a different free port.
 
-We also need to pass `--distributed_launch` as a parameter **to our script** (`experiment.py`) as opposed to `torch.distributed.launch`. This is so we tell SpeechBrain to enable DDP.
-
 Hence, we get:
 
 ```bash
 # Machine 1
 cd recipes/<dataset>/<task>/
-python -m torch.distributed.launch --nproc_per_node=2 --nnodes=2 --node_rank=0 --master_addr machine_1_address --master_port 5555 experiment.py hyperparams.yaml --distributed_launch
+torchrun --nproc_per_node=2 --nnodes=2 --node_rank=0 --master_addr machine_1_address --master_port 5555 experiment.py hyperparams.yaml
 ```
 
 ```bash
 # Machine 2
 cd recipes/<dataset>/<task>/
-python -m torch.distributed.launch --nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr machine_1_address --master_port 5555 experiment.py hyperparams.yaml --distributed_launch
+torchrun --nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr machine_1_address --master_port 5555 experiment.py hyperparams.yaml
 ```
 
 In this setup:
@@ -77,7 +75,7 @@ In this setup:
     - Subprocess #1: `local_rank`=0, `rank`=2
     - Subprocess #2: `local_rank`=1, `rank`=3
 
-In practice, using `torch.distributed.launch` ensures that the right environment variables are set (`local_rank` and `rank`), so you don't have to bother with it.
+In practice, using `torchrun` ensures that the right environment variables are set (`LOCAL_RANK` and `RANK`), so you don't have to bother with it.
 
 #### Multi-node setup with Slurm
 
@@ -118,8 +116,8 @@ conda activate super_cool_sb_env
 LISTNODES=`scontrol show hostname $SLURM_JOB_NODELIST`
 MASTER=`echo $LISTNODES | cut -d" " -f1`
 
-# here --nproc_per_node=4 because we want torch.distributed to spawn 4 processes (4 GPUs). Then we give the total amount of nodes requested (--nnodes) and then --node_rank that is necessary to dissociate the node that we are calling this from.
-python -m torch.distributed.launch --nproc_per_node=4 --nnodes=${SLURM_JOB_NUM_NODES} --node_rank=${SLURM_NODEID} --master_addr=${MASTER} --master_port=5555 train.py hparams/myrecipe.yaml
+# here --nproc_per_node=4 because we want torchrun to spawn 4 processes (4 GPUs). Then we give the total amount of nodes requested (--nnodes) and then --node_rank that is necessary to dissociate the node that we are calling this from.
+torchrun --nproc_per_node=4 --nnodes=${SLURM_JOB_NUM_NODES} --node_rank=${SLURM_NODEID} --master_addr=${MASTER} --master_port=5555 train.py hparams/myrecipe.yaml
 ```
 
 ## (DEPRECATED) Single-node multi-GPU training using Data Parallel
diff --git a/docs/readthedocs-requirements.txt b/docs/readthedocs-requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..28c9f6e93642a8e9b3a2c9c43b40ac90e323d6ee
--- /dev/null
+++ b/docs/readthedocs-requirements.txt
@@ -0,0 +1,9 @@
+# readthedocs only lets us define a single requirements file in the yaml
+# this file merges both the usual and the docs requirements so that everything
+# gets installed correctly.
+
+--find-links https://k2-fsa.github.io/k2/cpu.html
+-r ../requirements.txt
+-r docs-requirements.txt
+k2==1.24.4.dev20240223+cpu.torch2.2.1
+torch==2.2.1
diff --git a/recipes/AISHELL-1/ASR/CTC/README.md b/recipes/AISHELL-1/ASR/CTC/README.md
index 7743365445f8cce9e6697bd7621e264128c4b171..eef7d82b8f5248d9f9644cb99961a058b233b8e7 100644
--- a/recipes/AISHELL-1/ASR/CTC/README.md
+++ b/recipes/AISHELL-1/ASR/CTC/README.md
@@ -6,7 +6,6 @@ This folder contains a CTC-wav2vec2 recipe for speech recognition with [AISHELL-
 A pretrained tokenizer from [huggingface](https://huggingface.co/bert-base-chinese) is used and can be downloaded
 automatically.
 
-If not present in the specified data_folder, the dataset will be automatically downloaded there.
 This step is not mandatory. We will use the official tokenizer downloaded from the web if you do not
 specify a different tokenizer in the speech recognition recipe.
 
diff --git a/recipes/AISHELL-1/ASR/CTC/hparams/train_with_wav2vec.yaml b/recipes/AISHELL-1/ASR/CTC/hparams/train_with_wav2vec.yaml
index 491480a4856998daca43c75a890b532e8240d12c..486685f258aba538b0262cdd5339b23c9dc266a2 100644
--- a/recipes/AISHELL-1/ASR/CTC/hparams/train_with_wav2vec.yaml
+++ b/recipes/AISHELL-1/ASR/CTC/hparams/train_with_wav2vec.yaml
@@ -10,7 +10,7 @@
 
 seed: 2
 __set_seed: !apply:torch.manual_seed [!ref <seed>]
-output_folder: !ref results/ctc_wav2vec/<seed>
+output_folder: !ref results/ctc_wav2vec2/<seed>
 cer_file: !ref <output_folder>/cer.txt
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
@@ -19,6 +19,7 @@ train_log: !ref <output_folder>/train_log.txt
 data_folder: !PLACEHOLDER # e,g./path/to/aishell
 
 skip_prep: False
+remove_compressed_wavs: False
 ckpt_interval_minutes: 15 # save checkpoint every N min
 train_data: !ref <output_folder>/train.csv
 valid_data: !ref <output_folder>/dev.csv
@@ -27,29 +28,31 @@ test_data: !ref <output_folder>/test.csv
 wav2vec2_hub: TencentGameMate/chinese-wav2vec2-large
 wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
 
-# Training parameters
+####################### Training Parameters ####################################
+
 number_of_epochs: 80
 lr: 1.0
 lr_wav2vec: 0.0001
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 
 # With data_parallel batch_size is split into N jobs
 # With DDP batch_size is multiplied by N jobs
 # Must be 8 per GPU to fit 32GB of VRAM
 batch_size: 10
-test_batch_size: 4
+test_batch_size: 1
 
 dynamic_batching: False
+max_batch_length: 15 # in terms of "duration" in annotations by default, second here
+shuffle: False # if true re-creates batches at each epoch shuffling examples.
+num_buckets: 10 # floor(log(max_batch_len/left_bucket_len, multiplier)) + 1
+batch_ordering: ascending
 dynamic_batch_sampler:
-   feats_hop_size: 0.01
-   max_batch_len: 15 # in terms of "duration" in annotations by default, second here
-   left_bucket_len: 200 # old implementation attributs
-   multiplier: 1.1 # old implementation attributs
-   shuffle_ex: False # if true re-creates batches at each epoch shuffling examples.
-   num_buckets: 10 # floor(log(max_batch_len/left_bucket_len, multiplier)) + 1
-   batch_ordering: ascending
+   max_batch_length: !ref <max_batch_length>
+   shuffle: !ref <shuffle>
+   num_buckets: !ref <num_buckets>
+   batch_ordering: !ref <batch_ordering>
 
 num_workers: 6
 
@@ -74,36 +77,69 @@ tokenizer: !apply:transformers.BertTokenizer.from_pretrained
 # bert-base-chinese tokens length
 output_neurons: 21128
 
-# Decoding parameters
+############################## Decoding ########################################
+
 # Be sure that the bos and eos index match with the BPEs ones
+# Decoding parameters
+test_searcher: !name:speechbrain.decoders.CTCBeamSearcher
 blank_index: 0
+beam_size: 100
+beam_prune_logp: -12.0
+token_prune_min_logp: -1.2
+prune_history: True
+topk: 1
+alpha: 1.0
+beta: 0.5
+# can be downloaded from here https://www.openslr.org/11/ or trained with kenLM
+# It can either be a .bin or .arpa ; note: .arpa is much slower at loading
+# If you don't want to use an LM, comment it out or set it to null
+# kenlm_model_path: none
+
 
 # AISHELL-1 has spaces between words in the transcripts,
 # which Chinese writing normally does not do.
 # If remove_spaces, spaces are removed
 # from the transcript before computing CER.
-# (e.g., 祝 可爱 的 你 —> 祝可爱的你)
 remove_spaces: True
 split_tokens: !apply:operator.not_ [!ref <remove_spaces>]
 
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>
 
-SpeedPerturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-   sample_rate: !ref <sample_rate>
-   speeds: [90, 100, 110]
-
-SpecAugment: !new:speechbrain.lobes.augment.SpecAugment
-   time_warp: True
-   time_warp_window: 5
-   time_warp_mode: bicubic
-   freq_mask: True
-   n_freq_mask: 2
-   time_mask: True
-   n_time_mask: 2
-   replace_with_zero: False
-   freq_mask_width: 30
-   time_mask_width: 40
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+   orig_freq: !ref <sample_rate>
+
+# Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+   drop_length_low: 35
+   drop_length_high: 45
+   drop_count_low: 2
+   drop_count_high: 2
+
+# Frequency Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+   drop_length_low: 25
+   drop_length_high: 35
+   drop_count_low: 2
+   drop_count_high: 2
+   dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+   min_augmentations: 3
+   max_augmentations: 3
+   augment_prob: 1.0
+   augmentations: [
+      !ref <time_drop>,
+      !ref <freq_drop>,
+      !ref <time_warp>]
+
+############################## Models ##########################################
 
 enc: !new:speechbrain.nnet.containers.Sequential
    input_shape: [null, null, !ref <wav2vec_output_dim>]
@@ -127,7 +163,7 @@ enc: !new:speechbrain.nnet.containers.Sequential
    bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
    activation3: !new:torch.nn.LeakyReLU
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
    source: !ref <wav2vec2_hub>
    output_norm: True
    freeze: !ref <freeze_wav2vec>
@@ -171,6 +207,8 @@ lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
    annealing_factor: 0.9
    patient: 0
 
+############################## Logging and Pretrainer ##########################
+
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
diff --git a/recipes/AISHELL-1/ASR/CTC/train_with_wav2vec.py b/recipes/AISHELL-1/ASR/CTC/train_with_wav2vec.py
index 853e8aa161714b73d97530278fd88f5ba6b827de..43783eed7f49f54a3f41e8e6adffdb2381eee50e 100644
--- a/recipes/AISHELL-1/ASR/CTC/train_with_wav2vec.py
+++ b/recipes/AISHELL-1/ASR/CTC/train_with_wav2vec.py
@@ -35,19 +35,14 @@ class ASR(sb.Brain):
         # Add augmentation if specified
         if stage == sb.Stage.TRAIN:
             if hasattr(self.hparams, "SpeedPerturb"):
-                wavs = self.hparams.SpeedPerturb(wavs, wav_lens)
-
-            if hasattr(self.modules, "env_corrupt"):
-                wavs_noise = self.modules.env_corrupt(wavs, wav_lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                wav_lens = torch.cat([wav_lens, wav_lens])
+                wavs = self.hparams.speed_perturb(wavs, wav_lens)
 
         # Forward pass
         feats = self.modules.wav2vec2(wavs, wav_lens)
 
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "SpecAugment"):
-                feats = self.hparams.SpecAugment(feats)
+        # Add feature augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "fea_augment"):
+            feats, fea_lens = self.hparams.fea_augment(feats, wav_lens)
 
         x = self.modules.enc(feats)
         logits = self.modules.ctc_lin(x)
@@ -61,19 +56,23 @@ class ASR(sb.Brain):
         ids = batch.id
         tokens, tokens_lens = batch.tokens
 
-        if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
-            tokens = torch.cat([tokens, tokens], dim=0)
-            tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0)
+        # Labels must be extended if parallel augmentation or concatenated
+        # augmentation was performed on the input (increasing the time dimension)
+        if stage == sb.Stage.TRAIN:
+            if hasattr(self.hparams, "fea_augment"):
+                tokens = self.hparams.fea_augment.replicate_labels(tokens)
+                tokens_lens = self.hparams.fea_augment.replicate_labels(
+                    tokens_lens
+                )
 
         loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
 
-        if stage != sb.Stage.TRAIN:
+        if stage == sb.Stage.VALID:
             # Decode token terms to words
             sequences = sb.decoders.ctc_greedy_decode(
                 p_ctc, wav_lens, blank_id=self.hparams.blank_index
             )
             predicted_words_list = []
-            target_words_list = [list(wrd) for wrd in batch.wrd]
 
             for sequence in sequences:
                 # Decode token terms to words
@@ -92,36 +91,29 @@ class ASR(sb.Brain):
 
                 predicted_words_list.append(predicted_words)
 
+        elif stage == sb.Stage.TEST:
+            p_tokens = test_searcher(p_ctc, wav_lens)
+            # select one-best
+            text_hyps = [hyp[0].text for hyp in p_tokens]
+
+            predicted_words_list = []
+            preds = []
+            for seq in text_hyps:
+                seq = seq.replace("[CLS]", "")
+                seq = seq.replace("[SEP]", "")
+                seq = seq.replace("[PAD]", "")
+                for c in seq:
+                    preds.append(c)
+                predicted_words_list.append(preds)
+
+        if stage != sb.Stage.TRAIN:
+            target_words_list = [list(wrd) for wrd in batch.wrd]
             self.cer_metric.append(
                 ids=ids, predict=predicted_words_list, target=target_words_list,
             )
 
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-        loss.backward()
-
-        if self.check_gradients(loss):
-            if not self.hparams.wav2vec2.freeze:
-                self.wav2vec_optimizer.step()
-            self.model_optimizer.step()
-
-        if not self.hparams.wav2vec2.freeze:
-            self.wav2vec_optimizer.zero_grad()
-        self.model_optimizer.zero_grad()
-
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        with torch.no_grad():
-            loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
         self.batch_idx = 0
@@ -192,10 +184,23 @@ class ASR(sb.Brain):
         if self.checkpointer is not None:
             self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
 
-    def zero_grad(self, set_to_none=False):
         if not self.hparams.wav2vec2.freeze:
-            self.wav2vec_optimizer.zero_grad(set_to_none)
-        self.model_optimizer.zero_grad(set_to_none)
+            self.optimizers_dict = {
+                "wav2vec_optimizer": self.wav2vec_optimizer,
+                "model_optimizer": self.model_optimizer,
+            }
+        else:
+            self.optimizers_dict = {"model_optimizer": self.model_optimizer}
+
+    def freeze_optimizers(self, optimizers):
+        """Freezes the wav2vec2 optimizer according to the warmup steps"""
+        valid_optimizers = {}
+        if not self.hparams.wav2vec2.freeze:
+            valid_optimizers["wav2vec_optimizer"] = optimizers[
+                "wav2vec_optimizer"
+            ]
+        valid_optimizers["model_optimizer"] = optimizers["model_optimizer"]
+        return valid_optimizers
 
 
 def dataio_prepare(hparams):
@@ -277,24 +282,13 @@ def dataio_prepare(hparams):
         from speechbrain.dataio.sampler import DynamicBatchSampler  # noqa
 
         dynamic_hparams = hparams["dynamic_batch_sampler"]
-        num_buckets = dynamic_hparams["num_buckets"]
 
         train_batch_sampler = DynamicBatchSampler(
-            train_data,
-            dynamic_hparams["max_batch_len"],
-            num_buckets=num_buckets,
-            length_func=lambda x: x["duration"],
-            shuffle=dynamic_hparams["shuffle_ex"],
-            batch_ordering=dynamic_hparams["batch_ordering"],
+            train_data, **dynamic_hparams, length_func=lambda x: x["duration"],
         )
 
         valid_batch_sampler = DynamicBatchSampler(
-            valid_data,
-            dynamic_hparams["max_batch_len"],
-            num_buckets=num_buckets,
-            length_func=lambda x: x["duration"],
-            shuffle=dynamic_hparams["shuffle_ex"],
-            batch_ordering=dynamic_hparams["batch_ordering"],
+            valid_data, **dynamic_hparams, length_func=lambda x: x["duration"],
         )
 
     return (
@@ -314,7 +308,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -335,6 +328,7 @@ if __name__ == "__main__":
             "data_folder": hparams["data_folder"],
             "save_folder": hparams["output_folder"],
             "skip_prep": hparams["skip_prep"],
+            "remove_compressed_wavs": hparams["remove_compressed_wavs"],
         },
     )
 
@@ -356,8 +350,22 @@ if __name__ == "__main__":
         checkpointer=hparams["checkpointer"],
     )
 
-    # adding objects to trainer:
     asr_brain.tokenizer = tokenizer
+    vocab_list = [
+        tokenizer.convert_ids_to_tokens(i) for i in range(tokenizer.vocab_size)
+    ]
+    test_searcher = hparams["test_searcher"](
+        blank_index=hparams["blank_index"],
+        vocab_list=vocab_list,
+        alpha=hparams["alpha"],
+        beta=hparams["beta"],
+        beam_size=hparams["beam_size"],
+        beam_prune_logp=hparams["beam_prune_logp"],
+        token_prune_min_logp=hparams["token_prune_min_logp"],
+        prune_history=hparams["prune_history"],
+        topk=hparams["topk"],
+        kenlm_model_path=hparams.get("kenlm_model_path"),
+    )
 
     # Changing the samplers if dynamic batching is activated
     train_dataloader_opts = hparams["train_dataloader_opts"]
diff --git a/recipes/AISHELL-1/ASR/seq2seq/README.md b/recipes/AISHELL-1/ASR/seq2seq/README.md
index 6d77920fa1f2392ac732d17d9a822d9d1a5a38fe..4db3746e084ce794d5a256fb894370cc68061109 100644
--- a/recipes/AISHELL-1/ASR/seq2seq/README.md
+++ b/recipes/AISHELL-1/ASR/seq2seq/README.md
@@ -9,9 +9,8 @@ To train a full recipe:
 
 ```
 cd ../../Tokenizer
-python train.py hparams/tokenizer_bpe5000.yaml --data_folder=/localscratch/aishell/
+python train.py hparams/tokenizer_bpe5000.yaml --data_folder=/path/to/aishell/
 ```
-If not present in the specified data_folder, the dataset will be automatically downloaded there.
 This step is not mandatory. We will use the official tokenizer downloaded from the web if you do not
 specify a different tokenizer in the speech recognition recipe.
 
diff --git a/recipes/AISHELL-1/ASR/seq2seq/hparams/train.yaml b/recipes/AISHELL-1/ASR/seq2seq/hparams/train.yaml
index 7acbfd00ea99971fa335cf6a11db8299e5addc92..e6fda7de26ed417a2cae1ddf6389aacb717d420d 100644
--- a/recipes/AISHELL-1/ASR/seq2seq/hparams/train.yaml
+++ b/recipes/AISHELL-1/ASR/seq2seq/hparams/train.yaml
@@ -17,35 +17,38 @@ save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
 
 # Data files
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
 data_folder: !PLACEHOLDER # e,g./path/to/aishell
-# noise/ris dataset will automatically be downloaded
-data_folder_rirs: !ref <data_folder> # Change this is needed
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
 skip_prep: False
+remove_compressed_wavs: False
 ckpt_interval_minutes: 15 # save checkpoint every N min
 train_data: !ref <output_folder>/train.csv
 valid_data: !ref <output_folder>/dev.csv
 test_data: !ref <output_folder>/test.csv
+noise_annotation: !ref <save_folder>/noise.csv #The data manifest files are created by the data preparation script
 tokenizer_file: speechbrain/asr-transformer-aishell/tokenizer.ckpt
 
-# Training parameters
+####################### Training Parameters ####################################
+
 number_of_epochs: 40
 number_of_ctc_epochs: 10
 batch_size: 16
 lr: 0.0003
 ctc_weight: 0.5
 sorting: ascending
+precision: fp32 # bf16, fp16 or fp32
 
 dynamic_batching: True
+max_batch_length: 15 # in terms of "duration" in annotations by default, second here
+shuffle: False # if true re-creates batches at each epoch shuffling examples.
+num_buckets: 10 # floor(log(max_batch_len/left_bucket_len, multiplier)) + 1
+batch_ordering: ascending
 dynamic_batch_sampler:
-   feats_hop_size: 0.01
-   max_batch_len: 15 # in terms of "duration" in annotations by default, second here
-   left_bucket_len: 200 # old implementation attributs
-   multiplier: 1.1 # old implementation attributs
-   shuffle_ex: False # if true re-creates batches at each epoch shuffling examples.
-   num_buckets: 10 # floor(log(max_batch_len/left_bucket_len, multiplier)) + 1
-   batch_ordering: ascending
-
-num_workers: 6
+   max_batch_length: !ref <max_batch_length>
+   shuffle: !ref <shuffle>
+   num_buckets: !ref <num_buckets>
+   batch_ordering: !ref <batch_ordering>
 
 # Feature parameters
 sample_rate: 16000
@@ -56,6 +59,7 @@ opt_class: !name:torch.optim.Adam
    lr: !ref <lr>
 
 # Dataloader options
+num_workers: 4
 train_dataloader_opts:
    batch_size: !ref <batch_size>
    num_workers: !ref <num_workers>
@@ -68,7 +72,7 @@ test_dataloader_opts:
    batch_size: !ref <batch_size>
    num_workers: !ref <num_workers>
 
-# Model parameters
+####################### Model Parameters #######################################
 activation: !name:torch.nn.LeakyReLU
 dropout: 0.15
 cnn_blocks: 2
@@ -85,9 +89,11 @@ dnn_neurons: 512
 emb_size: 128
 dec_neurons: 1024
 output_neurons: 5000  # Number of tokens
+# we need to have blank_index != bos_index != eos_index when using CTCScorer
 blank_index: 0
-bos_index: 0
-eos_index: 0
+bos_index: 1
+eos_index: 2
+label_smoothing: 0.1
 
 # Decoding parameters
 min_decode_ratio: 0.0
@@ -98,12 +104,11 @@ using_max_attn_shift: True
 max_attn_shift: 240
 coverage_penalty: 1.5
 temperature: 1.25
-
+scorer_beam_scale: 0.5
 # AISHELL-1 has spaces between words in the transcripts,
 # which Chinese writing normally does not do.
 # If remove_spaces, spaces are removed
 # from the transcript before computing CER.
-# (e.g., 祝 可爱 的 你 —> 祝可爱的你)
 remove_spaces: True
 split_tokens: !apply:operator.not_ [!ref <remove_spaces>]
 
@@ -113,23 +118,64 @@ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
 normalize: !new:speechbrain.processing.features.InputNormalization
    norm_type: global
 
+############################## Augmentations ###################################
+
 compute_features: !new:speechbrain.lobes.features.Fbank
    sample_rate: !ref <sample_rate>
    n_fft: !ref <n_fft>
    n_mels: !ref <n_mels>
 
-env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt
-   openrir_folder: !ref <data_folder_rirs>
-   babble_prob: 0.0
-   reverb_prob: 0.0
-   noise_prob: 1.0
-   noise_snr_low: 0
-   noise_snr_high: 15
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+   URL: !ref <NOISE_DATASET_URL>
+   dest_folder: !ref <data_folder_noise>
+   ext: wav
+   csv_file: !ref <noise_annotation>
+
+
+# Add noise to input signal
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+   csv_file: !ref <noise_annotation>
+   snr_low: 0
+   snr_high: 15
+   noise_sample_rate: !ref <sample_rate>
+   clean_sample_rate: !ref <sample_rate>
+   num_workers: !ref <num_workers>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-   sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+   orig_freq: !ref <sample_rate>
    speeds: [95, 100, 105]
 
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+   drop_freq_low: 0
+   drop_freq_high: 1
+   drop_freq_count_low: 1
+   drop_freq_count_high: 3
+   drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+   drop_length_low: 1000
+   drop_length_high: 2000
+   drop_count_low: 1
+   drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+   concat_original: True
+   min_augmentations: 4
+   max_augmentations: 4
+   augment_prob: 1.0
+   augmentations: [
+      !ref <add_noise>,
+      !ref <speed_perturb>,
+      !ref <drop_freq>,
+      !ref <drop_chunk>]
+
+############################## Models ##########################################
+
 enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
    input_shape: [null, null, !ref <n_mels>]
    activation: !ref <activation>
@@ -183,7 +229,7 @@ ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
    blank_index: !ref <blank_index>
 
 seq_cost: !name:speechbrain.nnet.losses.nll_loss
-   label_smoothing: 0.1
+   label_smoothing: !ref <label_smoothing>
 
 # Models
 modules:
@@ -193,8 +239,6 @@ modules:
    ctc_lin: !ref <ctc_lin>
    seq_lin: !ref <seq_lin>
    normalize: !ref <normalize>
-   env_corrupt: !ref <env_corrupt>
-   #lm_model: !ref <lm_model>
 
 model: !new:torch.nn.ModuleList
    - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <ctc_lin>, !ref <seq_lin>]
@@ -208,22 +252,37 @@ pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
    paths:
       tokenizer: !ref <tokenizer_file>
 
+############################## Decoding ########################################
+
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
+   eos_index: !ref <eos_index>
+   blank_index: !ref <blank_index>
+   ctc_fc: !ref <ctc_lin>
+
+coverage_scorer: !new:speechbrain.decoders.scorer.CoverageScorer
+   vocab_size: !ref <output_neurons>
+
+scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
+   full_scorers: [!ref <coverage_scorer>, !ref <ctc_scorer>]
+   weights:
+      coverage: !ref <coverage_penalty>
+      ctc: !ref <ctc_weight>
+   scorer_beam_scale: !ref <scorer_beam_scale>
+
 beam_search: !new:speechbrain.decoders.S2SRNNBeamSearcher
    embedding: !ref <emb>
    decoder: !ref <dec>
    linear: !ref <seq_lin>
-   ctc_linear: !ref <ctc_lin>
    bos_index: !ref <bos_index>
    eos_index: !ref <eos_index>
-   blank_index: !ref <blank_index>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    beam_size: !ref <beam_size>
    eos_threshold: !ref <eos_threshold>
+   temperature: !ref <temperature>
    using_max_attn_shift: !ref <using_max_attn_shift>
    max_attn_shift: !ref <max_attn_shift>
-   coverage_penalty: !ref <coverage_penalty>
-   temperature: !ref <temperature>
+   scorer: !ref <scorer>
 
 lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
@@ -231,6 +290,8 @@ lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
    annealing_factor: 0.8
    patient: 0
 
+############################## Logging and Pretrainer ##########################
+
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
diff --git a/recipes/AISHELL-1/ASR/seq2seq/train.py b/recipes/AISHELL-1/ASR/seq2seq/train.py
index f8a22afccfa8a9afa64fc9cbdb15458ae0533830..bc2c49b888ce0cb997d03a87da151a0cc524d35d 100644
--- a/recipes/AISHELL-1/ASR/seq2seq/train.py
+++ b/recipes/AISHELL-1/ASR/seq2seq/train.py
@@ -1,8 +1,6 @@
 #!/usr/bin/env/python3
 """
-
 AISHELL-1 seq2seq model recipe. (Adapted from the LibriSpeech recipe.)
-
 """
 
 import sys
@@ -24,16 +22,10 @@ class ASR(sb.Brain):
         tokens_bos, _ = batch.tokens_bos
         wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
 
-        # Add augmentation if specified
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.modules, "env_corrupt"):
-                wavs_noise = self.modules.env_corrupt(wavs, wav_lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                wav_lens = torch.cat([wav_lens, wav_lens])
-                tokens_bos = torch.cat([tokens_bos, tokens_bos], dim=0)
-
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+            tokens_bos = self.hparams.wav_augment.replicate_labels(tokens_bos)
 
         # Forward pass
         feats = self.hparams.compute_features(wavs)
@@ -47,42 +39,39 @@ class ASR(sb.Brain):
         p_seq = self.hparams.log_softmax(logits)
 
         # Compute outputs
+        p_ctc, p_tokens = None, None
         if stage == sb.Stage.TRAIN:
             current_epoch = self.hparams.epoch_counter.current
             if current_epoch <= self.hparams.number_of_ctc_epochs:
                 # Output layer for ctc log-probabilities
                 logits = self.modules.ctc_lin(x)
                 p_ctc = self.hparams.log_softmax(logits)
-                return p_ctc, p_seq, wav_lens
-            else:
-                return p_seq, wav_lens
         else:
-            p_tokens, scores = self.hparams.beam_search(x, wav_lens)
-            return p_seq, wav_lens, p_tokens
+            p_tokens, _, _, _ = self.hparams.beam_search(x, wav_lens)
+
+        return p_ctc, p_seq, wav_lens, p_tokens
 
     def compute_objectives(self, predictions, batch, stage):
         """Computes the loss (CTC+NLL) given predictions and targets."""
 
         current_epoch = self.hparams.epoch_counter.current
-        if stage == sb.Stage.TRAIN:
-            if current_epoch <= self.hparams.number_of_ctc_epochs:
-                p_ctc, p_seq, wav_lens = predictions
-            else:
-                p_seq, wav_lens = predictions
-        else:
-            p_seq, wav_lens, predicted_tokens = predictions
+        p_ctc, p_seq, wav_lens, predicted_tokens = predictions
 
         ids = batch.id
         tokens_eos, tokens_eos_lens = batch.tokens_eos
         tokens, tokens_lens = batch.tokens
 
-        if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
-            tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
-            tokens_eos_lens = torch.cat(
-                [tokens_eos_lens, tokens_eos_lens], dim=0
+        # Labels must be extended if parallel augmentation or concatenated
+        # augmentation was performed on the input (increasing the time dimension)
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            (
+                tokens,
+                tokens_lens,
+                tokens_eos,
+                tokens_eos_lens,
+            ) = self.hparams.wav_augment.replicate_multiple_labels(
+                tokens, tokens_lens, tokens_eos, tokens_eos_lens
             )
-            tokens = torch.cat([tokens, tokens], dim=0)
-            tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0)
 
         loss_seq = self.hparams.seq_cost(
             p_seq, tokens_eos, length=tokens_eos_lens
@@ -117,24 +106,6 @@ class ASR(sb.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-        loss.backward()
-        if self.check_gradients(loss):
-            self.optimizer.step()
-        self.optimizer.zero_grad()
-        self.batch_idx += 1
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        with torch.no_grad():
-            loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
         self.batch_idx = 0
@@ -255,24 +226,13 @@ def dataio_prepare(hparams):
         from speechbrain.dataio.sampler import DynamicBatchSampler  # noqa
 
         dynamic_hparams = hparams["dynamic_batch_sampler"]
-        num_buckets = dynamic_hparams["num_buckets"]
 
         train_batch_sampler = DynamicBatchSampler(
-            train_data,
-            dynamic_hparams["max_batch_len"],
-            num_buckets=num_buckets,
-            length_func=lambda x: x["duration"],
-            shuffle=dynamic_hparams["shuffle_ex"],
-            batch_ordering=dynamic_hparams["batch_ordering"],
+            train_data, **dynamic_hparams, length_func=lambda x: x["duration"],
         )
 
         valid_batch_sampler = DynamicBatchSampler(
-            valid_data,
-            dynamic_hparams["max_batch_len"],
-            num_buckets=num_buckets,
-            length_func=lambda x: x["duration"],
-            shuffle=dynamic_hparams["shuffle_ex"],
-            batch_ordering=dynamic_hparams["batch_ordering"],
+            valid_data, **dynamic_hparams, length_func=lambda x: x["duration"],
         )
 
     return (
@@ -292,7 +252,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -313,8 +272,10 @@ if __name__ == "__main__":
             "data_folder": hparams["data_folder"],
             "save_folder": hparams["output_folder"],
             "skip_prep": hparams["skip_prep"],
+            "remove_compressed_wavs": hparams["remove_compressed_wavs"],
         },
     )
+    run_on_main(hparams["prepare_noise_data"])
 
     # here we create the datasets objects as well as tokenization and encoding
     (
@@ -328,7 +289,7 @@ if __name__ == "__main__":
 
     # We download and pretrain the tokenizer
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     # Trainer initialization
     asr_brain = ASR(
diff --git a/recipes/AISHELL-1/ASR/transformer/README.md b/recipes/AISHELL-1/ASR/transformer/README.md
index d7e6f1973f96d30ea942553d0a8535aea9957ee8..722dab20b87d303111ee83a5b16927167e943e47 100644
--- a/recipes/AISHELL-1/ASR/transformer/README.md
+++ b/recipes/AISHELL-1/ASR/transformer/README.md
@@ -6,9 +6,8 @@ This folder contains recipes for tokenization and speech recognition with [AISHE
 
 ```
 cd ../../Tokenizer
-python train.py hparams/train_transformer_tokenizer_bpe5000.yaml --data_folder=/localscratch/aishell/
+python train.py hparams/train_transformer_tokenizer_bpe5000.yaml --data_folder=/path/to/aishell
 ```
-If not present in the specified data_folder, the dataset will be automatically downloaded there.
 This step is not mandatory. We will use the official tokenizer downloaded from the web if you do not
 specify a different tokenizer in the speech recognition recipe.
 
@@ -17,7 +16,7 @@ specify a different tokenizer in the speech recognition recipe.
 python train.py hparams/train_ASR_transformer.yaml --data_folder=/localscratch/aishell/
 ```
 
-Make sure to have "transformers" installed if you use the wav2vec2 recipe (see extra-requirements.txt)
+Make sure to have `transformers` installed if you use the wav2vec2 recipe (see extra-requirements.txt)
 
 # Performance summary
 Results are reported in terms of Character Error Rate (CER).
diff --git a/recipes/AISHELL-1/ASR/transformer/hparams/train_ASR_transformer.yaml b/recipes/AISHELL-1/ASR/transformer/hparams/train_ASR_transformer.yaml
index 7916146b48c8300130054467a7013e934190582e..408c9e68008f74950c1916b53b6b9c29da7efdec 100644
--- a/recipes/AISHELL-1/ASR/transformer/hparams/train_ASR_transformer.yaml
+++ b/recipes/AISHELL-1/ASR/transformer/hparams/train_ASR_transformer.yaml
@@ -15,36 +15,44 @@ cer_file: !ref <output_folder>/cer.txt
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
 
+# Data for augmentation
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+
 # Data files
 data_folder: !PLACEHOLDER # e,g./path/to/aishell
-# noise/ris dataset will automatically be downloaded
-data_folder_rirs: !ref <data_folder> # Change this is needed
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
 skip_prep: False
+remove_compressed_wavs: False
 ckpt_interval_minutes: 15 # save checkpoint every N min
-train_data: !ref <output_folder>/train.csv
-valid_data: !ref <output_folder>/dev.csv
-test_data: !ref <output_folder>/test.csv
+train_data: !ref <save_folder>/train.csv
+valid_data: !ref <save_folder>/dev.csv
+test_data: !ref <save_folder>/test.csv
+noise_annotation: !ref <save_folder>/noise.csv #The data manifest files are created by the data preparation script
 tokenizer_file: speechbrain/asr-transformer-aishell/tokenizer.ckpt
 
-# Training parameters
+####################### Training Parameters ####################################
+
 number_of_epochs: 50
 batch_size: 8
 ctc_weight: 0.3
-gradient_accumulation: 4
+grad_accumulation_factor: 4
 loss_reduction: 'batchmean'
 sorting: random
+avg_checkpoints: 10 # Number of checkpoints to average for evaluation
+precision: fp32 # bf16, fp16 or fp32
 
 dynamic_batching: False
+max_batch_length: 15 # in terms of "duration" in annotations by default, second here
+shuffle: False # if true re-creates batches at each epoch shuffling examples.
+num_buckets: 10 # floor(log(max_batch_len/left_bucket_len, multiplier)) + 1
+batch_ordering: ascending
 dynamic_batch_sampler:
-    feats_hop_size: 0.01
-    max_batch_len: 15 # in terms of "duration" in annotations by default, second here
-    left_bucket_len: 200 # old implementation attributs
-    multiplier: 1.1 # old implementation attributs
-    shuffle_ex: False # if true re-creates batches at each epoch shuffling examples.
-    num_buckets: 10 # floor(log(max_batch_len/left_bucket_len, multiplier)) + 1
-    batch_ordering: ascending
+    max_batch_length: !ref <max_batch_length>
+    shuffle: !ref <shuffle>
+    num_buckets: !ref <num_buckets>
+    batch_ordering: !ref <batch_ordering>
 
-num_workers: 6
+num_workers: 4
 
 # stages related parameters
 stage_one_epochs: 40
@@ -59,15 +67,18 @@ n_mels: 80
 # Dataloader options
 train_dataloader_opts:
     batch_size: !ref <batch_size>
+    num_workers: !ref <num_workers>
     shuffle: True
 
 valid_dataloader_opts:
     batch_size: !ref <batch_size>
+    num_workers: !ref <num_workers>
 
 test_dataloader_opts:
     batch_size: !ref <batch_size>
+    num_workers: !ref <num_workers>
 
-####################### Model parameters ###########################
+####################### Model Parameters #######################################
 # Transformer
 d_model: 256
 nhead: 4
@@ -93,7 +104,7 @@ valid_beam_size: 10
 test_beam_size: 10
 ctc_weight_decode: 0.40
 
-############################## models ################################
+############################## Models ##########################################
 
 CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
     input_shape: (8, 10, 80)
@@ -126,20 +137,12 @@ seq_lin: !new:speechbrain.nnet.linear.Linear
     input_size: !ref <d_model>
     n_neurons: !ref <output_neurons>
 
-env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <data_folder_rirs>
-    babble_prob: 0.0
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
 
 modules:
     CNN: !ref <CNN>
     Transformer: !ref <Transformer>
     seq_lin: !ref <seq_lin>
     ctc_lin: !ref <ctc_lin>
-    env_corrupt: !ref <env_corrupt>
 
 model: !new:torch.nn.ModuleList
     - [!ref <CNN>, !ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
@@ -155,30 +158,39 @@ SGD: !name:torch.optim.SGD
     momentum: 0.99
     nesterov: True
 
+############################## Decoding & optimiser ############################
 
-valid_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
-    bos_index: !ref <bos_index>
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
     eos_index: !ref <eos_index>
     blank_index: !ref <blank_index>
+    ctc_fc: !ref <ctc_lin>
+
+scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+
+valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <valid_beam_size>
-    ctc_weight: !ref <ctc_weight_decode>
     using_eos_threshold: False
     length_normalization: True
+    scorer: !ref <scorer>
 
-test_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
+test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <test_beam_size>
-    ctc_weight: !ref <ctc_weight_decode>
     using_eos_threshold: False
     length_normalization: True
+    scorer: !ref <scorer>
 
 log_softmax: !new:torch.nn.LogSoftmax
     dim: -1
@@ -211,23 +223,72 @@ normalize: !new:speechbrain.processing.features.InputNormalization
     norm_type: global
     update_until_epoch: 4
 
-augmentation: !new:speechbrain.lobes.augment.SpecAugment
-    time_warp: True
-    time_warp_window: 5
-    time_warp_mode: bicubic
-    freq_mask: True
-    n_freq_mask: 2
-    time_mask: True
-    n_time_mask: 2
-    replace_with_zero: False
-    freq_mask_width: 30
-    time_mask_width: 40
 
 compute_features: !new:speechbrain.lobes.features.Fbank
     sample_rate: !ref <sample_rate>
     n_fft: !ref <n_fft>
     n_mels: !ref <n_mels>
 
+############################## Augmentation ####################################
+
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+
+# Add noise to input signal
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: 0
+    snr_high: 15
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    concat_original: True
+    min_augmentations: 1
+    max_augmentations: 1
+    augment_prob: 1.0
+    augmentations: [
+        !ref <add_noise>]
+
+# Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 0
+    drop_length_high: 100
+    drop_count_low: 2
+    drop_count_high: 2
+
+# Frequency Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 30
+    drop_length_high: 40
+    drop_count_low: 2
+    drop_count_high: 2
+    dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    concat_original: True
+    min_augmentations: 1
+    max_augmentations: 1
+    augment_start_index: !ref <batch_size> # This leaves unchanges original inputs
+    concat_end_index: !ref <batch_size> # This leaves unchanges original inputs
+    augment_prob: 1.0
+    augmentations: [
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>]
+
+############################## Logging and Pretrainer ##########################
+
 train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
     save_file: !ref <train_log>
 
@@ -235,7 +296,6 @@ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
 # which Chinese writing normally does not do.
 # If remove_spaces, spaces are removed
 # from the transcript before computing CER.
-# (e.g., 祝 可爱 的 你 —> 祝可爱的你)
 remove_spaces: True
 split_tokens: !apply:operator.not_ [!ref <remove_spaces>]
 
diff --git a/recipes/AISHELL-1/ASR/transformer/hparams/train_ASR_transformer_with_wav2vect.yaml b/recipes/AISHELL-1/ASR/transformer/hparams/train_ASR_transformer_with_wav2vect.yaml
index 95eeb50c8a89d7cdda91c724b9adee9683f66a9b..a196afc5842ef5235d0f6e9c5179984319c327bf 100644
--- a/recipes/AISHELL-1/ASR/transformer/hparams/train_ASR_transformer_with_wav2vect.yaml
+++ b/recipes/AISHELL-1/ASR/transformer/hparams/train_ASR_transformer_with_wav2vect.yaml
@@ -17,63 +17,63 @@ train_log: !ref <output_folder>/train_log.txt
 
 # Data files
 data_folder: !PLACEHOLDER # e,g./path/to/aishell
-# noise/ris dataset will automatically be downloaded
-data_folder_rirs: !ref <data_folder> # Change this is needed
 skip_prep: False
+remove_compressed_wavs: False
 ckpt_interval_minutes: 30 # save checkpoint every N min
-train_data: !ref <output_folder>/train.csv
-valid_data: !ref <output_folder>/dev.csv
-test_data: !ref <output_folder>/test.csv
+train_data: !ref <save_folder>/train.csv
+valid_data: !ref <save_folder>/dev.csv
+test_data: !ref <save_folder>/test.csv
 tokenizer_file: speechbrain/asr-transformer-aishell/tokenizer.ckpt
-
+sample_rate: 16000
 # Self-supervised pre-training
 wav2vec2_hub: facebook/wav2vec2-large-100k-voxpopuli
 wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
 freeze_wav2vec: False
 
-# Training parameters
+####################### Training Parameters ####################################
+
 number_of_epochs: 80
 batch_size: 2
-ctc_weight: 0.3
-gradient_accumulation: 16
+grad_accumulation_factor: 16
 loss_reduction: 'batchmean'
 sorting: random
+ctc_weight: 0.3
+avg_checkpoints: 10 # Number of checkpoints to average for evaluation
+precision: fp32 # bf16, fp16 or fp32
 
 dynamic_batching: False
+max_batch_length: 15 # in terms of "duration" in annotations by default, second here
+shuffle: False # if true re-creates batches at each epoch shuffling examples.
+num_buckets: 10 # floor(log(max_batch_len/left_bucket_len, multiplier)) + 1
+batch_ordering: ascending
 dynamic_batch_sampler:
-    feats_hop_size: 0.01
-    max_batch_len: 15 # in terms of "duration" in annotations by default, second here
-    left_bucket_len: 200 # old implementation attributs
-    multiplier: 1.1 # old implementation attributs
-    shuffle_ex: False # if true re-creates batches at each epoch shuffling examples.
-    num_buckets: 10 # floor(log(max_batch_len/left_bucket_len, multiplier)) + 1
-    batch_ordering: ascending
+    max_batch_length: !ref <max_batch_length>
+    shuffle: !ref <shuffle>
+    num_buckets: !ref <num_buckets>
+    batch_ordering: !ref <batch_ordering>
 
-num_workers: 6
+num_workers: 4
 
 # stages related parameters
 stage_one_epochs: 40
 lr_adam: 1.0
 lr_sgd: 0.000025
-# lr_wav2vec: 0.0001
-
-# Feature parameters
-# sample_rate: 16000
-# n_fft: 400
-# n_mels: 80
 
 # Dataloader options
 train_dataloader_opts:
     batch_size: !ref <batch_size>
+    num_workers: !ref <num_workers>
     shuffle: True
 
 valid_dataloader_opts:
     batch_size: !ref <batch_size>
+    num_workers: !ref <num_workers>
 
 test_dataloader_opts:
     batch_size: !ref <batch_size>
+    num_workers: !ref <num_workers>
 
-####################### Model parameters ###########################
+####################### Model Parameters #######################################
 # Transformer
 d_model: 256
 nhead: 4
@@ -99,9 +99,9 @@ valid_beam_size: 10
 test_beam_size: 10
 ctc_weight_decode: 0.40
 
-############################## models ################################
+############################## Models ##########################################
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec>
@@ -130,24 +130,48 @@ seq_lin: !new:speechbrain.nnet.linear.Linear
     input_size: !ref <d_model>
     n_neurons: !ref <output_neurons>
 
-env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <data_folder_rirs>
-    babble_prob: 0.0
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
 
 modules:
     wav2vec2: !ref <wav2vec2>
     Transformer: !ref <Transformer>
     seq_lin: !ref <seq_lin>
     ctc_lin: !ref <ctc_lin>
-    env_corrupt: !ref <env_corrupt>
 
 model: !new:torch.nn.ModuleList
     - [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
 
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: [95, 100, 105]
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Decoding & optimiser ############################
+
 # define two optimizers here for two-stage training
 Adam: !name:torch.optim.Adam
     lr: 0
@@ -164,30 +188,38 @@ wav2vec_opt_class: !name:torch.optim.Adam
     betas: (0.9, 0.98)
     eps: 0.000000001
 
+# Scorer
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
+    eos_index: !ref <eos_index>
+    blank_index: !ref <blank_index>
+    ctc_fc: !ref <ctc_lin>
+
+scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
 
-valid_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
+valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <valid_beam_size>
-    ctc_weight: !ref <ctc_weight_decode>
     using_eos_threshold: False
     length_normalization: True
+    scorer: !ref <scorer>
 
-test_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
+test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <test_beam_size>
-    ctc_weight: !ref <ctc_weight_decode>
     using_eos_threshold: False
     length_normalization: True
+    scorer: !ref <scorer>
 
 log_softmax: !new:torch.nn.LogSoftmax
     dim: -1
@@ -210,6 +242,7 @@ noam_annealing_wav2vect: !new:speechbrain.nnet.schedulers.NoamScheduler
     n_warmup_steps: 25000
     model_size: !ref <d_model>
 
+############################## Logging and Pretrainer ##########################
 
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
     checkpoints_dir: !ref <save_folder>
@@ -224,19 +257,6 @@ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
 
-augmentation: !new:speechbrain.lobes.augment.SpecAugment
-    time_warp: True
-    time_warp_window: 5
-    time_warp_mode: bicubic
-    freq_mask: True
-    n_freq_mask: 2
-    time_mask: True
-    n_time_mask: 2
-    replace_with_zero: False
-    freq_mask_width: 30
-    time_mask_width: 40
-
-
 train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
     save_file: !ref <train_log>
 
@@ -244,7 +264,6 @@ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
 # which Chinese writing normally does not do.
 # If remove_spaces, spaces are removed
 # from the transcript before computing CER.
-# (e.g., 祝 可爱 的 你 —> 祝可爱的你)
 remove_spaces: True
 split_tokens: !apply:operator.not_ [!ref <remove_spaces>]
 
diff --git a/recipes/AISHELL-1/ASR/transformer/train.py b/recipes/AISHELL-1/ASR/transformer/train.py
index f27779488da44fa0febc8e2d9d678b4fbea6aafc..63361bf0dd9d597664ce175238529ad10f6442be 100644
--- a/recipes/AISHELL-1/ASR/transformer/train.py
+++ b/recipes/AISHELL-1/ASR/transformer/train.py
@@ -3,6 +3,9 @@
 
 AISHELL-1 transformer model recipe. (Adapted from the LibriSpeech recipe.)
 
+Authors
+    * Jianyuan Zhong 2021
+    * Titouan Parcollet 2021
 """
 
 import sys
@@ -23,23 +26,20 @@ class ASR(sb.core.Brain):
         wavs, wav_lens = batch.sig
         tokens_bos, _ = batch.tokens_bos
 
-        # Add augmentation if specified
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.modules, "env_corrupt"):
-                wavs_noise = self.modules.env_corrupt(wavs, wav_lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                wav_lens = torch.cat([wav_lens, wav_lens])
-                tokens_bos = torch.cat([tokens_bos, tokens_bos], dim=0)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+            tokens_bos = self.hparams.wav_augment.replicate_labels(tokens_bos)
 
         # compute features
         feats = self.hparams.compute_features(wavs)
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "fea_augment"):
+            feats, fea_lens = self.hparams.fea_augment(feats, wav_lens)
+            tokens_bos = self.hparams.wav_augment.replicate_labels(tokens_bos)
+
         current_epoch = self.hparams.epoch_counter.current
         feats = self.hparams.normalize(feats, wav_lens, epoch=current_epoch)
 
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "augmentation"):
-                feats = self.hparams.augmentation(feats)
-
         # forward modules
         src = self.modules.CNN(feats)
         enc_out, pred = self.modules.Transformer(
@@ -56,17 +56,19 @@ class ASR(sb.core.Brain):
 
         # Compute outputs
         hyps = None
-        if stage == sb.Stage.TRAIN:
-            hyps = None
-        elif stage == sb.Stage.VALID:
-            hyps = None
-            current_epoch = self.hparams.epoch_counter.current
-            if current_epoch % self.hparams.valid_search_interval == 0:
-                # for the sake of efficiency, we only perform beamsearch with limited capacity
-                # and no LM to give user some idea of how the AM is doing
-                hyps, _ = self.hparams.valid_search(enc_out.detach(), wav_lens)
-        elif stage == sb.Stage.TEST:
-            hyps, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
+        current_epoch = self.hparams.epoch_counter.current
+        is_valid_search = (
+            stage == sb.Stage.VALID
+            and current_epoch % self.hparams.valid_search_interval == 0
+        )
+        is_test_search = stage == sb.Stage.TEST
+
+        if is_valid_search:
+            hyps, _, _, _ = self.hparams.valid_search(
+                enc_out.detach(), wav_lens
+            )
+        elif is_test_search:
+            hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
 
         return p_ctc, p_seq, wav_lens, hyps
 
@@ -79,13 +81,28 @@ class ASR(sb.core.Brain):
         tokens_eos, tokens_eos_lens = batch.tokens_eos
         tokens, tokens_lens = batch.tokens
 
-        if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
-            tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
-            tokens_eos_lens = torch.cat(
-                [tokens_eos_lens, tokens_eos_lens], dim=0
-            )
-            tokens = torch.cat([tokens, tokens], dim=0)
-            tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0)
+        if stage == sb.Stage.TRAIN:
+            # Labels must be extended if parallel augmentation or concatenated
+            # augmentation was performed on the input (increasing the time dimension)
+            if hasattr(self.hparams, "wav_augment"):
+                (
+                    tokens,
+                    tokens_lens,
+                    tokens_eos,
+                    tokens_eos_lens,
+                ) = self.hparams.wav_augment.replicate_multiple_labels(
+                    tokens, tokens_lens, tokens_eos, tokens_eos_lens
+                )
+
+            if hasattr(self.hparams, "fea_augment"):
+                (
+                    tokens,
+                    tokens_lens,
+                    tokens_eos,
+                    tokens_eos_lens,
+                ) = self.hparams.fea_augment.replicate_multiple_labels(
+                    tokens, tokens_lens, tokens_eos, tokens_eos_lens
+                )
 
         loss_seq = self.hparams.seq_cost(
             p_seq, tokens_eos, length=tokens_eos_lens
@@ -117,37 +134,17 @@ class ASR(sb.core.Brain):
             self.acc_metric.append(p_seq, tokens_eos, tokens_eos_lens)
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
+    def on_fit_batch_start(self, batch, should_step):
+        """Gets called at the beginning of each fit_batch."""
         # check if we need to switch optimizer
         # if so change the optimizer from Adam to SGD
         self.check_and_reset_optimizer()
 
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-
-        # normalize the loss by gradient_accumulation step
-        (loss / self.hparams.gradient_accumulation).backward()
-
-        if self.step % self.hparams.gradient_accumulation == 0:
-            # gradient clipping & early stop if loss is not fini
-            self.check_gradients(loss)
-
-            self.optimizer.step()
-            self.optimizer.zero_grad()
-
-            # anneal lr every update
+    def on_fit_batch_end(self, batch, outputs, loss, should_step):
+        """At the end of the optimizer step, apply noam annealing."""
+        if should_step:
             self.hparams.noam_annealing(self.optimizer)
 
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        with torch.no_grad():
-            predictions = self.compute_forward(batch, stage=stage)
-            loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
         if stage != sb.Stage.TRAIN:
@@ -171,7 +168,7 @@ class ASR(sb.core.Brain):
                 stage_stats["CER"] = self.cer_metric.summarize("error_rate")
 
         # log stats and save checkpoint at end-of-epoch
-        if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+        if stage == sb.Stage.VALID:
 
             # report different epoch stages according current stage
             current_epoch = self.hparams.epoch_counter.current
@@ -198,7 +195,7 @@ class ASR(sb.core.Brain):
             self.checkpointer.save_and_keep_only(
                 meta={"ACC": stage_stats["ACC"], "epoch": epoch},
                 max_keys=["ACC"],
-                num_to_keep=10,
+                num_to_keep=self.hparams.avg_checkpoints,
             )
 
         elif stage == sb.Stage.TEST:
@@ -256,9 +253,7 @@ class ASR(sb.core.Brain):
                 if "momentum" not in group:
                     return
 
-                self.checkpointer.recover_if_possible(
-                    device=torch.device(self.device)
-                )
+                self.checkpointer.recover_if_possible()
 
     def on_evaluate_start(self, max_key=None, min_key=None):
         """perform checkpoint averge if needed"""
@@ -268,7 +263,7 @@ class ASR(sb.core.Brain):
             max_key=max_key, min_key=min_key
         )
         ckpt = sb.utils.checkpoints.average_checkpoints(
-            ckpts, recoverable_name="model", device=self.device
+            ckpts, recoverable_name="model",
         )
 
         self.hparams.model.load_state_dict(ckpt, strict=True)
@@ -313,7 +308,7 @@ def dataio_prepare(hparams):
     test_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
         csv_path=hparams["test_data"], replacements={"data_root": data_folder},
     )
-    test_data = test_data.filtered_sorted(sort_key="duration")
+    test_data = test_data.filtered_sorted(sort_key="duration", reverse=True)
 
     datasets = [train_data, valid_data, test_data]
 
@@ -359,24 +354,13 @@ def dataio_prepare(hparams):
         from speechbrain.dataio.sampler import DynamicBatchSampler  # noqa
 
         dynamic_hparams = hparams["dynamic_batch_sampler"]
-        num_buckets = dynamic_hparams["num_buckets"]
 
         train_batch_sampler = DynamicBatchSampler(
-            train_data,
-            dynamic_hparams["max_batch_len"],
-            num_buckets=num_buckets,
-            length_func=lambda x: x["duration"],
-            shuffle=dynamic_hparams["shuffle_ex"],
-            batch_ordering=dynamic_hparams["batch_ordering"],
+            train_data, **dynamic_hparams, length_func=lambda x: x["duration"],
         )
 
         valid_batch_sampler = DynamicBatchSampler(
-            valid_data,
-            dynamic_hparams["max_batch_len"],
-            num_buckets=num_buckets,
-            length_func=lambda x: x["duration"],
-            shuffle=dynamic_hparams["shuffle_ex"],
-            batch_ordering=dynamic_hparams["batch_ordering"],
+            valid_data, **dynamic_hparams, length_func=lambda x: x["duration"],
         )
 
     return (
@@ -396,7 +380,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -417,8 +400,10 @@ if __name__ == "__main__":
             "data_folder": hparams["data_folder"],
             "save_folder": hparams["output_folder"],
             "skip_prep": hparams["skip_prep"],
+            "remove_compressed_wavs": hparams["remove_compressed_wavs"],
         },
     )
+    run_on_main(hparams["prepare_noise_data"])
 
     # here we create the datasets objects as well as tokenization and encoding
     (
@@ -432,7 +417,7 @@ if __name__ == "__main__":
 
     # We download and pretrain the tokenizer
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     # Trainer initialization
     asr_brain = ASR(
diff --git a/recipes/AISHELL-1/ASR/transformer/train_with_wav2vect.py b/recipes/AISHELL-1/ASR/transformer/train_with_wav2vect.py
index b43f2b42ac9f8b542de2f1b575a0109ec3917d89..53aa47375146d3707224211740ad31ccbdc00c66 100644
--- a/recipes/AISHELL-1/ASR/transformer/train_with_wav2vect.py
+++ b/recipes/AISHELL-1/ASR/transformer/train_with_wav2vect.py
@@ -24,22 +24,15 @@ class ASR(sb.core.Brain):
         wavs, wav_lens = batch.sig
         tokens_bos, _ = batch.tokens_bos
 
-        # Add augmentation if specified
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.modules, "env_corrupt"):
-                wavs_noise = self.modules.env_corrupt(wavs, wav_lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                wav_lens = torch.cat([wav_lens, wav_lens])
-                tokens_bos = torch.cat([tokens_bos, tokens_bos], dim=0)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+            tokens_bos = self.hparams.wav_augment.replicate_labels(tokens_bos)
 
         # compute features
         feats = self.modules.wav2vec2(wavs, wav_lens)
         current_epoch = self.hparams.epoch_counter.current
 
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "augmentation"):
-                feats = self.hparams.augmentation(feats)
-
         # forward modules
         enc_out, pred = self.hparams.Transformer(
             feats, tokens_bos, wav_lens, pad_idx=self.hparams.pad_index
@@ -55,17 +48,19 @@ class ASR(sb.core.Brain):
 
         # Compute outputs
         hyps = None
-        if stage == sb.Stage.TRAIN:
-            hyps = None
-        elif stage == sb.Stage.VALID:
-            hyps = None
-            current_epoch = self.hparams.epoch_counter.current
-            if current_epoch % self.hparams.valid_search_interval == 0:
-                # for the sake of efficiency, we only perform beamsearch with limited capacity
-                # and no LM to give user some idea of how the AM is doing
-                hyps, _ = self.hparams.valid_search(enc_out.detach(), wav_lens)
-        elif stage == sb.Stage.TEST:
-            hyps, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
+        current_epoch = self.hparams.epoch_counter.current
+        is_valid_search = (
+            stage == sb.Stage.VALID
+            and current_epoch % self.hparams.valid_search_interval == 0
+        )
+        is_test_search = stage == sb.Stage.TEST
+
+        if is_valid_search:
+            hyps, _, _, _ = self.hparams.valid_search(
+                enc_out.detach(), wav_lens
+            )
+        elif is_test_search:
+            hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
 
         return p_ctc, p_seq, wav_lens, hyps
 
@@ -78,13 +73,18 @@ class ASR(sb.core.Brain):
         tokens_eos, tokens_eos_lens = batch.tokens_eos
         tokens, tokens_lens = batch.tokens
 
-        if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
-            tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
-            tokens_eos_lens = torch.cat(
-                [tokens_eos_lens, tokens_eos_lens], dim=0
-            )
-            tokens = torch.cat([tokens, tokens], dim=0)
-            tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0)
+        if stage == sb.Stage.TRAIN:
+            # Labels must be extended if parallel augmentation or concatenated
+            # augmentation was performed on the input (increasing the time dimension)
+            if hasattr(self.hparams, "wav_augment"):
+                (
+                    tokens,
+                    tokens_lens,
+                    tokens_eos,
+                    tokens_eos_lens,
+                ) = self.hparams.wav_augment.replicate_multiple_labels(
+                    tokens, tokens_lens, tokens_eos, tokens_eos_lens
+                )
 
         loss_seq = self.hparams.seq_cost(
             p_seq, tokens_eos, length=tokens_eos_lens
@@ -116,40 +116,18 @@ class ASR(sb.core.Brain):
             self.acc_metric.append(p_seq, tokens_eos, tokens_eos_lens)
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
+    def on_fit_batch_start(self, batch, should_step):
+        """Gets called at the beginning of each fit_batch."""
         # check if we need to switch optimizer
         # if so change the optimizer from Adam to SGD
         self.check_and_reset_optimizer()
 
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-
-        # normalize the loss by gradient_accumulation step
-        (loss / self.hparams.gradient_accumulation).backward()
-
-        if self.step % self.hparams.gradient_accumulation == 0:
-            # gradient clipping & early stop if loss is not fini
-            self.check_gradients(loss)
-
-            self.optimizer.step()
-            self.optimizer_wav2vect.step()
-            self.optimizer.zero_grad()
-            self.optimizer_wav2vect.zero_grad()
-
-            # anneal lr every update
+    def on_fit_batch_end(self, batch, outputs, loss, should_step):
+        """At the end of the optimizer step, apply noam annealing."""
+        if should_step:
             self.hparams.noam_annealing(self.optimizer)
             self.hparams.noam_annealing_wav2vect(self.optimizer_wav2vect)
 
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        with torch.no_grad():
-            predictions = self.compute_forward(batch, stage=stage)
-            loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
         if stage != sb.Stage.TRAIN:
@@ -173,7 +151,7 @@ class ASR(sb.core.Brain):
                 stage_stats["CER"] = self.cer_metric.summarize("error_rate")
 
         # log stats and save checkpoint at end-of-epoch
-        if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+        if stage == sb.Stage.VALID:
 
             # report different epoch stages according current stage
             current_epoch = self.hparams.epoch_counter.current
@@ -200,7 +178,7 @@ class ASR(sb.core.Brain):
             self.checkpointer.save_and_keep_only(
                 meta={"ACC": stage_stats["ACC"], "epoch": epoch},
                 max_keys=["ACC"],
-                num_to_keep=10,
+                num_to_keep=self.hparams.avg_checkpoints,
             )
 
         elif stage == sb.Stage.TEST:
@@ -270,7 +248,7 @@ class ASR(sb.core.Brain):
             max_key=max_key, min_key=min_key
         )
         ckpt = sb.utils.checkpoints.average_checkpoints(
-            ckpts, recoverable_name="model", device=self.device
+            ckpts, recoverable_name="model",
         )
 
         self.hparams.model.load_state_dict(ckpt, strict=True)
@@ -289,9 +267,10 @@ class ASR(sb.core.Brain):
             )
             self.checkpointer.add_recoverable("modelopt", self.optimizer)
 
-    def zero_grad(self, set_to_none=False):
-        self.optimizer_wav2vect.zero_grad(set_to_none)
-        self.optimizer.zero_grad(set_to_none)
+        self.optimizers_dict = {
+            "wav2vect_optimizer": self.optimizer_wav2vect,
+            "model_optimizer": self.optimizer,
+        }
 
 
 def dataio_prepare(hparams):
@@ -378,24 +357,13 @@ def dataio_prepare(hparams):
         from speechbrain.dataio.sampler import DynamicBatchSampler  # noqa
 
         dynamic_hparams = hparams["dynamic_batch_sampler"]
-        num_buckets = dynamic_hparams["num_buckets"]
 
         train_batch_sampler = DynamicBatchSampler(
-            train_data,
-            dynamic_hparams["max_batch_len"],
-            num_buckets=num_buckets,
-            length_func=lambda x: x["duration"],
-            shuffle=dynamic_hparams["shuffle_ex"],
-            batch_ordering=dynamic_hparams["batch_ordering"],
+            train_data, **dynamic_hparams, length_func=lambda x: x["duration"],
         )
 
         valid_batch_sampler = DynamicBatchSampler(
-            valid_data,
-            dynamic_hparams["max_batch_len"],
-            num_buckets=num_buckets,
-            length_func=lambda x: x["duration"],
-            shuffle=dynamic_hparams["shuffle_ex"],
-            batch_ordering=dynamic_hparams["batch_ordering"],
+            valid_data, **dynamic_hparams, length_func=lambda x: x["duration"],
         )
 
     return (
@@ -415,7 +383,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -436,6 +403,7 @@ if __name__ == "__main__":
             "data_folder": hparams["data_folder"],
             "save_folder": hparams["output_folder"],
             "skip_prep": hparams["skip_prep"],
+            "remove_compressed_wavs": hparams["remove_compressed_wavs"],
         },
     )
 
@@ -451,7 +419,7 @@ if __name__ == "__main__":
 
     # We download and pretrain the tokenizer
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     # Trainer initialization
     asr_brain = ASR(
diff --git a/recipes/AISHELL-1/Tokenizer/hparams/tokenizer_bpe5000.yaml b/recipes/AISHELL-1/Tokenizer/hparams/tokenizer_bpe5000.yaml
index 1c550971f99d33473888c9acd75e7631d1d11867..d2cb230189911a83d54c5b6caedf7fffa334c57b 100644
--- a/recipes/AISHELL-1/Tokenizer/hparams/tokenizer_bpe5000.yaml
+++ b/recipes/AISHELL-1/Tokenizer/hparams/tokenizer_bpe5000.yaml
@@ -10,11 +10,12 @@ output_folder: !ref results/tokenizer_bpe5000/
 # Data files
 data_folder: !PLACEHOLDER # e.g, /localscratch/aishell
 skip_prep: False
+remove_compressed_wavs: False
 train_csv: !ref <output_folder>/train.csv
 valid_csv: !ref <output_folder>/dev.csv
 
 
-# Training parameters
+####################### Training Parameters ####################################
 token_type: unigram  # ["unigram", "bpe", "char"]
 token_output: 5000  # index(blank/eos/bos/unk) = 0
 character_coverage: 1.0
diff --git a/recipes/AISHELL-1/Tokenizer/hparams/train_transformer_tokenizer_bpe5000.yaml b/recipes/AISHELL-1/Tokenizer/hparams/train_transformer_tokenizer_bpe5000.yaml
index 4163a47ab35cd657d9eea3ae695d97299f414ab2..973df9a1194b9828e6002062fdefbb871da0fc77 100644
--- a/recipes/AISHELL-1/Tokenizer/hparams/train_transformer_tokenizer_bpe5000.yaml
+++ b/recipes/AISHELL-1/Tokenizer/hparams/train_transformer_tokenizer_bpe5000.yaml
@@ -10,11 +10,12 @@ output_folder: !ref results/transformer_tokenizer_bpe5000/
 # Data files
 data_folder: !PLACEHOLDER # e.g, /localscratch/aishell
 skip_prep: False
+remove_compressed_wavs: False
 train_csv: !ref <output_folder>/train.csv
 valid_csv: !ref <output_folder>/dev.csv
 
 
-# Training parameters
+####################### Training Parameters ####################################
 token_type: unigram  # ["unigram", "bpe", "char"]
 token_output: 5000  # index(blank/eos/bos/unk) = 0
 character_coverage: 1.0
diff --git a/recipes/AISHELL-1/Tokenizer/train.py b/recipes/AISHELL-1/Tokenizer/train.py
index 4e31ab61c32f5680a47279a7f8050da210c9a34f..6aac5e11ba1240dd1a42901276890b6435d77e96 100644
--- a/recipes/AISHELL-1/Tokenizer/train.py
+++ b/recipes/AISHELL-1/Tokenizer/train.py
@@ -25,7 +25,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -46,6 +45,7 @@ if __name__ == "__main__":
             "data_folder": hparams["data_folder"],
             "save_folder": hparams["output_folder"],
             "skip_prep": hparams["skip_prep"],
+            "remove_compressed_wavs": hparams["remove_compressed_wavs"],
         },
     )
 
diff --git a/recipes/AISHELL-1/aishell_prepare.py b/recipes/AISHELL-1/aishell_prepare.py
index 290ba03b1bb1c781a44704f95d42ce6417a2a9b6..1fb9f2bd9c56d1b3ddff360225e86286ec0e2b49 100644
--- a/recipes/AISHELL-1/aishell_prepare.py
+++ b/recipes/AISHELL-1/aishell_prepare.py
@@ -1,102 +1,220 @@
+"""
+Data preparation.
+
+Download: https://www.openslr.org/33/
+
+Authors
+-------
+ * Adel Moumen 2023
+"""
+
 import os
 import shutil
 import logging
-from speechbrain.dataio.dataio import read_audio
-from speechbrain.utils.data_utils import download_file
 import glob
 import csv
+from speechbrain.dataio.dataio import read_audio_info
+from speechbrain.utils.parallel import parallel_map
+import functools
 
 logger = logging.getLogger(__name__)
 
 
-def prepare_aishell(data_folder, save_folder, skip_prep=False):
+def extract_and_cleanup_wav_files(
+    tgz_list, wav_dir, splits, remove_compressed_wavs
+):
+    """This function extracts the wav files in the AISHELL-1 dataset.
+
+    Arguments
+    ---------
+    tgz_list: list
+        list of paths to the tar.gz files.
+    wav_dir: str
+        path to the wav directory.
+    splits: list
+        list of splits.
+    remove_compressed_wavs: bool
+        If True, remove compressed wav files after extraction.
     """
-    This function prepares the AISHELL-1 dataset.
-    If the folder does not exist, the zip file will be extracted. If the zip file does not exist, it will be downloaded.
+    if len(tgz_list) > 0:
+        logger.info(f"Extracting wav files in {wav_dir}...")
+
+        decompress_processor = functools.partial(
+            shutil.unpack_archive, extract_dir=wav_dir,
+        )
+
+        for split in splits:
+            os.makedirs(os.path.join(wav_dir, split), exist_ok=True)
+
+        for _ in parallel_map(decompress_processor, tgz_list, chunk_size=64):
+            pass
 
-    data_folder : path to AISHELL-1 dataset.
-    save_folder: path where to store the manifest csv files.
-    skip_prep: If True, skip data preparation.
+        if remove_compressed_wavs:
+            for tgz in tgz_list:
+                os.remove(tgz)
 
+
+def process_line(wav, filename2transcript):
+    """This function processes a line of the csv file.
+
+    This function is being used in the context of multi-processing.
+
+    Arguments
+    ---------
+    wav: str
+        path to the wav file.
+    filename2transcript: dict
+        dictionary mapping filenames to transcripts.
+
+    Returns
+    -------
+    list
+        list containing the duration, the path to the wav file and the transcript.
     """
+    filename = wav.split("/")[-1].split(".wav")[0]
+
+    info = read_audio_info(wav)
+    duration = info.num_frames / info.sample_rate
+
+    transcript_ = filename2transcript[filename]
+
+    return [str(duration), wav, transcript_]
+
+
+def skip(splits, save_folder):
+    """ Detect when the AiSHELL-1 data preparation can be skipped.
+
+    Arguments
+    ---------
+    splits : list
+        A list of the splits expected in the preparation.
+    save_folder : str
+        The location of the save directory
+
+    Returns
+    -------
+    bool
+        if True, the preparation phase can be skipped.
+        if False, it must be done.
+    """
+    # Checking csv files
+    skip = True
+
+    for split in splits:
+        if not os.path.isfile(os.path.join(save_folder, split + ".csv")):
+            skip = False
+
+    return skip
+
+
+def prepare_aishell(
+    data_folder, save_folder, skip_prep=False, remove_compressed_wavs=True
+):
+    """This function prepares the AISHELL-1 dataset.
+
+    Arguments
+    ---------
+    data_folder: str
+        path to AISHELL-1 dataset.
+    save_folder: str
+        path where to store the manifest csv files.
+    skip_prep: bool
+        If True, skip data preparation.
+    remove_compressed_wavs: bool
+        If True, remove compressed wav files after extraction.
+    """
+
     if skip_prep:
         return
 
-    # If the data folders do not exist, we need to extract the data
-    if not os.path.isdir(os.path.join(data_folder, "data_aishell/wav")):
-        # Check for zip file and download if it doesn't exist
-        zip_location = os.path.join(data_folder, "data_aishell.tgz")
-        if not os.path.exists(zip_location):
-            url = "https://www.openslr.org/resources/33/data_aishell.tgz"
-            download_file(url, zip_location, unpack=True)
-        logger.info("Extracting data_aishell.tgz...")
-        shutil.unpack_archive(zip_location, data_folder)
-        wav_dir = os.path.join(data_folder, "data_aishell/wav")
-        tgz_list = glob.glob(wav_dir + "/*.tar.gz")
-        for tgz in tgz_list:
-            shutil.unpack_archive(tgz, wav_dir)
-            os.remove(tgz)
+    wav_dir = os.path.join(data_folder, "wav")
+    tgz_list = glob.glob(wav_dir + "/*.tar.gz")
+
+    splits = [
+        "train",
+        "dev",
+        "test",
+    ]
+
+    if skip(splits, save_folder):
+        return
+
+    extract_and_cleanup_wav_files(
+        tgz_list, wav_dir, splits, remove_compressed_wavs=remove_compressed_wavs
+    )
 
     # Create filename-to-transcript dictionary
     filename2transcript = {}
-    with open(
-        os.path.join(
-            data_folder, "data_aishell/transcript/aishell_transcript_v0.8.txt"
-        ),
-        "r",
-    ) as f:
+    path_to_transcript = os.path.join(
+        data_folder, "transcript/aishell_transcript_v0.8.txt"
+    )
+
+    with open(path_to_transcript, "r",) as f:
         lines = f.readlines()
         for line in lines:
             key = line.split()[0]
             value = " ".join(line.split()[1:])
             filename2transcript[key] = value
 
-    splits = [
-        "train",
-        "dev",
-        "test",
-    ]
-    ID_start = 0  # needed to have a unique ID for each audio
+    line_processor = functools.partial(
+        process_line, filename2transcript=filename2transcript,
+    )
+
     for split in splits:
-        new_filename = os.path.join(save_folder, split) + ".csv"
-        if os.path.exists(new_filename):
-            continue
-        logger.info("Preparing %s..." % new_filename)
 
-        csv_output = [["ID", "duration", "wav", "transcript"]]
-        entry = []
+        final_csv = os.path.join(save_folder, split) + ".csv"
+        tmp_csv = os.path.join(save_folder, split) + ".tmp"
+
+        logger.info("Preparing %s..." % final_csv)
 
         all_wavs = glob.glob(
-            os.path.join(data_folder, "data_aishell/wav")
-            + "/"
-            + split
-            + "/*/*.wav"
+            os.path.join(data_folder, "wav") + "/" + split + "/*/*.wav"
         )
-        for i in range(len(all_wavs)):
-            filename = all_wavs[i].split("/")[-1].split(".wav")[0]
-            if filename not in filename2transcript:
-                continue
-            signal = read_audio(all_wavs[i])
-            duration = signal.shape[0] / 16000
-            transcript_ = filename2transcript[filename]
-            csv_line = [
-                ID_start + i,
-                str(duration),
-                all_wavs[i],
-                transcript_,
-            ]
-            entry.append(csv_line)
-
-        csv_output = csv_output + entry
-
-        with open(new_filename, mode="w") as csv_f:
+        # only keep the files that are in the transcript
+        transcript_wavs = [
+            wav
+            for wav in all_wavs
+            if wav.split("/")[-1].split(".wav")[0] in filename2transcript
+        ]
+
+        total_line = 0
+        total_duration = 0
+        id = 0
+        with open(tmp_csv, mode="w", encoding="utf-8") as csv_f:
             csv_writer = csv.writer(
                 csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL
             )
-            for line in csv_output:
-                csv_writer.writerow(line)
+            csv_writer.writerow(["ID", "duration", "wav", "transcript"])
+            for row in parallel_map(
+                line_processor, transcript_wavs, chunk_size=4092
+            ):
+
+                if row is None:
+                    continue
+
+                row = [str(id)] + row
+                csv_writer.writerow(row)
+
+                total_line += 1
+                total_duration += float(row[1])
+                id += 1
+
+        msg = f"Number of samples: {total_line} "
+        logger.info(msg)
+        msg = "Total duration: %s Hours" % (
+            str(round(total_duration / 3600, 2))
+        )
+
+        logger.info(msg)
+
+        os.replace(tmp_csv, final_csv)
 
-        msg = "\t%s successfully created!" % (new_filename)
+        msg = "\t%s successfully created!" % (final_csv)
         logger.info(msg)
 
-        ID_start += len(all_wavs)
+        msg = f"Number of samples: {total_line} "
+        logger.info(msg)
+        msg = "Total duration: %s Hours" % (
+            str(round(total_duration / 3600, 2))
+        )
+        logger.info(msg)
diff --git a/recipes/AMI/Diarization/experiment.py b/recipes/AMI/Diarization/experiment.py
index 4005d35bda3bf0167734f0624c24b00d861af350..0e9ff0395f4d87733273d2e702065adbe48a633a 100755
--- a/recipes/AMI/Diarization/experiment.py
+++ b/recipes/AMI/Diarization/experiment.py
@@ -552,7 +552,7 @@ if __name__ == "__main__":  # noqa: C901
     # We download the pretrained Model from HuggingFace (or elsewhere depending on
     # the path given in the YAML file).
     run_on_main(params["pretrainer"].collect_files)
-    params["pretrainer"].load_collected(device=run_opts["device"])
+    params["pretrainer"].load_collected()
     params["embedding_model"].eval()
     params["embedding_model"].to(run_opts["device"])
 
diff --git a/recipes/Aishell1Mix/separation/README.md b/recipes/Aishell1Mix/separation/README.md
index fb74e7d87eeea3d93aff1727240fa627a82be915..13c4b589efa877d1022acbf4422118efef550bf6 100644
--- a/recipes/Aishell1Mix/separation/README.md
+++ b/recipes/Aishell1Mix/separation/README.md
@@ -77,8 +77,8 @@ The output folders with model checkpoints and logs is available [here](https://w
 
 You can run the following command to train the model using Distributed Data Parallel (DDP) with 2 GPUs:
 
-```
- python -m torch.distributed.launch --nproc_per_node=2 train.py hparams/sepformer.yaml --data_folder /yourdatapath --distributed_launch --distributed_backend='nccl'
+```bash
+torchrun --nproc_per_node=2 train.py hparams/sepformer.yaml --data_folder /yourdatapath
 ```
 You can add the other runtime options as appropriate. For more complete information on multi-GPU usage, take a look at this [tutorial](https://colab.research.google.com/drive/13pBUacPiotw1IvyffvGZ-HrtBr9T6l15).
 
diff --git a/recipes/Aishell1Mix/separation/dynamic_mixing.py b/recipes/Aishell1Mix/separation/dynamic_mixing.py
index 990aca19bb49d0eed394cd619f4e4cd3316c2bcc..3b95a81be2a10e0a3178ff238b8ec1bb43314e1c 100644
--- a/recipes/Aishell1Mix/separation/dynamic_mixing.py
+++ b/recipes/Aishell1Mix/separation/dynamic_mixing.py
@@ -1,3 +1,10 @@
+"""
+The file implement Dynamic Mixing For SpeechSeparation
+
+Authors
+    * Samuele Cornell 2021
+    * Cem Subakan 2021
+"""
 import speechbrain as sb
 import numpy as np
 import torch
@@ -10,13 +17,6 @@ import warnings
 import pyloudnorm
 import random
 
-"""
-The functions to implement Dynamic Mixing For SpeechSeparation
-Authors
-    * Samuele Cornell 2021
-    * Cem Subakan 2021
-"""
-
 
 def build_spk_hashtable_aishell1mix(hparams):
     """
diff --git a/recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix2-wham.yaml b/recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix2-wham.yaml
index 0b877728798c2a2b707565ee58677fe0c104a1fd..d3cb9493e4ce2a4a9a83ef00f7379fe6c58c2df4 100644
--- a/recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix2-wham.yaml
+++ b/recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix2-wham.yaml
@@ -35,12 +35,12 @@ skip_prep: False
 ckpt_interval_minutes: 60
 
 # Experiment params
-auto_mix_prec: True # Set it to True for mixed precision
+precision: fp16 # bf16, fp16 or fp32 # Set it to True for mixed precision
 noprogressbar: False
 save_audio: False # Save estimated sources on disk
 sample_rate: 8000
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -63,18 +63,38 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
@@ -162,7 +182,6 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
         decoder: !ref <Decoder>
         masknet: !ref <MaskNet>
         counter: !ref <epoch_counter>
-        # lr_scheduler: !ref <lr_scheduler>
 
 train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
     save_file: !ref <train_log>
diff --git a/recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix2.yaml b/recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix2.yaml
index e80fd9ef3bfa6d159a106e28a34d09c871fb7802..168471dbb1d2f12207ec51e1f7a2eb0e89afe096 100644
--- a/recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix2.yaml
+++ b/recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix2.yaml
@@ -35,12 +35,12 @@ skip_prep: False
 ckpt_interval_minutes: 60
 
 # Experiment params
-auto_mix_prec: True # Set it to True for mixed precision
+precision: fp16 # bf16, fp16 or fp32 # Set it to True for mixed precision
 noprogressbar: False
 save_audio: False # Save estimated sources on disk
 sample_rate: 8000
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -63,18 +63,38 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
@@ -162,7 +182,6 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
         decoder: !ref <Decoder>
         masknet: !ref <MaskNet>
         counter: !ref <epoch_counter>
-        # lr_scheduler: !ref <lr_scheduler>
 
 train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
     save_file: !ref <train_log>
diff --git a/recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix3-wham.yaml b/recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix3-wham.yaml
index 91235f63e9ef1457b2f41e655fed44d6b06cf0b1..834857ed77f61a2a2da292de742fae950c2bd05e 100644
--- a/recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix3-wham.yaml
+++ b/recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix3-wham.yaml
@@ -35,12 +35,12 @@ skip_prep: False
 ckpt_interval_minutes: 60
 
 # Experiment params
-auto_mix_prec: True # Set it to True for mixed precision
+precision: fp16 # bf16, fp16 or fp32 # Set it to True for mixed precision
 noprogressbar: False
 save_audio: False # Save estimated sources on disk
 sample_rate: 8000
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -63,18 +63,38 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix3.yaml b/recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix3.yaml
index 27d78b7eabe546dbbacb7cd819c61eba072cef8a..d48fdecb214c6acef700f1186817142a90af2f0d 100644
--- a/recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix3.yaml
+++ b/recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix3.yaml
@@ -35,12 +35,12 @@ skip_prep: False
 ckpt_interval_minutes: 60
 
 # Experiment params
-auto_mix_prec: True # Set it to True for mixed precision
+precision: fp16 # bf16, fp16 or fp32 # Set it to True for mixed precision
 noprogressbar: False
 save_audio: False # Save estimated sources on disk
 sample_rate: 8000
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -63,18 +63,38 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/Aishell1Mix/separation/train.py b/recipes/Aishell1Mix/separation/train.py
index 2ea258501c7370470d9718a8d4b8cf004aeb737d..4d34f6f9ba0200736e45db2fd20675fe56ed981e 100644
--- a/recipes/Aishell1Mix/separation/train.py
+++ b/recipes/Aishell1Mix/separation/train.py
@@ -30,11 +30,11 @@ import numpy as np
 from tqdm import tqdm
 import speechbrain as sb
 import torch.nn.functional as F
-from torch.cuda.amp import autocast
 import speechbrain.nnet.schedulers as schedulers
 from speechbrain.utils.distributed import run_on_main
 from hyperpyyaml import load_hyperpyyaml
 import logging
+from speechbrain.core import AMPConfig
 
 
 # from: recipes/LibriMix/separation/train.py
@@ -73,7 +73,8 @@ class Separation(sb.Brain):
                         targets = targets[:, :min_len, :]
 
                 if self.hparams.use_wavedrop:
-                    mix = self.hparams.wavedrop(mix, mix_lens)
+                    mix = self.hparams.drop_chunk(mix, mix_lens)
+                    mix = self.hparams.drop_freq(mix)
 
                 if self.hparams.limit_training_signal_len:
                     mix, targets = self.cut_signals(mix, targets)
@@ -109,6 +110,9 @@ class Separation(sb.Brain):
 
     def fit_batch(self, batch):
         """Trains one batch"""
+        amp = AMPConfig.from_name(self.precision)
+        should_step = (self.step % self.grad_accumulation_factor) == 0
+
         # Unpacking batch list
         mixture = batch.mix_sig
         targets = [batch.s1_sig, batch.s2_sig]
@@ -120,14 +124,51 @@ class Separation(sb.Brain):
         if self.hparams.num_spks == 3:
             targets.append(batch.s3_sig)
 
-        if self.auto_mix_prec:
-            with autocast():
+        with self.no_sync(not should_step):
+            if self.use_amp:
+                with torch.autocast(
+                    dtype=amp.dtype, device_type=torch.device(self.device).type,
+                ):
+                    predictions, targets = self.compute_forward(
+                        mixture, targets, sb.Stage.TRAIN, noise
+                    )
+                    loss = self.compute_objectives(predictions, targets)
+
+                    # hard threshold the easy dataitems
+                    if self.hparams.threshold_byloss:
+                        th = self.hparams.threshold
+                        loss_to_keep = loss[loss > th]
+                        if loss_to_keep.nelement() > 0:
+                            loss = loss_to_keep.mean()
+                    else:
+                        loss = loss.mean()
+
+                if (
+                    loss < self.hparams.loss_upper_lim and loss.nelement() > 0
+                ):  # the fix for computational problems
+                    self.scaler.scale(loss).backward()
+                    if self.hparams.clip_grad_norm >= 0:
+                        self.scaler.unscale_(self.optimizer)
+                        torch.nn.utils.clip_grad_norm_(
+                            self.modules.parameters(),
+                            self.hparams.clip_grad_norm,
+                        )
+                    self.scaler.step(self.optimizer)
+                    self.scaler.update()
+                else:
+                    self.nonfinite_count += 1
+                    logger.info(
+                        "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
+                            self.nonfinite_count
+                        )
+                    )
+                    loss.data = torch.tensor(0).to(self.device)
+            else:
                 predictions, targets = self.compute_forward(
                     mixture, targets, sb.Stage.TRAIN, noise
                 )
                 loss = self.compute_objectives(predictions, targets)
 
-                # hard threshold the easy dataitems
                 if self.hparams.threshold_byloss:
                     th = self.hparams.threshold
                     loss_to_keep = loss[loss > th]
@@ -136,56 +177,24 @@ class Separation(sb.Brain):
                 else:
                     loss = loss.mean()
 
-            if (
-                loss < self.hparams.loss_upper_lim and loss.nelement() > 0
-            ):  # the fix for computational problems
-                self.scaler.scale(loss).backward()
-                if self.hparams.clip_grad_norm >= 0:
-                    self.scaler.unscale_(self.optimizer)
-                    torch.nn.utils.clip_grad_norm_(
-                        self.modules.parameters(), self.hparams.clip_grad_norm,
-                    )
-                self.scaler.step(self.optimizer)
-                self.scaler.update()
-            else:
-                self.nonfinite_count += 1
-                logger.info(
-                    "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
-                        self.nonfinite_count
-                    )
-                )
-                loss.data = torch.tensor(0).to(self.device)
-        else:
-            predictions, targets = self.compute_forward(
-                mixture, targets, sb.Stage.TRAIN, noise
-            )
-            loss = self.compute_objectives(predictions, targets)
-
-            if self.hparams.threshold_byloss:
-                th = self.hparams.threshold
-                loss_to_keep = loss[loss > th]
-                if loss_to_keep.nelement() > 0:
-                    loss = loss_to_keep.mean()
-            else:
-                loss = loss.mean()
-
-            if (
-                loss < self.hparams.loss_upper_lim and loss.nelement() > 0
-            ):  # the fix for computational problems
-                loss.backward()
-                if self.hparams.clip_grad_norm >= 0:
-                    torch.nn.utils.clip_grad_norm_(
-                        self.modules.parameters(), self.hparams.clip_grad_norm
-                    )
-                self.optimizer.step()
-            else:
-                self.nonfinite_count += 1
-                logger.info(
-                    "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
-                        self.nonfinite_count
+                if (
+                    loss < self.hparams.loss_upper_lim and loss.nelement() > 0
+                ):  # the fix for computational problems
+                    loss.backward()
+                    if self.hparams.clip_grad_norm >= 0:
+                        torch.nn.utils.clip_grad_norm_(
+                            self.modules.parameters(),
+                            self.hparams.clip_grad_norm,
+                        )
+                    self.optimizer.step()
+                else:
+                    self.nonfinite_count += 1
+                    logger.info(
+                        "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
+                            self.nonfinite_count
+                        )
                     )
-                )
-                loss.data = torch.tensor(0).to(self.device)
+                    loss.data = torch.tensor(0).to(self.device)
         self.optimizer.zero_grad()
 
         return loss.detach().cpu()
@@ -261,9 +270,7 @@ class Separation(sb.Brain):
             recombine = True
 
             for i in range(targets.shape[-1]):
-                new_target = self.hparams.speedperturb(
-                    targets[:, :, i], targ_lens
-                )
+                new_target = self.hparams.speed_perturb(targets[:, :, i],)
                 new_targets.append(new_target)
                 if i == 0:
                     min_len = new_target.shape[-1]
@@ -554,7 +561,6 @@ if __name__ == "__main__":
     hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
-    run_opts["auto_mix_prec"] = hparams["auto_mix_prec"]
 
     # Initialize ddp (useful only for multi-GPU DDP training)
     sb.utils.distributed.ddp_init_group(run_opts)
@@ -562,6 +568,8 @@ if __name__ == "__main__":
     # Logger info
     logger = logging.getLogger(__name__)
 
+    # If device is cpu use precision='bf16'
+
     # Create experiment directory
     sb.create_experiment_directory(
         experiment_directory=hparams["output_folder"],
@@ -569,6 +577,10 @@ if __name__ == "__main__":
         overrides=overrides,
     )
 
+    # Update precision to bf16 if the device is CPU and precision is fp16
+    if run_opts.get("device") == "cpu" and hparams.get("precision") == "fp16":
+        hparams["precision"] = "bf16"
+
     # Data preparation
     from prepare_data import prepare_aishell1mix
 
diff --git a/recipes/AudioMNIST/diffusion/hparams/train.yaml b/recipes/AudioMNIST/diffusion/hparams/train.yaml
index eb59ccc12859648ed5deb6c64b73c3bfeac2609b..90f4d347c5b652e482dc6d09ead166fd0d67aa55 100644
--- a/recipes/AudioMNIST/diffusion/hparams/train.yaml
+++ b/recipes/AudioMNIST/diffusion/hparams/train.yaml
@@ -93,7 +93,6 @@ pad_level_db: -50.
 
 # Model Parameters
 model_channels: 128
-model_norm_num_groups: 32
 model_num_res_blocks: 4
 diffusion_channels: 1
 
@@ -208,7 +207,7 @@ diffusion_sample_channels: !ref <diffusion_channels>
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-vocoder: !name:speechbrain.pretrained.interfaces.HIFIGAN.from_hparams
+vocoder: !name:speechbrain.inference.vocoders.HIFIGAN.from_hparams
     source: !ref <vocoder_model>
 
 
diff --git a/recipes/AudioMNIST/diffusion/hparams/train_latent.yaml b/recipes/AudioMNIST/diffusion/hparams/train_latent.yaml
index 254a930c5f95c186ec34e32fe427310ea01658bc..d346cebb6e23b565ca4461b6e783c258ff430088 100644
--- a/recipes/AudioMNIST/diffusion/hparams/train_latent.yaml
+++ b/recipes/AudioMNIST/diffusion/hparams/train_latent.yaml
@@ -115,7 +115,6 @@ autoencoder_channels: 32
 autoencoder_norm_num_groups: 32
 autoencoder_num_res_blocks: 1
 autoencoder_encoder_out_channels: 32
-autoencoder_nom_num_groups: 32
 autoencoder_latent_channels: 2
 autoencoder_dropout: 0.1
 latent_mask_value: -3.
@@ -262,7 +261,7 @@ done_detector: !new:speechbrain.nnet.utils.DoneDetector
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-vocoder: !name:speechbrain.pretrained.interfaces.HIFIGAN.from_hparams
+vocoder: !name:speechbrain.inference.vocoders.HIFIGAN.from_hparams
     source: !ref <vocoder_model>
 
 
diff --git a/recipes/AudioMNIST/diffusion/train.py b/recipes/AudioMNIST/diffusion/train.py
index 82063c89d6e87ad1c0dc4273923cb5c5bf3c45ab..70637eaed4f19957fc020aa0310feb7387ee8647 100644
--- a/recipes/AudioMNIST/diffusion/train.py
+++ b/recipes/AudioMNIST/diffusion/train.py
@@ -92,21 +92,20 @@ class DiffusionBrain(sb.Brain):
         hparams=None,
         run_opts=None,
         checkpointer=None,
-        profiler=None,
     ):
-        super().__init__(
-            modules, opt_class, hparams, run_opts, checkpointer, profiler
-        )
+        super().__init__(modules, opt_class, hparams, run_opts, checkpointer)
         self.diffusion_mode = DiffusionMode(self.hparams.diffusion_mode)
         self.use_done_detector = "done_detector" in self.modules
 
     def init_optimizers(self):
         """Initializes the diffusion model optimizer - and the
         autoencoder optimizer, if applicable"""
+        self.optimizers_dict = {}
         if self.opt_class is not None:
             self.optimizer = self.opt_class(self.modules.unet.parameters())
             if self.checkpointer is not None:
                 self.checkpointer.add_recoverable("optimizer", self.optimizer)
+            self.optimizers_dict["opt_class"] = self.optimizer
 
         if self.use_done_detector:
             self.optimizer_done = self.hparams.opt_class_done(
@@ -116,6 +115,7 @@ class DiffusionBrain(sb.Brain):
                 self.checkpointer.add_recoverable(
                     "optimizer_done", self.optimizer
                 )
+            self.optimizers_dict["opt_class_done"] = self.optimizer_done
 
         if self.diffusion_mode == DiffusionMode.LATENT:
             self.autoencoder_optimizer = self.hparams.opt_class_autoencoder(
@@ -125,6 +125,9 @@ class DiffusionBrain(sb.Brain):
                 self.checkpointer.add_recoverable(
                     "autoencoder_optimizer", self.autoencoder_optimizer
                 )
+            self.optimizers_dict[
+                "opt_class_autoencoder"
+            ] = self.autoencoder_optimizer
 
     def compute_forward(self, batch, stage):
         """Runs all the computation of that transforms the input into the
@@ -315,7 +318,7 @@ class DiffusionBrain(sb.Brain):
                 )
 
         if should_step:
-            if self.train_diffusion and self.check_gradients(loss):
+            if self.train_diffusion:
                 self.optimizer.step()
             self.optimizer.zero_grad()
 
@@ -330,7 +333,10 @@ class DiffusionBrain(sb.Brain):
         ):
             with self.no_sync(not should_step):
                 (loss_autoencoder / self.grad_accumulation_factor).backward()
-            if should_step and self.check_gradients(loss_autoencoder):
+            if should_step:
+                torch.nn.utils.clip_grad_norm_(
+                    self.modules.parameters(), self.max_grad_norm
+                )
                 self.autoencoder_optimizer.step()
             self.autoencoder_optimizer.zero_grad()
 
@@ -1463,7 +1469,12 @@ def dataio_prep(hparams):
 
     train_split = dataset_splits["train"]
     data_count = None
-    train_split = apply_overfit_test(hparams, train_split)
+    train_split = apply_overfit_test(
+        hparams["overfit_test"],
+        hparams["overfit_test_sample_count"],
+        hparams["overfit_test_epoch_data_count"],
+        train_split,
+    )
 
     if hparams["train_data_count"] is not None:
         data_count = hparams["train_data_count"]
diff --git a/recipes/BinauralWSJ0Mix/extra_requirements.txt b/recipes/BinauralWSJ0Mix/extra_requirements.txt
index 7ccc17fff47d099410dfe83d2b65443deb65b5b4..3cfaa1241dff68b6fc36a6207662c6030ea41308 100644
--- a/recipes/BinauralWSJ0Mix/extra_requirements.txt
+++ b/recipes/BinauralWSJ0Mix/extra_requirements.txt
@@ -1,4 +1,4 @@
-gitpython==3.1.32
+gitpython==3.1.37
 mir-eval==0.6
 pyroomacoustics>=0.7.3
 
diff --git a/recipes/BinauralWSJ0Mix/separation/README.md b/recipes/BinauralWSJ0Mix/separation/README.md
index 3cc01273366828505d4d2d55d533154c76e4f78b..ba0abfd2a6bf49ac6633cd31721e8db1498c0d22 100644
--- a/recipes/BinauralWSJ0Mix/separation/README.md
+++ b/recipes/BinauralWSJ0Mix/separation/README.md
@@ -69,8 +69,8 @@ The output folders with the checkpoints, logs, etc are available [here](https://
 
 You can run the following command to train the model using Distributed Data Parallel (DDP) with 2 GPUs:
 
-```
- python -m torch.distributed.launch --nproc_per_node=2 train.py hparams/convtasnet-parallel.yaml --data_folder /yourdatapath --distributed_launch --distributed_backend='nccl'
+```bash
+torchrun --nproc_per_node=2 train.py hparams/convtasnet-parallel.yaml --data_folder /yourdatapath
 ```
 You can add the other runtime options as appropriate. For more complete information on multi-GPU usage, take a look at this [tutorial](https://colab.research.google.com/drive/13pBUacPiotw1IvyffvGZ-HrtBr9T6l15).
 
diff --git a/recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-cross.yaml b/recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-cross.yaml
index ea4e4dbe3ae55c856d35b8c79fb2721275ed326f..043845aebd245c6ef40c8a41b9635cac018f02cd 100644
--- a/recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-cross.yaml
+++ b/recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-cross.yaml
@@ -37,13 +37,13 @@ skip_prep: False
 
 
 # Experiment params
-auto_mix_prec: False # Set it to True for mixed precision
+precision: fp32 # bf16, fp16 or fp32
 num_spks: 2 # set to 3 for wsj0-3mix
 save_audio: True # Save estimated sources on disk
 n_audio_to_save: 10
 sample_rate: 8000
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -65,18 +65,38 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-independent.yaml b/recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-independent.yaml
index eb6caa8e9ac9eb72bf0396bd1e288ef2b79e9460..164ccc45b17fd617f2c6a6342bbf23d26001f68f 100644
--- a/recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-independent.yaml
+++ b/recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-independent.yaml
@@ -37,13 +37,13 @@ skip_prep: False
 
 
 # Experiment params
-auto_mix_prec: False # Set it to True for mixed precision
+precision: fp32 # bf16, fp16 or fp32
 num_spks: 2 # set to 3 for wsj0-3mix
 save_audio: True # Save estimated sources on disk
 n_audio_to_save: 10
 sample_rate: 8000
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -65,18 +65,38 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-parallel-noise.yaml b/recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-parallel-noise.yaml
index f5413fa1674c7d5a0dc8399b3228d4390205046a..fef85267f644aad3f01047cc64100e7a55ba8beb 100644
--- a/recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-parallel-noise.yaml
+++ b/recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-parallel-noise.yaml
@@ -37,13 +37,13 @@ skip_prep: False
 
 
 # Experiment params
-auto_mix_prec: False # Set it to True for mixed precision
+precision: fp32 # bf16, fp16 or fp32
 num_spks: 2 # set to 3 for wsj0-3mix
 save_audio: True # Save estimated sources on disk
 n_audio_to_save: 10
 sample_rate: 8000
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -65,18 +65,38 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-parallel-reverb.yaml b/recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-parallel-reverb.yaml
index 2d455b8de819ebc0ff39004825131cdaa03f77f4..4ec5054f982631250b29b6c03147ced3fbddb586 100644
--- a/recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-parallel-reverb.yaml
+++ b/recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-parallel-reverb.yaml
@@ -37,13 +37,13 @@ skip_prep: False
 
 
 # Experiment params
-auto_mix_prec: False # Set it to True for mixed precision
+precision: fp32 # bf16, fp16 or fp32
 num_spks: 2 # set to 3 for wsj0-3mix
 save_audio: True # Save estimated sources on disk
 n_audio_to_save: 10
 sample_rate: 8000
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -65,18 +65,38 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-parallel.yaml b/recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-parallel.yaml
index 3347da4bebe7c13693721f15c44305bf25a357b4..adb31ddc68fd3284774c58e4e60b3a28b65457c5 100644
--- a/recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-parallel.yaml
+++ b/recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-parallel.yaml
@@ -37,13 +37,13 @@ skip_prep: False
 
 
 # Experiment params
-auto_mix_prec: False # Set it to True for mixed precision
+precision: fp32 # bf16, fp16 or fp32
 num_spks: 2 # set to 3 for wsj0-3mix
 save_audio: True # Save estimated sources on disk
 n_audio_to_save: 10
 sample_rate: 8000
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -65,18 +65,38 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/BinauralWSJ0Mix/separation/train.py b/recipes/BinauralWSJ0Mix/separation/train.py
index 30711f81e72b93aec296c4c44018783b20479f43..99ec4253189e100c608d8e567c9c32680f747cda 100644
--- a/recipes/BinauralWSJ0Mix/separation/train.py
+++ b/recipes/BinauralWSJ0Mix/separation/train.py
@@ -29,7 +29,6 @@ import torchaudio
 import speechbrain as sb
 import speechbrain.nnet.schedulers as schedulers
 from speechbrain.utils.distributed import run_on_main
-from torch.cuda.amp import autocast
 from hyperpyyaml import load_hyperpyyaml
 import numpy as np
 from tqdm import tqdm
@@ -38,6 +37,7 @@ import logging
 from pyroomacoustics.experimental.localization import tdoa
 from speechbrain.processing.features import STFT, spectral_magnitude
 from torch.nn import Conv1d
+from speechbrain.core import AMPConfig
 
 logger = logging.getLogger(__name__)
 
@@ -80,7 +80,8 @@ class Separation(sb.Brain):
                         targets = targets[:, :min_len, :]
 
                 if self.hparams.use_wavedrop:
-                    mix = self.hparams.wavedrop(mix, mix_lens)
+                    mix = self.hparams.drop_chunk(mix, mix_lens)
+                    mix = self.hparams.drop_freq(mix)
 
                 if self.hparams.limit_training_signal_len:
                     mix, targets = self.cut_signals(mix, targets)
@@ -197,6 +198,9 @@ class Separation(sb.Brain):
 
     def fit_batch(self, batch):
         """Trains one batch"""
+        amp = AMPConfig.from_name(self.precision)
+        should_step = (self.step % self.grad_accumulation_factor) == 0
+
         # Unpacking batch list
         mixture = batch.mix_sig
         targets = [batch.s1_sig, batch.s2_sig]
@@ -208,14 +212,51 @@ class Separation(sb.Brain):
         if "noise" in self.hparams.experiment_name:
             noise = batch.noise_sig[0]
 
-        if self.auto_mix_prec:
-            with autocast():
+        with self.no_sync(not should_step):
+            if self.use_amp:
+                with torch.autocast(
+                    dtype=amp.dtype, device_type=torch.device(self.device).type,
+                ):
+                    predictions, targets = self.compute_forward(
+                        mixture, targets, sb.Stage.TRAIN, noise
+                    )
+                    loss = self.compute_objectives(predictions, targets)
+
+                    # hard threshold the easy dataitems
+                    if self.hparams.threshold_byloss:
+                        th = self.hparams.threshold
+                        loss = loss[loss > th]
+                        if loss.nelement() > 0:
+                            loss = loss.mean()
+                    else:
+                        loss = loss.mean()
+
+                if (
+                    loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
+                ):  # the fix for computational problems
+                    self.scaler.scale(loss).backward()
+                    if self.hparams.clip_grad_norm >= 0:
+                        self.scaler.unscale_(self.optimizer)
+                        torch.nn.utils.clip_grad_norm_(
+                            self.modules.parameters(),
+                            self.hparams.clip_grad_norm,
+                        )
+                    self.scaler.step(self.optimizer)
+                    self.scaler.update()
+                else:
+                    self.nonfinite_count += 1
+                    logger.info(
+                        "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
+                            self.nonfinite_count
+                        )
+                    )
+                    loss.data = torch.tensor(0).to(self.device)
+            else:
                 predictions, targets = self.compute_forward(
                     mixture, targets, sb.Stage.TRAIN, noise
                 )
                 loss = self.compute_objectives(predictions, targets)
 
-                # hard threshold the easy dataitems
                 if self.hparams.threshold_byloss:
                     th = self.hparams.threshold
                     loss = loss[loss > th]
@@ -224,56 +265,24 @@ class Separation(sb.Brain):
                 else:
                     loss = loss.mean()
 
-            if (
-                loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
-            ):  # the fix for computational problems
-                self.scaler.scale(loss).backward()
-                if self.hparams.clip_grad_norm >= 0:
-                    self.scaler.unscale_(self.optimizer)
-                    torch.nn.utils.clip_grad_norm_(
-                        self.modules.parameters(), self.hparams.clip_grad_norm,
-                    )
-                self.scaler.step(self.optimizer)
-                self.scaler.update()
-            else:
-                self.nonfinite_count += 1
-                logger.info(
-                    "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
-                        self.nonfinite_count
-                    )
-                )
-                loss.data = torch.tensor(0).to(self.device)
-        else:
-            predictions, targets = self.compute_forward(
-                mixture, targets, sb.Stage.TRAIN, noise
-            )
-            loss = self.compute_objectives(predictions, targets)
-
-            if self.hparams.threshold_byloss:
-                th = self.hparams.threshold
-                loss = loss[loss > th]
-                if loss.nelement() > 0:
-                    loss = loss.mean()
-            else:
-                loss = loss.mean()
-
-            if (
-                loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
-            ):  # the fix for computational problems
-                loss.backward()
-                if self.hparams.clip_grad_norm >= 0:
-                    torch.nn.utils.clip_grad_norm_(
-                        self.modules.parameters(), self.hparams.clip_grad_norm
-                    )
-                self.optimizer.step()
-            else:
-                self.nonfinite_count += 1
-                logger.info(
-                    "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
-                        self.nonfinite_count
+                if (
+                    loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
+                ):  # the fix for computational problems
+                    loss.backward()
+                    if self.hparams.clip_grad_norm >= 0:
+                        torch.nn.utils.clip_grad_norm_(
+                            self.modules.parameters(),
+                            self.hparams.clip_grad_norm,
+                        )
+                    self.optimizer.step()
+                else:
+                    self.nonfinite_count += 1
+                    logger.info(
+                        "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
+                            self.nonfinite_count
+                        )
                     )
-                )
-                loss.data = torch.tensor(0).to(self.device)
+                    loss.data = torch.tensor(0).to(self.device)
         self.optimizer.zero_grad()
 
         return loss.detach().cpu()
@@ -349,9 +358,7 @@ class Separation(sb.Brain):
             recombine = True
 
             for i in range(targets.shape[-1]):
-                new_target = self.hparams.speedperturb(
-                    targets[:, :, :, i], targ_lens
-                )
+                new_target = self.hparams.speed_perturb(targets[:, :, :, i])
                 new_targets.append(new_target)
                 if i == 0:
                     min_len = new_target.shape[1]
diff --git a/recipes/CVSS/S2ST/README.md b/recipes/CVSS/S2ST/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d2c9172a13f9d1f9d6fe77209f7c9170eb1910d1
--- /dev/null
+++ b/recipes/CVSS/S2ST/README.md
@@ -0,0 +1,65 @@
+# Speech-to-Speech Translation (with CVSS)
+This folder contains the recipe for training a speech-to-unit translation (S2UT) model using a pre-trained Wav2Vec 2.0 encoder and a transformer decoder on the CVSS dataset.
+The implementation is based on [Textless Speech-to-Speech Translation](https://arxiv.org/abs/2112.08352) and [Enhanced Direct Speech-to-Speech Translation Using Self-supervised Pre-training and Data Augmentatio](https://arxiv.org/abs/2204.02967) papers.
+
+## Dataset
+[CVSS](https://github.com/google-research-datasets/cvss) is a massively multilingual-to-English speech-to-speech translation corpus. It covers pairs from 21 languages into English. CVSS is derived from the Common Voice speech corpus and the CoVoST 2 speech-to-text translation corpus.
+The CVSS dataset includes two versions of spoken translation: CVSS-C and CVSS-T. While both versions can be utilized to train the S2UT model, we recommend using CVSS-C because of its superior speech quality.
+
+The first step is to select a source language and download [Common Voice (version 4)](https://commonvoice.mozilla.org/en/datasets) for the chosen language code. In this recipe, we've chosen French as the source language.
+The next step is to pair translation audio clips with the source speech by downloading the corresponding subset of the [CVSS dataset](https://github.com/google-research-datasets/cvss). In our case, we have to download the French CVSS-C subset, which corresponds to the English translation of the French portion of the Common Voice dataset.
+
+At this point, you should have two distinct folders: the first one containing the Common Voice data and the second one containing the CVSS data.
+
+> Note: In the recipe, we frequently employ the terms `src_data` and `tgt_data`.
+> `src_data` refers to the source language data (Common Voice).
+> `tgt_data` refers to the target language data (CVSS).
+
+## Installing Extra Dependencies
+Before proceeding, ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal:
+```
+pip install -r extra_requirements.txt
+```
+
+## How to Run
+Before training the speech-to-unit translation (S2UT) model, we have to quantize the target speech into discrete speech units. This is achieved by training a k-means model on raw speech features, which will then serve as the target for training the S2UT model. By default, we use a pre-trained model with `k=100` trained on the 6th layer of HuBERT. For instructions on training a new quantization model, please refer to `recipes/LJSpeech/TTS/quantization`.
+
+To train the S2UT model on French-English, simply run the following command:
+```
+python train.py hparams/train_fr-en.yaml --src_data_folder=/corpus/CommonVoice/fr --tgt_data_folder=/corpus/CVSS/fr --bfloat16_mix_prec
+```
+
+>  Dynamic batch settings are optimized for a 40GB VRAM GPU. Don't hesitate to adjust max_batch_len and max_batch_len_val to fit your GPU's capabilities.
+
+
+# Performance summary
+Results are reported in terms of sacrebleu.
+
+| hyperparams file | valid | test   | Model      | Training logs | GPUs       |
+|:----------------:|:-----:| :-----:|:-------:   | :-----------: |:---------: |
+| train_fr-en.yaml | 24.25   | 24.47    | [dropbox](https://www.dropbox.com/sh/woz4i1p8pkfkqhf/AACmOvr3sS7p95iXl3twCj_xa?dl=0) | [wandb](https://wandb.ai/jar0d/s2ut_cvss_sb/runs/uh4tvc8c?workspace=user-jar0d)    |1xA100 80GB |
+
+Training requires about 1 hour and 5 minutes for each epoch on an NVIDIA A100 GPU. A total of 30 epochs are needed.
+
+To synthesize speech from the predicted speech units, you need to train a unit-based HiFi-GAN vocoder. If you haven't done this already, please refer to the `LJSpeech/TTS/vocoder/unit_hifi_gan` recipe.
+
+# **About SpeechBrain**
+- Website: https://speechbrain.github.io/
+- Code: https://github.com/speechbrain/speechbrain/
+- HuggingFace: https://huggingface.co/speechbrain/
+
+
+# **Citing SpeechBrain**
+Please, cite SpeechBrain if you use it for your research or business.
+
+```bibtex
+@misc{speechbrain,
+  title={{SpeechBrain}: A General-Purpose Speech Toolkit},
+  author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
+  year={2021},
+  eprint={2106.04624},
+  archivePrefix={arXiv},
+  primaryClass={eess.AS},
+  note={arXiv:2106.04624}
+}
+```
diff --git a/recipes/CVSS/S2ST/cvss_prepare.py b/recipes/CVSS/S2ST/cvss_prepare.py
new file mode 120000
index 0000000000000000000000000000000000000000..7f1db79de657a1b2ae61b48a43ae85f14c2eec5d
--- /dev/null
+++ b/recipes/CVSS/S2ST/cvss_prepare.py
@@ -0,0 +1 @@
+../cvss_prepare.py
\ No newline at end of file
diff --git a/recipes/CVSS/S2ST/extra_requirements.txt b/recipes/CVSS/S2ST/extra_requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..454a21c03558b3210da0efe42581688e9cfc4b1d
--- /dev/null
+++ b/recipes/CVSS/S2ST/extra_requirements.txt
@@ -0,0 +1 @@
+sacrebleu
diff --git a/recipes/CVSS/S2ST/extract_code.py b/recipes/CVSS/S2ST/extract_code.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4d46cb7a5b653c21cf741a34bed5bfbfca927c8
--- /dev/null
+++ b/recipes/CVSS/S2ST/extract_code.py
@@ -0,0 +1,223 @@
+"""
+Apply K-means clustering over acoustic features to extract speech units for training the speech-to-unit translation model.
+
+Authors
+ * Jarod Duret 2023
+"""
+
+import logging
+import json
+import pathlib as pl
+
+import joblib
+import torch
+import torchaudio
+import numpy as np
+from tqdm import tqdm
+import speechbrain as sb
+from speechbrain.dataio.dataio import (
+    load_pkl,
+    save_pkl,
+)
+from speechbrain.lobes.models.huggingface_transformers.wav2vec2 import Wav2Vec2
+from huggingface_hub import hf_hub_download
+
+OPT_FILE = "opt_cvss_extract.pkl"
+TRAIN_JSON = "train.json"
+VALID_JSON = "valid.json"
+VALID_SMALL = "valid_small.json"
+TEST_JSON = "test.json"
+
+
+def setup_logger():
+    """Set up a logger with a log format and logging level."""
+    log_format = "[%(asctime)s] [%(levelname)s]: %(message)s"
+    logging.basicConfig(format=log_format, level=logging.INFO)
+    logger = logging.getLogger(__name__)
+    return logger
+
+
+def get_device(use_cuda):
+    """Determine and return the appropriate device for computation."""
+    use_cuda = use_cuda and torch.cuda.is_available()
+    print("\n" + "=" * 30)
+    print("USE_CUDA SET TO: {}".format(use_cuda))
+    print("CUDA AVAILABLE?: {}".format(torch.cuda.is_available()))
+    print("=" * 30 + "\n")
+    return torch.device("cuda" if use_cuda else "cpu")
+
+
+def np_array(tensor):
+    """Convert a Pytorch tensor to a Numpy array."""
+    tensor = tensor.squeeze(0)
+    tensor = tensor.detach().cpu()
+    return tensor.numpy()
+
+
+def skip(splits, save_folder, conf):
+    """
+    Detects if the ljspeech data_extraction has been already done.
+    If the extraction has been done, we can skip it.
+
+    Returns
+    -------
+    bool
+        if True, the preparation phase can be skipped.
+        if False, it must be done.
+    """
+    # Checking json files
+    skip = True
+
+    split_files = {
+        "train": TRAIN_JSON,
+        "valid": VALID_JSON,
+        "valid_small": VALID_SMALL,
+        "test": TEST_JSON,
+    }
+
+    for split in splits:
+        if not (save_folder / split_files[split]).exists():
+            skip = False
+
+    code_folder = save_folder / "codes"
+    if not code_folder.exists():
+        skip = False
+
+    #  Checking saved options
+    save_opt = save_folder / OPT_FILE
+    if skip is True:
+        if save_opt.is_file():
+            opts_old = load_pkl(save_opt.as_posix())
+            if opts_old == conf:
+                skip = True
+            else:
+                skip = False
+        else:
+            skip = False
+    return skip
+
+
+def extract_cvss(
+    data_folder,
+    splits,
+    kmeans_folder,
+    encoder,
+    layer,
+    save_folder,
+    sample_rate=16000,
+    skip_extract=False,
+):
+    """
+    Extract speech units for HiFi-GAN training on the CVSS datasets.
+
+    Arguments
+    ---------
+    data_folder : str
+        Path to the folder where the original CVSS dataset is stored.
+    splits : list
+        List of splits to prepare.
+    kmeans_folder: str
+        Path to the folder where the k-means model checkpoint is stored.
+    encoder: str
+        Url to the model used as feature extractor.
+    layer: int
+        Layer from which features are extracted.
+    save_folder: str
+        Path to the folder where the speech units are stored.
+    sample_rate: int
+        CVSS dataset sample rate
+    skip_extract: Bool
+        If True, skip extraction.
+
+    Example
+    -------
+    >>> from recipes.CVSS.S2ST.extract_code import extract_cvss
+    >>> data_folder = 'data/CVSS/'
+    >>> splits = ['train', 'valid']
+    >>> kmeans_folder = ./Quantization/results/kmeans/4321/save
+    >>> encoder = facebook/hubert-base-ls960
+    >>> layer = 6
+    >>> save_folder = 'save/'
+    >>> extract_cvss(data_folder, splits, kmeans_folder, encoder, layer, save_folder)
+    """
+    logger = setup_logger()
+
+    if skip_extract:
+        return
+    # Create configuration for easily skipping code extraction stage
+    conf = {
+        "data_folder": data_folder,
+        "splits": splits,
+        "save_folder": save_folder,
+        "kmeans_folder": kmeans_folder,
+        "encoder": encoder,
+        "layer": layer,
+    }
+
+    save_folder = pl.Path(save_folder)
+    # Check if this phase is already done (if so, skip it)
+    if skip(splits, save_folder, conf):
+        logger.info("Skipping code extraction, completed in previous run.")
+        return
+
+    # Fetch device
+    device = get_device(use_cuda=True)
+
+    save_opt = save_folder / OPT_FILE
+    data_folder = pl.Path(data_folder)
+
+    # Fetch K-means model
+    kmeans_folder = pl.Path(kmeans_folder)
+    kmeans_ckpt = kmeans_folder / "kmeans.ckpt"
+    if not kmeans_ckpt.exists():
+        logger.info("K-means checkpoint not found, downloading it from HF.")
+        kmeans_download_path = save_folder / "pretrained_models/quantization"
+        kmeans_download_path.mkdir(exist_ok=True, parents=True)
+        hf_hub_download(
+            repo_id=kmeans_folder.as_posix(),
+            filename="kmeans.ckpt",
+            local_dir=kmeans_download_path,
+        )
+        kmeans_ckpt = kmeans_download_path / "kmeans.ckpt"
+
+    encoder_save_path = save_folder / "pretrained_models"
+    code_folder = save_folder / "codes"
+    code_folder.mkdir(parents=True, exist_ok=True)
+
+    logger.info(f"Loading encoder: {encoder} ...")
+    encoder = Wav2Vec2(
+        encoder,
+        encoder_save_path.as_posix(),
+        output_all_hiddens=True,
+        output_norm=False,
+        freeze_feature_extractor=True,
+        freeze=True,
+    ).to(device)
+
+    # K-means model
+    logger.info(f"Loading K-means model from {kmeans_ckpt} ...")
+    kmeans_model = joblib.load(open(kmeans_ckpt, "rb"))
+    kmeans_model.verbose = False
+
+    for split in splits:
+        dataset_path = data_folder / f"{split}.json"
+        logger.info(f"Reading dataset from {dataset_path} ...")
+        meta_json = json.load(open(dataset_path))
+        for key in tqdm(meta_json.keys()):
+            item = meta_json[key]
+            wav = item["tgt_audio"]
+            with torch.no_grad():
+                info = torchaudio.info(wav)
+                audio = sb.dataio.dataio.read_audio(wav)
+                audio = torchaudio.transforms.Resample(
+                    info.sample_rate, sample_rate,
+                )(audio)
+                audio = audio.unsqueeze(0).to(device)
+                feats = encoder.extract_features(audio)
+                feats = feats[layer]
+                feats = np_array(feats)
+            pred = kmeans_model.predict(feats)
+            np.save(code_folder / f"{key}_tgt.npy", pred)
+
+    logger.info("Extraction completed.")
+    save_pkl(conf, save_opt)
diff --git a/recipes/CVSS/S2ST/hparams/train_fr-en.yaml b/recipes/CVSS/S2ST/hparams/train_fr-en.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..678dd1c178017fd8f71e7fbbbfccdf048c4cfdd9
--- /dev/null
+++ b/recipes/CVSS/S2ST/hparams/train_fr-en.yaml
@@ -0,0 +1,234 @@
+############################################################################
+# Model: Speech-to-Unit Translation (S2UT)
+# Language: French-English (Fr-En)
+# Training: CVSS
+# Authors: Jarod Duret
+# ############################################################################
+
+###################################
+# Experiment Parameters and setup #
+###################################
+seed: 888
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/s2ut/<seed>
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+epochs: 30
+
+progress_samples: True
+progress_sample_path: !ref <output_folder>/samples
+progress_samples_interval: 1
+progress_batch_sample_size: 4
+
+evaluation_interval: 4
+
+#################################
+# Data files and pre-processing #
+#################################
+src_data_folder: !PLACEHOLDER # e.g, /corpus/CommonVoice/fr (French Data)
+tgt_data_folder: !PLACEHOLDER # e.g, /corpus/CV4/fr (English Data)
+sample_rate: 16000
+
+train_json: !ref <save_folder>/train.json
+valid_json: !ref <save_folder>/valid.json
+valid_small_json: !ref <save_folder>/valid_small.json
+test_json: !ref <save_folder>/test.json
+splits: ["train", "valid_small", "valid", "test"]
+skip_prep: False
+
+# SSL model used to encode target features
+encoder_source: facebook/hubert-base-ls960
+layer: 6
+kmeans_source: speechbrain/tts-hifigan-unit-hubert-l6-k100-ljspeech
+codes_folder: !ref <save_folder>/codes
+skip_extract: False
+
+# Vocoder model used for evaluation
+vocoder_source: speechbrain/tts-hifigan-unit-hubert-l6-k100-ljspeech
+vocoder_download_path: !ref <save_folder>/pretrained_models/vocoder
+
+# ASR model used for evaluation
+asr_source: speechbrain/asr-transformer-transformerlm-librispeech
+asr_download_path: !ref <save_folder>/pretrained_models/asr
+
+# Wav2vec2 encoder
+wav2vec2_source: LeBenchmark/wav2vec2-FR-7K-large
+wav2vec2_download_path: !ref <save_folder>/pretrained_models
+
+# wav2vec2 encoder specific parameters
+wav2vec2_frozen: False
+wav2vec2_freeze_steps: 10000
+
+####################### Training Parameters ####################################
+lr: 0.0005
+lr_wav2vec: 0.00001
+loss_reduction: batchmean
+
+# Outputs
+# blank_index: 102
+bos_index: 100
+eos_index: 101
+pad_index: 102
+label_smoothing: 0.2
+
+# Dynamic batching
+sorting: random
+num_workers: 4
+dynamic_batching: True
+max_batch_len: 180 # 40 GB GPU
+num_bucket: 200
+
+train_batch_size: 32 # if not using dynamic batching
+valid_batch_size: 16
+
+dynamic_batch_sampler:
+    max_batch_len: !ref <max_batch_len>
+    num_buckets: !ref <num_bucket>
+    shuffle_ex: True # if true re-creates batches at each epoch shuffling examples.
+    batch_ordering: random
+    max_batch_ex: 128
+
+train_dataloader_opts:
+    batch_size: !ref <train_batch_size>
+    drop_last: False
+    num_workers: !ref <num_workers>
+    collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
+        padding_kwargs:
+            value: !ref <pad_index>
+
+valid_dataloader_opts:
+    batch_size: !ref <valid_batch_size>
+    num_workers: !ref <num_workers>
+    collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
+        padding_kwargs:
+            value: !ref <pad_index>
+
+################################
+# Model Parameters and model   #
+################################
+
+# Feature parameters (W2V2 etc)
+features_dim: 1024 # large wav2vec output dimension, for base replace by 768
+
+# Length Regulator
+enc_kernel_size: 3
+enc_stride: 2
+
+# Transformer
+embedding_size: 512
+d_model: 512
+nhead: 8
+num_encoder_layers: 0
+num_decoder_layers: 6
+d_ffn: 2048
+transformer_dropout: 0.1
+activation: !name:torch.nn.GELU
+output_neurons: 103 # /!\ needs to be changed accordingly to the vocabulary
+attention_type: "regularMHA" # "RelPosMHAXL" or "regularMHA"
+
+# Decoding parameters
+test_bs: 10
+min_decode_ratio: 0.0
+max_decode_ratio: 1.0
+
+############################## models ################################
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
+    source: !ref <wav2vec2_source>
+    output_norm: True ### Test in baseline_v2
+    freeze: !ref <wav2vec2_frozen>
+    freeze_feature_extractor: False
+    save_path: !ref <wav2vec2_download_path>
+    apply_spec_augment: False
+
+enc: !new:speechbrain.nnet.CNN.Conv1d
+    input_shape: [null, null, !ref <features_dim>]
+    out_channels: !ref <embedding_size>
+    kernel_size: !ref <enc_kernel_size>
+    stride: !ref <enc_stride>
+
+transformer: !new:speechbrain.lobes.models.transformer.TransformerST.TransformerST # yamllint disable-line rule:line-length
+    input_size: !ref <embedding_size>
+    tgt_vocab: !ref <output_neurons>
+    d_model: !ref <d_model>
+    nhead: !ref <nhead>
+    num_encoder_layers: !ref <num_encoder_layers>
+    num_decoder_layers: !ref <num_decoder_layers>
+    d_ffn: !ref <d_ffn>
+    dropout: !ref <transformer_dropout>
+    activation: !ref <activation>
+    attention_type: !ref <attention_type>
+    normalize_before: True
+    causal: False
+
+log_softmax: !new:speechbrain.nnet.activations.Softmax
+    apply_log: True
+
+seq_lin: !new:speechbrain.nnet.linear.Linear
+    input_size: !ref <d_model>
+    n_neurons: !ref <output_neurons>
+
+modules:
+    wav2vec2: !ref <wav2vec2>
+    enc: !ref <enc>
+    transformer: !ref <transformer>
+    seq_lin: !ref <seq_lin>
+
+model: !new:torch.nn.ModuleList
+    - [!ref <enc>, !ref <transformer>, !ref <seq_lin>]
+
+opt_class: !name:torch.optim.AdamW
+    lr: !ref <lr>
+    betas: (0.9, 0.98)
+
+wav2vec_opt_class: !name:torch.optim.AdamW
+    lr: !ref <lr_wav2vec>
+
+seq_cost: !name:speechbrain.nnet.losses.nll_loss
+    label_smoothing: !ref <label_smoothing>
+    reduction: !ref <loss_reduction>
+
+noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler
+    lr_initial: !ref <lr>
+    n_warmup_steps: 5000
+
+wav2vec_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
+    initial_value: !ref <lr_wav2vec>
+    improvement_threshold: 0.0025
+    annealing_factor: 0.98
+
+#epoch object
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <epochs>
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+valid_search: !new:speechbrain.decoders.seq2seq.S2STransformerGreedySearch
+    modules: [!ref <transformer>, !ref <seq_lin>, null]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    temperature: 1.0
+
+test_search: !new:speechbrain.decoders.seq2seq.S2STransformerBeamSearcher
+    modules: [!ref <transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <test_bs>
+
+acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats
+bleu_computer: !name:speechbrain.utils.bleu.BLEUStats
+    merge_words: False
+
+#checkpointer
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        model: !ref <model>
+        wav2vec2: !ref <wav2vec2>
+        counter: !ref <epoch_counter>
+        noam_scheduler: !ref <noam_annealing>
+        wav2vec_scheduler: !ref <wav2vec_annealing>
diff --git a/recipes/CVSS/S2ST/train.py b/recipes/CVSS/S2ST/train.py
new file mode 100755
index 0000000000000000000000000000000000000000..63eeb0d05153601f56a032f3059af3d210fa047a
--- /dev/null
+++ b/recipes/CVSS/S2ST/train.py
@@ -0,0 +1,621 @@
+"""
+ Recipe for training the speech-to-unit translation (S2UT) model, the implementation is based on the following papers:
+ - Direct speech-to-speech translation with discrete units: (https://arxiv.org/abs/2006.04558)
+ - Enhanced Direct Speech-to-Speech Translation Using Self-supervised Pre-training and Data Augmentation: (https://arxiv.org/abs/2204.02967)
+ To run this recipe, do the following:
+ # python train.py hparams/train_fr-en.yaml --src_data_folder=/corpus/CommonVoice/fr --tgt_data_folder=/corpus/CVSS/fr
+
+ Authors
+ * Jarod Duret 2023
+"""
+
+import sys
+import torch
+import logging
+import pathlib as pl
+from hyperpyyaml import load_hyperpyyaml
+import speechbrain as sb
+from speechbrain.inference.vocoders import UnitHIFIGAN
+from speechbrain.inference.ASR import EncoderDecoderASR
+import tqdm
+import torchaudio
+import numpy as np
+from torch.nn.parallel import DistributedDataParallel
+
+logger = logging.getLogger(__name__)
+
+
+class S2UT(sb.core.Brain):
+    def compute_forward(self, batch, stage):
+        """Computes the forward pass.
+
+        Arguments
+        ---------
+        batch : torch.Tensor or tensors
+            An element from the dataloader, including inputs for processing.
+        stage : Stage
+            The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST
+
+        Returns
+        -------
+        (torch.Tensor or Tensors, list of float or None, list of str or None)
+            The outputs after all processing is complete.
+        """
+        batch = batch.to(self.device)
+        wavs, wav_lens = batch.src_sig
+        tokens_bos, _ = batch.code_bos
+
+        # Use default padding value for wav2vec2
+        wavs[wavs == self.hparams.pad_index] = 0.0
+
+        # compute features
+        enc_out = self.modules.wav2vec2(wavs, wav_lens)
+
+        # dimensionality reduction
+        enc_out = self.modules.enc(enc_out)
+
+        if isinstance(self.modules.transformer, DistributedDataParallel):
+            dec_out = self.modules.transformer.module.forward_mt_decoder_only(
+                enc_out, tokens_bos, pad_idx=self.hparams.pad_index
+            )
+        else:
+            dec_out = self.modules.transformer.forward_mt_decoder_only(
+                enc_out, tokens_bos, pad_idx=self.hparams.pad_index
+            )
+
+        # logits and softmax
+        pred = self.modules.seq_lin(dec_out)
+        p_seq = self.hparams.log_softmax(pred)
+
+        hyps = None
+        wavs = None
+        transcripts = None
+        if stage != sb.Stage.TRAIN:
+            if (
+                stage == sb.Stage.TEST
+                or self.hparams.epoch_counter.current
+                % self.hparams.evaluation_interval
+                == 0
+            ):
+                ids = batch.id
+                tgt_text = batch.tgt_text
+
+                search = (
+                    self.hparams.valid_search
+                    if stage == sb.Stage.VALID
+                    else self.hparams.test_search
+                )
+                hyps, _, _, _ = search(enc_out.detach(), wav_lens)
+
+                # generate speech and transcriptions
+                wavs = []
+                for hyp in hyps:
+                    if len(hyp) > 3:
+                        code = torch.LongTensor(hyp)
+                        wav = self.test_vocoder.decode_unit(code)
+                        wavs.append(wav.squeeze(0))
+                if wavs:
+                    wavs, wav_lens = sb.utils.data_utils.batch_pad_right(wavs)
+                    transcripts, _ = self.test_asr.transcribe_batch(
+                        wavs, wav_lens
+                    )
+                    transcripts = [
+                        transcript.lower() for transcript in transcripts
+                    ]
+
+                    self.bleu_metric.append(ids, transcripts, [tgt_text])
+
+        return (
+            p_seq,
+            wavs,
+            transcripts,
+        )
+
+    def compute_objectives(self, predictions, batch, stage):
+        """Computes the loss given the predicted and targeted outputs.
+        Arguments
+        ---------
+        predictions : torch.Tensor
+            The model generated spectrograms and other metrics from `compute_forward`.
+        batch : PaddedBatch
+            This batch object contains all the relevant tensors for computation.
+        stage : sb.Stage
+            One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST.
+        Returns
+        -------
+        loss : torch.Tensor
+            A one-element tensor used for backpropagating the gradient.
+        """
+        (p_seq, wavs, transcripts) = predictions
+        tokens_eos, tokens_eos_lens = batch.code_eos
+        ids = batch.id
+
+        # speech translation loss
+        loss = self.hparams.seq_cost(p_seq, tokens_eos, length=tokens_eos_lens)
+
+        if stage != sb.Stage.TRAIN:
+            if (
+                stage == sb.Stage.TEST
+                or self.hparams.epoch_counter.current
+                % self.hparams.evaluation_interval
+                == 0
+            ):
+                # compute the accuracy of the one-step-forward prediction
+                self.acc_metric.append(p_seq, tokens_eos, tokens_eos_lens)
+
+                tgt_wavs, _ = batch.tgt_sig
+                tgt_transcripts = batch.tgt_text
+
+                # Save last batch
+                wavs = [wav.cpu() for wav in wavs]
+                tgt_wavs = [wav.cpu() for wav in tgt_wavs]
+                self.last_batch = [
+                    ids,
+                    (wavs, transcripts),
+                    (tgt_transcripts, tgt_wavs),
+                ]
+
+        return loss
+
+    def freeze_optimizers(self, optimizers):
+        """Freezes the wav2vec2 optimizer according to the warmup steps"""
+        valid_optimizers = {}
+        if (
+            not self.hparams.wav2vec2_frozen
+            and self.optimizer_step >= self.hparams.wav2vec2_freeze_steps
+        ):
+            valid_optimizers["wav2vec_optimizer"] = optimizers[
+                "wav2vec_optimizer"
+            ]
+        valid_optimizers["model_optimizer"] = optimizers["model_optimizer"]
+        return valid_optimizers
+
+    def init_optimizers(self):
+        """Called during ``on_fit_start()``, initialize optimizers
+        after parameters are fully configured (e.g. DDP, jit).
+        """
+        self.optimizers_dict = {}
+
+        # Initializes the wav2vec2 optimizer if the model is not wav2vec2_frozen
+        if not self.hparams.wav2vec2_frozen:
+            self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
+                self.modules.wav2vec2.parameters()
+            )
+            self.optimizers_dict["wav2vec_optimizer"] = self.wav2vec_optimizer
+
+        self.model_optimizer = self.hparams.opt_class(
+            self.hparams.model.parameters()
+        )
+        self.optimizers_dict["model_optimizer"] = self.model_optimizer
+
+        if self.checkpointer is not None:
+            self.checkpointer.add_recoverable(
+                "wav2vec_optimizer", self.wav2vec_optimizer
+            )
+            self.checkpointer.add_recoverable(
+                "model_optimizer", self.model_optimizer
+            )
+
+    def on_fit_batch_start(self, batch, should_step):
+        """Called at the beginning of ``fit_batch()``.
+
+        Arguments
+        ---------
+        batch : list of torch.Tensors
+            Batch of data to use for training. Default implementation assumes
+            this batch has two elements: inputs and targets.
+        should_step : boolean
+            Whether optimizer.step() was called or not.
+        """
+        if self.optimizer_step == self.hparams.wav2vec2_freeze_steps:
+            logger.warning(
+                "speechbrain.lobes.models.huggingface_wav2vec - wav2vec 2.0 is unfrozen."
+            )
+
+    def on_fit_batch_end(self, batch, outputs, loss, should_step):
+        """Called after ``fit_batch()``, meant for calculating and logging metrics.
+
+        Arguments
+        ---------
+        batch : list of torch.Tensors
+            Batch of data to use for training. Default implementation assumes
+            this batch has two elements: inputs and targets.
+        outputs : list or dictionary of torch.Tensors
+            Returned value of compute_forward().
+        loss : torch.Tensor
+            Returned value of compute_objectives().
+        should_step : boolean
+            Whether optimizer.step() was called or not.
+        """
+        if should_step:
+            # anneal model lr every update
+            self.hparams.noam_annealing(self.model_optimizer)
+
+    def on_stage_start(self, stage, epoch):
+        """Gets called when a stage starts.
+
+        Arguments
+        ---------
+        stage : Stage
+            The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST
+        epoch : int
+            The current epoch count.
+        """
+        if stage != sb.Stage.TRAIN:
+            if (
+                stage == sb.Stage.VALID
+                and epoch % self.hparams.evaluation_interval != 0
+            ):
+                return
+
+            self.acc_metric = self.hparams.acc_computer()
+            self.bleu_metric = self.hparams.bleu_computer()
+            self.last_batch = None
+
+            logger.info("Loading pretrained HiFi-GAN ...")
+            self.test_vocoder = UnitHIFIGAN.from_hparams(
+                source=self.hparams.vocoder_source,
+                savedir=self.hparams.vocoder_download_path,
+                run_opts={"device": "cpu"},
+            )
+
+            logger.info("Loading pretrained ASR ...")
+            self.test_asr = EncoderDecoderASR.from_hparams(
+                source=self.hparams.asr_source,
+                savedir=self.hparams.asr_download_path,
+                run_opts={"device": "cpu"},
+            )
+
+    def on_stage_end(self, stage, stage_loss, epoch):
+        """Gets called at the end of an epoch.
+
+        Arguments
+        ---------
+        stage : sb.Stage
+            One of sb.Stage.TRAIN, sb.Stage.VALID, sb.Stage.TEST
+        stage_loss : float
+            The average loss for all of the data processed in this stage.
+        epoch : int
+            The currently-starting epoch. This is passed
+            `None` during the test stage.
+        """
+        if stage == sb.Stage.TRAIN:
+            self.train_stats = stage_loss
+
+        # At the end of validation, we can write
+        elif (
+            stage == sb.Stage.VALID
+            and epoch % self.hparams.evaluation_interval == 0
+        ):
+            # delete vocoder and asr to free memory for next training epoch
+            del self.test_vocoder
+            del self.test_asr
+
+            stage_stats = {"loss": stage_loss}
+            stage_stats["ACC"] = self.acc_metric.summarize()
+            stage_stats["BLEU"] = self.bleu_metric.summarize("BLEU")
+
+            output_progress_sample = (
+                self.hparams.progress_samples
+                and epoch % self.hparams.progress_samples_interval == 0
+            )
+
+            if output_progress_sample:
+                self._save_progress_sample(epoch)
+
+            current_epoch = self.hparams.epoch_counter.current
+            lr_model = self.hparams.noam_annealing.current_lr
+            lr_wav2vec = 0.0
+
+            if not self.hparams.wav2vec2_frozen:
+                (lr_wav2vec, new_lr_wav2vec,) = self.hparams.wav2vec_annealing(
+                    stage_stats["ACC"]
+                )
+                sb.nnet.schedulers.update_learning_rate(
+                    self.wav2vec_optimizer, new_lr_wav2vec
+                )
+
+            self.hparams.train_logger.log_stats(
+                stats_meta={
+                    "epoch": current_epoch,
+                    "lr_model": lr_model,
+                    "lr_wav2vec": lr_wav2vec,
+                },
+                train_stats={"loss": self.train_stats},
+                valid_stats=stage_stats,
+            )
+
+            # Save the current checkpoint and delete previous checkpoints.
+            self.checkpointer.save_and_keep_only(
+                meta={
+                    "ACC": stage_stats["ACC"],
+                    "BLEU": stage_stats["BLEU"],
+                    "epoch": epoch,
+                },
+                max_keys=["BLEU"],
+                num_to_keep=10,
+            )
+
+        elif stage == sb.Stage.TEST:
+            stage_stats = {"loss": stage_loss}
+            stage_stats["BLEU"] = self.bleu_metric.summarize("BLEU")
+
+            self.hparams.train_logger.log_stats(
+                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
+                test_stats=stage_stats,
+            )
+
+            logger.info(
+                f"BLEU score: {round(self.bleu_metric.summarize('BLEU'), 2)}"
+            )
+            bleu_file = pl.Path(self.hparams.output_folder) / "bleu.txt"
+            with open(bleu_file, "a+", encoding="utf-8") as w:
+                self.bleu_metric.write_stats(w)
+
+    def _save_progress_sample(self, epoch):
+        """Save samples and BLEU score from last batch for current epoch.
+        Arguments
+        ---------
+        epoch : int
+            The currently-starting epoch. This is passed
+            `None` during the test stage.
+        """
+        if self.last_batch is None:
+            return
+
+        (
+            ids,
+            (wavs, transcripts),
+            (tgt_transcripts, tgt_wavs),
+        ) = self.last_batch
+
+        save_folder = pl.Path(self.hparams.progress_sample_path) / f"{epoch}"
+        save_folder.mkdir(parents=True, exist_ok=True)
+
+        sample_size = self.hparams.progress_batch_sample_size
+        if len(ids) < sample_size:
+            sample_size = len(ids)
+
+        for i in tqdm.tqdm(range(sample_size)):
+            utt_id = ids[i]
+            wav = wavs[i]
+            transcript = transcripts[i]
+            tgt_transcript = tgt_transcripts[i]
+            tgt_wav = tgt_wavs[i]
+
+            sample_path = save_folder / f"{utt_id}_pred.wav"
+            sb.dataio.dataio.write_audio(
+                sample_path, wav, self.hparams.sample_rate
+            )
+
+            sample_path = save_folder / f"{utt_id}_ref.wav"
+            sb.dataio.dataio.write_audio(
+                sample_path, tgt_wav, self.hparams.sample_rate
+            )
+
+            sample_path = save_folder / f"{utt_id}.txt"
+            with open(sample_path, "w") as file:
+                file.write(f"pred: {transcript}\n")
+                file.write(f"ref: {tgt_transcript}\n")
+
+        self.bleu_metric.append(
+            ids[:sample_size],
+            transcripts[:sample_size],
+            [tgt_transcripts[:sample_size]],
+        )
+
+        bleu_path = save_folder / "bleu.txt"
+        with open(bleu_path, "w") as file:
+            file.write(
+                f"BLEU score: {round(self.bleu_metric.summarize('BLEU'), 2)}\n"
+            )
+
+
+def dataio_prepare(hparams):
+    """This function prepares the datasets to be used in the brain class.
+    It also defines the data processing pipeline through user-defined functions.
+    """
+    codes_folder = pl.Path(hparams["codes_folder"])
+
+    # Define audio pipeline. In this case, we simply read the audio contained
+    # in the variable src_audio with the custom reader.
+    @sb.utils.data_pipeline.takes("src_audio")
+    @sb.utils.data_pipeline.provides("src_sig")
+    def src_audio_pipeline(wav):
+        """Load the source language audio signal.
+        This is done on the CPU in the `collate_fn`
+        """
+        info = torchaudio.info(wav)
+        sig = sb.dataio.dataio.read_audio(wav)
+        sig = torchaudio.transforms.Resample(
+            info.sample_rate, hparams["sample_rate"],
+        )(sig)
+        return sig
+
+    @sb.utils.data_pipeline.takes("tgt_audio")
+    @sb.utils.data_pipeline.provides("tgt_sig")
+    def tgt_audio_pipeline(wav):
+        """Load the target language audio signal.
+        This is done on the CPU in the `collate_fn`.
+        """
+        info = torchaudio.info(wav)
+        sig = sb.dataio.dataio.read_audio(wav)
+        sig = torchaudio.transforms.Resample(
+            info.sample_rate, hparams["sample_rate"],
+        )(sig)
+        return sig
+
+    @sb.utils.data_pipeline.takes("id")
+    @sb.utils.data_pipeline.provides("code_bos", "code_eos")
+    def unit_pipeline(utt_id):
+        """Load target codes"""
+        code = np.load(codes_folder / f"{utt_id}_tgt.npy")
+        code = torch.LongTensor(code)
+        code = torch.unique_consecutive(code)
+        code_bos = torch.cat((torch.LongTensor([hparams["bos_index"]]), code))
+        yield code_bos
+        code_eos = torch.cat((code, torch.LongTensor([hparams["eos_index"]])))
+        yield code_eos
+
+    datasets = {}
+    for split in hparams["splits"]:
+        datasets[split] = sb.dataio.dataset.DynamicItemDataset.from_json(
+            json_path=hparams[f"{split}_json"],
+            dynamic_items=[
+                src_audio_pipeline,
+                tgt_audio_pipeline,
+                unit_pipeline,
+            ],
+            output_keys=[
+                "id",
+                "src_sig",
+                "tgt_sig",
+                "duration",
+                "code_bos",
+                "code_eos",
+                "tgt_text",
+            ],
+        )
+
+    # Sorting training data with ascending order makes the code  much
+    # faster  because we minimize zero-padding. In most of the cases, this
+    # does not harm the performance.
+    if hparams["sorting"] == "ascending":
+        datasets["train"] = datasets["train"].filtered_sorted(
+            sort_key="duration"
+        )
+        datasets["valid"] = datasets["valid"].filtered_sorted(
+            sort_key="duration"
+        )
+
+        hparams["train_dataloader_opts"]["shuffle"] = False
+        hparams["valid_dataloader_opts"]["shuffle"] = False
+
+    elif hparams["sorting"] == "descending":
+        datasets["train"] = datasets["train"].filtered_sorted(
+            sort_key="duration", reverse=True
+        )
+        datasets["valid"] = datasets["valid"].filtered_sorted(
+            sort_key="duration", reverse=True
+        )
+
+        hparams["train_dataloader_opts"]["shuffle"] = False
+        hparams["valid_dataloader_opts"]["shuffle"] = False
+
+    elif hparams["sorting"] == "random":
+        hparams["train_dataloader_opts"]["shuffle"] = True
+        hparams["valid_dataloader_opts"]["shuffle"] = False
+
+    else:
+        raise NotImplementedError(
+            "sorting must be random, ascending or descending"
+        )
+
+    # Dynamic Batching is used, we instantiate the needed samplers.
+    train_batch_sampler = None
+    if hparams["dynamic_batching"]:
+        from speechbrain.dataio.sampler import DynamicBatchSampler  # noqa
+
+        dynamic_hparams = hparams["dynamic_batch_sampler"]
+        num_buckets = dynamic_hparams["num_buckets"]
+
+        train_batch_sampler = DynamicBatchSampler(
+            datasets["train"],
+            dynamic_hparams["max_batch_len"],
+            num_buckets=num_buckets,
+            length_func=lambda x: x["duration"],
+            shuffle=dynamic_hparams["shuffle_ex"],
+            batch_ordering=dynamic_hparams["batch_ordering"],
+            max_batch_ex=dynamic_hparams["max_batch_ex"],
+        )
+
+    return datasets, train_batch_sampler
+
+
+if __name__ == "__main__":
+    # Load hyperparameters file with command-line overrides
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # If distributed_launch=True then
+    # create ddp_group with the right communication protocol
+    sb.utils.distributed.ddp_init_group(run_opts)
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    sys.path.append("../")
+    from cvss_prepare import prepare_cvss
+
+    sb.utils.distributed.run_on_main(
+        prepare_cvss,
+        kwargs={
+            "src_data_folder": hparams["src_data_folder"],
+            "tgt_data_folder": hparams["tgt_data_folder"],
+            "save_folder": hparams["save_folder"],
+            "splits": hparams["splits"],
+            "seed": hparams["seed"],
+            "skip_prep": hparams["skip_prep"],
+        },
+    )
+
+    from extract_code import extract_cvss
+
+    sb.utils.distributed.run_on_main(
+        extract_cvss,
+        kwargs={
+            "data_folder": hparams["save_folder"],
+            "splits": hparams["splits"],
+            "kmeans_folder": hparams["kmeans_source"],
+            "encoder": hparams["encoder_source"],
+            "layer": hparams["layer"],
+            "save_folder": hparams["save_folder"],
+            "sample_rate": hparams["sample_rate"],
+            "skip_extract": hparams["skip_extract"],
+        },
+    )
+
+    datasets, train_bsampler = dataio_prepare(hparams)
+
+    s2ut_brain = S2UT(
+        modules=hparams["modules"],
+        opt_class=hparams["opt_class"],
+        hparams=hparams,
+        run_opts=run_opts,
+        checkpointer=hparams["checkpointer"],
+    )
+
+    train_dataloader_opts = hparams["train_dataloader_opts"]
+    valid_dataloader_opts = hparams["valid_dataloader_opts"]
+
+    if train_bsampler is not None:
+        train_dataloader_opts = {
+            "batch_sampler": train_bsampler,
+            "num_workers": hparams["num_workers"],
+            "collate_fn": hparams["train_dataloader_opts"]["collate_fn"],
+        }
+
+    s2ut_brain.fit(
+        s2ut_brain.hparams.epoch_counter,
+        datasets["train"],
+        datasets["valid_small"],
+        train_loader_kwargs=train_dataloader_opts,
+        valid_loader_kwargs=valid_dataloader_opts,
+    )
+
+    test_dataloader_opts = {
+        "batch_size": 1,
+    }
+
+    for dataset in ["valid", "test"]:
+        s2ut_brain.evaluate(
+            datasets[dataset],
+            max_key="BLEU",
+            test_loader_kwargs=test_dataloader_opts,
+        )
diff --git a/recipes/CVSS/cvss_prepare.py b/recipes/CVSS/cvss_prepare.py
new file mode 100644
index 0000000000000000000000000000000000000000..77335f6851bd6d0fb111072fffeddda3d7223ae3
--- /dev/null
+++ b/recipes/CVSS/cvss_prepare.py
@@ -0,0 +1,237 @@
+"""
+CVSS data preparation.
+Download: https://github.com/google-research-datasets/cvss
+
+Authors
+ * Jarod DURET 2023
+"""
+
+import os
+import csv
+import json
+import logging
+import random
+import tqdm
+import pathlib as pl
+
+import torchaudio
+from speechbrain.dataio.dataio import (
+    load_pkl,
+    save_pkl,
+)
+
+OPT_FILE = "opt_cvss_prepare.pkl"
+
+SRC_METADATA = "validated.tsv"
+TGT_METADATA = {
+    "train": "train.tsv",
+    "valid": "dev.tsv",
+    "test": "test.tsv",
+}
+
+# Need to be set according to your system
+SRC_AUDIO = "clips"
+TGT_AUDIO = {
+    "train": "train",
+    "valid": "dev",
+    "test": "test",
+}
+
+# Number of samples for the small evalution subset
+SMALL_EVAL_SIZE = 1000
+
+log_format = "[%(asctime)s] [%(levelname)s]: %(message)s"
+logging.basicConfig(format=log_format, level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def prepare_cvss(
+    src_data_folder,
+    tgt_data_folder,
+    save_folder,
+    splits=["train", "valid", "test"],
+    seed=1234,
+    skip_prep=False,
+):
+    """
+    Prepares the csv files for the CVSS datasets.
+
+    Arguments
+    ---------
+    src_data_folder : str
+        Path to the folder where the original source CV data is stored.
+    tgt_data_folder : str
+        Path to the folder where the original target CVSS data is stored.
+    save_folder : str
+        The directory where to store the csv files.
+    splits : list
+        List of splits to prepare.
+    skip_prep: Bool
+        If True, skip preparation.
+    seed : int
+        Random seed
+    """
+    # setting seeds for reproducible code.
+    random.seed(seed)
+
+    if skip_prep:
+        return
+
+    # Create configuration for easily skipping data_preparation stage
+    conf = {
+        "src_data_folder": src_data_folder,
+        "tgt_data_folder": tgt_data_folder,
+        "splits": splits,
+        "save_folder": save_folder,
+        "seed": seed,
+    }
+
+    if not os.path.exists(save_folder):
+        os.makedirs(save_folder)
+
+    src_validated = pl.Path(src_data_folder) / SRC_METADATA
+    tgt_train = pl.Path(tgt_data_folder) / TGT_METADATA["train"]
+    tgt_valid = pl.Path(tgt_data_folder) / TGT_METADATA["valid"]
+    tgt_test = pl.Path(tgt_data_folder) / TGT_METADATA["test"]
+
+    src_audio = pl.Path(src_data_folder) / SRC_AUDIO
+    tgt_audio_train = pl.Path(tgt_data_folder) / TGT_AUDIO["train"]
+    tgt_audio_valid = pl.Path(tgt_data_folder) / TGT_AUDIO["valid"]
+    tgt_audio_test = pl.Path(tgt_data_folder) / TGT_AUDIO["test"]
+
+    save_opt = pl.Path(save_folder) / OPT_FILE
+    save_json_train = pl.Path(save_folder) / "train.json"
+    save_json_valid = pl.Path(save_folder) / "valid.json"
+    save_json_valid_small = pl.Path(save_folder) / "valid_small.json"
+    save_json_test = pl.Path(save_folder) / "test.json"
+
+    # Check if this phase is already done (if so, skip it)
+    if skip(splits, save_folder, conf):
+        logger.info("Skipping preparation, completed in previous run.")
+        return
+
+    msg = "\tCreating json file for CVSS Dataset.."
+    logger.info(msg)
+
+    # Prepare csv
+    if "train" in splits:
+        prepare_json(
+            save_json_train,
+            src_audio,
+            tgt_audio_train,
+            src_validated,
+            tgt_train,
+        )
+    if "valid" in splits:
+        prepare_json(
+            save_json_valid,
+            src_audio,
+            tgt_audio_valid,
+            src_validated,
+            tgt_valid,
+        )
+        prepare_json(
+            save_json_valid_small,
+            src_audio,
+            tgt_audio_valid,
+            src_validated,
+            tgt_valid,
+            limit_to_n_sample=SMALL_EVAL_SIZE,
+        )
+    if "test" in splits:
+        prepare_json(
+            save_json_test, src_audio, tgt_audio_test, src_validated, tgt_test,
+        )
+
+    save_pkl(conf, save_opt)
+
+
+def skip(splits, save_folder, conf):
+    """
+    Detects if the cvss data_preparation has been already done.
+    If the preparation has been done, we can skip it.
+
+    Returns
+    -------
+    bool
+        if True, the preparation phase can be skipped.
+        if False, it must be done.
+    """
+    # Checking json files
+    skip = True
+
+    split_files = {
+        "train": "train.json",
+        "valid": "valid.json",
+        "valid_small": "valid_small.json",
+        "test": "test.json",
+    }
+
+    for split in splits:
+        if not os.path.isfile(os.path.join(save_folder, split_files[split])):
+            skip = False
+
+    #  Checking saved options
+    save_opt = os.path.join(save_folder, OPT_FILE)
+    if skip is True:
+        if os.path.isfile(save_opt):
+            opts_old = load_pkl(save_opt)
+            if opts_old == conf:
+                skip = True
+            else:
+                skip = False
+        else:
+            skip = False
+    return skip
+
+
+def prepare_json(
+    json_file,
+    src_audio_folder,
+    tgt_audio_folder,
+    src_validated,
+    tgt_split,
+    limit_to_n_sample=None,
+):
+    """
+    Creates json file.
+
+    """
+
+    json_dict = {}
+    tgt_meta = list(
+        csv.reader(open(tgt_split), delimiter="\t", quoting=csv.QUOTE_NONE)
+    )
+
+    limit_to_n_sample = (
+        len(tgt_meta) if not limit_to_n_sample else limit_to_n_sample
+    )
+
+    for i in tqdm.tqdm(range(limit_to_n_sample)):
+        session_id = tgt_meta[i][0].split(".")[0]
+
+        tgt_audio = f"{tgt_audio_folder}/{session_id}.mp3.wav"
+        src_audio = f"{src_audio_folder}/{session_id}.mp3"
+
+        src_sig, sr = torchaudio.load(src_audio)
+        duration = src_sig.shape[1] / sr
+
+        # src_text = meta_dict[session_id]["sentence"]
+        tgt_text = tgt_meta[i][1]
+
+        if duration < 1.5 or len(tgt_text) < 10:
+            continue
+
+        json_dict[session_id] = {
+            "src_audio": src_audio,
+            "tgt_audio": tgt_audio,
+            "duration": duration,
+            # "src_text": src_text,
+            "tgt_text": tgt_text,
+        }
+
+    # Writing the dictionary to the json file
+    with open(json_file, mode="w") as json_f:
+        json.dump(json_dict, json_f, indent=2)
+
+    logger.info(f"{json_file} successfully created!")
diff --git a/recipes/CommonLanguage/lang_id/hparams/train_ecapa_tdnn.yaml b/recipes/CommonLanguage/lang_id/hparams/train_ecapa_tdnn.yaml
index ebc067dffaeb0c1603a70fac4b21a048e9b8edd7..d4722f45ca06cc58946ff78f0aed3d3185b8c15c 100644
--- a/recipes/CommonLanguage/lang_id/hparams/train_ecapa_tdnn.yaml
+++ b/recipes/CommonLanguage/lang_id/hparams/train_ecapa_tdnn.yaml
@@ -16,13 +16,20 @@ __set_seed: !apply:torch.manual_seed [!ref <seed>]
 data_folder: !PLACEHOLDER # e.g. /localscratch/common_voice_kpd/
 output_folder: !ref results/ECAPA-TDNN/<seed>
 save_folder: !ref <output_folder>/save
-rir_folder: !ref <data_folder>
 train_log: !ref <output_folder>/train_log.txt
 train_csv: !ref <save_folder>/train.csv
 dev_csv: !ref <save_folder>/dev.csv
 test_csv: !ref <save_folder>/test.csv
 skip_prep: False
 
+# Data for augmentation
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+RIR_DATASET_URL: https://www.dropbox.com/scl/fi/linhy77c36mu10965a836/RIRs.zip?rlkey=pg9cu8vrpn2u173vhiqyu743u&dl=1
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+data_folder_rir: !ref <data_folder>/rir # The impulse responses used for data augmentation will automatically be downloaded here.
+noise_annotation: !ref <save_folder>/noise.csv
+rir_annotation: !ref <save_folder>/rir.csv
+
 # The train logger writes training statistics to a file, as well as stdout.
 train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
     save_file: !ref <train_log>
@@ -31,10 +38,10 @@ error_stats: !name:speechbrain.utils.metric_stats.MetricStats
     metric: !name:speechbrain.nnet.losses.classification_error
         reduction: batch
 
+####################### Training Parameters ####################################
+
 # Feature parameters btw: 40 - 80
 n_mels: 80
-
-# Training Parameters
 sample_rate: 16000
 number_of_epochs: 30
 batch_size: 4
@@ -44,32 +51,65 @@ emb_channels: [1024, 1024, 1024, 1024, 3072]
 emb_attention_channels: 128
 
 # Dataloaders
+num_workers: 4
+drop_last: True
 train_dataloader_options:
+    num_workers: !ref <num_workers>
     batch_size: !ref <batch_size>
-    drop_last: True
+    drop_last: !ref <drop_last>
     shuffle: True
 
 test_dataloader_options:
+    num_workers: !ref <num_workers>
     batch_size: !ref <batch_size>
     shuffle: True
 
-# Added noise and reverb come from OpenRIR dataset, automatically
-# downloaded and prepared with this Environmental Corruption class.
-env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    openrir_max_noise_len: 3.0  # seconds
-    babble_prob: 0.0
-    reverb_prob: 1.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-# Adds speech change + time and frequency dropouts (time-domain implementation)
-# A small speed change help to improve the performance of speaker-id as well.
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
+############################## Augmentations ###################################
+
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+# Download and prepare the dataset of room impulse responses for augmentation
+prepare_rir_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <RIR_DATASET_URL>
+    dest_folder: !ref <data_folder_rir>
+    ext: wav
+    csv_file: !ref <rir_annotation>
+
+# Add reverberation to input signal
+add_reverb: !new:speechbrain.augment.time_domain.AddReverb
+    csv_file: !ref <rir_annotation>
+    reverb_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Add noise to input signal
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: 0
+    snr_high: 15
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    concat_original: True
+    shuffle_augmentations: True
+    min_augmentations: 1
+    max_augmentations: 3
+    augmentations: [
+        !ref <add_reverb>,
+        !ref <add_noise>,
+        !ref <speed_perturb>]
 
 # Feature extraction
 compute_features: !new:speechbrain.lobes.features.Fbank
@@ -80,6 +120,8 @@ mean_var_norm_input: !new:speechbrain.processing.features.InputNormalization
     norm_type: sentence
     std_norm: False
 
+############################## Models ##########################################
+
 # To design a custom model, either just edit the simple CustomModel
 # class that's listed here, or replace this `!new` call with a line
 # pointing to a different file you've defined.
@@ -109,8 +151,6 @@ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
 # device, as well as having train()/eval() called on them by the Brain class.
 modules:
     compute_features: !ref <compute_features>
-    env_corrupt: !ref <env_corrupt>
-    augmentation: !ref <augmentation>
     embedding_model: !ref <embedding_model>
     mean_var_norm_input: !ref <mean_var_norm_input>
     classifier: !ref <classifier>
@@ -139,6 +179,8 @@ lr_annealing: !new:speechbrain.nnet.schedulers.LinearScheduler
     final_value: !ref <lr_final>
     epoch_count: !ref <number_of_epochs>
 
+############################## Logging and Pretrainer ##########################
+
 # This object is used for saving the state of training both so that it
 # can be resumed if it gets interrupted, and also so that the best checkpoint
 # can be later loaded for evaluation or inference.
diff --git a/recipes/CommonLanguage/lang_id/train.py b/recipes/CommonLanguage/lang_id/train.py
index bb245481618ab3006c88deaa0226b170e891e3c1..098da429204edc81135a93fcc54ec29ae8b35786 100644
--- a/recipes/CommonLanguage/lang_id/train.py
+++ b/recipes/CommonLanguage/lang_id/train.py
@@ -1,7 +1,6 @@
 #!/usr/bin/env python3
 import os
 import sys
-import torch
 import logging
 import torchaudio
 import speechbrain as sb
@@ -36,15 +35,9 @@ class LID(sb.Brain):
         """
         wavs, lens = wavs
 
-        # Add augmentation if specified. In this version of augmentation, we
-        # concatenate the original and the augment batches in a single bigger
-        # batch. This is more memory-demanding, but helps to improve the
-        # performance. Change it if you run OOM.
-        if stage == sb.Stage.TRAIN:
-            wavs_noise = self.modules.env_corrupt(wavs, lens)
-            wavs = torch.cat([wavs, wavs_noise], dim=0)
-            lens = torch.cat([lens, lens], dim=0)
-            wavs = self.hparams.augmentation(wavs, lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, lens = self.hparams.wav_augment(wavs, lens)
 
         # Feature extraction and normalization
         feats = self.modules.compute_features(wavs)
@@ -103,11 +96,10 @@ class LID(sb.Brain):
 
         # Concatenate labels (due to data augmentation)
         if stage == sb.Stage.TRAIN:
-            targets = torch.cat([targets, targets], dim=0)
-            lens = torch.cat([lens, lens], dim=0)
-
-            if hasattr(self.hparams.lr_annealing, "on_batch_end"):
-                self.hparams.lr_annealing.on_batch_end(self.optimizer)
+            if hasattr(self.hparams, "wav_augment"):
+                targets = self.hparams.wav_augment.replicate_labels(targets)
+                if hasattr(self.hparams.lr_annealing, "on_batch_end"):
+                    self.hparams.lr_annealing.on_batch_end(self.optimizer)
 
         loss = self.hparams.compute_cost(predictions, targets)
 
@@ -279,13 +271,16 @@ if __name__ == "__main__":
             "skip_prep": hparams["skip_prep"],
         },
     )
+    # Data preparation for augmentation
+    sb.utils.distributed.run_on_main(hparams["prepare_noise_data"])
+    sb.utils.distributed.run_on_main(hparams["prepare_rir_data"])
 
     # Create dataset objects "train", "dev", and "test" and language_encoder
     datasets, language_encoder = dataio_prep(hparams)
 
     # Fetch and laod pretrained modules
     sb.utils.distributed.run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     # Initialize the Brain object to prepare for mask training.
     lid_brain = LID(
diff --git a/recipes/CommonVoice/ASR/CTC/README.md b/recipes/CommonVoice/ASR/CTC/README.md
index 6f28ddce23ddb22eefdda21838e3c0713dc0d2a7..0b6cbdcf61436d1a1408c442baabfee4c5eb810c 100644
--- a/recipes/CommonVoice/ASR/CTC/README.md
+++ b/recipes/CommonVoice/ASR/CTC/README.md
@@ -1,9 +1,18 @@
 # CommonVoice ASR with CTC based Seq2Seq models.
-This folder contains scripts necessary to run an ASR experiment with the CommonVoice dataset: [CommonVoice Homepage](https://commonvoice.mozilla.org/)
+This folder contains scripts necessary to run an ASR experiment with the CommonVoice 14.0 dataset
 
 # How to run
 python train.py hparams/{hparam_file}.yaml
 
+To use an n-gram Language Model (LM) for decoding, follow these steps:
+1. Uncomment the line `kenlm_model_path: none` in the `test_beam_serch` entry in the yaml file.
+2. Set a path to an ARPA or bin file containing the n-gram LM.
+
+For training an n-gram LM in ARPA (or bin) format, refer to the LM recipe in recipes/CommonVoice/LM.
+Alternatively, you can download a pre-trained n-gram LM from our Dropbox repository at this link: [Pretrained n-gram LMs](https://www.dropbox.com/scl/fo/zw505t10kesqpvkt6m3tu/h?rlkey=6626h1h665tvlo1mtekop9rx5&dl=0).
+
+These models are trained on the Commonvoice audio transcriptions available in the training set.
+
 # Data preparation
 It is important to note that CommonVoice initially offers mp3 audio files at 42Hz. Hence, audio files are downsampled on the fly within the dataio function of the training script.
 
@@ -14,18 +23,31 @@ Here is a list of the different languages that we tested within the CommonVoice
 - French
 - Italian
 - Kinyarwanda
+- Arabic
+- Spanish
+- Portuguese
+- Chinese(china)
+
+>>Note:
+>In our experiments,  we use CTC beam search and also boost the performance using the 5-gram model previously trained
+on the transcription of the training data.(Refer to LM recipe: recipes/CommonVoice/LM).
+
+>>Note:
+> For Chinese the concept of word is not well-defined, hence, we consider the character error rate instead of the word error rate. For the same reason,  we don't also employ 5-gram.
 
 # Results
 | Language | CommonVoice Release | hyperparams file | LM | Val. CER | Val. WER | Test CER | Test WER | HuggingFace link | Model link | GPUs |
 | ------------- |:-------------:|:---------------------------:| -----:| -----:| -----:| -----:| -----:| :-----------:| :-----------:| :-----------:|
-| English | 2020-12-11 | train_en_with_wav2vec.yaml | No | 5.01 | 12.57 | 7.32 | 15.58 | Not Avail. | [model](https://www.dropbox.com/sh/o3q43r4wdovbmnd/AADXcVomQr549NdAgCpI7OQHa?dl=0) | 2xV100 32GB |
-| German | 2022-08-16 | train_de_with_wav2vec.yaml | No | 1.90 | 8.02 | 2.40 | 9.54 | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-de) | [model](https://www.dropbox.com/sh/vdz7apt16nbq94g/AADI5o23Ll_NmjiPlg9bzPjta?dl=0) | 1xRTXA6000 48GB |
-| French | 2020-12-11 | train_fr_with_wav2vec.yaml | No | 2.60 | 8.59 | 3.19 | 9.96 | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-fr) | [model](https://www.dropbox.com/sh/wytlbeddrt8oe4n/AAAY59qMsDlWy5F017bmBeVua?dl=0) | 2xV100 32GB |
-| Italian | 2020-12-11 | train_it_with_wav2vec.yaml | No | 2.77 | 9.83 | 3.16 | 10.85 | Not Avail. | [model](https://www.dropbox.com/sh/0v2o2hmrv1j33p6/AAA3xUiqKbSKsX88fWfptPmFa?dl=0) | 2xV100 32GB |
-| Kinyarwanda | 2020-12-11 | train_rw_with_wav2vec.yaml | No | 6.20 | 20.07 | 8.25 | 23.12 | Not Avail. | [model](https://www.dropbox.com/sh/ccgirbq9r8uzubi/AAAynCvEV8EjEpMavFRPp87Ta?dl=0) | 2xV100 32GB |
-
-*For German, it takes around 5.5 hrs an epoch.* <br>
-The output folders with checkpoints and logs can be found [here](https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0).
+| English | 2023-08-15 | train_en_with_wav2vec.yaml | No | 5.65 | 13.67 | 7.76 | 16.16 | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-en) | [model](https://www.dropbox.com/sh/ch10cnbhf1faz3w/AACdHFG65LC6582H0Tet_glTa?dl=0) | 1xV100 32GB |
+| German | 2023-08-15 | train_de_with_wav2vec.yaml | No | 1.74 | 7.40 | 2.18 | 8.39 | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-de) | [model](https://www.dropbox.com/sh/dn7plq4wfsujsi1/AABS1kqB_uqLJVkg-bFkyPpVa?dl=0) | 1xV100 32GB |
+| French | 2023-08-15 | train_fr_with_wav2vec.yaml | No | 2.59 | 8.47 | 3.36 | 9.71 | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-fr) | [model](https://www.dropbox.com/sh/0i7esfa8jp3rxpp/AAArdi8IuCRmob2WAS7lg6M4a?dl=0) | 1xV100 32GB |
+| Italian | 2023-08-15 | train_it_with_wav2vec.yaml | No | 2.10 | 7.77 |  2.30 | 7.99 |[model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-it) | [model](https://www.dropbox.com/sh/hthxqzh5boq15rn/AACftSab_FM6EFWWPgHpKw82a?dl=0) | 1xV100 32GB |
+| Kinyarwanda | 2023-08-15 | train_rw_with_wav2vec.yaml | No | 5.47 | 19.58 | 7.30 | 22.52 | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-rw) | [model](https://www.dropbox.com/sh/4iax0l4yfry37gn/AABuQ31JY-Sbyi1VlOJfV7haa?dl=0) | 1xV100 32GB |
+| Arabic | 2023-08-15 | train_ar_with_wav2vec.yaml | No | 6.45 | 20.80 | 9.65 | 28.53 | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-ar) | [model](https://www.dropbox.com/sh/7tnuqqbr4vy96cc/AAA_5_R0RmqFIiyR0o1nVS4Ia?dl=0) | 1xV100 32GB |
+| Spanish | 2023-08-15 | train_es_with_wav2vec.yaml | No | 3.36 | 12.61 | 3.67 | 12.67 | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-es) | [model](https://www.dropbox.com/sh/ejvzgl3d3g8g9su/AACYtbSWbDHvBr06lAb7A4mVa?dl=0) | 1xV100 32GB |
+| Portuguese | 2023-08-15 | train_pt_with_wav2vec.yaml | No | 6.26 | 21.05 | 6.63 | 21.69 | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-pt) | [model](https://www.dropbox.com/sh/80wucrvijdvao2a/AAD6-SZ2_ZZXmlAjOTw6fVloa?dl=0) | 1xV100 32GB |
+| Chinese(china) | 2023-08-15 | train_zh-CN_with_wav2vec.yaml | No | 25.03 | - | 23.17 | - | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-zh-CN) | [model](https://www.dropbox.com/sh/2bikr81vgufoglf/AABMpD0rLIaZBxjtwBHgrNpga?dl=0) | 1xV100 32GB |
+
 
 ## How to simply use pretrained models to transcribe my audio file?
 
diff --git a/recipes/CommonVoice/ASR/seq2seq/hparams/train_it_with_wav2vec.yaml b/recipes/CommonVoice/ASR/CTC/hparams/train_ar_with_wav2vec.yaml
similarity index 56%
rename from recipes/CommonVoice/ASR/seq2seq/hparams/train_it_with_wav2vec.yaml
rename to recipes/CommonVoice/ASR/CTC/hparams/train_ar_with_wav2vec.yaml
index 5b7053ef836a73b103bd209700f4fa4a48aca13f..643df09944098da647e4067ff7d6546b433ee036 100644
--- a/recipes/CommonVoice/ASR/seq2seq/hparams/train_it_with_wav2vec.yaml
+++ b/recipes/CommonVoice/ASR/CTC/hparams/train_ar_with_wav2vec.yaml
@@ -1,29 +1,28 @@
 # ################################
-# Model: wav2vec2 + DNN + CTC/Attention
+# Model: wav2vec2 + DNN + CTC
 # Augmentation: SpecAugment
-# Authors: Titouan Parcollet 2021
-#          Mirco Ravanelli 2021
+# Authors: Pooneh Mousavi 2023
 # ################################
 
 # Seed needs to be set at top of yaml, before objects with parameters are made
 seed: 1234
 __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
-output_folder: !ref results/wav2vec2_ctcatt_it/<seed>
+output_folder: !ref results/wav2vec2_ctc_ar/<seed>
 test_wer_file: !ref <output_folder>/wer_test.txt
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
 
-# URL for the biggest Fairseq english wav2vec2 model.
-wav2vec2_hub: facebook/wav2vec2-large-100k-voxpopuli
+# URL for the biggest Fairseq multilingual
+wav2vec2_hub: facebook/wav2vec2-large-xlsr-53
 wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
 
 # Data files
-data_folder: !PLACEHOLDER  # e.g, /localscratch/cv-corpus-6.1-2020-12-11/it
+data_folder: !PLACEHOLDER  # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
 train_tsv_file: !ref <data_folder>/train.tsv  # Standard CommonVoice .tsv files
 dev_tsv_file: !ref <data_folder>/dev.tsv  # Standard CommonVoice .tsv files
 test_tsv_file: !ref <data_folder>/test.tsv  # Standard CommonVoice .tsv files
 accented_letters: True
-language: it # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
+language: ar # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
 train_csv: !ref <save_folder>/train.csv
 valid_csv: !ref <save_folder>/dev.csv
 test_csv: !ref <save_folder>/test.csv
@@ -31,23 +30,21 @@ skip_prep: False # Skip data preparation
 
 # We remove utterance slonger than 10s in the train/dev/test sets as
 # longer sentences certainly correspond to "open microphones".
-avoid_if_longer_than: 8.0
+avoid_if_longer_than: 10.0
 
-# Training parameters
-number_of_epochs: 45
-number_of_ctc_epochs: 15
+####################### Training Parameters ####################################
+
+number_of_epochs: 30
 lr: 1.0
 lr_wav2vec: 0.0001
-ctc_weight: 0.3
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
-
 # With data_parallel batch_size is split into N jobs
 # With DDP batch_size is multiplied by N jobs
-# Must be 6 per GPU to fit 16GB of VRAM
+# Must be 8 per GPU to fit 32GB of VRAM
 batch_size: 12
 test_batch_size: 4
 
@@ -62,33 +59,38 @@ test_dataloader_options:
 token_type: unigram  # ["unigram", "bpe", "char"]
 character_coverage: 1.0
 
-# Model parameters
-activation: !name:torch.nn.LeakyReLU
+####################### Model Parameters #######################################
 wav2vec_output_dim: 1024
-dnn_layers: 2
 dnn_neurons: 1024
-emb_size: 128
-dec_neurons: 1024
-dec_hidden_size: !ref <dec_neurons>
-dec_attn_dim: !ref <dec_neurons>
 freeze_wav2vec: False
+freeze_feature_extractor: False
+dropout: 0.15
+warmup_steps: 500
 
 # Outputs
-output_neurons: 500  # BPE size, index(blank/eos/bos) = 0
+output_neurons: 1000  # BPE size, index(blank/eos/bos) = 0
 
 # Decoding parameters
 # Be sure that the bos and eos index match with the BPEs ones
 blank_index: 0
 bos_index: 1
 eos_index: 2
-min_decode_ratio: 0.0
-max_decode_ratio: 1.0
-beam_size: 80
-eos_threshold: 1.5
-using_max_attn_shift: True
-max_attn_shift: 140
-# ctc_weight_decode: 0.0
-temperature: 1.50
+
+# Decoding parameters
+# Be sure that the bos and eos index match with the BPEs ones
+# Decoding parameters
+test_beam_search:
+    blank_index: !ref <blank_index>
+    beam_size: 100
+    beam_prune_logp: -12.0
+    token_prune_min_logp: -1.2
+    prune_history: True
+    topk: 1
+    alpha: 1.0
+    beta: 0.5
+    # To use n-gram LM for decoding, follow steps in README.md.
+    # kenlm_model_path: none
+
 
 #
 # Functions and classes
@@ -96,20 +98,67 @@ temperature: 1.50
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
-enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
+enc: !new:speechbrain.nnet.containers.Sequential
     input_shape: [null, null, !ref <wav2vec_output_dim>]
-    activation: !ref <activation>
-    dnn_blocks: !ref <dnn_layers>
-    dnn_neurons: !ref <dnn_neurons>
-
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+    linear1: !name:speechbrain.nnet.linear.Linear
+        n_neurons: !ref <dnn_neurons>
+        bias: True
+    bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
+    activation: !new:torch.nn.LeakyReLU
+    drop: !new:torch.nn.Dropout
+        p: !ref <dropout>
+    linear2: !name:speechbrain.nnet.linear.Linear
+        n_neurons: !ref <dnn_neurons>
+        bias: True
+    bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
+    activation2: !new:torch.nn.LeakyReLU
+    drop2: !new:torch.nn.Dropout
+        p: !ref <dropout>
+    linear3: !name:speechbrain.nnet.linear.Linear
+        n_neurons: !ref <dnn_neurons>
+        bias: True
+    bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
+    activation3: !new:torch.nn.LeakyReLU
+
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec>
+    freeze_feature_extractor: !ref <freeze_feature_extractor>
     save_path: !ref <wav2vec2_folder>
 
 #####
@@ -125,51 +174,23 @@ wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
 #    save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
 #####
 
-emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: !ref <dec_neurons>
-    input_size: !ref <emb_size>
-    rnn_type: gru
-    attn_type: location
-    hidden_size: !ref <dec_hidden_size>
-    attn_dim: !ref <dec_attn_dim>
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.15
-
 ctc_lin: !new:speechbrain.nnet.linear.Linear
     input_size: !ref <dnn_neurons>
     n_neurons: !ref <output_neurons>
 
-seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dec_neurons>
-    n_neurons: !ref <output_neurons>
-
 log_softmax: !new:speechbrain.nnet.activations.Softmax
     apply_log: True
 
 ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
     blank_index: !ref <blank_index>
 
-seq_cost: !name:speechbrain.nnet.losses.nll_loss
-    label_smoothing: 0.1
-
 modules:
     wav2vec2: !ref <wav2vec2>
     enc: !ref <enc>
-    emb: !ref <emb>
-    dec: !ref <dec>
     ctc_lin: !ref <ctc_lin>
-    seq_lin: !ref <seq_lin>
 
 model: !new:torch.nn.ModuleList
-    - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <ctc_lin>, !ref <seq_lin>]
+    - [!ref <enc>, !ref <ctc_lin>]
 
 model_opt_class: !name:torch.optim.Adadelta
     lr: !ref <lr>
@@ -191,22 +212,6 @@ lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
     annealing_factor: 0.9
     patient: 0
 
-beam_searcher: !new:speechbrain.decoders.S2SRNNBeamSearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    ctc_linear: !ref <ctc_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-    beam_size: !ref <beam_size>
-    eos_threshold: !ref <eos_threshold>
-    using_max_attn_shift: !ref <using_max_attn_shift>
-    max_attn_shift: !ref <max_attn_shift>
-    temperature: !ref <temperature>
-
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
     checkpoints_dir: !ref <save_folder>
     recoverables:
diff --git a/recipes/CommonVoice/ASR/CTC/hparams/train_de_with_wav2vec.yaml b/recipes/CommonVoice/ASR/CTC/hparams/train_de_with_wav2vec.yaml
index 49e3e6909e95ecc97a9fec33deb29b0e8a1a7afa..adb8e5bb52626992dfd6899871a8f74a30bc62dd 100644
--- a/recipes/CommonVoice/ASR/CTC/hparams/train_de_with_wav2vec.yaml
+++ b/recipes/CommonVoice/ASR/CTC/hparams/train_de_with_wav2vec.yaml
@@ -33,12 +33,13 @@ skip_prep: False
 # longer sentences certainly correspond to "open microphones".
 avoid_if_longer_than: 10.0
 
-# Training parameters
+####################### Training Parameters ####################################
+
 number_of_epochs: 45
 lr: 1.0
 lr_wav2vec: 0.0001
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
@@ -51,17 +52,17 @@ dataloader_num_workers: 8
 test_num_workers: 8
 
 dataloader_options:
-  batch_size: !ref <batch_size>
-  num_workers: !ref <dataloader_num_workers>
+    batch_size: !ref <batch_size>
+    num_workers: !ref <dataloader_num_workers>
 test_dataloader_options:
-  batch_size: !ref <test_batch_size>
-  num_workers: !ref <test_num_workers>
+    batch_size: !ref <test_batch_size>
+    num_workers: !ref <test_num_workers>
 
 # BPE parameters
 token_type: char # ["unigram", "bpe", "char"]
 character_coverage: 1.0
 
-# Model parameters
+####################### Model Parameters #######################################
 # activation: !name:torch.nn.LeakyReLU
 dnn_neurons: 1024
 wav2vec_output_dim: !ref <dnn_neurons>
@@ -78,42 +79,85 @@ blank_index: 0
 bos_index: 1
 eos_index: 2
 
+# Decoding parameters
+test_beam_search:
+    blank_index: !ref <blank_index>
+    beam_size: 100
+    beam_prune_logp: -12.0
+    token_prune_min_logp: -1.2
+    prune_history: True
+    topk: 1
+    alpha: 1.0
+    beta: 0.5
+    # To use n-gram LM for decoding, follow steps in README.md.
+    # kenlm_model_path: none
+
 # Functions and classes
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
-  limit: !ref <number_of_epochs>
-
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-  sample_rate: !ref <sample_rate>
-  speeds: [95, 100, 105]
+    limit: !ref <number_of_epochs>
+
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: [95, 100, 105]
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
 
 enc: !new:speechbrain.nnet.containers.Sequential
-  input_shape: [null, null, !ref <wav2vec_output_dim>]
-  linear1: !name:speechbrain.nnet.linear.Linear
-    n_neurons: !ref <dnn_neurons>
-    bias: True
-  bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
-  activation: !new:torch.nn.LeakyReLU
-  drop: !new:torch.nn.Dropout
-    p: !ref <dropout>
-  linear2: !name:speechbrain.nnet.linear.Linear
-    n_neurons: !ref <dnn_neurons>
-    bias: True
-  bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
-  activation2: !new:torch.nn.LeakyReLU
-  drop2: !new:torch.nn.Dropout
-    p: !ref <dropout>
-  linear3: !name:speechbrain.nnet.linear.Linear
-    n_neurons: !ref <dnn_neurons>
-    bias: True
-  bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
-  activation3: !new:torch.nn.LeakyReLU
-
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
-  source: !ref <wav2vec2_hub>
-  output_norm: True
-  freeze: !ref <freeze_wav2vec>
-  freeze_feature_extractor: !ref <freeze_feature_extractor>
-  save_path: !ref <wav2vec2_folder>
+    input_shape: [null, null, !ref <wav2vec_output_dim>]
+    linear1: !name:speechbrain.nnet.linear.Linear
+        n_neurons: !ref <dnn_neurons>
+        bias: True
+    bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
+    activation: !new:torch.nn.LeakyReLU
+    drop: !new:torch.nn.Dropout
+        p: !ref <dropout>
+    linear2: !name:speechbrain.nnet.linear.Linear
+        n_neurons: !ref <dnn_neurons>
+        bias: True
+    bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
+    activation2: !new:torch.nn.LeakyReLU
+    drop2: !new:torch.nn.Dropout
+        p: !ref <dropout>
+    linear3: !name:speechbrain.nnet.linear.Linear
+        n_neurons: !ref <dnn_neurons>
+        bias: True
+    bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
+    activation3: !new:torch.nn.LeakyReLU
+
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
+    source: !ref <wav2vec2_hub>
+    output_norm: True
+    freeze: !ref <freeze_wav2vec>
+    freeze_feature_extractor: !ref <freeze_feature_extractor>
+    save_path: !ref <wav2vec2_folder>
 
 #####
 # Uncomment this block if you prefer to use a Fairseq pretrained model instead
@@ -128,56 +172,56 @@ wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
 #    save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
 
 ctc_lin: !new:speechbrain.nnet.linear.Linear
-  input_size: !ref <dnn_neurons>
-  n_neurons: !ref <output_neurons>
+    input_size: !ref <dnn_neurons>
+    n_neurons: !ref <output_neurons>
 
 log_softmax: !new:speechbrain.nnet.activations.Softmax
-  apply_log: True
+    apply_log: True
 
 ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
-  blank_index: !ref <blank_index>
+    blank_index: !ref <blank_index>
 
 modules:
-  wav2vec2: !ref <wav2vec2>
-  enc: !ref <enc>
-  ctc_lin: !ref <ctc_lin>
+    wav2vec2: !ref <wav2vec2>
+    enc: !ref <enc>
+    ctc_lin: !ref <ctc_lin>
 
 model: !new:torch.nn.ModuleList
-  - [!ref <enc>, !ref <ctc_lin>]
+    - [!ref <enc>, !ref <ctc_lin>]
 
 model_opt_class: !name:torch.optim.Adadelta
-  lr: !ref <lr>
-  rho: 0.95
-  eps: 1.e-8
+    lr: !ref <lr>
+    rho: 0.95
+    eps: 1.e-8
 
 wav2vec_opt_class: !name:torch.optim.Adam
-  lr: !ref <lr_wav2vec>
+    lr: !ref <lr_wav2vec>
 
 lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
-  initial_value: !ref <lr>
-  improvement_threshold: 0.0025
-  annealing_factor: 0.8
-  patient: 0
+    initial_value: !ref <lr>
+    improvement_threshold: 0.0025
+    annealing_factor: 0.8
+    patient: 0
 
 lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
-  initial_value: !ref <lr_wav2vec>
-  improvement_threshold: 0.0025
-  annealing_factor: 0.9
-  patient: 0
+    initial_value: !ref <lr_wav2vec>
+    improvement_threshold: 0.0025
+    annealing_factor: 0.9
+    patient: 0
 
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
-  checkpoints_dir: !ref <save_folder>
-  recoverables:
-    wav2vec2: !ref <wav2vec2>
-    model: !ref <model>
-    scheduler_model: !ref <lr_annealing_model>
-    scheduler_wav2vec: !ref <lr_annealing_wav2vec>
-    counter: !ref <epoch_counter>
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        wav2vec2: !ref <wav2vec2>
+        model: !ref <model>
+        scheduler_model: !ref <lr_annealing_model>
+        scheduler_wav2vec: !ref <lr_annealing_wav2vec>
+        counter: !ref <epoch_counter>
 
 train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
-  save_file: !ref <train_log>
+    save_file: !ref <train_log>
 
 error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
 
 cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
-  split_tokens: True
+    split_tokens: True
diff --git a/recipes/CommonVoice/ASR/CTC/hparams/train_en_with_wav2vec.yaml b/recipes/CommonVoice/ASR/CTC/hparams/train_en_with_wav2vec.yaml
index 65853cd83cb999609b9e17a8eae333984c756865..d8aaea36e46f52efaf947e3cf10a14b339c8cb0f 100644
--- a/recipes/CommonVoice/ASR/CTC/hparams/train_en_with_wav2vec.yaml
+++ b/recipes/CommonVoice/ASR/CTC/hparams/train_en_with_wav2vec.yaml
@@ -32,12 +32,13 @@ skip_prep: False # Skip data preparation
 # longer sentences certainly correspond to "open microphones".
 avoid_if_longer_than: 10.0
 
-# Training parameters
+####################### Training Parameters ####################################
+
 number_of_epochs: 30
 lr: 1.0
 lr_wav2vec: 0.0001
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
@@ -58,7 +59,7 @@ test_dataloader_options:
 token_type: unigram  # ["unigram", "bpe", "char"]
 character_coverage: 1.0
 
-# Model parameters
+####################### Model Parameters #######################################
 # activation: !name:torch.nn.LeakyReLU
 wav2vec_output_dim: 1024
 dnn_neurons: 1024
@@ -76,16 +77,59 @@ blank_index: 0
 bos_index: 1
 eos_index: 2
 
+# Decoding parameters
+test_beam_search:
+    blank_index: !ref <blank_index>
+    beam_size: 100
+    beam_prune_logp: -12.0
+    token_prune_min_logp: -1.2
+    prune_history: True
+    topk: 1
+    alpha: 1.0
+    beta: 0.5
+    # To use n-gram LM for decoding, follow steps in README.md.
+    # kenlm_model_path: none
+
 #
 # Functions and classes
 #
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
 enc: !new:speechbrain.nnet.containers.Sequential
     input_shape: [null, null, !ref <wav2vec_output_dim>]
     linear1: !name:speechbrain.nnet.linear.Linear
@@ -108,7 +152,7 @@ enc: !new:speechbrain.nnet.containers.Sequential
     bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
     activation3: !new:torch.nn.LeakyReLU
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec>
diff --git a/recipes/CommonVoice/ASR/seq2seq/hparams/train_rw_with_wav2vec.yaml b/recipes/CommonVoice/ASR/CTC/hparams/train_es_with_wav2vec.yaml
similarity index 59%
rename from recipes/CommonVoice/ASR/seq2seq/hparams/train_rw_with_wav2vec.yaml
rename to recipes/CommonVoice/ASR/CTC/hparams/train_es_with_wav2vec.yaml
index 09be21e4bcc2f1739a59c06c67513947504d2f64..e32a242d1a4f6d09b5d01b0b614e94df08298167 100644
--- a/recipes/CommonVoice/ASR/seq2seq/hparams/train_rw_with_wav2vec.yaml
+++ b/recipes/CommonVoice/ASR/CTC/hparams/train_es_with_wav2vec.yaml
@@ -1,18 +1,18 @@
 # ################################
-# Model: wav2vec2 + DNN + CTC/Attention
+# Model: wav2vec2 + DNN + CTC
 # Augmentation: SpecAugment
-# Authors: Titouan Parcollet 2021
+# Authors: Pooneh Mousavi 2023
 # ################################
 
 # Seed needs to be set at top of yaml, before objects with parameters are made
 seed: 1234
 __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
-output_folder: !ref results/wav2vec2_ctcatt_rw/<seed>
+output_folder: !ref results/wav2vec2_ctc_es/<seed>
 test_wer_file: !ref <output_folder>/wer_test.txt
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
 
-# URL for the biggest HuggingFace multilingual w2v2 from XLSR.
+# URL for the biggest Fairseq multilingual
 wav2vec2_hub: facebook/wav2vec2-large-xlsr-53
 wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
 
@@ -21,8 +21,8 @@ data_folder: !PLACEHOLDER  # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
 train_tsv_file: !ref <data_folder>/train.tsv  # Standard CommonVoice .tsv files
 dev_tsv_file: !ref <data_folder>/dev.tsv  # Standard CommonVoice .tsv files
 test_tsv_file: !ref <data_folder>/test.tsv  # Standard CommonVoice .tsv files
-accented_letters: False
-language: rw # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
+accented_letters: True
+language: es # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
 train_csv: !ref <save_folder>/train.csv
 valid_csv: !ref <save_folder>/dev.csv
 test_csv: !ref <save_folder>/test.csv
@@ -32,21 +32,19 @@ skip_prep: False # Skip data preparation
 # longer sentences certainly correspond to "open microphones".
 avoid_if_longer_than: 10.0
 
-# Training parameters
+####################### Training Parameters ####################################
+
 number_of_epochs: 30
-number_of_ctc_epochs: 20
 lr: 1.0
 lr_wav2vec: 0.0001
-ctc_weight: 0.3
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
-
 # With data_parallel batch_size is split into N jobs
 # With DDP batch_size is multiplied by N jobs
-# Must be 6 per GPU to fit 32GB of VRAM
+# Must be 8 per GPU to fit 32GB of VRAM
 batch_size: 12
 test_batch_size: 4
 
@@ -61,16 +59,13 @@ test_dataloader_options:
 token_type: unigram  # ["unigram", "bpe", "char"]
 character_coverage: 1.0
 
-# Model parameters
-activation: !name:torch.nn.LeakyReLU
+####################### Model Parameters #######################################
 wav2vec_output_dim: 1024
-dnn_layers: 2
 dnn_neurons: 1024
-emb_size: 128
-dec_neurons: 1024
-dec_hidden_size: !ref <dec_neurons>
-dec_attn_dim: !ref <dec_neurons>
 freeze_wav2vec: False
+freeze_feature_extractor: False
+dropout: 0.15
+warmup_steps: 500
 
 # Outputs
 output_neurons: 1000  # BPE size, index(blank/eos/bos) = 0
@@ -80,14 +75,21 @@ output_neurons: 1000  # BPE size, index(blank/eos/bos) = 0
 blank_index: 0
 bos_index: 1
 eos_index: 2
-min_decode_ratio: 0.0
-max_decode_ratio: 1.0
-beam_size: 80
-eos_threshold: 1.5
-using_max_attn_shift: True
-max_attn_shift: 140
-# ctc_weight_decode: 0.0
-temperature: 1.50
+
+# Decoding parameters
+# Be sure that the bos and eos index match with the BPEs ones
+# Decoding parameters
+test_beam_search:
+    blank_index: !ref <blank_index>
+    beam_size: 100
+    beam_prune_logp: -12.0
+    token_prune_min_logp: -1.2
+    prune_history: True
+    topk: 1
+    alpha: 1.0
+    beta: 0.5
+    # To use n-gram LM for decoding, follow steps in README.md.
+    # kenlm_model_path: none
 
 #
 # Functions and classes
@@ -95,20 +97,67 @@ temperature: 1.50
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
-enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
+enc: !new:speechbrain.nnet.containers.Sequential
     input_shape: [null, null, !ref <wav2vec_output_dim>]
-    activation: !ref <activation>
-    dnn_blocks: !ref <dnn_layers>
-    dnn_neurons: !ref <dnn_neurons>
-
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+    linear1: !name:speechbrain.nnet.linear.Linear
+        n_neurons: !ref <dnn_neurons>
+        bias: True
+    bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
+    activation: !new:torch.nn.LeakyReLU
+    drop: !new:torch.nn.Dropout
+        p: !ref <dropout>
+    linear2: !name:speechbrain.nnet.linear.Linear
+        n_neurons: !ref <dnn_neurons>
+        bias: True
+    bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
+    activation2: !new:torch.nn.LeakyReLU
+    drop2: !new:torch.nn.Dropout
+        p: !ref <dropout>
+    linear3: !name:speechbrain.nnet.linear.Linear
+        n_neurons: !ref <dnn_neurons>
+        bias: True
+    bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
+    activation3: !new:torch.nn.LeakyReLU
+
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec>
+    freeze_feature_extractor: !ref <freeze_feature_extractor>
     save_path: !ref <wav2vec2_folder>
 
 #####
@@ -124,51 +173,23 @@ wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
 #    save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
 #####
 
-emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: !ref <dec_neurons>
-    input_size: !ref <emb_size>
-    rnn_type: gru
-    attn_type: location
-    hidden_size: !ref <dec_hidden_size>
-    attn_dim: !ref <dec_attn_dim>
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.15
-
 ctc_lin: !new:speechbrain.nnet.linear.Linear
     input_size: !ref <dnn_neurons>
     n_neurons: !ref <output_neurons>
 
-seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dec_neurons>
-    n_neurons: !ref <output_neurons>
-
 log_softmax: !new:speechbrain.nnet.activations.Softmax
     apply_log: True
 
 ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
     blank_index: !ref <blank_index>
 
-seq_cost: !name:speechbrain.nnet.losses.nll_loss
-    label_smoothing: 0.1
-
 modules:
     wav2vec2: !ref <wav2vec2>
     enc: !ref <enc>
-    emb: !ref <emb>
-    dec: !ref <dec>
     ctc_lin: !ref <ctc_lin>
-    seq_lin: !ref <seq_lin>
 
 model: !new:torch.nn.ModuleList
-    - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <ctc_lin>, !ref <seq_lin>]
+    - [!ref <enc>, !ref <ctc_lin>]
 
 model_opt_class: !name:torch.optim.Adadelta
     lr: !ref <lr>
@@ -190,22 +211,6 @@ lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
     annealing_factor: 0.9
     patient: 0
 
-beam_searcher: !new:speechbrain.decoders.S2SRNNBeamSearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    ctc_linear: !ref <ctc_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-    beam_size: !ref <beam_size>
-    eos_threshold: !ref <eos_threshold>
-    using_max_attn_shift: !ref <using_max_attn_shift>
-    max_attn_shift: !ref <max_attn_shift>
-    temperature: !ref <temperature>
-
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
     checkpoints_dir: !ref <save_folder>
     recoverables:
diff --git a/recipes/CommonVoice/ASR/CTC/hparams/train_fr_with_wav2vec.yaml b/recipes/CommonVoice/ASR/CTC/hparams/train_fr_with_wav2vec.yaml
index 500f32da33740d312cdb3846e5ee31c3db89cce2..079cfe73fce4ce69541d86017ac971f2daa4f2fc 100644
--- a/recipes/CommonVoice/ASR/CTC/hparams/train_fr_with_wav2vec.yaml
+++ b/recipes/CommonVoice/ASR/CTC/hparams/train_fr_with_wav2vec.yaml
@@ -32,12 +32,13 @@ skip_prep: False # Skip data preparation
 # longer sentences certainly correspond to "open microphones".
 avoid_if_longer_than: 10.0
 
-# Training parameters
+####################### Training Parameters ####################################
+
 number_of_epochs: 30
 lr: 1.0
 lr_wav2vec: 0.0001
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
@@ -58,7 +59,7 @@ test_dataloader_options:
 token_type: char  # ["unigram", "bpe", "char"]
 character_coverage: 1.0
 
-# Model parameters
+####################### Model Parameters #######################################
 # activation: !name:torch.nn.LeakyReLU
 wav2vec_output_dim: 1024
 dnn_neurons: 1024
@@ -76,16 +77,58 @@ blank_index: 0
 bos_index: 1
 eos_index: 2
 
+# Decoding parameters
+test_beam_search:
+    blank_index: !ref <blank_index>
+    beam_size: 100
+    beam_prune_logp: -12.0
+    token_prune_min_logp: -1.2
+    prune_history: True
+    topk: 1
+    alpha: 1.0
+    beta: 0.5
+    # To use n-gram LM for decoding, follow steps in README.md.
+    # kenlm_model_path: none
 #
 # Functions and classes
 #
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
 enc: !new:speechbrain.nnet.containers.Sequential
     input_shape: [null, null, !ref <wav2vec_output_dim>]
     linear1: !name:speechbrain.nnet.linear.Linear
@@ -108,7 +151,7 @@ enc: !new:speechbrain.nnet.containers.Sequential
     bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
     activation3: !new:torch.nn.LeakyReLU
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: False
     freeze: !ref <freeze_wav2vec>
diff --git a/recipes/CommonVoice/ASR/CTC/hparams/train_it_with_wav2vec.yaml b/recipes/CommonVoice/ASR/CTC/hparams/train_it_with_wav2vec.yaml
index 1e81b2dff5f8b56fc5f1fb711f6786aaca7057d0..0332997523960e7d5ff2cb002335f464bb04a70b 100644
--- a/recipes/CommonVoice/ASR/CTC/hparams/train_it_with_wav2vec.yaml
+++ b/recipes/CommonVoice/ASR/CTC/hparams/train_it_with_wav2vec.yaml
@@ -33,12 +33,13 @@ skip_prep: False # Skip data preparation
 # longer sentences certainly correspond to "open microphones".
 avoid_if_longer_than: 8.0
 
-# Training parameters
+####################### Training Parameters ####################################
+
 number_of_epochs: 45
 lr: 1.0
 lr_wav2vec: 0.0001
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
@@ -59,7 +60,7 @@ test_dataloader_options:
 token_type: unigram  # ["unigram", "bpe", "char"]
 character_coverage: 1.0
 
-# Model parameters
+####################### Model Parameters #######################################
 # activation: !name:torch.nn.LeakyReLU
 wav2vec_output_dim: 1024
 dnn_neurons: 1024
@@ -77,16 +78,58 @@ blank_index: 0
 bos_index: 1
 eos_index: 2
 
+# Decoding parameters
+test_beam_search:
+    blank_index: !ref <blank_index>
+    beam_size: 100
+    beam_prune_logp: -12.0
+    token_prune_min_logp: -1.2
+    prune_history: True
+    topk: 1
+    alpha: 1.0
+    beta: 0.5
+    # To use n-gram LM for decoding, follow steps in README.md.
+    #kenlm_model_path: none
 #
 # Functions and classes
 #
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
 enc: !new:speechbrain.nnet.containers.Sequential
     input_shape: [null, null, !ref <wav2vec_output_dim>]
     linear1: !name:speechbrain.nnet.linear.Linear
@@ -109,7 +152,7 @@ enc: !new:speechbrain.nnet.containers.Sequential
     bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
     activation3: !new:torch.nn.LeakyReLU
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec>
diff --git a/recipes/CommonVoice/ASR/seq2seq/hparams/train_en_with_wav2vec.yaml b/recipes/CommonVoice/ASR/CTC/hparams/train_pt_with_wav2vec.yaml
similarity index 59%
rename from recipes/CommonVoice/ASR/seq2seq/hparams/train_en_with_wav2vec.yaml
rename to recipes/CommonVoice/ASR/CTC/hparams/train_pt_with_wav2vec.yaml
index 5fd62fffd12598a34d75d388683a54928b09644a..d4b703eb456fc66104146de3e1b45e986d6eb9b0 100644
--- a/recipes/CommonVoice/ASR/seq2seq/hparams/train_en_with_wav2vec.yaml
+++ b/recipes/CommonVoice/ASR/CTC/hparams/train_pt_with_wav2vec.yaml
@@ -1,28 +1,27 @@
 # ################################
-# Model: wav2vec2 + DNN + CTC/Attention
+# Model: wav2vec2 + DNN + CTC
 # Augmentation: SpecAugment
-# Authors: Titouan Parcollet 2021
+# Authors: Pooneh Mousavi 2023
 # ################################
 
 # Seed needs to be set at top of yaml, before objects with parameters are made
 seed: 1234
 __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
-output_folder: !ref results/wav2vec2_ctcatt_en/<seed>
+output_folder: !ref results/wav2vec2_ctc_pt/<seed>
 test_wer_file: !ref <output_folder>/wer_test.txt
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
 
-# URL for the biggest HuggingFace english wav2vec2 model.
-wav2vec2_hub: facebook/wav2vec2-large-lv60
+# URL for the biggest Fairseq multilingual
+wav2vec2_hub: facebook/wav2vec2-large-xlsr-53
 wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
-
 # Data files
 data_folder: !PLACEHOLDER  # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
 train_tsv_file: !ref <data_folder>/train.tsv  # Standard CommonVoice .tsv files
 dev_tsv_file: !ref <data_folder>/dev.tsv  # Standard CommonVoice .tsv files
 test_tsv_file: !ref <data_folder>/test.tsv  # Standard CommonVoice .tsv files
-accented_letters: False
-language: en # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
+accented_letters: True
+language: pt # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
 train_csv: !ref <save_folder>/train.csv
 valid_csv: !ref <save_folder>/dev.csv
 test_csv: !ref <save_folder>/test.csv
@@ -32,21 +31,19 @@ skip_prep: False # Skip data preparation
 # longer sentences certainly correspond to "open microphones".
 avoid_if_longer_than: 10.0
 
-# Training parameters
+####################### Training Parameters ####################################
+
 number_of_epochs: 30
-number_of_ctc_epochs: 20
 lr: 1.0
 lr_wav2vec: 0.0001
-ctc_weight: 0.3
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
-
 # With data_parallel batch_size is split into N jobs
 # With DDP batch_size is multiplied by N jobs
-# Must be 6 per GPU to fit 32GB of VRAM
+# Must be 8 per GPU to fit 32GB of VRAM
 batch_size: 12
 test_batch_size: 4
 
@@ -61,16 +58,13 @@ test_dataloader_options:
 token_type: unigram  # ["unigram", "bpe", "char"]
 character_coverage: 1.0
 
-# Model parameters
-activation: !name:torch.nn.LeakyReLU
+####################### Model Parameters #######################################
 wav2vec_output_dim: 1024
-dnn_layers: 2
 dnn_neurons: 1024
-emb_size: 128
-dec_neurons: 1024
-dec_hidden_size: !ref <dec_neurons>
-dec_attn_dim: !ref <dec_neurons>
 freeze_wav2vec: False
+freeze_feature_extractor: False
+dropout: 0.15
+warmup_steps: 500
 
 # Outputs
 output_neurons: 1000  # BPE size, index(blank/eos/bos) = 0
@@ -80,14 +74,21 @@ output_neurons: 1000  # BPE size, index(blank/eos/bos) = 0
 blank_index: 0
 bos_index: 1
 eos_index: 2
-min_decode_ratio: 0.0
-max_decode_ratio: 1.0
-beam_size: 80
-eos_threshold: 1.5
-using_max_attn_shift: True
-max_attn_shift: 140
-# ctc_weight_decode: 0.0
-temperature: 1.50
+
+# Decoding parameters
+# Be sure that the bos and eos index match with the BPEs ones
+# Decoding parameters
+test_beam_search:
+    blank_index: !ref <blank_index>
+    beam_size: 100
+    beam_prune_logp: -12.0
+    token_prune_min_logp: -1.2
+    prune_history: True
+    topk: 1
+    alpha: 1.0
+    beta: 0.5
+    # To use n-gram LM for decoding, follow steps in README.md.
+    # kenlm_model_path: none
 
 #
 # Functions and classes
@@ -95,20 +96,67 @@ temperature: 1.50
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
-enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
+enc: !new:speechbrain.nnet.containers.Sequential
     input_shape: [null, null, !ref <wav2vec_output_dim>]
-    activation: !ref <activation>
-    dnn_blocks: !ref <dnn_layers>
-    dnn_neurons: !ref <dnn_neurons>
-
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+    linear1: !name:speechbrain.nnet.linear.Linear
+        n_neurons: !ref <dnn_neurons>
+        bias: True
+    bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
+    activation: !new:torch.nn.LeakyReLU
+    drop: !new:torch.nn.Dropout
+        p: !ref <dropout>
+    linear2: !name:speechbrain.nnet.linear.Linear
+        n_neurons: !ref <dnn_neurons>
+        bias: True
+    bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
+    activation2: !new:torch.nn.LeakyReLU
+    drop2: !new:torch.nn.Dropout
+        p: !ref <dropout>
+    linear3: !name:speechbrain.nnet.linear.Linear
+        n_neurons: !ref <dnn_neurons>
+        bias: True
+    bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
+    activation3: !new:torch.nn.LeakyReLU
+
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec>
+    freeze_feature_extractor: !ref <freeze_feature_extractor>
     save_path: !ref <wav2vec2_folder>
 
 #####
@@ -124,51 +172,24 @@ wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
 #    save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
 #####
 
-emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: !ref <dec_neurons>
-    input_size: !ref <emb_size>
-    rnn_type: gru
-    attn_type: location
-    hidden_size: !ref <dec_hidden_size>
-    attn_dim: !ref <dec_attn_dim>
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.15
 
 ctc_lin: !new:speechbrain.nnet.linear.Linear
     input_size: !ref <dnn_neurons>
     n_neurons: !ref <output_neurons>
 
-seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dec_neurons>
-    n_neurons: !ref <output_neurons>
-
 log_softmax: !new:speechbrain.nnet.activations.Softmax
     apply_log: True
 
 ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
     blank_index: !ref <blank_index>
 
-seq_cost: !name:speechbrain.nnet.losses.nll_loss
-    label_smoothing: 0.1
-
 modules:
     wav2vec2: !ref <wav2vec2>
     enc: !ref <enc>
-    emb: !ref <emb>
-    dec: !ref <dec>
     ctc_lin: !ref <ctc_lin>
-    seq_lin: !ref <seq_lin>
 
 model: !new:torch.nn.ModuleList
-    - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <ctc_lin>, !ref <seq_lin>]
+    - [!ref <enc>, !ref <ctc_lin>]
 
 model_opt_class: !name:torch.optim.Adadelta
     lr: !ref <lr>
@@ -190,22 +211,6 @@ lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
     annealing_factor: 0.9
     patient: 0
 
-beam_searcher: !new:speechbrain.decoders.S2SRNNBeamSearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    ctc_linear: !ref <ctc_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-    beam_size: !ref <beam_size>
-    eos_threshold: !ref <eos_threshold>
-    using_max_attn_shift: !ref <using_max_attn_shift>
-    max_attn_shift: !ref <max_attn_shift>
-    temperature: !ref <temperature>
-
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
     checkpoints_dir: !ref <save_folder>
     recoverables:
diff --git a/recipes/CommonVoice/ASR/CTC/hparams/train_rw_with_wav2vec.yaml b/recipes/CommonVoice/ASR/CTC/hparams/train_rw_with_wav2vec.yaml
index b394ef9d8d3d63a2568b9950dd236dc02d32daea..ed15a8aadada5e157843d953264dc518abed1658 100644
--- a/recipes/CommonVoice/ASR/CTC/hparams/train_rw_with_wav2vec.yaml
+++ b/recipes/CommonVoice/ASR/CTC/hparams/train_rw_with_wav2vec.yaml
@@ -32,12 +32,13 @@ skip_prep: False # Skip data preparation
 # longer sentences certainly correspond to "open microphones".
 avoid_if_longer_than: 10.0
 
-# Training parameters
+####################### Training Parameters ####################################
+
 number_of_epochs: 30
 lr: 1.0
 lr_wav2vec: 0.0001
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
@@ -59,7 +60,7 @@ test_dataloader_options:
 token_type: unigram  # ["unigram", "bpe", "char"]
 character_coverage: 1.0
 
-# Model parameters
+####################### Model Parameters #######################################
 # activation: !name:torch.nn.LeakyReLU
 wav2vec_output_dim: 1024
 dnn_neurons: 1024
@@ -77,16 +78,58 @@ blank_index: 0
 bos_index: 1
 eos_index: 2
 
+# Decoding parameters
+test_beam_search:
+    blank_index: !ref <blank_index>
+    beam_size: 100
+    beam_prune_logp: -12.0
+    token_prune_min_logp: -1.2
+    prune_history: True
+    topk: 1
+    alpha: 1.0
+    beta: 0.5
+    # To use n-gram LM for decoding, follow steps in README.md.
+    # kenlm_model_path: none
 #
 # Functions and classes
 #
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
 enc: !new:speechbrain.nnet.containers.Sequential
     input_shape: [null, null, !ref <wav2vec_output_dim>]
     linear1: !name:speechbrain.nnet.linear.Linear
@@ -109,7 +152,7 @@ enc: !new:speechbrain.nnet.containers.Sequential
     bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
     activation3: !new:torch.nn.LeakyReLU
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec>
diff --git a/recipes/CommonVoice/ASR/CTC/hparams/train_zh-CN_with_wav2vec.yaml b/recipes/CommonVoice/ASR/CTC/hparams/train_zh-CN_with_wav2vec.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a1709931a46d9f6f2c7146e901749b2ac25b2ba1
--- /dev/null
+++ b/recipes/CommonVoice/ASR/CTC/hparams/train_zh-CN_with_wav2vec.yaml
@@ -0,0 +1,230 @@
+# ################################
+# Model: wav2vec2 + DNN + CTC
+# Augmentation: SpecAugment
+# Authors: Pooneh Mousavi 2023
+# ################################
+
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1234
+__set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/wav2vec2_ctc_zh-CN/<seed>
+test_wer_file: !ref <output_folder>/wer_test.txt
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# URL for the biggest Fairseq multilingual
+wav2vec2_hub: facebook/wav2vec2-large-xlsr-53
+wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
+
+# Data files
+data_folder: !PLACEHOLDER  # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
+train_tsv_file: !ref <data_folder>/train.tsv  # Standard CommonVoice .tsv files
+dev_tsv_file: !ref <data_folder>/dev.tsv  # Standard CommonVoice .tsv files
+test_tsv_file: !ref <data_folder>/test.tsv  # Standard CommonVoice .tsv files
+accented_letters: True
+language: zh-CN # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
+train_csv: !ref <save_folder>/train.csv
+valid_csv: !ref <save_folder>/dev.csv
+test_csv: !ref <save_folder>/test.csv
+skip_prep: False # Skip data preparation
+
+# We remove utterance slonger than 10s in the train/dev/test sets as
+# longer sentences certainly correspond to "open microphones".
+avoid_if_longer_than: 10.0
+
+####################### Training Parameters ####################################
+
+number_of_epochs: 30
+lr: 1.0
+lr_wav2vec: 0.0001
+sorting: ascending
+precision: fp32 # bf16, fp16 or fp32
+sample_rate: 16000
+ckpt_interval_minutes: 30 # save checkpoint every N min
+
+
+# With data_parallel batch_size is split into N jobs
+# With DDP batch_size is multiplied by N jobs
+# Must be 8 per GPU to fit 32GB of VRAM
+batch_size: 12
+test_batch_size: 4
+
+dataloader_options:
+    batch_size: !ref <batch_size>
+    num_workers: 6
+test_dataloader_options:
+    batch_size: !ref <test_batch_size>
+    num_workers: 6
+
+# BPE parameters
+token_type: unigram  # ["unigram", "bpe", "char"]
+character_coverage: 1.0
+
+####################### Model Parameters #######################################
+wav2vec_output_dim: 1024
+dnn_neurons: 1024
+freeze_wav2vec: False
+freeze_feature_extractor: False
+dropout: 0.15
+warmup_steps: 500
+
+# Outputs
+output_neurons: 4652  # BPE size, index(blank/eos/bos) = 0
+
+# Decoding parameters
+# Be sure that the bos and eos index match with the BPEs ones
+blank_index: 0
+bos_index: 1
+eos_index: 2
+
+# Decoding parameters
+# Be sure that the bos and eos index match with the BPEs ones
+# Decoding parameters
+test_beam_search:
+    blank_index: !ref <blank_index>
+    beam_size: 100
+    beam_prune_logp: -12.0
+    token_prune_min_logp: -1.2
+    prune_history: True
+    topk: 1
+    alpha: 1.0
+    beta: 0.5
+    # To use n-gram LM for decoding, follow steps in README.md.
+    # kenlm_model_path: none
+
+#
+# Functions and classes
+#
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: [95, 100, 105]
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
+enc: !new:speechbrain.nnet.containers.Sequential
+    input_shape: [null, null, !ref <wav2vec_output_dim>]
+    linear1: !name:speechbrain.nnet.linear.Linear
+        n_neurons: !ref <dnn_neurons>
+        bias: True
+    bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
+    activation: !new:torch.nn.LeakyReLU
+    drop: !new:torch.nn.Dropout
+        p: !ref <dropout>
+    linear2: !name:speechbrain.nnet.linear.Linear
+        n_neurons: !ref <dnn_neurons>
+        bias: True
+    bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
+    activation2: !new:torch.nn.LeakyReLU
+    drop2: !new:torch.nn.Dropout
+        p: !ref <dropout>
+    linear3: !name:speechbrain.nnet.linear.Linear
+        n_neurons: !ref <dnn_neurons>
+        bias: True
+    bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
+    activation3: !new:torch.nn.LeakyReLU
+
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
+    source: !ref <wav2vec2_hub>
+    output_norm: True
+    freeze: !ref <freeze_wav2vec>
+    freeze_feature_extractor: !ref <freeze_feature_extractor>
+    save_path: !ref <wav2vec2_folder>
+
+#####
+# Uncomment this block if you prefer to use a Fairseq pretrained model instead
+# of a HuggingFace one. Here, we provide an URL that is obtained from the
+# Fairseq github for the multilingual XLSR.
+#
+#wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt
+#wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2
+#    pretrained_path: !ref <wav2vec2_url>
+#    output_norm: True
+#    freeze: False
+#    save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
+#####
+
+ctc_lin: !new:speechbrain.nnet.linear.Linear
+    input_size: !ref <dnn_neurons>
+    n_neurons: !ref <output_neurons>
+
+log_softmax: !new:speechbrain.nnet.activations.Softmax
+    apply_log: True
+
+ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
+    blank_index: !ref <blank_index>
+
+modules:
+    wav2vec2: !ref <wav2vec2>
+    enc: !ref <enc>
+    ctc_lin: !ref <ctc_lin>
+
+model: !new:torch.nn.ModuleList
+    - [!ref <enc>, !ref <ctc_lin>]
+
+model_opt_class: !name:torch.optim.Adadelta
+    lr: !ref <lr>
+    rho: 0.95
+    eps: 1.e-8
+
+wav2vec_opt_class: !name:torch.optim.Adam
+    lr: !ref <lr_wav2vec>
+
+lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
+    initial_value: !ref <lr>
+    improvement_threshold: 0.0025
+    annealing_factor: 0.8
+    patient: 0
+
+lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
+    initial_value: !ref <lr_wav2vec>
+    improvement_threshold: 0.0025
+    annealing_factor: 0.9
+    patient: 0
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        wav2vec2: !ref <wav2vec2>
+        model: !ref <model>
+        scheduler_model: !ref <lr_annealing_model>
+        scheduler_wav2vec: !ref <lr_annealing_wav2vec>
+        counter: !ref <epoch_counter>
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+
+cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+    split_tokens: True
diff --git a/recipes/CommonVoice/ASR/CTC/train_with_wav2vec.py b/recipes/CommonVoice/ASR/CTC/train_with_wav2vec.py
index 256e86f7e246bf305656dc2c1e73b773126cd56a..1749fb6ff03d361af731d5e44b0712ec8f058df4 100644
--- a/recipes/CommonVoice/ASR/CTC/train_with_wav2vec.py
+++ b/recipes/CommonVoice/ASR/CTC/train_with_wav2vec.py
@@ -46,9 +46,9 @@ class ASR(sb.core.Brain):
         tokens_bos, _ = batch.tokens_bos
         wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
 
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
 
         # Forward pass
         feats = self.modules.wav2vec2(wavs, wav_lens)
@@ -56,27 +56,39 @@ class ASR(sb.core.Brain):
         logits = self.modules.ctc_lin(x)
         p_ctc = self.hparams.log_softmax(logits)
 
-        return p_ctc, wav_lens
+        p_tokens = None
+        if stage == sb.Stage.VALID:
+            p_tokens = sb.decoders.ctc_greedy_decode(
+                p_ctc, wav_lens, blank_id=self.hparams.blank_index
+            )
+        elif stage == sb.Stage.TEST:
+            p_tokens = test_searcher(p_ctc, wav_lens)
+
+        return p_ctc, wav_lens, p_tokens
 
     def compute_objectives(self, predictions, batch, stage):
         """Computes the loss (CTC) given predictions and targets."""
 
-        p_ctc, wav_lens = predictions
+        p_ctc, wav_lens, p_tokens = predictions
 
         ids = batch.id
         tokens_eos, tokens_eos_lens = batch.tokens_eos
         tokens, tokens_lens = batch.tokens
 
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            tokens = self.hparams.wav_augment.replicate_labels(tokens)
+            tokens_lens = self.hparams.wav_augment.replicate_labels(tokens_lens)
+
         loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
 
-        if stage != sb.Stage.TRAIN:
-            # Decode token terms to words
-            sequence = sb.decoders.ctc_greedy_decode(
-                p_ctc, wav_lens, blank_id=self.hparams.blank_index
-            )
+        if stage == sb.Stage.VALID:
+            # Convert token indices to words
+            predicted_words = self.tokenizer(p_tokens, task="decode_from_list")
 
-            predicted_words = self.tokenizer(sequence, task="decode_from_list")
+        elif stage == sb.Stage.TEST:
+            predicted_words = [hyp[0].text.split(" ") for hyp in p_tokens]
 
+        if stage != sb.Stage.TRAIN:
             # Convert indices to words
             target_words = undo_padding(tokens, tokens_lens)
             target_words = self.tokenizer(target_words, task="decode_from_list")
@@ -86,63 +98,6 @@ class ASR(sb.core.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        should_step = self.step % self.grad_accumulation_factor == 0
-        # Managing automatic mixed precision
-        # TOFIX: CTC fine-tuning currently is unstable
-        # This is certainly due to CTC being done in fp16 instead of fp32
-        if self.auto_mix_prec:
-            with torch.cuda.amp.autocast():
-                with self.no_sync():
-                    outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-                loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-            with self.no_sync(not should_step):
-                self.scaler.scale(
-                    loss / self.grad_accumulation_factor
-                ).backward()
-            if should_step:
-
-                if not self.hparams.wav2vec2.freeze:
-                    self.scaler.unscale_(self.wav2vec_optimizer)
-                self.scaler.unscale_(self.model_optimizer)
-                if self.check_gradients(loss):
-                    if not self.hparams.wav2vec2.freeze:
-                        if self.optimizer_step >= self.hparams.warmup_steps:
-                            self.scaler.step(self.wav2vec_optimizer)
-                    self.scaler.step(self.model_optimizer)
-                self.scaler.update()
-                self.zero_grad()
-                self.optimizer_step += 1
-        else:
-            # This is mandatory because HF models have a weird behavior with DDP
-            # on the forward pass
-            with self.no_sync():
-                outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-
-            loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-
-            with self.no_sync(not should_step):
-                (loss / self.grad_accumulation_factor).backward()
-            if should_step:
-                if self.check_gradients(loss):
-                    if not self.hparams.wav2vec2.freeze:
-                        if self.optimizer_step >= self.hparams.warmup_steps:
-                            self.wav2vec_optimizer.step()
-                    self.model_optimizer.step()
-                self.zero_grad()
-                self.optimizer_step += 1
-
-        self.on_fit_batch_end(batch, outputs, loss, should_step)
-        return loss.detach().cpu()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        with torch.no_grad():
-            loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
         if stage != sb.Stage.TRAIN:
@@ -215,10 +170,24 @@ class ASR(sb.core.Brain):
         if self.checkpointer is not None:
             self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
 
-    def zero_grad(self, set_to_none=False):
         if not self.hparams.wav2vec2.freeze:
-            self.wav2vec_optimizer.zero_grad(set_to_none)
-        self.model_optimizer.zero_grad(set_to_none)
+            self.optimizers_dict = {
+                "wav2vec_optimizer": self.wav2vec_optimizer,
+                "model_optimizer": self.model_optimizer,
+            }
+        else:
+            self.optimizers_dict = {"model_optimizer": self.model_optimizer}
+
+    def freeze_optimizers(self, optimizers):
+        """Freezes the wav2vec2 optimizer according to the warmup steps"""
+        valid_optimizers = {}
+        if not self.hparams.wav2vec2.freeze:
+            if self.optimizer_step >= self.hparams.warmup_steps:
+                valid_optimizers["wav2vec_optimizer"] = optimizers[
+                    "wav2vec_optimizer"
+                ]
+        valid_optimizers["model_optimizer"] = optimizers["model_optimizer"]
+        return valid_optimizers
 
 
 # Define custom data procedure
@@ -318,7 +287,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -370,6 +338,15 @@ if __name__ == "__main__":
 
     # Adding objects to trainer.
     asr_brain.tokenizer = tokenizer
+    vocab_list = [
+        tokenizer.sp.id_to_piece(i) for i in range(tokenizer.sp.vocab_size())
+    ]
+
+    from speechbrain.decoders.ctc import CTCBeamSearcher
+
+    test_searcher = CTCBeamSearcher(
+        **hparams["test_beam_search"], vocab_list=vocab_list,
+    )
 
     # Training
     asr_brain.fit(
diff --git a/recipes/CommonVoice/ASR/seq2seq/README.md b/recipes/CommonVoice/ASR/seq2seq/README.md
index de9b21f2c350a6f13508f31c1bccc809f96fd57a..7a1dd7b8e0827d5846580aaf4259558050da672f 100644
--- a/recipes/CommonVoice/ASR/seq2seq/README.md
+++ b/recipes/CommonVoice/ASR/seq2seq/README.md
@@ -1,6 +1,5 @@
 # CommonVoice ASR with CTC + Attention based Seq2Seq models.
-This folder contains scripts necessary to run an ASR experiment with the CommonVoice dataset: [CommonVoice Homepage](https://commonvoice.mozilla.org/)
-
+This folder contains scripts necessary to run an ASR experiment with the CommonVoice 14.0 dataset: [CommonVoice Homepage](https://commonvoice.mozilla.org/) and pytorch 2.0
 # How to run
 python train.py hparams/{hparam_file}.py
 
@@ -15,22 +14,20 @@ Here is a list of the different languages that we tested within the CommonVoice
 - Kinyarwanda
 - Italian
 - English
+- German
+- Spanish
 
 # Results
 
 | Language | CommonVoice Release | hyperparams file | LM | Val. CER | Val. WER | Test CER | Test WER | HuggingFace link | Model link | GPUs |
 | ------------- |:-------------:|:---------------------------:| -----:| -----:| -----:| -----:| -----:| :-----------:| :-----------:| :-----------:|
-| French | 2020-12-11 | train_fr.yaml | No | 5.22 | 13.92 | 6.43 | 15.99 | [model](https://huggingface.co/speechbrain/asr-crdnn-commonvoice-fr) | [model](https://www.dropbox.com/sh/nskc700cheejyu9/AADnRP1TO1Yh92jC-WuYCFf5a?dl=0) | 2xV100 16GB |
-| French | 2020-12-11 | train_fr_with_wav2vec.yaml | No | 6.13 | 11.82 | 9.78 | 13.34 | Not Avail. | 2xV100 32GB |
-| Kinyarwanda | 2020-12-11 | train_rw.yaml | No | 7.30 | 21.36 | 9.55 | 24.27 | Not Avail. | [model](https://www.dropbox.com/sh/glzq0hrqw2khcjq/AADfl_7ra0cLWi1VOzpy74NUa?dl=0) | 2xV100 32GB |
-| Kinyarwanda | 2020-12-11 | train_rw_with_wav2vec.yaml | No | 5.08 | 15.88 | 8.33 | 18.91 | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-rw) | [model](https://www.dropbox.com/sh/x714xp9wq1a9azr/AADoxhS7JVDQ7IY1lYL7U8rJa?dl=0) | 2xV100 16GB |
-| English | 2020-12-11 | train_en.yaml | No | 8.66 | 20.16 | 12.93 | 24.89 | Not Avail. | [model](https://www.dropbox.com/sh/bdetfgii7xwscyj/AABYQL_eso8K1937QAg1GK66a?dl=0) | 2xV100 16GB |
-| English | 2020-12-11 | train_en_with_wav2vec.yaml | No | 14.50 | 13.21 | 24.65 | 15.69 | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-en) | [model](https://www.dropbox.com/sh/s4khqzp7qe5iiaa/AADkWpWPE1UAbu2isycykcAAa?dl=0) | 2xV100 32GB |
-| Italian | 2020-12-11 | train_it.yaml | No | 5.14 | 15.59 | 15.40 | 16.61 | [model](https://huggingface.co/speechbrain/asr-crdnn-commonvoice-it) | [model](https://www.dropbox.com/sh/438hgcl1wwdzmbo/AAAPZGfNXMztNYHyWwT_kk8la?dl=0) | 2xV100 16GB |
-| Italian | 2020-12-11 | train_it_with_wav2vec.yaml | No | 3.11 | 8.30 | 5.75 | 9.86 | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-it) | [model](https://www.dropbox.com/sh/46z6xbydqjaxbb2/AACyQ_NaErAfev6JVjeaIWyea?dl=0) | 2xV100 16GB |
-| German | 2021-10-28 | train_de.yaml | No | 4.32 | 13.99 | 4.93 | 15.37 | [model](https://huggingface.co/speechbrain/asr-crdnn-commonvoice-de) | -- | 1x V100 16GB |
-
-The output folders with checkpoints and logs can be found [here](https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0).
+| French | 2023-08-15 | train_fr.yaml | No | 4.40 | 12.17 | 5.93 | 14.88 | [model](https://huggingface.co/speechbrain/asr-crdnn-commonvoice-14-fr) | [model](https://www.dropbox.com/sh/07a5lt21wxp98x5/AABhNwmWFaNFyA734bNZUO03a?dl=0) | 1xV100 32GB |
+| Kinyarwanda | 2023-08-15 | train_rw.yaml | No | 6.75 | 23.66 | 10.80 | 29.22 | [model](https://huggingface.co/speechbrain/asr-crdnn-commonvoice-14-rw) | [model](https://www.dropbox.com/sh/i1fv4f8miilqgii/AAB3gE97kmFDA0ISkIDSUW_La?dl=0) | 1xV100 32GB |
+| English | 2023-08-15 | train_en.yaml | No | 9.75 | 20.23 | 12.76 | 23.88 | [model](https://huggingface.co/speechbrain/asr-crdnn-commonvoice-14-en) | [model](https://www.dropbox.com/sh/h8ged0yu3ztypkh/AAAu-12k_Ceg-tTjuZnrg7dza?dl=0) | 1xV100 32GB |
+| Italian | 2023-08-15 | train_it.yaml | No | 5.89 | 15.99 | 6.27 | 17.02 | [model](https://huggingface.co/speechbrain/asr-crdnn-commonvoice-14-it) | [model](https://www.dropbox.com/sh/ss59uu0j5boscvp/AAASsiFhlB1nDWPkFX410bzna?dl=0) | 1xV100 32GB |
+| German | 2023-08-15 | train_de.yaml | No | 2.90 | 10.21 | 3.82 | 12.25 | [model](https://huggingface.co/speechbrain/asr-crdnn-commonvoice-14-de) | [model](https://www.dropbox.com/sh/zgatirb118f79ef/AACmjh-D94nNDWcnVI4Ef5K7a?dl=0) | 1xV100 32GB |
+| Spanish | 2023-08-15 | train_es.yaml | No | 4.10 | 14.10 | 4.68 | 14.77 | [model](https://huggingface.co/speechbrain/asr-crdnn-commonvoice-14-es) | [model](https://www.dropbox.com/sh/r3w0b2tm1p73vft/AADCxdhUwDN6j4PVT9TYe-d5a?dl=0) | 1xV100 32GB |
+
 
 ## How to simply use pretrained models to transcribe my audio file?
 
diff --git a/recipes/CommonVoice/ASR/seq2seq/hparams/train_de.yaml b/recipes/CommonVoice/ASR/seq2seq/hparams/train_de.yaml
index 88f53f91b9b625eb7625b21bd6f35df69b4b0105..cb6f2b3be095a6db5d948d76cc5beca09691d96f 100644
--- a/recipes/CommonVoice/ASR/seq2seq/hparams/train_de.yaml
+++ b/recipes/CommonVoice/ASR/seq2seq/hparams/train_de.yaml
@@ -30,12 +30,14 @@ skip_prep: False # Skip data preparation
 # longer sentences certainly correspond to "open microphones".
 avoid_if_longer_than: 10.0
 
-# Training parameters
+####################### Training Parameters ####################################
+
 number_of_epochs: 25
 number_of_ctc_epochs: 20
 lr: 1.0
 ctc_weight: 0.3
 sorting: ascending
+precision: fp32 # bf16, fp16 or fp32
 ckpt_interval_minutes: 30
 
 # With data_parallel batch_size is split into N jobs
@@ -61,7 +63,7 @@ sample_rate: 16000
 n_fft: 400
 n_mels: 80
 
-# Model parameters
+####################### Model Parameters #######################################
 activation: !name:torch.nn.LeakyReLU
 dropout: 0.15
 cnn_blocks: 3
@@ -103,18 +105,34 @@ temperature: 1.50
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-# Frequency domain SpecAugment
-augmentation: !new:speechbrain.lobes.augment.SpecAugment
-    time_warp: True
-    time_warp_window: 5
-    time_warp_mode: bicubic
-    freq_mask: True
-    n_freq_mask: 2
-    time_mask: True
-    n_time_mask: 2
-    replace_with_zero: False
-    freq_mask_width: 30
-    time_mask_width: 40
+############################## Augmentations ###################################
+
+ # Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 15
+    drop_length_high: 25
+    drop_count_low: 5
+    drop_count_high: 5
+
+# Frequency Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 25
+    drop_length_high: 35
+    drop_count_low: 2
+    drop_count_high: 2
+    dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>]
 
 normalize: !new:speechbrain.processing.features.InputNormalization
     norm_type: global
@@ -124,6 +142,8 @@ compute_features: !new:speechbrain.lobes.features.Fbank
     n_fft: !ref <n_fft>
     n_mels: !ref <n_mels>
 
+############################## Models ##########################################
+
 enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
     input_shape: [null, null, !ref <n_mels>]
     activation: !ref <activation>
diff --git a/recipes/CommonVoice/ASR/seq2seq/hparams/train_en.yaml b/recipes/CommonVoice/ASR/seq2seq/hparams/train_en.yaml
index 38b52969b21fcc8aedfaece8832bc1782acbce5b..49f9a0d2b76464c2dbc840a3a414fff0c8fd946e 100644
--- a/recipes/CommonVoice/ASR/seq2seq/hparams/train_en.yaml
+++ b/recipes/CommonVoice/ASR/seq2seq/hparams/train_en.yaml
@@ -29,12 +29,14 @@ skip_prep: False # Skip data preparation
 # longer sentences certainly correspond to "open microphones".
 avoid_if_longer_than: 10.0
 
-# Training parameters
+####################### Training Parameters ####################################
+
 number_of_epochs: 25
 number_of_ctc_epochs: 10
 lr: 1.0
 ctc_weight: 0.3
 sorting: ascending
+precision: fp32 # bf16, fp16 or fp32
 
 # With data_parallel batch_size is split into N jobs
 # With DDP batch_size is multiplied by N jobs
@@ -52,13 +54,14 @@ test_dataloader_options:
 # BPE parameters
 token_type: unigram  # ["unigram", "bpe", "char"]
 character_coverage: 1.0
+label_smoothing: 0.1
 
 # Feature parameters (FBANKS etc)
 sample_rate: 16000
 n_fft: 400
 n_mels: 80
 
-# Model parameters
+####################### Model Parameters #######################################
 activation: !name:torch.nn.LeakyReLU
 dropout: 0.15
 cnn_blocks: 3
@@ -100,18 +103,34 @@ temperature: 1.50
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-# Frequency domain SpecAugment
-augmentation: !new:speechbrain.lobes.augment.SpecAugment
-    time_warp: True
-    time_warp_window: 5
-    time_warp_mode: bicubic
-    freq_mask: True
-    n_freq_mask: 2
-    time_mask: True
-    n_time_mask: 2
-    replace_with_zero: False
-    freq_mask_width: 30
-    time_mask_width: 40
+############################## Augmentations ###################################
+
+ # Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 15
+    drop_length_high: 25
+    drop_count_low: 5
+    drop_count_high: 5
+
+# Frequency Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 25
+    drop_length_high: 35
+    drop_count_low: 2
+    drop_count_high: 2
+    dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>]
 
 normalize: !new:speechbrain.processing.features.InputNormalization
     norm_type: global
@@ -121,6 +140,8 @@ compute_features: !new:speechbrain.lobes.features.Fbank
     n_fft: !ref <n_fft>
     n_mels: !ref <n_mels>
 
+############################## Models ##########################################
+
 enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
     input_shape: [null, null, !ref <n_mels>]
     activation: !ref <activation>
@@ -173,7 +194,7 @@ ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
     blank_index: !ref <blank_index>
 
 seq_cost: !name:speechbrain.nnet.losses.nll_loss
-    label_smoothing: 0.1
+    label_smoothing: !ref <label_smoothing>
 
 modules:
     enc: !ref <enc>
diff --git a/recipes/CommonVoice/ASR/seq2seq/hparams/train_fr_with_wav2vec.yaml b/recipes/CommonVoice/ASR/seq2seq/hparams/train_es.yaml
similarity index 60%
rename from recipes/CommonVoice/ASR/seq2seq/hparams/train_fr_with_wav2vec.yaml
rename to recipes/CommonVoice/ASR/seq2seq/hparams/train_es.yaml
index cf95e8878767ab52f0015a17eea64ac04a99a686..b94373e9bd0028e1668b37da3736259516fe14e1 100644
--- a/recipes/CommonVoice/ASR/seq2seq/hparams/train_fr_with_wav2vec.yaml
+++ b/recipes/CommonVoice/ASR/seq2seq/hparams/train_es.yaml
@@ -1,28 +1,26 @@
 # ################################
-# Model: wav2vec2 + DNN + CTC
+# Model: VGG2 + LSTM + time pooling
 # Augmentation: SpecAugment
-# Authors: Titouan Parcollet 2021
+# Authors: Titouan Parcollet, Mirco Ravanelli, Peter Plantinga, Ju-Chieh Chou,
+# and Abdel HEBA 2020
+# edited: Andreas Nautsch, 2021
 # ################################
 
 # Seed needs to be set at top of yaml, before objects with parameters are made
 seed: 1234
 __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
-output_folder: !ref results/wav2vec2_ctcatt_fr/<seed>
+output_folder: !ref results/CRDNN_es/<seed>
 test_wer_file: !ref <output_folder>/wer_test.txt
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
 
-# URL for the biggest HuggingFace LeBenchmarh french w2v2
-wav2vec2_hub: LeBenchmark/wav2vec2-FR-7K-large
-wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
-
 # Data files
-data_folder: !PLACEHOLDER  # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
+data_folder: !PLACEHOLDER  # e.g, /localscratch/cv-corpus-7.0-2021-07-21/de
 train_tsv_file: !ref <data_folder>/train.tsv  # Standard CommonVoice .tsv files
 dev_tsv_file: !ref <data_folder>/dev.tsv  # Standard CommonVoice .tsv files
 test_tsv_file: !ref <data_folder>/test.tsv  # Standard CommonVoice .tsv files
 accented_letters: True
-language: fr # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
+language: es # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
 train_csv: !ref <save_folder>/train.csv
 valid_csv: !ref <save_folder>/dev.csv
 test_csv: !ref <save_folder>/test.csv
@@ -32,23 +30,20 @@ skip_prep: False # Skip data preparation
 # longer sentences certainly correspond to "open microphones".
 avoid_if_longer_than: 10.0
 
-# Training parameters
-number_of_epochs: 30
-number_of_ctc_epochs: 15
+####################### Training Parameters ####################################
+
+number_of_epochs: 25
+number_of_ctc_epochs: 20
 lr: 1.0
-lr_wav2vec: 0.0001
 ctc_weight: 0.3
 sorting: ascending
-auto_mix_prec: False
-sample_rate: 16000
-ckpt_interval_minutes: 30 # save checkpoint every N min
-
+precision: fp32 # bf16, fp16 or fp32
 
 # With data_parallel batch_size is split into N jobs
 # With DDP batch_size is multiplied by N jobs
 # Must be 6 per GPU to fit 16GB of VRAM
-batch_size: 12
-test_batch_size: 4
+batch_size: 8
+test_batch_size: 6
 
 dataloader_options:
     batch_size: !ref <batch_size>
@@ -60,17 +55,31 @@ test_dataloader_options:
 # BPE parameters
 token_type: unigram  # ["unigram", "bpe", "char"]
 character_coverage: 1.0
+label_smoothing: 0.1
 
-# Model parameters
+# Feature parameters (FBANKS etc)
+sample_rate: 16000
+n_fft: 400
+n_mels: 80
+
+####################### Model Parameters #######################################
 activation: !name:torch.nn.LeakyReLU
-wav2vec_output_dim: 1024
-dnn_layers: 2
+dropout: 0.15
+cnn_blocks: 3
+cnn_channels: (128, 200, 256)
+inter_layer_pooling_size: (2, 2, 2)
+cnn_kernelsize: (3, 3)
+time_pooling_size: 4
+rnn_class: !name:speechbrain.nnet.RNN.LSTM
+rnn_layers: 5
+rnn_neurons: 1024
+rnn_bidirectional: True
+dnn_blocks: 2
 dnn_neurons: 1024
 emb_size: 128
 dec_neurons: 1024
 dec_hidden_size: !ref <dec_neurons>
 dec_attn_dim: !ref <dec_neurons>
-freeze_wav2vec: False
 
 # Outputs
 output_neurons: 500  # BPE size, index(blank/eos/bos) = 0
@@ -78,8 +87,8 @@ output_neurons: 500  # BPE size, index(blank/eos/bos) = 0
 # Decoding parameters
 # Be sure that the bos and eos index match with the BPEs ones
 blank_index: 0
-bos_index: 1
-eos_index: 2
+bos_index: 0
+eos_index: 0
 min_decode_ratio: 0.0
 max_decode_ratio: 1.0
 beam_size: 80
@@ -95,35 +104,64 @@ temperature: 1.50
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
+############################## Augmentations ###################################
+
+ # Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 15
+    drop_length_high: 25
+    drop_count_low: 5
+    drop_count_high: 5
+
+# Frequency Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 25
+    drop_length_high: 35
+    drop_count_low: 2
+    drop_count_high: 2
+    dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>]
+
+normalize: !new:speechbrain.processing.features.InputNormalization
+    norm_type: global
+
+compute_features: !new:speechbrain.lobes.features.Fbank
     sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
+    n_fft: !ref <n_fft>
+    n_mels: !ref <n_mels>
 
-enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
-    input_shape: [null, null, !ref <wav2vec_output_dim>]
+############################## Models ##########################################
+
+enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
+    input_shape: [null, null, !ref <n_mels>]
     activation: !ref <activation>
-    dnn_blocks: !ref <dnn_layers>
+    dropout: !ref <dropout>
+    cnn_blocks: !ref <cnn_blocks>
+    cnn_channels: !ref <cnn_channels>
+    cnn_kernelsize: !ref <cnn_kernelsize>
+    inter_layer_pooling_size: !ref <inter_layer_pooling_size>
+    time_pooling: True
+    using_2d_pooling: False
+    time_pooling_size: !ref <time_pooling_size>
+    rnn_class: !ref <rnn_class>
+    rnn_layers: !ref <rnn_layers>
+    rnn_neurons: !ref <rnn_neurons>
+    rnn_bidirectional: !ref <rnn_bidirectional>
+    rnn_re_init: True
+    dnn_blocks: !ref <dnn_blocks>
     dnn_neurons: !ref <dnn_neurons>
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
-    source: !ref <wav2vec2_hub>
-    output_norm: True
-    freeze: !ref <freeze_wav2vec>
-    save_path: !ref <wav2vec2_folder>
-
-#####
-# Uncomment this block if you prefer to use a Fairseq pretrained model instead
-# of a HuggingFace one. Here, we provide an URL that is obtained from the
-# Fairseq github for the multilingual XLSR.
-#
-#wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt
-#wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2
-#    pretrained_path: !ref <wav2vec2_url>
-#    output_norm: True
-#    freeze: False
-#    save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
-#####
-
 emb: !new:speechbrain.nnet.embedding.Embedding
     num_embeddings: !ref <output_neurons>
     embedding_dim: !ref <emb_size>
@@ -140,7 +178,7 @@ dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
     channels: 10
     kernel_size: 100
     re_init: True
-    dropout: 0.15
+    dropout: !ref <dropout>
 
 ctc_lin: !new:speechbrain.nnet.linear.Linear
     input_size: !ref <dnn_neurons>
@@ -157,47 +195,36 @@ ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
     blank_index: !ref <blank_index>
 
 seq_cost: !name:speechbrain.nnet.losses.nll_loss
-    label_smoothing: 0.1
+    label_smoothing: !ref <label_smoothing>
 
 modules:
-    wav2vec2: !ref <wav2vec2>
     enc: !ref <enc>
     emb: !ref <emb>
     dec: !ref <dec>
     ctc_lin: !ref <ctc_lin>
     seq_lin: !ref <seq_lin>
+    normalize: !ref <normalize>
 
 model: !new:torch.nn.ModuleList
     - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <ctc_lin>, !ref <seq_lin>]
 
-model_opt_class: !name:torch.optim.Adadelta
+opt_class: !name:torch.optim.Adadelta
     lr: !ref <lr>
     rho: 0.95
     eps: 1.e-8
 
-wav2vec_opt_class: !name:torch.optim.Adam
-    lr: !ref <lr_wav2vec>
-
-lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
+lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
     initial_value: !ref <lr>
     improvement_threshold: 0.0025
     annealing_factor: 0.8
     patient: 0
 
-lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
-    initial_value: !ref <lr_wav2vec>
-    improvement_threshold: 0.0025
-    annealing_factor: 0.9
-    patient: 0
-
 beam_searcher: !new:speechbrain.decoders.S2SRNNBeamSearcher
     embedding: !ref <emb>
     decoder: !ref <dec>
     linear: !ref <seq_lin>
-    ctc_linear: !ref <ctc_lin>
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <beam_size>
@@ -209,10 +236,9 @@ beam_searcher: !new:speechbrain.decoders.S2SRNNBeamSearcher
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
     checkpoints_dir: !ref <save_folder>
     recoverables:
-        wav2vec2: !ref <wav2vec2>
         model: !ref <model>
-        scheduler_model: !ref <lr_annealing_model>
-        scheduler_wav2vec: !ref <lr_annealing_wav2vec>
+        scheduler: !ref <lr_annealing>
+        normalizer: !ref <normalize>
         counter: !ref <epoch_counter>
 
 train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
diff --git a/recipes/CommonVoice/ASR/seq2seq/hparams/train_fr.yaml b/recipes/CommonVoice/ASR/seq2seq/hparams/train_fr.yaml
index f32876afe2f5b6487373b2a4e5ef3833e6b55247..cc9b0aa995ece9f1e011d0671e59b7fa63780389 100644
--- a/recipes/CommonVoice/ASR/seq2seq/hparams/train_fr.yaml
+++ b/recipes/CommonVoice/ASR/seq2seq/hparams/train_fr.yaml
@@ -29,12 +29,14 @@ skip_prep: False # Skip data preparation
 # longer sentences certainly correspond to "open microphones".
 avoid_if_longer_than: 10.0
 
-# Training parameters
+####################### Training Parameters ####################################
+
 number_of_epochs: 25
 number_of_ctc_epochs: 20
 lr: 1.0
 ctc_weight: 0.3
 sorting: ascending
+precision: fp32 # bf16, fp16 or fp32
 
 # With data_parallel batch_size is split into N jobs
 # With DDP batch_size is multiplied by N jobs
@@ -59,7 +61,7 @@ sample_rate: 16000
 n_fft: 400
 n_mels: 80
 
-# Model parameters
+####################### Model Parameters #######################################
 activation: !name:torch.nn.LeakyReLU
 dropout: 0.15
 cnn_blocks: 3
@@ -101,18 +103,34 @@ temperature: 1.50
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-# Frequency domain SpecAugment
-augmentation: !new:speechbrain.lobes.augment.SpecAugment
-    time_warp: True
-    time_warp_window: 5
-    time_warp_mode: bicubic
-    freq_mask: True
-    n_freq_mask: 2
-    time_mask: True
-    n_time_mask: 2
-    replace_with_zero: False
-    freq_mask_width: 30
-    time_mask_width: 40
+############################## Augmentations ###################################
+
+ # Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 15
+    drop_length_high: 25
+    drop_count_low: 5
+    drop_count_high: 5
+
+# Frequency Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 25
+    drop_length_high: 35
+    drop_count_low: 2
+    drop_count_high: 2
+    dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>]
 
 normalize: !new:speechbrain.processing.features.InputNormalization
     norm_type: global
@@ -122,6 +140,8 @@ compute_features: !new:speechbrain.lobes.features.Fbank
     n_fft: !ref <n_fft>
     n_mels: !ref <n_mels>
 
+############################## Models ##########################################
+
 enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
     input_shape: [null, null, !ref <n_mels>]
     activation: !ref <activation>
diff --git a/recipes/CommonVoice/ASR/seq2seq/hparams/train_it.yaml b/recipes/CommonVoice/ASR/seq2seq/hparams/train_it.yaml
index 52cba3477558b7b3bbe3cb6d867057db3a6f7447..2c0355ae55c9fa226a43f7081309c95fb3559814 100644
--- a/recipes/CommonVoice/ASR/seq2seq/hparams/train_it.yaml
+++ b/recipes/CommonVoice/ASR/seq2seq/hparams/train_it.yaml
@@ -29,12 +29,14 @@ skip_prep: False # Skip data preparation
 # longer sentences certainly correspond to "open microphones".
 avoid_if_longer_than: 8.0
 
-# Training parameters
+####################### Training Parameters ####################################
+
 number_of_epochs: 50
 number_of_ctc_epochs: 40
 lr: 1.0
 ctc_weight: 0.3
 sorting: ascending
+precision: fp32 # bf16, fp16 or fp32
 
 # With data_parallel batch_size is split into N jobs
 # With DDP batch_size is multiplied by N jobs
@@ -58,7 +60,7 @@ sample_rate: 16000
 n_fft: 400
 n_mels: 80
 
-# Model parameters
+####################### Model Parameters #######################################
 activation: !name:torch.nn.LeakyReLU
 dropout: 0.15
 cnn_blocks: 3
@@ -100,18 +102,34 @@ temperature: 1.50
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-# Frequency domain SpecAugment
-augmentation: !new:speechbrain.lobes.augment.SpecAugment
-    time_warp: True
-    time_warp_window: 5
-    time_warp_mode: bicubic
-    freq_mask: True
-    n_freq_mask: 2
-    time_mask: True
-    n_time_mask: 2
-    replace_with_zero: False
-    freq_mask_width: 30
-    time_mask_width: 40
+############################## Augmentations ###################################
+
+ # Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 15
+    drop_length_high: 25
+    drop_count_low: 5
+    drop_count_high: 5
+
+# Frequency Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 25
+    drop_length_high: 35
+    drop_count_low: 2
+    drop_count_high: 2
+    dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>]
 
 normalize: !new:speechbrain.processing.features.InputNormalization
     norm_type: global
@@ -121,6 +139,8 @@ compute_features: !new:speechbrain.lobes.features.Fbank
     n_fft: !ref <n_fft>
     n_mels: !ref <n_mels>
 
+############################## Models ##########################################
+
 enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
     input_shape: [null, null, !ref <n_mels>]
     activation: !ref <activation>
diff --git a/recipes/CommonVoice/ASR/seq2seq/hparams/train_rw.yaml b/recipes/CommonVoice/ASR/seq2seq/hparams/train_rw.yaml
index 03082e4a32c724940c268f6a6fe664211760ff6b..8bc89c1c4bbced9dc62234b0dfbe0507ffd50374 100644
--- a/recipes/CommonVoice/ASR/seq2seq/hparams/train_rw.yaml
+++ b/recipes/CommonVoice/ASR/seq2seq/hparams/train_rw.yaml
@@ -29,12 +29,14 @@ skip_prep: False # Skip data preparation
 # longer sentences certainly correspond to "open microphones".
 avoid_if_longer_than: 8.0
 
-# Training parameters
+####################### Training Parameters ####################################
+
 number_of_epochs: 25
 number_of_ctc_epochs: 20
 lr: 1.0
 ctc_weight: 0.3
 sorting: ascending
+precision: fp32 # bf16, fp16 or fp32
 
 # With data_parallel batch_size is split into N jobs
 # With DDP batch_size is multiplied by N jobs
@@ -58,7 +60,7 @@ sample_rate: 16000
 n_fft: 400
 n_mels: 80
 
-# Model parameters
+####################### Model Parameters #######################################
 activation: !name:torch.nn.LeakyReLU
 dropout: 0.15
 cnn_blocks: 3
@@ -100,18 +102,35 @@ temperature: 1.50
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-# Frequency domain SpecAugment
-augmentation: !new:speechbrain.lobes.augment.SpecAugment
-    time_warp: True
-    time_warp_window: 5
-    time_warp_mode: bicubic
-    freq_mask: True
-    n_freq_mask: 2
-    time_mask: True
-    n_time_mask: 2
-    replace_with_zero: False
-    freq_mask_width: 30
-    time_mask_width: 40
+
+############################## Augmentations ###################################
+
+ # Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 15
+    drop_length_high: 25
+    drop_count_low: 5
+    drop_count_high: 5
+
+# Frequency Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 25
+    drop_length_high: 35
+    drop_count_low: 2
+    drop_count_high: 2
+    dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>]
 
 normalize: !new:speechbrain.processing.features.InputNormalization
     norm_type: global
@@ -121,6 +140,8 @@ compute_features: !new:speechbrain.lobes.features.Fbank
     n_fft: !ref <n_fft>
     n_mels: !ref <n_mels>
 
+############################## Models ##########################################
+
 enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
     input_shape: [null, null, !ref <n_mels>]
     activation: !ref <activation>
diff --git a/recipes/CommonVoice/ASR/seq2seq/train.py b/recipes/CommonVoice/ASR/seq2seq/train.py
index 3fefeb0deec8e8baece84dfe1a8205c11482a180..ebfbdb676660ccb833b944af1c1cb2ddaad30670 100644
--- a/recipes/CommonVoice/ASR/seq2seq/train.py
+++ b/recipes/CommonVoice/ASR/seq2seq/train.py
@@ -1,14 +1,3 @@
-#!/usr/bin/env python3
-import sys
-import torch
-import logging
-import speechbrain as sb
-import torchaudio
-from hyperpyyaml import load_hyperpyyaml
-from speechbrain.tokenizers.SentencePiece import SentencePiece
-from speechbrain.utils.data_utils import undo_padding
-from speechbrain.utils.distributed import run_on_main, if_main_process
-
 """Recipe for training a sequence-to-sequence ASR system with CommonVoice.
 The system employs an encoder, a decoder, and an attention mechanism
 between them. Decoding is performed with beamsearch.
@@ -23,10 +12,21 @@ different systems. By properly changing the parameter files, you can try
 different encoders, decoders, tokens (e.g, characters instead of BPE),
 training languages (all CommonVoice languages), and many
 other possible variations.
+
 Authors
  * Titouan Parcollet 2020
 """
 
+import sys
+import torch
+import logging
+import speechbrain as sb
+import torchaudio
+from hyperpyyaml import load_hyperpyyaml
+from speechbrain.tokenizers.SentencePiece import SentencePiece
+from speechbrain.utils.data_utils import undo_padding
+from speechbrain.utils.distributed import run_on_main, if_main_process
+
 logger = logging.getLogger(__name__)
 
 
@@ -40,14 +40,19 @@ class ASR(sb.core.Brain):
         tokens_bos, _ = batch.tokens_bos
         wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
 
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+            tokens_bos = self.hparams.wav_augment.replicate_labels(tokens_bos)
+
         # Forward pass
         feats = self.hparams.compute_features(wavs)
         feats = self.modules.normalize(feats, wav_lens)
 
-        ## Add augmentation if specified
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "augmentation"):
-                feats = self.hparams.augmentation(feats)
+        # Add feature augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "fea_augment"):
+            feats, fea_lens = self.hparams.fea_augment(feats, wav_lens)
+            tokens_bos = self.hparams.fea_augment.replicate_labels(tokens_bos)
 
         x = self.modules.enc(feats.detach())
         e_in = self.modules.emb(tokens_bos)  # y_in bos + tokens
@@ -57,40 +62,57 @@ class ASR(sb.core.Brain):
         p_seq = self.hparams.log_softmax(logits)
 
         # Compute outputs
+        p_ctc, p_tokens = None, None
         if stage == sb.Stage.TRAIN:
             current_epoch = self.hparams.epoch_counter.current
             if current_epoch <= self.hparams.number_of_ctc_epochs:
                 # Output layer for ctc log-probabilities
                 logits = self.modules.ctc_lin(x)
                 p_ctc = self.hparams.log_softmax(logits)
-                return p_ctc, p_seq, wav_lens
-            else:
-                return p_seq, wav_lens
         else:
-            p_tokens, scores = self.hparams.beam_searcher(x, wav_lens)
-            return p_seq, wav_lens, p_tokens
+            p_tokens, _, _, _ = self.hparams.beam_searcher(x, wav_lens)
+
+        return p_ctc, p_seq, wav_lens, p_tokens
 
     def compute_objectives(self, predictions, batch, stage):
         """Computes the loss (CTC+NLL) given predictions and targets."""
 
-        current_epoch = self.hparams.epoch_counter.current
-        if stage == sb.Stage.TRAIN:
-            if current_epoch <= self.hparams.number_of_ctc_epochs:
-                p_ctc, p_seq, wav_lens = predictions
-            else:
-                p_seq, wav_lens = predictions
-        else:
-            p_seq, wav_lens, predicted_tokens = predictions
+        p_ctc, p_seq, wav_lens, predicted_tokens = predictions
 
         ids = batch.id
         tokens_eos, tokens_eos_lens = batch.tokens_eos
         tokens, tokens_lens = batch.tokens
 
+        if stage == sb.Stage.TRAIN:
+            if hasattr(self.hparams, "wav_augment"):
+                tokens = self.hparams.wav_augment.replicate_labels(tokens)
+                tokens_lens = self.hparams.wav_augment.replicate_labels(
+                    tokens_lens
+                )
+                tokens_eos = self.hparams.wav_augment.replicate_labels(
+                    tokens_eos
+                )
+                tokens_eos_lens = self.hparams.wav_augment.replicate_labels(
+                    tokens_eos_lens
+                )
+            if hasattr(self.hparams, "fea_augment"):
+                tokens = self.hparams.fea_augment.replicate_labels(tokens)
+                tokens_lens = self.hparams.fea_augment.replicate_labels(
+                    tokens_lens
+                )
+                tokens_eos = self.hparams.fea_augment.replicate_labels(
+                    tokens_eos
+                )
+                tokens_eos_lens = self.hparams.fea_augment.replicate_labels(
+                    tokens_eos_lens
+                )
+
         loss_seq = self.hparams.seq_cost(
             p_seq, tokens_eos, length=tokens_eos_lens
         )
 
         # Add ctc loss if necessary
+        current_epoch = self.hparams.epoch_counter.current
         if (
             stage == sb.Stage.TRAIN
             and current_epoch <= self.hparams.number_of_ctc_epochs
@@ -118,23 +140,6 @@ class ASR(sb.core.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-        loss.backward()
-        if self.check_gradients(loss):
-            self.optimizer.step()
-        self.optimizer.zero_grad()
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        with torch.no_grad():
-            loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
         if stage != sb.Stage.TRAIN:
@@ -272,7 +277,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
diff --git a/recipes/CommonVoice/ASR/seq2seq/train_with_wav2vec.py b/recipes/CommonVoice/ASR/seq2seq/train_with_wav2vec.py
index f4492088e10a735a64cd6ca0f7dc9f1a57e9b168..ec2cb2e585843b32779b72c4ad2d1f822a093ce6 100644
--- a/recipes/CommonVoice/ASR/seq2seq/train_with_wav2vec.py
+++ b/recipes/CommonVoice/ASR/seq2seq/train_with_wav2vec.py
@@ -48,9 +48,10 @@ class ASR(sb.core.Brain):
         tokens_bos, _ = batch.tokens_bos
         wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
 
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+            tokens_bos = self.hparams.wav_augment.replicate_labels(tokens_bos)
 
         # Forward pass
         feats = self.modules.wav2vec2(wavs, wav_lens)
@@ -63,40 +64,42 @@ class ASR(sb.core.Brain):
         p_seq = self.hparams.log_softmax(logits)
 
         # Compute outputs
+        p_ctc, p_tokens = None, None
         if stage == sb.Stage.TRAIN:
             current_epoch = self.hparams.epoch_counter.current
             if current_epoch <= self.hparams.number_of_ctc_epochs:
                 # Output layer for ctc log-probabilities
                 logits = self.modules.ctc_lin(x)
                 p_ctc = self.hparams.log_softmax(logits)
-                return p_ctc, p_seq, wav_lens
-            else:
-                return p_seq, wav_lens
         else:
-            p_tokens, scores = self.hparams.beam_searcher(x, wav_lens)
-            return p_seq, wav_lens, p_tokens
+            p_tokens, _, _, _ = self.hparams.beam_searcher(x, wav_lens)
+
+        return p_ctc, p_seq, wav_lens, p_tokens
 
     def compute_objectives(self, predictions, batch, stage):
         """Computes the loss (CTC+NLL) given predictions and targets."""
 
-        current_epoch = self.hparams.epoch_counter.current
-        if stage == sb.Stage.TRAIN:
-            if current_epoch <= self.hparams.number_of_ctc_epochs:
-                p_ctc, p_seq, wav_lens = predictions
-            else:
-                p_seq, wav_lens = predictions
-        else:
-            p_seq, wav_lens, predicted_tokens = predictions
+        p_ctc, p_seq, wav_lens, predicted_tokens = predictions
 
         ids = batch.id
         tokens_eos, tokens_eos_lens = batch.tokens_eos
         tokens, tokens_lens = batch.tokens
 
+        # Augment Labels
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            tokens = self.hparams.wav_augment.replicate_labels(tokens)
+            tokens_lens = self.hparams.wav_augment.replicate_labels(tokens_lens)
+            tokens_eos = self.hparams.wav_augment.replicate_labels(tokens_eos)
+            tokens_eos_lens = self.hparams.wav_augment.replicate_labels(
+                tokens_eos_lens
+            )
+
         loss_seq = self.hparams.seq_cost(
             p_seq, tokens_eos, length=tokens_eos_lens
         )
 
         # Add ctc loss if necessary
+        current_epoch = self.hparams.epoch_counter.current
         if (
             stage == sb.Stage.TRAIN
             and current_epoch <= self.hparams.number_of_ctc_epochs
@@ -124,53 +127,6 @@ class ASR(sb.core.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        if self.auto_mix_prec:
-
-            if not self.hparams.wav2vec2.freeze:
-                self.wav2vec_optimizer.zero_grad()
-            self.model_optimizer.zero_grad()
-
-            with torch.cuda.amp.autocast():
-                outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-                loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-
-            self.scaler.scale(loss).backward()
-            if not self.hparams.wav2vec2.freeze:
-                self.scaler.unscale_(self.wav2vec_optimizer)
-            self.scaler.unscale_(self.model_optimizer)
-
-            if self.check_gradients(loss):
-                if not self.hparams.wav2vec2.freeze:
-                    self.scaler.step(self.wav2vec_optimizer)
-                self.scaler.step(self.model_optimizer)
-
-            self.scaler.update()
-        else:
-            outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-
-            loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-            loss.backward()
-
-            if self.check_gradients(loss):
-                if not self.hparams.wav2vec2.freeze:
-                    self.wav2vec_optimizer.step()
-                self.model_optimizer.step()
-
-            if not self.hparams.wav2vec2.freeze:
-                self.wav2vec_optimizer.zero_grad()
-            self.model_optimizer.zero_grad()
-
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        with torch.no_grad():
-            loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
         if stage != sb.Stage.TRAIN:
@@ -240,10 +196,23 @@ class ASR(sb.core.Brain):
         if self.checkpointer is not None:
             self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
 
-    def zero_grad(self, set_to_none=False):
-        self.model_optimizer.zero_grad(set_to_none)
         if not self.hparams.wav2vec2.freeze:
-            self.wav2vec_optimizer.zero_grad(set_to_none)
+            self.optimizers_dict = {
+                "wav2vec_optimizer": self.wav2vec_optimizer,
+                "model_optimizer": self.model_optimizer,
+            }
+        else:
+            self.optimizers_dict = {"model_optimizer": self.model_optimizer}
+
+    def freeze_optimizers(self, optimizers):
+        """Freezes the wav2vec2 optimizer according to the warmup steps"""
+        valid_optimizers = {}
+        if not self.hparams.wav2vec2.freeze:
+            valid_optimizers["wav2vec_optimizer"] = optimizers[
+                "wav2vec_optimizer"
+            ]
+        valid_optimizers["model_optimizer"] = optimizers["model_optimizer"]
+        return valid_optimizers
 
 
 # Define custom data procedure
@@ -343,7 +312,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
diff --git a/recipes/CommonVoice/ASR/transducer/README.md b/recipes/CommonVoice/ASR/transducer/README.md
index 27b6407c5d097aa4c1727272640524b5527afe2f..219a7f270237d4c3861485ea92309cf542cf1eca 100644
--- a/recipes/CommonVoice/ASR/transducer/README.md
+++ b/recipes/CommonVoice/ASR/transducer/README.md
@@ -1,5 +1,5 @@
 # CommonVoice ASR with Transducers.
-This folder contains scripts necessary to run an ASR experiment with the CommonVoice dataset: [CommonVoice Homepage](https://commonvoice.mozilla.org/)
+This folder contains scripts necessary to run an ASR experiment with the CommonVoice 14.0 dataset: [CommonVoice Homepage](https://commonvoice.mozilla.org/) and pytorch 2.0
 
 # Extra-Dependencies
 This recipe support two implementation of Transducer loss, see `use_torchaudio` arg in Yaml file:
@@ -20,12 +20,16 @@ It is important to note that CommonVoice initially offers mp3 audio files at 42H
 Here is a list of the different languages that we tested within the CommonVoice dataset
 with our transducers:
 - French
+- Italian
+- German
 
 # Results
 
 | Language | Release | hyperparams file | LM | Val. CER | Val. WER | Test CER | Test WER | Model link | GPUs |
 | ------------- |:-------------:|:---------------------------:| -----:| -----:| -----:| -----:| -----:| :-----------:| :-----------:|
-| French | 2020-06-22 | train_fr.yaml | No | 6.70 | 18.97 | 7.41 | 20.18 | [model](https://www.dropbox.com/sh/mp5w1asmuy88vhr/AABF9fFIOh3AIBTP-Xn-7c8_a?dl=0) | 2xV100 16GB |
+| French | 2023-08-15 | train_fr.yaml | No | 5.75 | 14.53 | 7.61 | 17.58 | [model](https://huggingface.co/speechbrain/asr-transducer-commonvoice-14-fr) | [model](https://www.dropbox.com/sh/nv2pnpo5n3besn3/AADZ7l41oLt11ZuOE4MqoJhCa?dl=0) | 1xV100 16GB |
+| Italian | 2023-08-15 | train_it.yaml | No | 4.66 | 14.08 | 5.11 | 14.88 | [model](https://huggingface.co/speechbrain/asr-transducer-commonvoice-14-it) | [model](https://www.dropbox.com/sh/ksm08x0wwiomrgs/AABnjPePWGPxqIqW7bJHp1jea?dl=0) | 1xV100 16GB |
+| German | 2023-08-15 | train_de.yaml | No | 4.32 | 13.09 | 5.43 | 15.25 | [model](https://huggingface.co/speechbrain/asr-transducer-commonvoice-14-de) | [model](https://www.dropbox.com/sh/jfge6ixbtoje64t/AADeAgL5un0A8uEjPSM84ex8a?dl=0) | 1xV100 16GB |
 
 The output folders with checkpoints and logs can be found [here](https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0).
 
diff --git a/recipes/CommonVoice/ASR/transducer/hparams/train_de.yaml b/recipes/CommonVoice/ASR/transducer/hparams/train_de.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9bbab166930f6e31a086be82c595a307961db551
--- /dev/null
+++ b/recipes/CommonVoice/ASR/transducer/hparams/train_de.yaml
@@ -0,0 +1,264 @@
+# ############################################################################
+# Model: E2E ASR with attention-based ASR
+# Encoder: CRDNN model
+# Decoder: GRU + beamsearch + RNNLM
+# Tokens: BPE with unigram
+# losses: Transducer
+# Training: Librispeech 100h
+# Authors:  Abdel HEBA, Mirco Ravanelli, Sung-Lin Yeh 2020
+# ############################################################################
+
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1234
+__set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/cv_transducer_de/<seed>
+test_wer_file: !ref <output_folder>/wer_test.txt
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# Data files
+data_folder: !PLACEHOLDER  # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
+train_tsv_file: !ref <data_folder>/train.tsv  # Standard CommonVoice .tsv files
+dev_tsv_file: !ref <data_folder>/dev.tsv  # Standard CommonVoice .tsv files
+test_tsv_file: !ref <data_folder>/test.tsv  # Standard CommonVoice .tsv files
+accented_letters: True
+language: de # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
+train_csv: !ref <save_folder>/train.csv
+valid_csv: !ref <save_folder>/dev.csv
+test_csv: !ref <save_folder>/test.csv
+skip_prep: False # Skip data preparation
+
+# We remove utterance slonger than 10s in the train/dev/test sets as
+# longer sentences certainly correspond to "open microphones".
+avoid_if_longer_than: 10.0
+
+####################### Training Parameters ####################################
+number_of_epochs: 30
+batch_size: 6
+batch_size_valid: 1
+lr: 1.0
+sorting: ascending
+precision: fp32 # bf16, fp16 or fp32
+ckpt_interval_minutes: 15 # save checkpoint every N min
+# MTL for encoder with CTC (uncomment enc_lin layer)
+#number_of_ctc_epochs: 2
+#ctc_weight: 0.33
+# MTL for decoder with CE (uncomment dec_lin layer)
+#number_of_ce_epochs: 2
+#ce_weight: 0.33
+
+# Feature parameters
+sample_rate: 16000
+n_fft: 400
+n_mels: 80
+
+opt_class: !name:torch.optim.Adadelta
+   lr: !ref <lr>
+   rho: 0.95
+   eps: 1.e-8
+
+# BPE parameters
+token_type: unigram  # ["unigram", "bpe", "char"]
+character_coverage: 1.0
+
+# Dataloader options
+train_dataloader_opts:
+   batch_size: !ref <batch_size>
+
+valid_dataloader_opts:
+   batch_size: !ref <batch_size_valid>
+
+test_dataloader_opts:
+   batch_size: !ref <batch_size_valid>
+
+####################### Model Parameters #######################################
+activation: !name:torch.nn.LeakyReLU
+dropout: 0.15
+cnn_blocks: 3
+cnn_channels: (128, 200, 256)
+inter_layer_pooling_size: (2, 2, 2)
+cnn_kernelsize: (3, 3)
+time_pooling_size: 4
+rnn_class: !name:speechbrain.nnet.RNN.LSTM
+rnn_layers: 5
+rnn_neurons: 1024
+rnn_bidirectional: True
+dnn_blocks: 2
+dnn_neurons: 1024
+dec_neurons: 1024
+output_neurons: 1000  # index(blank/eos/bos) = 0
+joint_dim: 1024
+blank_index: 0
+
+# Decoding parameters
+beam_size: 4
+nbest: 1
+# by default {state,expand}_beam = 2.3 as mention in paper
+# https://arxiv.org/abs/1904.02619
+state_beam: 2.3
+expand_beam: 2.3
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+   limit: !ref <number_of_epochs>
+
+normalize: !new:speechbrain.processing.features.InputNormalization
+   norm_type: global
+
+compute_features: !new:speechbrain.lobes.features.Fbank
+   sample_rate: !ref <sample_rate>
+   n_fft: !ref <n_fft>
+   n_mels: !ref <n_mels>
+
+############################## Augmentations ###################################
+
+# Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+   drop_length_low: 15
+   drop_length_high: 25
+   drop_count_low: 5
+   drop_count_high: 5
+
+# Frequency Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+   drop_length_low: 25
+   drop_length_high: 35
+   drop_count_low: 3
+   drop_count_high: 3
+   dim: 2
+
+# Time warp
+
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+   min_augmentations: 3
+   max_augmentations: 3
+   augment_prob: 1.0
+   augmentations: [
+      !ref <time_drop>,
+      !ref <freq_drop>,
+      !ref <time_warp>]
+
+############################## Models ##########################################
+
+enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
+   input_shape: [null, null, !ref <n_mels>]
+   activation: !ref <activation>
+   dropout: !ref <dropout>
+   cnn_blocks: !ref <cnn_blocks>
+   cnn_channels: !ref <cnn_channels>
+   cnn_kernelsize: !ref <cnn_kernelsize>
+   inter_layer_pooling_size: !ref <inter_layer_pooling_size>
+   time_pooling: True
+   using_2d_pooling: False
+   time_pooling_size: !ref <time_pooling_size>
+   rnn_class: !ref <rnn_class>
+   rnn_layers: !ref <rnn_layers>
+   rnn_neurons: !ref <rnn_neurons>
+   rnn_bidirectional: !ref <rnn_bidirectional>
+   rnn_re_init: True
+   dnn_blocks: !ref <dnn_blocks>
+   dnn_neurons: !ref <dnn_neurons>
+
+# For MTL CTC over the encoder
+enc_lin: !new:speechbrain.nnet.linear.Linear
+   input_size: !ref <dnn_neurons>
+   n_neurons: !ref <joint_dim>
+
+# Uncomment for MTL with CTC
+# ctc_cost: !name:speechbrain.nnet.ctc_loss
+#    blank_index: !ref <blank_index>
+
+emb: !new:speechbrain.nnet.embedding.Embedding
+   num_embeddings: !ref <output_neurons>
+   consider_as_one_hot: True
+   blank_id: !ref <blank_index>
+
+dec: !new:speechbrain.nnet.RNN.GRU
+   input_shape: [null, null, !ref <output_neurons> - 1]
+   hidden_size: !ref <dec_neurons>
+   num_layers: 1
+   re_init: True
+
+# For MTL with LM over the decoder
+dec_lin: !new:speechbrain.nnet.linear.Linear
+   input_size: !ref <dec_neurons>
+   n_neurons: !ref <joint_dim>
+   bias: False
+
+# For MLT with CTC
+#ce_cost: !name:speechbrain.nnet.losses.nll_loss
+#   label_smoothing: 0.1
+
+Tjoint: !new:speechbrain.nnet.transducer.transducer_joint.Transducer_joint
+   joint: sum # joint [sum | concat]
+   nonlinearity: !ref <activation>
+
+transducer_lin: !new:speechbrain.nnet.linear.Linear
+   input_size: !ref <joint_dim>
+   n_neurons: !ref <output_neurons>
+   bias: False
+
+log_softmax: !new:speechbrain.nnet.activations.Softmax
+   apply_log: True
+
+transducer_cost: !name:speechbrain.nnet.losses.transducer_loss
+   use_torchaudio: True
+   blank_index: !ref <blank_index>
+
+# for MTL
+# update model if any HEAD module is added
+modules:
+   enc: !ref <enc>
+   enc_lin: !ref <enc_lin>
+   emb: !ref <emb>
+   dec: !ref <dec>
+   dec_lin: !ref <dec_lin>
+   Tjoint: !ref <Tjoint>
+   transducer_lin: !ref <transducer_lin>
+   normalize: !ref <normalize>
+
+# for MTL
+# update model if any HEAD module is added
+model: !new:torch.nn.ModuleList
+   - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <transducer_lin>]
+
+# greedy_searcher: !new:speechbrain.decoders.transducer.TransducerBeamSearcher
+#   decode_network_lst: [!ref <emb>, !ref <dec>]
+#   tjoint: !ref <Tjoint>
+#   classifier_network: [!ref <transducer_lin>]
+#   blank_id: !ref <blank_index>
+#   beam_size: 1
+#   nbest: 1
+
+beam_searcher: !new:speechbrain.decoders.transducer.TransducerBeamSearcher
+   decode_network_lst: [!ref <emb>, !ref <dec>]
+   tjoint: !ref <Tjoint>
+   classifier_network: [!ref <transducer_lin>]
+   blank_id: !ref <blank_index>
+   beam_size: !ref <beam_size>
+   nbest: !ref <nbest>
+   state_beam: !ref <state_beam>
+   expand_beam: !ref <expand_beam>
+
+lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
+   initial_value: !ref <lr>
+   improvement_threshold: 0.0025
+   annealing_factor: 0.8
+   patient: 0
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+   checkpoints_dir: !ref <save_folder>
+   recoverables:
+      model: !ref <model>
+      scheduler: !ref <lr_annealing>
+      normalizer: !ref <normalize>
+      counter: !ref <epoch_counter>
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+   save_file: !ref <train_log>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+
+cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+   split_tokens: True
diff --git a/recipes/CommonVoice/ASR/transducer/hparams/train_fr.yaml b/recipes/CommonVoice/ASR/transducer/hparams/train_fr.yaml
index c137724ab469a374a53e72a4ccdfbcfceb42589e..c96a0939466a88d7e364efe5d69226f020d702e2 100644
--- a/recipes/CommonVoice/ASR/transducer/hparams/train_fr.yaml
+++ b/recipes/CommonVoice/ASR/transducer/hparams/train_fr.yaml
@@ -11,7 +11,7 @@
 # Seed needs to be set at top of yaml, before objects with parameters are made
 seed: 1234
 __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
-output_folder: !ref results/cv_transducer/<seed>
+output_folder: !ref results/cv_transducer_fr/<seed>
 test_wer_file: !ref <output_folder>/wer_test.txt
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
@@ -32,12 +32,13 @@ skip_prep: False # Skip data preparation
 # longer sentences certainly correspond to "open microphones".
 avoid_if_longer_than: 10.0
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 30
 batch_size: 6
 batch_size_valid: 1
 lr: 1.0
 sorting: ascending
+precision: fp32 # bf16, fp16 or fp32
 ckpt_interval_minutes: 15 # save checkpoint every N min
 # MTL for encoder with CTC (uncomment enc_lin layer)
 #number_of_ctc_epochs: 2
@@ -70,7 +71,7 @@ valid_dataloader_opts:
 test_dataloader_opts:
    batch_size: !ref <batch_size_valid>
 
-# Model parameters
+####################### Model Parameters #######################################
 activation: !name:torch.nn.LeakyReLU
 dropout: 0.15
 cnn_blocks: 3
@@ -108,18 +109,37 @@ compute_features: !new:speechbrain.lobes.features.Fbank
    n_fft: !ref <n_fft>
    n_mels: !ref <n_mels>
 
-# Frequency domain SpecAugment
-augmentation: !new:speechbrain.lobes.augment.SpecAugment
-   time_warp: True
-   time_warp_window: 5
-   time_warp_mode: bicubic
-   freq_mask: True
-   n_freq_mask: 2
-   time_mask: True
-   n_time_mask: 2
-   replace_with_zero: False
-   freq_mask_width: 30
-   time_mask_width: 40
+############################## Augmentations ###################################
+
+# Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+   drop_length_low: 15
+   drop_length_high: 25
+   drop_count_low: 5
+   drop_count_high: 5
+
+# Frequency Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+   drop_length_low: 25
+   drop_length_high: 35
+   drop_count_low: 3
+   drop_count_high: 3
+   dim: 2
+
+# Time warp
+
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+   min_augmentations: 3
+   max_augmentations: 3
+   augment_prob: 1.0
+   augmentations: [
+      !ref <time_drop>,
+      !ref <freq_drop>,
+      !ref <time_warp>]
+
+############################## Models ##########################################
 
 enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
    input_shape: [null, null, !ref <n_mels>]
@@ -197,7 +217,6 @@ modules:
    Tjoint: !ref <Tjoint>
    transducer_lin: !ref <transducer_lin>
    normalize: !ref <normalize>
-   augmentation: !ref <augmentation>
 
 # for MTL
 # update model if any HEAD module is added
diff --git a/recipes/CommonVoice/ASR/transducer/hparams/train_it.yaml b/recipes/CommonVoice/ASR/transducer/hparams/train_it.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cf366205efffe515ddc91d4a31bb48d75f505bc3
--- /dev/null
+++ b/recipes/CommonVoice/ASR/transducer/hparams/train_it.yaml
@@ -0,0 +1,264 @@
+# ############################################################################
+# Model: E2E ASR with attention-based ASR
+# Encoder: CRDNN model
+# Decoder: GRU + beamsearch + RNNLM
+# Tokens: BPE with unigram
+# losses: Transducer
+# Training: Librispeech 100h
+# Authors:  Abdel HEBA, Mirco Ravanelli, Sung-Lin Yeh 2020
+# ############################################################################
+
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1234
+__set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/cv_transducer_it/<seed>
+test_wer_file: !ref <output_folder>/wer_test.txt
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# Data files
+data_folder: !PLACEHOLDER  # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
+train_tsv_file: !ref <data_folder>/train.tsv  # Standard CommonVoice .tsv files
+dev_tsv_file: !ref <data_folder>/dev.tsv  # Standard CommonVoice .tsv files
+test_tsv_file: !ref <data_folder>/test.tsv  # Standard CommonVoice .tsv files
+accented_letters: True
+language: it # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
+train_csv: !ref <save_folder>/train.csv
+valid_csv: !ref <save_folder>/dev.csv
+test_csv: !ref <save_folder>/test.csv
+skip_prep: False # Skip data preparation
+
+# We remove utterance slonger than 10s in the train/dev/test sets as
+# longer sentences certainly correspond to "open microphones".
+avoid_if_longer_than: 10.0
+
+####################### Training Parameters ####################################
+number_of_epochs: 30
+batch_size: 6
+batch_size_valid: 1
+lr: 1.0
+sorting: ascending
+precision: fp32 # bf16, fp16 or fp32
+ckpt_interval_minutes: 15 # save checkpoint every N min
+# MTL for encoder with CTC (uncomment enc_lin layer)
+#number_of_ctc_epochs: 2
+#ctc_weight: 0.33
+# MTL for decoder with CE (uncomment dec_lin layer)
+#number_of_ce_epochs: 2
+#ce_weight: 0.33
+
+# Feature parameters
+sample_rate: 16000
+n_fft: 400
+n_mels: 80
+
+opt_class: !name:torch.optim.Adadelta
+   lr: !ref <lr>
+   rho: 0.95
+   eps: 1.e-8
+
+# BPE parameters
+token_type: unigram  # ["unigram", "bpe", "char"]
+character_coverage: 1.0
+
+# Dataloader options
+train_dataloader_opts:
+   batch_size: !ref <batch_size>
+
+valid_dataloader_opts:
+   batch_size: !ref <batch_size_valid>
+
+test_dataloader_opts:
+   batch_size: !ref <batch_size_valid>
+
+####################### Model Parameters #######################################
+activation: !name:torch.nn.LeakyReLU
+dropout: 0.15
+cnn_blocks: 3
+cnn_channels: (128, 200, 256)
+inter_layer_pooling_size: (2, 2, 2)
+cnn_kernelsize: (3, 3)
+time_pooling_size: 4
+rnn_class: !name:speechbrain.nnet.RNN.LSTM
+rnn_layers: 5
+rnn_neurons: 1024
+rnn_bidirectional: True
+dnn_blocks: 2
+dnn_neurons: 1024
+dec_neurons: 1024
+output_neurons: 1000  # index(blank/eos/bos) = 0
+joint_dim: 1024
+blank_index: 0
+
+# Decoding parameters
+beam_size: 4
+nbest: 1
+# by default {state,expand}_beam = 2.3 as mention in paper
+# https://arxiv.org/abs/1904.02619
+state_beam: 2.3
+expand_beam: 2.3
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+   limit: !ref <number_of_epochs>
+
+normalize: !new:speechbrain.processing.features.InputNormalization
+   norm_type: global
+
+compute_features: !new:speechbrain.lobes.features.Fbank
+   sample_rate: !ref <sample_rate>
+   n_fft: !ref <n_fft>
+   n_mels: !ref <n_mels>
+
+############################## Augmentations ###################################
+
+# Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+   drop_length_low: 15
+   drop_length_high: 25
+   drop_count_low: 5
+   drop_count_high: 5
+
+# Frequency Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+   drop_length_low: 25
+   drop_length_high: 35
+   drop_count_low: 3
+   drop_count_high: 3
+   dim: 2
+
+# Time warp
+
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+   min_augmentations: 3
+   max_augmentations: 3
+   augment_prob: 1.0
+   augmentations: [
+      !ref <time_drop>,
+      !ref <freq_drop>,
+      !ref <time_warp>]
+
+############################## Models ##########################################
+
+enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
+   input_shape: [null, null, !ref <n_mels>]
+   activation: !ref <activation>
+   dropout: !ref <dropout>
+   cnn_blocks: !ref <cnn_blocks>
+   cnn_channels: !ref <cnn_channels>
+   cnn_kernelsize: !ref <cnn_kernelsize>
+   inter_layer_pooling_size: !ref <inter_layer_pooling_size>
+   time_pooling: True
+   using_2d_pooling: False
+   time_pooling_size: !ref <time_pooling_size>
+   rnn_class: !ref <rnn_class>
+   rnn_layers: !ref <rnn_layers>
+   rnn_neurons: !ref <rnn_neurons>
+   rnn_bidirectional: !ref <rnn_bidirectional>
+   rnn_re_init: True
+   dnn_blocks: !ref <dnn_blocks>
+   dnn_neurons: !ref <dnn_neurons>
+
+# For MTL CTC over the encoder
+enc_lin: !new:speechbrain.nnet.linear.Linear
+   input_size: !ref <dnn_neurons>
+   n_neurons: !ref <joint_dim>
+
+# Uncomment for MTL with CTC
+# ctc_cost: !name:speechbrain.nnet.ctc_loss
+#    blank_index: !ref <blank_index>
+
+emb: !new:speechbrain.nnet.embedding.Embedding
+   num_embeddings: !ref <output_neurons>
+   consider_as_one_hot: True
+   blank_id: !ref <blank_index>
+
+dec: !new:speechbrain.nnet.RNN.GRU
+   input_shape: [null, null, !ref <output_neurons> - 1]
+   hidden_size: !ref <dec_neurons>
+   num_layers: 1
+   re_init: True
+
+# For MTL with LM over the decoder
+dec_lin: !new:speechbrain.nnet.linear.Linear
+   input_size: !ref <dec_neurons>
+   n_neurons: !ref <joint_dim>
+   bias: False
+
+# For MLT with CTC
+#ce_cost: !name:speechbrain.nnet.losses.nll_loss
+#   label_smoothing: 0.1
+
+Tjoint: !new:speechbrain.nnet.transducer.transducer_joint.Transducer_joint
+   joint: sum # joint [sum | concat]
+   nonlinearity: !ref <activation>
+
+transducer_lin: !new:speechbrain.nnet.linear.Linear
+   input_size: !ref <joint_dim>
+   n_neurons: !ref <output_neurons>
+   bias: False
+
+log_softmax: !new:speechbrain.nnet.activations.Softmax
+   apply_log: True
+
+transducer_cost: !name:speechbrain.nnet.losses.transducer_loss
+   use_torchaudio: True
+   blank_index: !ref <blank_index>
+
+# for MTL
+# update model if any HEAD module is added
+modules:
+   enc: !ref <enc>
+   enc_lin: !ref <enc_lin>
+   emb: !ref <emb>
+   dec: !ref <dec>
+   dec_lin: !ref <dec_lin>
+   Tjoint: !ref <Tjoint>
+   transducer_lin: !ref <transducer_lin>
+   normalize: !ref <normalize>
+
+# for MTL
+# update model if any HEAD module is added
+model: !new:torch.nn.ModuleList
+   - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <transducer_lin>]
+
+# greedy_searcher: !new:speechbrain.decoders.transducer.TransducerBeamSearcher
+#   decode_network_lst: [!ref <emb>, !ref <dec>]
+#   tjoint: !ref <Tjoint>
+#   classifier_network: [!ref <transducer_lin>]
+#   blank_id: !ref <blank_index>
+#   beam_size: 1
+#   nbest: 1
+
+beam_searcher: !new:speechbrain.decoders.transducer.TransducerBeamSearcher
+   decode_network_lst: [!ref <emb>, !ref <dec>]
+   tjoint: !ref <Tjoint>
+   classifier_network: [!ref <transducer_lin>]
+   blank_id: !ref <blank_index>
+   beam_size: !ref <beam_size>
+   nbest: !ref <nbest>
+   state_beam: !ref <state_beam>
+   expand_beam: !ref <expand_beam>
+
+lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
+   initial_value: !ref <lr>
+   improvement_threshold: 0.0025
+   annealing_factor: 0.8
+   patient: 0
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+   checkpoints_dir: !ref <save_folder>
+   recoverables:
+      model: !ref <model>
+      scheduler: !ref <lr_annealing>
+      normalizer: !ref <normalize>
+      counter: !ref <epoch_counter>
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+   save_file: !ref <train_log>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+
+cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+   split_tokens: True
diff --git a/recipes/CommonVoice/ASR/transducer/train.py b/recipes/CommonVoice/ASR/transducer/train.py
index 69e0969e57b8acefa585ff31c4e3f9a080c1e3c1..0304aabc810bfba7df541982bbc5d62c8c3f83b2 100644
--- a/recipes/CommonVoice/ASR/transducer/train.py
+++ b/recipes/CommonVoice/ASR/transducer/train.py
@@ -51,16 +51,25 @@ class ASR(sb.Brain):
         batch = batch.to(self.device)
         wavs, wav_lens = batch.sig
         tokens_with_bos, token_with_bos_lens = batch.tokens_bos
-        # wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
+
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN:
+            if hasattr(self.hparams, "wav_augment"):
+                wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+                tokens_with_bos = self.hparams.wav_augment.replicate_labels(
+                    tokens_with_bos
+                )
 
         # Forward pass
         feats = self.hparams.compute_features(wavs)
         feats = self.modules.normalize(feats, wav_lens)
 
-        # Add augmentation if specified
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.modules, "augmentation"):
-                feats = self.modules.augmentation(feats)
+        # Add feature augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "fea_augment"):
+            feats, fea_lens = self.hparams.fea_augment(feats, wav_lens)
+            tokens_with_bos = self.hparams.fea_augment.replicate_labels(
+                tokens_with_bos
+            )
 
         x = self.modules.enc(feats.detach())
         e_in = self.modules.emb(tokens_with_bos)
@@ -123,6 +132,26 @@ class ASR(sb.Brain):
         tokens, token_lens = batch.tokens
         tokens_eos, token_eos_lens = batch.tokens_eos
 
+        if stage == sb.Stage.TRAIN:
+            if hasattr(self.hparams, "wav_augment"):
+                (
+                    tokens,
+                    token_lens,
+                    tokens_eos,
+                    token_eos_lens,
+                ) = self.hparams.wav_augment.replicate_multiple_labels(
+                    tokens, token_lens, tokens_eos, token_eos_lens
+                )
+            if hasattr(self.hparams, "fea_augment"):
+                (
+                    tokens,
+                    token_lens,
+                    tokens_eos,
+                    token_eos_lens,
+                ) = self.hparams.fea_augment.replicate_multiple_labels(
+                    tokens, token_lens, tokens_eos, token_eos_lens
+                )
+
         if stage == sb.Stage.TRAIN:
             if len(predictions) == 4:
                 p_ctc, p_ce, logits_transducer, wav_lens = predictions
@@ -199,23 +228,6 @@ class ASR(sb.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-        loss.backward()
-        if self.check_gradients(loss):
-            self.optimizer.step()
-        self.optimizer.zero_grad()
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        with torch.no_grad():
-            loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
         if stage != sb.Stage.TRAIN:
@@ -353,7 +365,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
diff --git a/recipes/CommonVoice/ASR/transformer/README.md b/recipes/CommonVoice/ASR/transformer/README.md
index 128f2177d62fd0e64558dac33c87f5286b5e3fc0..cbdc0340fea39f9da8395247158be201bb9ba8fc 100644
--- a/recipes/CommonVoice/ASR/transformer/README.md
+++ b/recipes/CommonVoice/ASR/transformer/README.md
@@ -1,6 +1,5 @@
 # CommonVoice ASR with CTC + Attention based Seq2Seq models.
-This folder contains scripts necessary to run an ASR experiment with the CommonVoice dataset: [CommonVoice Homepage](https://commonvoice.mozilla.org/)
-
+This folder contains scripts necessary to run an ASR experiment with the CommonVoice 14.0 dataset: [CommonVoice Homepage](https://commonvoice.mozilla.org/) and pytorch 2.0
 # How to run
 ```shell
 python train.py hparams/{hparam_file}.py
@@ -23,34 +22,44 @@ It is important to note that CommonVoice initially offers mp3 audio files at 42H
 Here is a list of the different languages that we tested within the CommonVoice dataset
 with our transformers:
 - French
+- Italian
+- German
 
-For Whisper-large-v2 finetuning, here is list of the different language that we tested  within the CommonVoice.10_0 dataset:
+For Whisper-large-v2 and medium finetuning, here is list of the different language that we tested  within the CommonVoice.14_0 dataset:
 - Hindi
 - Arabic
 - Persian
 - Serbian
 - Mongolian
 - French
+- Italian
 
 
 # Results
 
-| Language | Release | hyperparams file | LM | Val. CER | Val. WER | Test CER | Test WER | Model link | GPUs |
-| ------------- |:-------------:|:---------------------------:| -----:| -----:| -----:| -----:| -----:| :-----------:| :-----------:|
-| French | 2020-06-22 | train_fr.yaml | No | 5.15 | 17.80 | 6.01 | 19.21 | [model](https://www.dropbox.com/sh/o3bhc0n8h3t9tdg/AADxwA2j-z9a5Y9Um2E6rUwha?dl=0) | 1xV100 16GB |
+| Language | Release | hyperparams file | LM | Val. CER | Val. WER | Test CER | Test WER | Hugging Face link |  Model link | GPUs |
+| ------------- |:-------------:|:---------------------------:| -----:| -----:| -----:| -----:| -----:|:-----------:| :-----------:| :-----------:|
+| French | 2023-08-15 | train_fr.yaml | No | 5.41 | 16.00 | 5.41 | 17.61 | - | [model](https://www.dropbox.com/sh/zvu9h9pctksnuvp/AAD1kyS3-N0YtmcoMgjM-_Tba?dl=0) | 1xV100 32GB |
+| Italian | 2023-08-15 | train_it.yaml | No | 3.72 | 16.31 | 4.01 | 16.80 | - | [model](https://www.dropbox.com/sh/yy8du12jgbkm3qe/AACBHhTCM-cU-oGvAKJ9kTtaa?dl=0) | 1xV100 32GB |
+| German | 2023-08-15 | train_de.yaml | No | 3.60 | 15.33 | 4.22 | 16.76 |- | [model](https://www.dropbox.com/sh/umfq986o3d9o1px/AAARNF2BFYELOWx3xhIOEoZka?dl=0) | 1xV100 32GB |
 
 ## Whisper Finetuning Result:
-Following table contains whisper-finetuning results for 1 epoch using whisper_large_v2 model, freezing encoder and finetuning decoder.
-| Language | Release | hyperparams file | LM | Val. CER | Val. WER | Test CER | Test WER | Model link | GPUs |
-| ------------- |:-------------:|:---------------------------:| -----:| -----:| -----:| -----:| -----:| :-----------:| :-----------:|
-| Arabic | 2023-01-10 | train_ar_hf_whisper.yaml | No | 4.02 | 12.47 | 5.20 | 16.96 | [model](https://www.dropbox.com/sh/45o3xkxdheksdfi/AAAs1zxCw76mcAbudYEonzg0a?dl=0) | 1xV100 16GB |
-| Persian | 2023-01-10 | train_fa_hf_whisper.yaml | No | 6.91 | 25.30 | 9.38 | 31.75 | [model](https://www.dropbox.com/sh/a2vd6nn0icybdcz/AAC7z41jcheW1R9aNNK4-lHha?dl=0) | 1xV100 16GB |
-| Mongolian | 2023-01-10 | train_mn_hf_whisper.yaml | No | 24.05 | 62.37 | 25.73 | 64.92 | [model](https://www.dropbox.com/sh/2t0srpb2nt2wst5/AACRJQCwooRaLxPoLkmTvKq8a?dl=0) | 1xV100 16GB |
-| Hindi | 2023-01-10 | train_hi_hf_whisper.yaml | No | 4.54 | 10.46 | 7.00 | 15.27 | [model](https://www.dropbox.com/sh/qkcm86bzzb1y4sj/AABjA_ckw_hPwJCBzUiXLWrBa?dl=0) | 1xV100 16GB |
-| Serbian | 2023-01-10 | train_sr_hf_whisper.yaml | No | 8.92 | 27.12 |  7.60 | 23.63 | [model](https://www.dropbox.com/sh/a798gw3k2ezerp5/AADz7UxvQRQDOH4DnCJ4J4dja?dl=0) | 1xV100 16GB |
-| French | 2023-01-10 | train_fr_hf_whisper.yaml | No | 3.00 | 8.95 | 3.83 | 10.62 | [model](https://www.dropbox.com/sh/8c2lpa7m5amasjz/AAD5AZlD6OslhFc0W81D3nosa?dl=0) | 1xV100 16GB |
-
-The output folders with checkpoints and logs can be found [here](https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0).
+Following table contains whisper-finetuning results for 1 epoch using whisper_medium model, freezing encoder and finetuning decoder.
+| Language | Release | Model | hyperparams file | LM | Val. CER | Val. WER | Test CER | Test WER | HuggingFace link | Model link | GPUs |
+| ------------- |:-------------:| -----:|:---------------------------:| -----:| -----:| -----:| -----:| -----:| :-----------: |:-----------:| :-----------:|
+| Arabic | 2023-08-15 | large-v2 | train_ar_hf_whisper.yaml | No | 4.02 | 12.47 | 5.20 | 16.96 | [model](https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-ar) | [model](https://www.dropbox.com/sh/45o3xkxdheksdfi/AAAs1zxCw76mcAbudYEonzg0a?dl=0) | 1xV100 16GB |
+| Persian | 2023-08-15 | large-v2 | train_fa_hf_whisper.yaml | No | 6.91 | 25.30 | 9.38 | 31.75 | [model](https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-fa) | [model](https://www.dropbox.com/sh/a2vd6nn0icybdcz/AAC7z41jcheW1R9aNNK4-lHha?dl=0) | 1xV100 16GB |
+| Mongolian | 2023-08-15 | large-v2 | train_mn_hf_whisper.yaml | No | 24.05 | 62.37 | 25.73 | 64.92 | [model](https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-mn) | [model](https://www.dropbox.com/sh/2t0srpb2nt2wst5/AACRJQCwooRaLxPoLkmTvKq8a?dl=0) | 1xV100 16GB |
+| Hindi | 2023-08-15 | large-v2 | train_hi_hf_whisper.yaml | No | 4.54 | 10.46 | 7.00 | 15.27 | [model](https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-hi) | [model](https://www.dropbox.com/sh/qkcm86bzzb1y4sj/AABjA_ckw_hPwJCBzUiXLWrBa?dl=0) | 1xV100 16GB |
+| Serbian | 2023-08-15 | large-v2 | train_sr_hf_whisper.yaml | No | 8.92 | 27.12 |  7.60 | 23.63 | [model](https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-sr) | [model](https://www.dropbox.com/sh/a798gw3k2ezerp5/AADz7UxvQRQDOH4DnCJ4J4dja?dl=0) | 1xV100 16GB |
+| French | 2023-08-15 | large-v2 | train_fr_hf_whisper.yaml | No | 3.00 | 8.95 | 3.83 | 10.62 | [model](https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-fr) | [model](https://www.dropbox.com/sh/8c2lpa7m5amasjz/AAD5AZlD6OslhFc0W81D3nosa?dl=0) | 1xV100 16GB |
+| Arabic | 2023-08-15 | Medium | train_ar_hf_whisper.yaml | No | 4.95 | 14.82 | 6.51 | 20.24 | [model](https://huggingface.co/speechbrain/asr-whisper-medium-commonvoice-ar) | [model](https://www.dropbox.com/sh/0e4vtvbg6hf2e13/AAD-tfzCZGUrh85aeAeJj8I9a?dl=0) | 1xV100 16GB |
+| Persian | 2023-08-15 | Medium | train_fa_hf_whisper.yaml | No | 8.58 | 35.48 | 11.27 | 35.48 |[model](https://huggingface.co/speechbrain/asr-whisper-medium-commonvoice-fa) | [model](https://www.dropbox.com/sh/w1urihacmtoulmi/AADMtK3qeAF5mLYk5LMHyiOra?dl=0) | 1xV100 16GB |
+| Mongolian | 2023-08-15 | Medium | train_mn_hf_whisper.yaml | No |  27.08 |  67.41 | 27.69 | 67.84 | [model](https://huggingface.co/speechbrain/asr-whisper-medium-commonvoice-mn) | [model](https://www.dropbox.com/sh/6fbhmey7q1udykf/AAAiGObWTTe2cdXHt2Uv2VQXa?dl=0) | 1xV100 16GB |
+| Hindi | 2023-08-15 | Medium | train_hi_hf_whisper.yaml | No | 5.82 | 12.51 | 8.16 | 17.04 | [model](https://huggingface.co/speechbrain/asr-whisper-medium-commonvoice-hi) | [model](https://www.dropbox.com/sh/z9vriyy3i6xqvif/AAB7ql-40yWTjKEQJiuhYUr5a?dl=0) | 1xV100 16GB |
+| Serbian | 2023-08-15 | Medium | train_sr_hf_whisper.yaml | No | 8.63 | 25.10 |  7.25 | 22.29 | [model](https://huggingface.co/speechbrain/asr-whisper-medium-commonvoice-sr) | [model](https://www.dropbox.com/sh/5lhk230q45sd97z/AAD-U9b_Ws_vFPs-cazsbOY0a?dl=0) | 1xV100 16GB |
+| French | 2023-08-15 | Medium | train_fr_hf_whisper.yaml | No | 3.26 | 9.65 | 4.30 | 11.79 | [model](https://huggingface.co/speechbrain/asr-whisper-medium-commonvoice-fr) | [model](https://www.dropbox.com/sh/7zlk07yxnslk4yy/AAANcI3EaG0ZFy6UrKk1Mm2Ga?dl=0) | 1xV100 16GB |
+| Italian | 2023-08-15 | Medium | train_it_hf_whisper.yaml | No | 2.42 | 8.26 | 3.03 | 9.63 | [model](https://huggingface.co/speechbrain/asr-whisper-medium-commonvoice-it) | [model](https://www.dropbox.com/sh/u5tex3nvzzs5pex/AAD-J7cOBE_fNfBono8waTKCa?dl=0) | 1xV100 16GB |
 
 # **About SpeechBrain**
 - Website: https://speechbrain.github.io/
diff --git a/recipes/CommonVoice/ASR/transformer/hparams/train_ar_hf_whisper.yaml b/recipes/CommonVoice/ASR/transformer/hparams/train_ar_hf_whisper.yaml
index aaa0572c522f65c2713235107eb4842380982326..d33c50c2bacb21e3bea276f428933d63f92cdb22 100644
--- a/recipes/CommonVoice/ASR/transformer/hparams/train_ar_hf_whisper.yaml
+++ b/recipes/CommonVoice/ASR/transformer/hparams/train_ar_hf_whisper.yaml
@@ -37,11 +37,11 @@ avoid_if_longer_than: 10.0
 
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 1
 lr_whisper: 0.00003
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 
 # With data_parallel batch_size is split into N jobs
@@ -63,7 +63,7 @@ min_decode_ratio: 0.0
 max_decode_ratio: 1.0
 test_beam_size: 8
 
-# Model parameters
+####################### Model Parameters #######################################
 freeze_whisper: False
 freeze_encoder: True
 
@@ -82,11 +82,41 @@ test_loader_kwargs:
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
-whisper: !new:speechbrain.lobes.models.huggingface_whisper.HuggingFaceWhisper
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
+whisper: !new:speechbrain.lobes.models.huggingface_transformers.whisper.Whisper
     source: !ref <whisper_hub>
     freeze: !ref <freeze_whisper>
     freeze_encoder: !ref <freeze_encoder>
@@ -105,14 +135,14 @@ whisper_opt_class: !name:torch.optim.AdamW
     lr: !ref <lr_whisper>
     weight_decay: 0.000000001
 
-valid_greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
+valid_search: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
     model: !ref <whisper>
     bos_index: !ref <timestamp_index>
     eos_index: !ref <eos_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
 
-test_beam_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
+test_search: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
     module: [!ref <whisper>]
     bos_index: !ref <timestamp_index>
     eos_index: !ref <eos_index>
diff --git a/recipes/CommonVoice/ASR/transformer/hparams/train_de.yaml b/recipes/CommonVoice/ASR/transformer/hparams/train_de.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c5533e9bb30af916ba32d6554e7dd84ebf90f864
--- /dev/null
+++ b/recipes/CommonVoice/ASR/transformer/hparams/train_de.yaml
@@ -0,0 +1,256 @@
+# ############################################################################
+# Model: E2E ASR with Transformer
+# Encoder: Transformer Encoder
+# Decoder: Transformer Decoder + (CTC/ATT joint) beamsearch
+# Tokens: unigram
+# losses: CTC + KLdiv (Label Smoothing loss)
+# Authors:  Titouan Parcollet and Jianyuan Zhong
+# ############################################################################
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1234
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/transformer_de/<seed>
+test_wer_file: !ref <output_folder>/wer_test.txt
+valid_wer_file: !ref <output_folder>/wer_valid.txt
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# Data files
+data_folder: !PLACEHOLDER  # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
+train_tsv_file: !ref <data_folder>/train.tsv  # Standard CommonVoice .tsv files
+dev_tsv_file: !ref <data_folder>/dev.tsv  # Standard CommonVoice .tsv files
+test_tsv_file: !ref <data_folder>/test.tsv  # Standard CommonVoice .tsv files
+accented_letters: True
+language: de # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
+train_csv: !ref <save_folder>/train.csv
+valid_csv: !ref <save_folder>/dev.csv
+test_csv: !ref <save_folder>/test.csv
+skip_prep: False # Skip data preparation
+
+# We remove utterance slonger than 10s in the train/dev/test sets as
+# longer sentences certainly correspond to "open microphones".
+avoid_if_longer_than: 10.0
+
+ckpt_interval_minutes: 15 # save checkpoint every N min
+
+####################### Training Parameters ####################################
+number_of_epochs: 50
+batch_size: 32 # This works with a 32GB GPU ! (bs * nb_gpu * accum) > 128 !
+ctc_weight: 0.3
+grad_accumulation_factor: 4
+loss_reduction: 'batchmean'
+sorting: random
+precision: fp32 # bf16, fp16 or fp32
+
+# stages related parameters
+stage_one_epochs: 40
+lr_adam: 1.0
+lr_sgd: 0.00003
+
+# BPE parameters
+token_type: unigram  # ["unigram", "bpe", "char"]
+character_coverage: 1.0
+
+# Feature parameters
+sample_rate: 16000
+n_fft: 400
+n_mels: 80
+
+# Dataloader options
+train_dataloader_opts:
+    batch_size: !ref <batch_size>
+    shuffle: True
+    num_workers: 6
+
+valid_dataloader_opts:
+    batch_size: !ref <batch_size>
+    num_workers: 6
+
+test_dataloader_opts:
+    batch_size: !ref <batch_size>
+    num_workers: 6
+
+####################### Model Parameters ###########################
+# Transformer
+d_model: 768
+nhead: 8
+num_encoder_layers: 12
+num_decoder_layers: 6
+d_ffn: 3072
+transformer_dropout: 0.0
+activation: !name:torch.nn.GELU
+output_neurons: 500
+
+# Outputs
+blank_index: 0
+label_smoothing: 0.1
+pad_index: 0
+bos_index: 1
+eos_index: 2
+
+# Decoding parameters
+min_decode_ratio: 0.0
+max_decode_ratio: 1.0
+valid_search_interval: 5
+valid_beam_size: 10
+test_beam_size: 80
+ctc_weight_decode: 0.3
+scorer_beam_scale: 0.3
+
+############################## models ################################
+
+CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
+    input_shape: (8, 10, 80)
+    num_blocks: 3
+    num_layers_per_block: 1
+    out_channels: (128, 200, 256)
+    kernel_sizes: (3, 3, 1)
+    strides: (2, 2, 1)
+    residuals: (False, False, False)
+
+Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length
+    input_size: 5120
+    tgt_vocab: !ref <output_neurons>
+    d_model: !ref <d_model>
+    nhead: !ref <nhead>
+    num_encoder_layers: !ref <num_encoder_layers>
+    num_decoder_layers: !ref <num_decoder_layers>
+    d_ffn: !ref <d_ffn>
+    dropout: !ref <transformer_dropout>
+    activation: !ref <activation>
+    normalize_before: False
+
+ctc_lin: !new:speechbrain.nnet.linear.Linear
+    input_size: !ref <d_model>
+    n_neurons: !ref <output_neurons>
+
+seq_lin: !new:speechbrain.nnet.linear.Linear
+    input_size: !ref <d_model>
+    n_neurons: !ref <output_neurons>
+
+modules:
+    CNN: !ref <CNN>
+    Transformer: !ref <Transformer>
+    seq_lin: !ref <seq_lin>
+    ctc_lin: !ref <ctc_lin>
+
+model: !new:torch.nn.ModuleList
+    - [!ref <CNN>, !ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
+
+# We define two optimizers as we have two stages (training + finetuning)
+Adam: !name:torch.optim.Adam
+    lr: !ref <lr_adam>
+    betas: (0.9, 0.98)
+    eps: 0.000000001
+
+SGD: !name:torch.optim.SGD
+    lr: !ref <lr_sgd>
+    momentum: 0.99
+    nesterov: True
+
+# Scorer
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
+    eos_index: !ref <eos_index>
+    blank_index: !ref <blank_index>
+    ctc_fc: !ref <ctc_lin>
+
+scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+    scorer_beam_scale: !ref <scorer_beam_scale>
+
+valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <valid_beam_size>
+    using_eos_threshold: False
+    length_normalization: True
+    scorer: !ref <scorer>
+
+test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <test_beam_size>
+    temperature: 1.15
+    using_eos_threshold: True
+    scorer: !ref <scorer>
+
+log_softmax: !new:torch.nn.LogSoftmax
+    dim: -1
+
+ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
+    blank_index: !ref <blank_index>
+    reduction: !ref <loss_reduction>
+
+seq_cost: !name:speechbrain.nnet.losses.kldiv_loss
+    label_smoothing: !ref <label_smoothing>
+    reduction: !ref <loss_reduction>
+
+noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler
+    lr_initial: !ref <lr_adam>
+    n_warmup_steps: 25000
+    model_size: !ref <d_model>
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        model: !ref <model>
+        noam_scheduler: !ref <noam_annealing>
+        normalizer: !ref <normalize>
+        counter: !ref <epoch_counter>
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+normalize: !new:speechbrain.processing.features.InputNormalization
+    norm_type: global
+    update_until_epoch: 3
+
+############################## Augmentations ###################################
+
+# Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 15
+    drop_length_high: 25
+    drop_count_low: 5
+    drop_count_high: 5
+
+# Frequency Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 25
+    drop_length_high: 35
+    drop_count_low: 2
+    drop_count_high: 2
+    dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>]
+
+compute_features: !new:speechbrain.lobes.features.Fbank
+    sample_rate: !ref <sample_rate>
+    n_fft: !ref <n_fft>
+    n_mels: !ref <n_mels>
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats
+cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+    split_tokens: True
diff --git a/recipes/CommonVoice/ASR/transformer/hparams/train_fa_hf_whisper.yaml b/recipes/CommonVoice/ASR/transformer/hparams/train_fa_hf_whisper.yaml
index 098e5082bf288b579d9efe0360098e7d6fdc9560..bb23c98a6ede2b4aeedbba67a8fb4ba61071cf53 100644
--- a/recipes/CommonVoice/ASR/transformer/hparams/train_fa_hf_whisper.yaml
+++ b/recipes/CommonVoice/ASR/transformer/hparams/train_fa_hf_whisper.yaml
@@ -37,11 +37,11 @@ avoid_if_longer_than: 10.0
 
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 1
 lr_whisper: 0.00003
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 
 # With data_parallel batch_size is split into N jobs
@@ -63,7 +63,7 @@ min_decode_ratio: 0.0
 max_decode_ratio: 1.0
 test_beam_size: 8
 
-# Model parameters
+####################### Model Parameters #######################################
 freeze_whisper: False
 freeze_encoder: True
 
@@ -82,11 +82,42 @@ test_loader_kwargs:
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
-whisper: !new:speechbrain.lobes.models.huggingface_whisper.HuggingFaceWhisper
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
+
+whisper: !new:speechbrain.lobes.models.huggingface_transformers.whisper.Whisper
     source: !ref <whisper_hub>
     freeze: !ref <freeze_whisper>
     freeze_encoder: !ref <freeze_encoder>
@@ -105,14 +136,14 @@ whisper_opt_class: !name:torch.optim.AdamW
     lr: !ref <lr_whisper>
     weight_decay: 0.000000001
 
-valid_greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
+valid_search: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
     model: !ref <whisper>
     bos_index: !ref <timestamp_index>
     eos_index: !ref <eos_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
 
-test_beam_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
+test_search: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
     module: [!ref <whisper>]
     bos_index: !ref <timestamp_index>
     eos_index: !ref <eos_index>
diff --git a/recipes/CommonVoice/ASR/transformer/hparams/train_fr.yaml b/recipes/CommonVoice/ASR/transformer/hparams/train_fr.yaml
index a9c09a3508b03f92f12681e6a4495f108fbc30da..e62d9c390a91a49253f93e6c6e1587ff18f67639 100644
--- a/recipes/CommonVoice/ASR/transformer/hparams/train_fr.yaml
+++ b/recipes/CommonVoice/ASR/transformer/hparams/train_fr.yaml
@@ -9,7 +9,7 @@
 # Seed needs to be set at top of yaml, before objects with parameters are made
 seed: 1234
 __set_seed: !apply:torch.manual_seed [!ref <seed>]
-output_folder: !ref results/transformer/<seed>
+output_folder: !ref results/transformer_fr/<seed>
 test_wer_file: !ref <output_folder>/wer_test.txt
 valid_wer_file: !ref <output_folder>/wer_valid.txt
 save_folder: !ref <output_folder>/save
@@ -33,13 +33,14 @@ avoid_if_longer_than: 10.0
 
 ckpt_interval_minutes: 15 # save checkpoint every N min
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 50
 batch_size: 32 # This works with a 32GB GPU ! (bs * nb_gpu * accum) > 128 !
 ctc_weight: 0.3
-gradient_accumulation: 4
+grad_accumulation_factor: 4
 loss_reduction: 'batchmean'
 sorting: random
+precision: fp32 # bf16, fp16 or fp32
 
 # stages related parameters
 stage_one_epochs: 40
@@ -69,7 +70,7 @@ test_dataloader_opts:
     batch_size: !ref <batch_size>
     num_workers: 6
 
-####################### Model parameters ###########################
+####################### Model Parameters ###########################
 # Transformer
 d_model: 768
 nhead: 8
@@ -92,8 +93,9 @@ min_decode_ratio: 0.0
 max_decode_ratio: 1.0
 valid_search_interval: 5
 valid_beam_size: 10
-# test_beam_size: 80
+test_beam_size: 80
 ctc_weight_decode: 0.3
+scorer_beam_scale: 0.3
 
 ############################## models ################################
 
@@ -137,7 +139,7 @@ model: !new:torch.nn.ModuleList
 
 # We define two optimizers as we have two stages (training + finetuning)
 Adam: !name:torch.optim.Adam
-    lr: 0
+    lr: !ref <lr_adam>
     betas: (0.9, 0.98)
     eps: 0.000000001
 
@@ -146,17 +148,39 @@ SGD: !name:torch.optim.SGD
     momentum: 0.99
     nesterov: True
 
-beam_searcher: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
-    bos_index: !ref <bos_index>
+# Scorer
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
     eos_index: !ref <eos_index>
     blank_index: !ref <blank_index>
+    ctc_fc: !ref <ctc_lin>
+
+scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+    scorer_beam_scale: !ref <scorer_beam_scale>
+
+valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <valid_beam_size>
-    ctc_weight: !ref <ctc_weight_decode>
     using_eos_threshold: False
     length_normalization: True
+    scorer: !ref <scorer>
+
+test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <test_beam_size>
+    temperature: 1.15
+    using_eos_threshold: True
+    scorer: !ref <scorer>
 
 log_softmax: !new:torch.nn.LogSoftmax
     dim: -1
@@ -189,17 +213,34 @@ normalize: !new:speechbrain.processing.features.InputNormalization
     norm_type: global
     update_until_epoch: 3
 
-augmentation: !new:speechbrain.lobes.augment.SpecAugment
-    time_warp: True
-    time_warp_window: 5
-    time_warp_mode: bicubic
-    freq_mask: True
-    n_freq_mask: 2
-    time_mask: True
-    n_time_mask: 2
-    replace_with_zero: False
-    freq_mask_width: 30
-    time_mask_width: 40
+############################## Augmentations ###################################
+
+# Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 15
+    drop_length_high: 25
+    drop_count_low: 5
+    drop_count_high: 5
+
+# Frequency Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 25
+    drop_length_high: 35
+    drop_count_low: 2
+    drop_count_high: 2
+    dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>]
 
 compute_features: !new:speechbrain.lobes.features.Fbank
     sample_rate: !ref <sample_rate>
diff --git a/recipes/CommonVoice/ASR/transformer/hparams/train_fr_hf_whisper.yaml b/recipes/CommonVoice/ASR/transformer/hparams/train_fr_hf_whisper.yaml
index 0447d65a605fa0d4ac99132cf982a33f9eac91f3..62363bdada17730280ee59bc48ad90b5ed17f3a7 100644
--- a/recipes/CommonVoice/ASR/transformer/hparams/train_fr_hf_whisper.yaml
+++ b/recipes/CommonVoice/ASR/transformer/hparams/train_fr_hf_whisper.yaml
@@ -7,7 +7,7 @@
 # Seed needs to be set at top of yaml, before objects with parameters are made
 seed: 1986
 __set_seed: !apply:torch.manual_seed [!ref <seed>]
-output_folder: !ref results/train_whisper/<seed>/<locale>/test
+output_folder: !ref results/train_whisper/<seed>/<locale>
 test_wer_file: !ref <output_folder>/wer_test.txt
 valid_wer_file: !ref <output_folder>/wer_valid.txt
 save_folder: !ref <output_folder>/save
@@ -37,11 +37,11 @@ avoid_if_longer_than: 10.0
 
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 1
 lr_whisper: 0.00003
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 
 # With data_parallel batch_size is split into N jobs
@@ -63,7 +63,7 @@ min_decode_ratio: 0.0
 max_decode_ratio: 1.0
 test_beam_size: 8
 
-# Model parameters
+####################### Model Parameters #######################################
 freeze_whisper: False
 freeze_encoder: True
 
@@ -82,11 +82,42 @@ test_loader_kwargs:
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
-whisper: !new:speechbrain.lobes.models.huggingface_whisper.HuggingFaceWhisper
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
+
+whisper: !new:speechbrain.lobes.models.huggingface_transformers.whisper.Whisper
     source: !ref <whisper_hub>
     freeze: !ref <freeze_whisper>
     freeze_encoder: !ref <freeze_encoder>
@@ -105,14 +136,14 @@ whisper_opt_class: !name:torch.optim.AdamW
     lr: !ref <lr_whisper>
     weight_decay: 0.000000001
 
-valid_greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
+valid_search: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
     model: !ref <whisper>
     bos_index: !ref <timestamp_index>
     eos_index: !ref <eos_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
 
-test_beam_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
+test_search: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
     module: [!ref <whisper>]
     bos_index: !ref <timestamp_index>
     eos_index: !ref <eos_index>
diff --git a/recipes/CommonVoice/ASR/transformer/hparams/train_hi_hf_whisper.yaml b/recipes/CommonVoice/ASR/transformer/hparams/train_hi_hf_whisper.yaml
index 4d9c31e5bc7ba54543e10ead81ec6d4ab9e6d274..e21852639b9b75437dc27e1dd9e10a963de7f0b9 100644
--- a/recipes/CommonVoice/ASR/transformer/hparams/train_hi_hf_whisper.yaml
+++ b/recipes/CommonVoice/ASR/transformer/hparams/train_hi_hf_whisper.yaml
@@ -37,11 +37,11 @@ avoid_if_longer_than: 10.0
 
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 1
 lr_whisper: 0.00003
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 
 # With data_parallel batch_size is split into N jobs
@@ -63,7 +63,7 @@ min_decode_ratio: 0.0
 max_decode_ratio: 1.0
 test_beam_size: 8
 
-# Model parameters
+####################### Model Parameters #######################################
 freeze_whisper: False
 freeze_encoder: True
 
@@ -82,11 +82,42 @@ test_loader_kwargs:
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
-whisper: !new:speechbrain.lobes.models.huggingface_whisper.HuggingFaceWhisper
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
+
+whisper: !new:speechbrain.lobes.models.huggingface_transformers.whisper.Whisper
     source: !ref <whisper_hub>
     freeze: !ref <freeze_whisper>
     freeze_encoder: !ref <freeze_encoder>
@@ -105,14 +136,14 @@ whisper_opt_class: !name:torch.optim.AdamW
     lr: !ref <lr_whisper>
     weight_decay: 0.000000001
 
-valid_greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
+valid_search: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
     model: !ref <whisper>
     bos_index: !ref <timestamp_index>
     eos_index: !ref <eos_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
 
-test_beam_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
+test_search: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
     module: [!ref <whisper>]
     bos_index: !ref <timestamp_index>
     eos_index: !ref <eos_index>
diff --git a/recipes/CommonVoice/ASR/transformer/hparams/train_it.yaml b/recipes/CommonVoice/ASR/transformer/hparams/train_it.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d95fbaffae6d07fda47183b3a97d192b487bf649
--- /dev/null
+++ b/recipes/CommonVoice/ASR/transformer/hparams/train_it.yaml
@@ -0,0 +1,256 @@
+# ############################################################################
+# Model: E2E ASR with Transformer
+# Encoder: Transformer Encoder
+# Decoder: Transformer Decoder + (CTC/ATT joint) beamsearch
+# Tokens: unigram
+# losses: CTC + KLdiv (Label Smoothing loss)
+# Authors:  Titouan Parcollet and Jianyuan Zhong
+# ############################################################################
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1234
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/transformer_it/<seed>
+test_wer_file: !ref <output_folder>/wer_test.txt
+valid_wer_file: !ref <output_folder>/wer_valid.txt
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# Data files
+data_folder: !PLACEHOLDER  # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
+train_tsv_file: !ref <data_folder>/train.tsv  # Standard CommonVoice .tsv files
+dev_tsv_file: !ref <data_folder>/dev.tsv  # Standard CommonVoice .tsv files
+test_tsv_file: !ref <data_folder>/test.tsv  # Standard CommonVoice .tsv files
+accented_letters: True
+language: it # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
+train_csv: !ref <save_folder>/train.csv
+valid_csv: !ref <save_folder>/dev.csv
+test_csv: !ref <save_folder>/test.csv
+skip_prep: False # Skip data preparation
+
+# We remove utterance slonger than 10s in the train/dev/test sets as
+# longer sentences certainly correspond to "open microphones".
+avoid_if_longer_than: 10.0
+
+ckpt_interval_minutes: 15 # save checkpoint every N min
+
+####################### Training Parameters ####################################
+number_of_epochs: 50
+batch_size: 32 # This works with a 32GB GPU ! (bs * nb_gpu * accum) > 128 !
+ctc_weight: 0.3
+grad_accumulation_factor: 4
+loss_reduction: 'batchmean'
+sorting: random
+precision: fp32 # bf16, fp16 or fp32
+
+# stages related parameters
+stage_one_epochs: 40
+lr_adam: 1.0
+lr_sgd: 0.00003
+
+# BPE parameters
+token_type: unigram  # ["unigram", "bpe", "char"]
+character_coverage: 1.0
+
+# Feature parameters
+sample_rate: 16000
+n_fft: 400
+n_mels: 80
+
+# Dataloader options
+train_dataloader_opts:
+    batch_size: !ref <batch_size>
+    shuffle: True
+    num_workers: 6
+
+valid_dataloader_opts:
+    batch_size: !ref <batch_size>
+    num_workers: 6
+
+test_dataloader_opts:
+    batch_size: !ref <batch_size>
+    num_workers: 6
+
+####################### Model Parameters ###########################
+# Transformer
+d_model: 768
+nhead: 8
+num_encoder_layers: 12
+num_decoder_layers: 6
+d_ffn: 3072
+transformer_dropout: 0.0
+activation: !name:torch.nn.GELU
+output_neurons: 500
+
+# Outputs
+blank_index: 0
+label_smoothing: 0.1
+pad_index: 0
+bos_index: 1
+eos_index: 2
+
+# Decoding parameters
+min_decode_ratio: 0.0
+max_decode_ratio: 1.0
+valid_search_interval: 5
+valid_beam_size: 10
+test_beam_size: 80
+ctc_weight_decode: 0.3
+scorer_beam_scale: 0.3
+
+############################## models ################################
+
+CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
+    input_shape: (8, 10, 80)
+    num_blocks: 3
+    num_layers_per_block: 1
+    out_channels: (128, 200, 256)
+    kernel_sizes: (3, 3, 1)
+    strides: (2, 2, 1)
+    residuals: (False, False, False)
+
+Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length
+    input_size: 5120
+    tgt_vocab: !ref <output_neurons>
+    d_model: !ref <d_model>
+    nhead: !ref <nhead>
+    num_encoder_layers: !ref <num_encoder_layers>
+    num_decoder_layers: !ref <num_decoder_layers>
+    d_ffn: !ref <d_ffn>
+    dropout: !ref <transformer_dropout>
+    activation: !ref <activation>
+    normalize_before: False
+
+ctc_lin: !new:speechbrain.nnet.linear.Linear
+    input_size: !ref <d_model>
+    n_neurons: !ref <output_neurons>
+
+seq_lin: !new:speechbrain.nnet.linear.Linear
+    input_size: !ref <d_model>
+    n_neurons: !ref <output_neurons>
+
+modules:
+    CNN: !ref <CNN>
+    Transformer: !ref <Transformer>
+    seq_lin: !ref <seq_lin>
+    ctc_lin: !ref <ctc_lin>
+
+model: !new:torch.nn.ModuleList
+    - [!ref <CNN>, !ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
+
+# We define two optimizers as we have two stages (training + finetuning)
+Adam: !name:torch.optim.Adam
+    lr: !ref <lr_adam>
+    betas: (0.9, 0.98)
+    eps: 0.000000001
+
+SGD: !name:torch.optim.SGD
+    lr: !ref <lr_sgd>
+    momentum: 0.99
+    nesterov: True
+
+# Scorer
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
+    eos_index: !ref <eos_index>
+    blank_index: !ref <blank_index>
+    ctc_fc: !ref <ctc_lin>
+
+scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+    scorer_beam_scale: !ref <scorer_beam_scale>
+
+valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <valid_beam_size>
+    using_eos_threshold: False
+    length_normalization: True
+    scorer: !ref <scorer>
+
+test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <test_beam_size>
+    temperature: 1.15
+    using_eos_threshold: True
+    scorer: !ref <scorer>
+
+log_softmax: !new:torch.nn.LogSoftmax
+    dim: -1
+
+ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
+    blank_index: !ref <blank_index>
+    reduction: !ref <loss_reduction>
+
+seq_cost: !name:speechbrain.nnet.losses.kldiv_loss
+    label_smoothing: !ref <label_smoothing>
+    reduction: !ref <loss_reduction>
+
+noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler
+    lr_initial: !ref <lr_adam>
+    n_warmup_steps: 25000
+    model_size: !ref <d_model>
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        model: !ref <model>
+        noam_scheduler: !ref <noam_annealing>
+        normalizer: !ref <normalize>
+        counter: !ref <epoch_counter>
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+normalize: !new:speechbrain.processing.features.InputNormalization
+    norm_type: global
+    update_until_epoch: 3
+
+############################## Augmentations ###################################
+
+# Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 15
+    drop_length_high: 25
+    drop_count_low: 5
+    drop_count_high: 5
+
+# Frequency Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 25
+    drop_length_high: 35
+    drop_count_low: 2
+    drop_count_high: 2
+    dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>]
+
+compute_features: !new:speechbrain.lobes.features.Fbank
+    sample_rate: !ref <sample_rate>
+    n_fft: !ref <n_fft>
+    n_mels: !ref <n_mels>
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats
+cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+    split_tokens: True
diff --git a/recipes/CommonVoice/ASR/transformer/hparams/train_it_hf_whisper.yaml b/recipes/CommonVoice/ASR/transformer/hparams/train_it_hf_whisper.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e1fc08263096249cc2a2396263ed07029769c8e4
--- /dev/null
+++ b/recipes/CommonVoice/ASR/transformer/hparams/train_it_hf_whisper.yaml
@@ -0,0 +1,173 @@
+# ################################
+# Model: Whisper (Encoder-Decoder) + NLL
+# Augmentation: TimeDomainSpecAugment
+# Authors: Pooneh Mousavi 2022
+# ################################
+
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1986
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/train_whisper/<seed>/<locale>
+test_wer_file: !ref <output_folder>/wer_test.txt
+valid_wer_file: !ref <output_folder>/wer_valid.txt
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# URL for the biggest Fairseq english whisper model.
+whisper_hub: openai/whisper-tiny
+
+# Normalize inputs with the same normalization done in the paper (https://cdn.openai.com/papers/whisper.pdf). Refer to Appendix C for further information.
+normalized_transcripts: True
+
+# Data files
+locale: it # use 'it' for italian, 'fr' for french, 'en' for english , It is a language for common-voice data.
+data_folder: !PLACEHOLDER
+train_tsv_file: !ref <data_folder>/train.tsv  # Standard CommonVoice .tsv files
+dev_tsv_file: !ref <data_folder>/dev.tsv  # Standard CommonVoice .tsv files
+test_tsv_file: !ref <data_folder>/test.tsv  # Standard CommonVoice .tsv files
+accented_letters: True
+train_csv: !ref <save_folder>/train.csv
+valid_csv: !ref <save_folder>/dev.csv
+test_csv: !ref <save_folder>/test.csv
+skip_prep: False # Skip data preparation
+
+# We remove utterance slonger than 10s in the train/dev/test sets as
+# longer sentences certainly correspond to "open microphones".
+avoid_if_longer_than: 10.0
+
+ckpt_interval_minutes: 30 # save checkpoint every N min
+
+####################### Training Parameters ####################################
+number_of_epochs: 1
+lr_whisper: 0.00003
+sorting: ascending
+precision: fp32 # bf16, fp16 or fp32
+sample_rate: 16000
+
+# With data_parallel batch_size is split into N jobs
+# With DDP batch_size is multiplied by N jobs
+batch_size: 12
+test_batch_size: 8
+
+# These values are only used for the searchers.
+# They needs to be hardcoded and should not be changed with Whisper.
+# They are used as part of the searching process.
+# The bos token of the searcher will be timestamp_index
+# and will be concatenated with the bos, language and task tokens.
+timestamp_index: 50363
+eos_index: 50257
+bos_index: 50258
+
+# Decoding parameters
+min_decode_ratio: 0.0
+max_decode_ratio: 1.0
+test_beam_size: 8
+
+####################### Model Parameters #######################################
+freeze_whisper: False
+freeze_encoder: True
+
+train_loader_kwargs:
+    batch_size: !ref <batch_size>
+
+valid_loader_kwargs:
+    batch_size: !ref <batch_size>
+
+test_loader_kwargs:
+    batch_size: !ref <test_batch_size>
+
+#
+# Functions and classes
+#
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: [95, 100, 105]
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
+
+whisper: !new:speechbrain.lobes.models.huggingface_transformers.whisper.Whisper
+    source: !ref <whisper_hub>
+    freeze: !ref <freeze_whisper>
+    freeze_encoder: !ref <freeze_encoder>
+    save_path: !ref <save_folder>/whisper_checkpoint
+    encoder_only: False
+
+log_softmax: !new:speechbrain.nnet.activations.Softmax
+    apply_log: True
+
+nll_loss: !name:speechbrain.nnet.losses.nll_loss
+
+modules:
+    whisper: !ref <whisper>
+
+whisper_opt_class: !name:torch.optim.AdamW
+    lr: !ref <lr_whisper>
+    weight_decay: 0.000000001
+
+valid_search: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
+    model: !ref <whisper>
+    bos_index: !ref <timestamp_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+
+test_search: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
+    module: [!ref <whisper>]
+    bos_index: !ref <timestamp_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <test_beam_size>
+
+lr_annealing_whisper: !new:speechbrain.nnet.schedulers.NewBobScheduler
+    initial_value: !ref <lr_whisper>
+    improvement_threshold: 0.0025
+    annealing_factor: 0.9
+    patient: 0
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        whisper: !ref <whisper>
+        scheduler_whisper: !ref <lr_annealing_whisper>
+        counter: !ref <epoch_counter>
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+
+cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+    split_tokens: True
diff --git a/recipes/CommonVoice/ASR/transformer/hparams/train_mn_hf_whisper.yaml b/recipes/CommonVoice/ASR/transformer/hparams/train_mn_hf_whisper.yaml
index 1d01ec5ebde15298672ca583cff68a98903bd071..fe4fd6f17314fd33df1cce88f14f5ae5330c6fb8 100644
--- a/recipes/CommonVoice/ASR/transformer/hparams/train_mn_hf_whisper.yaml
+++ b/recipes/CommonVoice/ASR/transformer/hparams/train_mn_hf_whisper.yaml
@@ -37,11 +37,11 @@ avoid_if_longer_than: 10.0
 
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 1
 lr_whisper: 0.00003
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 
 # With data_parallel batch_size is split into N jobs
@@ -63,7 +63,7 @@ min_decode_ratio: 0.0
 max_decode_ratio: 1.0
 test_beam_size: 8
 
-# Model parameters
+####################### Model Parameters #######################################
 freeze_whisper: False
 freeze_encoder: True
 
@@ -82,11 +82,42 @@ test_loader_kwargs:
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
-whisper: !new:speechbrain.lobes.models.huggingface_whisper.HuggingFaceWhisper
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
+
+whisper: !new:speechbrain.lobes.models.huggingface_transformers.whisper.Whisper
     source: !ref <whisper_hub>
     freeze: !ref <freeze_whisper>
     freeze_encoder: !ref <freeze_encoder>
@@ -105,14 +136,14 @@ whisper_opt_class: !name:torch.optim.AdamW
     lr: !ref <lr_whisper>
     weight_decay: 0.000000001
 
-valid_greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
+valid_search: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
     model: !ref <whisper>
     bos_index: !ref <timestamp_index>
     eos_index: !ref <eos_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
 
-test_beam_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
+test_search: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
     module: [!ref <whisper>]
     bos_index: !ref <timestamp_index>
     eos_index: !ref <eos_index>
diff --git a/recipes/CommonVoice/ASR/transformer/hparams/train_sr_hf_whisper.yaml b/recipes/CommonVoice/ASR/transformer/hparams/train_sr_hf_whisper.yaml
index 05d50eacb440404bd9056211d6dd6c8726c7af63..d7390d9a59374591e54f9e69cd40dbf08b0f7ef7 100644
--- a/recipes/CommonVoice/ASR/transformer/hparams/train_sr_hf_whisper.yaml
+++ b/recipes/CommonVoice/ASR/transformer/hparams/train_sr_hf_whisper.yaml
@@ -37,11 +37,11 @@ avoid_if_longer_than: 10.0
 
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 1
 lr_whisper: 0.00003
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 
 # With data_parallel batch_size is split into N jobs
@@ -63,7 +63,7 @@ min_decode_ratio: 0.0
 max_decode_ratio: 1.0
 test_beam_size: 8
 
-# Model parameters
+####################### Model Parameters #######################################
 freeze_whisper: False
 freeze_encoder: True
 
@@ -83,11 +83,42 @@ test_loader_kwargs:
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
-whisper: !new:speechbrain.lobes.models.huggingface_whisper.HuggingFaceWhisper
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
+
+whisper: !new:speechbrain.lobes.models.huggingface_transformers.whisper.Whisper
     source: !ref <whisper_hub>
     freeze: !ref <freeze_whisper>
     freeze_encoder: !ref <freeze_encoder>
@@ -106,14 +137,14 @@ whisper_opt_class: !name:torch.optim.AdamW
     lr: !ref <lr_whisper>
     weight_decay: 0.000000001
 
-valid_greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
+valid_search: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
     model: !ref <whisper>
     bos_index: !ref <timestamp_index>
     eos_index: !ref <eos_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
 
-test_beam_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
+test_search: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
     module: [!ref <whisper>]
     bos_index: !ref <timestamp_index>
     eos_index: !ref <eos_index>
diff --git a/recipes/CommonVoice/ASR/transformer/train.py b/recipes/CommonVoice/ASR/transformer/train.py
index c961c8af3e3a1edbef652505df4d3b9d30aab46a..89847d352e28d65dd2b23717b2dc134f89ed65dc 100644
--- a/recipes/CommonVoice/ASR/transformer/train.py
+++ b/recipes/CommonVoice/ASR/transformer/train.py
@@ -23,6 +23,7 @@ other possible variations.
 Authors
  * Titouan Parcollet 2021
  * Jianyuan Zhong 2020
+ * Pooneh Mousavi 2023
 """
 import sys
 import torch
@@ -47,15 +48,20 @@ class ASR(sb.core.Brain):
         wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
         tokens_bos, _ = batch.tokens_bos
 
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+            tokens_bos = self.hparams.wav_augment.replicate_labels(tokens_bos)
+
         # compute features
         feats = self.hparams.compute_features(wavs)
         current_epoch = self.hparams.epoch_counter.current
         feats = self.hparams.normalize(feats, wav_lens, epoch=current_epoch)
 
-        # Augmentation
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "augmentation"):
-                feats = self.hparams.augmentation(feats)
+        # Add feature augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "fea_augment"):
+            feats, fea_lens = self.hparams.fea_augment(feats, wav_lens)
+            tokens_bos = self.hparams.fea_augment.replicate_labels(tokens_bos)
 
         # forward modules
         src = self.modules.CNN(feats)
@@ -73,15 +79,20 @@ class ASR(sb.core.Brain):
 
         # Compute outputs
         hyps = None
-        if stage == sb.Stage.TRAIN:
-            hyps = None
-        elif stage == sb.Stage.VALID:
-            hyps = None
-            current_epoch = self.hparams.epoch_counter.current
-            if current_epoch % self.hparams.valid_search_interval == 0:
-                hyps, _ = self.hparams.beam_searcher(enc_out.detach(), wav_lens)
-        elif stage == sb.Stage.TEST:
-            hyps, _ = self.hparams.beam_searcher(enc_out.detach(), wav_lens)
+        current_epoch = self.hparams.epoch_counter.current
+        is_valid_search = (
+            stage == sb.Stage.VALID
+            and current_epoch % self.hparams.valid_search_interval == 0
+        )
+        is_test_search = stage == sb.Stage.TEST
+
+        if is_valid_search:
+            hyps, _, _, _ = self.hparams.valid_search(
+                enc_out.detach(), wav_lens
+            )
+
+        elif is_test_search:
+            hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
 
         return p_ctc, p_seq, wav_lens, hyps
 
@@ -94,6 +105,29 @@ class ASR(sb.core.Brain):
         tokens_eos, tokens_eos_lens = batch.tokens_eos
         tokens, tokens_lens = batch.tokens
 
+        # Augment Labels
+        if stage == sb.Stage.TRAIN:
+            # Labels must be extended if parallel augmentation or concatenated
+            # augmentation was performed on the input (increasing the time dimension)
+            if hasattr(self.hparams, "wav_augment"):
+                (
+                    tokens,
+                    tokens_lens,
+                    tokens_eos,
+                    tokens_eos_lens,
+                ) = self.hparams.wav_augment.replicate_multiple_labels(
+                    tokens, tokens_lens, tokens_eos, tokens_eos_lens
+                )
+            if hasattr(self.hparams, "fea_augment"):
+                (
+                    tokens,
+                    tokens_lens,
+                    tokens_eos,
+                    tokens_eos_lens,
+                ) = self.hparams.fea_augment.replicate_multiple_labels(
+                    tokens, tokens_lens, tokens_eos, tokens_eos_lens
+                )
+
         loss_seq = self.hparams.seq_cost(
             p_seq, tokens_eos, length=tokens_eos_lens
         )
@@ -126,38 +160,17 @@ class ASR(sb.core.Brain):
             self.acc_metric.append(p_seq, tokens_eos, tokens_eos_lens)
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-
+    def on_fit_batch_start(self, batch, should_step):
+        """Gets called at the beginning of each fit_batch."""
         # check if we need to switch optimizer
         # if so change the optimizer from Adam to SGD
         self.check_and_reset_optimizer()
 
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-
-        # normalize the loss by gradient_accumulation step
-        (loss / self.hparams.gradient_accumulation).backward()
-
-        if self.step % self.hparams.gradient_accumulation == 0:
-            # gradient clipping & early stop if loss is not fini
-            self.check_gradients(loss)
-
-            self.optimizer.step()
-            self.optimizer.zero_grad()
-
-            # anneal lr every update
+    def on_fit_batch_end(self, batch, outputs, loss, should_step):
+        """At the end of the optimizer step, apply noam annealing."""
+        if should_step:
             self.hparams.noam_annealing(self.optimizer)
 
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        with torch.no_grad():
-            predictions = self.compute_forward(batch, stage=stage)
-            loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
         if stage != sb.Stage.TRAIN:
@@ -183,7 +196,7 @@ class ASR(sb.core.Brain):
                 stage_stats["CER"] = self.cer_metric.summarize("error_rate")
 
         # log stats and save checkpoint at end-of-epoch
-        if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+        if stage == sb.Stage.VALID:
 
             # report different epoch stages according current stage
             current_epoch = self.hparams.epoch_counter.current
@@ -254,18 +267,17 @@ class ASR(sb.core.Brain):
         # Wrap modules with parallel backend after jit
         self._wrap_distributed()
 
+        if self.checkpointer is not None:
+            self.checkpointer.recover_if_possible()
+
         # Initialize optimizers after parameters are configured
         self.init_optimizers()
 
-        # Load latest checkpoint to check to current epoch number
-        if self.checkpointer is not None:
-            self.checkpointer.recover_if_possible(
-                device=torch.device(self.device)
-            )
-
-        # if the model is resumed from stage two, reinitialize the optimizer
         current_epoch = self.hparams.epoch_counter.current
         if current_epoch > self.hparams.stage_one_epochs:
+            logger.info(
+                f"The checkpoint's epoch(= {current_epoch}) is greater than stage_one_epochs(= {self.hparams.stage_one_epochs}). Using the fine-tuning optimizer instead of the training one."
+            )
             self.optimizer = self.hparams.SGD(self.modules.parameters())
 
             if self.checkpointer is not None:
@@ -273,9 +285,7 @@ class ASR(sb.core.Brain):
 
         # Load latest checkpoint to resume training if interrupted
         if self.checkpointer is not None:
-            self.checkpointer.recover_if_possible(
-                device=torch.device(self.device)
-            )
+            self.checkpointer.recover_if_possible()
 
 
 # Define custom data procedure
@@ -375,7 +385,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
diff --git a/recipes/CommonVoice/ASR/transformer/train_with_whisper.py b/recipes/CommonVoice/ASR/transformer/train_with_whisper.py
index 4b960c0c66f727653ae3ee0b2f424c568399f522..7733b8a2a59685082a606ceddc0031334f1574b9 100644
--- a/recipes/CommonVoice/ASR/transformer/train_with_whisper.py
+++ b/recipes/CommonVoice/ASR/transformer/train_with_whisper.py
@@ -30,10 +30,10 @@ class ASR(sb.Brain):
         wavs, wav_lens = batch.sig
         bos_tokens, bos_tokens_lens = batch.tokens_bos
 
-        # Add augmentation if specified
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+            bos_tokens = self.hparams.wav_augment.replicate_labels(bos_tokens)
 
         # We compute the padding mask and replace the values with the pad_token_id
         # that the Whisper decoder expect to see.
@@ -49,9 +49,11 @@ class ASR(sb.Brain):
 
         hyps = None
         if stage == sb.Stage.VALID:
-            hyps, _ = self.hparams.valid_greedy_searcher(enc_out, wav_lens)
+            hyps, _, _, _ = self.hparams.valid_search(
+                enc_out.detach(), wav_lens
+            )
         elif stage == sb.Stage.TEST:
-            hyps, _ = self.hparams.valid_greedy_searcher(enc_out, wav_lens)
+            hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
 
         return logits, hyps, wav_lens
 
@@ -63,6 +65,13 @@ class ASR(sb.Brain):
         ids = batch.id
         tokens_eos, tokens_eos_lens = batch.tokens_eos
 
+        # Augment Labels
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            tokens_eos = self.hparams.wav_augment.replicate_labels(tokens_eos)
+            tokens_eos_lens = self.hparams.wav_augment.replicate_labels(
+                tokens_eos_lens
+            )
+
         log_probs = self.hparams.log_softmax(logits)
         loss = self.hparams.nll_loss(
             log_probs, tokens_eos, length=tokens_eos_lens,
@@ -71,6 +80,8 @@ class ASR(sb.Brain):
         if stage != sb.Stage.TRAIN:
             tokens, tokens_lens = batch.tokens
 
+            hyps = [hyp[0] if len(hyp) > 0 else [] for hyp in hyps]
+
             # Decode token terms to words
             predicted_words = self.tokenizer.batch_decode(
                 hyps, skip_special_tokens=True
@@ -141,7 +152,9 @@ class ASR(sb.Brain):
                 test_stats=stage_stats,
             )
             if if_main_process():
-                with open(self.hparams.test_wer_file, "w") as w:
+                with open(
+                    self.hparams.test_wer_file, "w", encoding="utf-8"
+                ) as w:
                     self.wer_metric.write_stats(w)
 
 
@@ -238,7 +251,6 @@ if __name__ == "__main__":
     # CLI:
     hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
 
-    # If distributed_launch=True then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -276,17 +288,11 @@ if __name__ == "__main__":
     tokenizer.set_prefix_tokens(language, "transcribe", False)
 
     # we need to prepare the tokens for searchers
-    hparams["valid_greedy_searcher"].set_decoder_input_tokens(
-        tokenizer.prefix_tokens
-    )
-    hparams["valid_greedy_searcher"].set_language_token(
-        tokenizer.prefix_tokens[1]
-    )
+    hparams["valid_search"].set_decoder_input_tokens(tokenizer.prefix_tokens)
+    hparams["valid_search"].set_language_token(tokenizer.prefix_tokens[1])
 
-    hparams["test_beam_searcher"].set_decoder_input_tokens(
-        tokenizer.prefix_tokens
-    )
-    hparams["test_beam_searcher"].set_language_token(tokenizer.prefix_tokens[1])
+    hparams["test_search"].set_decoder_input_tokens(tokenizer.prefix_tokens)
+    hparams["test_search"].set_language_token(tokenizer.prefix_tokens[1])
 
     # here we create the datasets objects as well as tokenization and encoding
     train_data, valid_data, test_data = dataio_prepare(hparams, tokenizer)
diff --git a/recipes/CommonVoice/LM/README.md b/recipes/CommonVoice/LM/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d3c13f8c09310dd6a2ba522b62161927be72a2a6
--- /dev/null
+++ b/recipes/CommonVoice/LM/README.md
@@ -0,0 +1,69 @@
+
+# Traing KenLM
+This folder contains recipes for training the kenLM-gram model for the CommonVoice Dataset.
+Using Wav2Vec2 in combination with a language model can yield a significant improvement, especially when the model is fine-tuned on small speech datasets. This is a guide to explain how one can create an n-gram language model and combine it with an existing fine-tuned Wav2Vec2.
+
+
+You can download CommonVoice at https://commonvoice.mozilla.org/en
+
+## Installing Extra Dependencies
+
+Before proceeding, ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal:
+
+```
+pip install -r extra_requirements.txt
+```
+
+We will use the popular KenLM library to build an n-gram. Let's start by installing the Ubuntu library prerequisites. For a complete guide on how to install required dependencies, please refer to [this](https://kheafield.com/code/kenlm/dependencies/) link:
+ ```
+ sudo apt install build-essential cmake libboost-system-dev libboost-thread-dev libboost-program-options-dev libboost-test-dev libeigen3-dev zlib1g-dev libbz2-dev liblzma-dev
+ ```
+
+ Next, we need to start downloading and unpacking the KenLM repo.
+ ```
+ wget -O - https://kheafield.com/code/kenlm.tar.gz | tar xz
+ ```
+
+KenLM is written in C++, so we'll make use of cmake to build the binaries.
+ ```
+mkdir kenlm/build && cd kenlm/build && cmake .. && make -j2
+ ```
+
+Now, make sure that the executables are added to your .bashrc file. To do it,
+- Open the ~/.bashrc file in a text editor.
+- Scroll to the end of the file and add the following line:  ```export PATH=$PATH:/your/path/to/kenlm/build/bin ```
+- Save it and type:  `source ~/.bashrc `
+
+ ```
+# How to run:
+```shell
+python train.py hparams/train_kenlm.yaml  --data_folder=your/data/folder
+```
+
+# Results
+The script trains a n-gram language model, which is stored in the popular ARPA format.
+The output folders with checkpoints and logs can be found [here](https://www.dropbox.com/scl/fo/zw505t10kesqpvkt6m3tu/h?rlkey=6626h1h665tvlo1mtekop9rx5&dl=0).
+
+
+
+
+# **About SpeechBrain**
+- Website: https://speechbrain.github.io/
+- Code: https://github.com/speechbrain/speechbrain/
+- HuggingFace: https://huggingface.co/speechbrain/
+
+
+# **Citing SpeechBrain**
+Please, cite SpeechBrain if you use it for your research or business.
+
+```bibtex
+@misc{speechbrain,
+  title={{SpeechBrain}: A General-Purpose Speech Toolkit},
+  author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
+  year={2021},
+  eprint={2106.04624},
+  archivePrefix={arXiv},
+  primaryClass={eess.AS},
+  note={arXiv:2106.04624}
+}
+```
diff --git a/recipes/CommonVoice/LM/common_voice_prepare.py b/recipes/CommonVoice/LM/common_voice_prepare.py
new file mode 120000
index 0000000000000000000000000000000000000000..5dacdbfdd153b86d9362e8b891f0cf095b9bfc10
--- /dev/null
+++ b/recipes/CommonVoice/LM/common_voice_prepare.py
@@ -0,0 +1 @@
+../common_voice_prepare.py
\ No newline at end of file
diff --git a/recipes/CommonVoice/LM/hparams/train_kenlm.yaml b/recipes/CommonVoice/LM/hparams/train_kenlm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3872ad09ae5e011591514a2fdc96301e7a0b2ac3
--- /dev/null
+++ b/recipes/CommonVoice/LM/hparams/train_kenlm.yaml
@@ -0,0 +1,22 @@
+#########
+# Recipe for Training kenLM on CommonVoice Data
+# It is  used to boost Wav2Vec2 with n-grams.
+#
+# Author: Pooneh Mousavi (2023)
+################################
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1986
+__set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/CommonVoice/ngrams/<language>/<seed>
+
+# Data files
+data_folder: !PLACEHOLDER # e.g, /localscratch/cv-corpus-14.0-2023-06-23/en
+train_tsv_file: !ref <data_folder>/train.tsv
+language: en
+# accented_letters should be set according to the language
+accented_letters: True
+train_csv: !ref <output_folder>/train.csv
+skip_prep: False
+text_file: !ref <output_folder>/train.txt
+ngram: 5
+ngram_file: !ref <output_folder>/<language>_<ngram>gram.arpa
diff --git a/recipes/CommonVoice/LM/train.py b/recipes/CommonVoice/LM/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a37de42eb595b79aeed29a4570802c62b5abceb
--- /dev/null
+++ b/recipes/CommonVoice/LM/train.py
@@ -0,0 +1,94 @@
+"""
+Recipe  to train  kenlm ngram model  to combine an n-gram with Wav2Vec2.
+https://huggingface.co/blog/wav2vec2-with-ngram
+
+To run this recipe, do the following:
+> python train.py hparams/train.yaml --data_folder=/path/to/CommonVoice
+Author
+ * Pooneh Mousavi 2023
+"""
+
+import os
+import csv
+import sys
+import logging
+import speechbrain as sb
+from speechbrain.utils.distributed import run_on_main
+from hyperpyyaml import load_hyperpyyaml
+
+
+logger = logging.getLogger(__name__)
+
+
+def csv2text():
+    """Read CSV file and convert specific data entries into text file.
+    """
+    annotation_file = open(hparams["train_csv"], "r")
+    reader = csv.reader(annotation_file)
+    headers = next(reader, None)
+    text_file = open(hparams["text_file"], "w+")
+    index_label = headers.index("wrd")
+    row_idx = 0
+    for row in reader:
+        row_idx += 1
+        sent = row[index_label]
+        text_file.write(sent + "\n")
+    text_file.close()
+    annotation_file.close()
+    logger.info("Text file created at: " + hparams["text_file"])
+
+
+if __name__ == "__main__":
+    # Load hyperparameters file with command-line overrides
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    # Dataset prep (parsing Librispeech)
+    from common_voice_prepare import prepare_common_voice  # noqa
+
+    # multi-gpu (ddp) save data preparation
+    if not os.path.exists(hparams["text_file"]):
+        run_on_main(
+            prepare_common_voice,
+            kwargs={
+                "data_folder": hparams["data_folder"],
+                "save_folder": hparams["output_folder"],
+                "train_tsv_file": hparams["train_tsv_file"],
+                "accented_letters": hparams["accented_letters"],
+                "language": hparams["language"],
+                "skip_prep": hparams["skip_prep"],
+            },
+        )
+        csv2text()
+
+    logger.info(f"Start tarining {hparams['ngram']}-gram kenlm model.")
+    tmp_ngram_file = "ngram.arpa"
+    cmd = f'lmplz -o {hparams["ngram"]} <"{hparams["text_file"]}" > "{tmp_ngram_file}"'
+    os.system(cmd)
+    with open(tmp_ngram_file, "r") as read_file, open(
+        hparams["ngram_file"], "w"
+    ) as write_file:
+        has_added_eos = False
+        for line in read_file:
+            if not has_added_eos and "ngram 1=" in line:
+                count = line.strip().split("=")[-1]
+                write_file.write(line.replace(f"{count}", f"{int(count)+1}"))
+            elif not has_added_eos and "<s>" in line:
+                write_file.write(line)
+                write_file.write(line.replace("<s>", "</s>"))
+                has_added_eos = True
+            else:
+                write_file.write(line)
+    os.remove(tmp_ngram_file)
+    logger.info(
+        f"{hparams['ngram']}-gram kenlm model was built and saved in {hparams['ngram_file']}."
+    )
diff --git a/recipes/CommonVoice/common_voice_prepare.py b/recipes/CommonVoice/common_voice_prepare.py
index 7c4f9b2f39bea0ba99642d2a0c86b20c2160f4b0..176b1c1e5cd4c73314ec247705eb9b68611dfd78 100644
--- a/recipes/CommonVoice/common_voice_prepare.py
+++ b/recipes/CommonVoice/common_voice_prepare.py
@@ -1,11 +1,12 @@
 """
 Data preparation.
-Download: https://voice.mozilla.org/en/datasets
+Download: https://commonvoice.mozilla.org/en/datasets
 Author
 ------
 Titouan Parcollet
 Luca Della Libera 2022
 Pooneh Mousavi 2022
+Salima Mdhaffar 2023
 """
 
 from dataclasses import dataclass
@@ -13,7 +14,6 @@ import os
 import csv
 import re
 import logging
-import torchaudio
 import unicodedata
 import functools
 
@@ -35,7 +35,8 @@ def prepare_common_voice(
 ):
     """
     Prepares the csv files for the Mozilla Common Voice dataset.
-    Download: https://voice.mozilla.org/en/datasets
+    Download: https://commonvoice.mozilla.org/en
+
     Arguments
     ---------
     data_folder : str
@@ -169,16 +170,11 @@ def process_line(line, data_folder, language, accented_letters):
     # Path is at indice 1 in Common Voice tsv files. And .mp3 files
     # are located in datasets/lang/clips/
     mp3_path = data_folder + "/clips/" + line.split("\t")[1]
+
     file_name = mp3_path.split(".")[-2].split("/")[-1]
     spk_id = line.split("\t")[0]
     snt_id = file_name
 
-    # Setting torchaudio backend to sox-io (needed to read mp3 files)
-    if torchaudio.get_audio_backend() != "sox_io":
-        logger.warning("This recipe needs the sox-io backend of torchaudio")
-        logger.warning("The torchaudio backend is changed to sox_io")
-        torchaudio.set_audio_backend("sox_io")
-
     # Reading the signal (to retrieve duration in seconds)
     if os.path.isfile(mp3_path):
         info = read_audio_info(mp3_path)
@@ -215,7 +211,7 @@ def process_line(line, data_folder, language, accented_letters):
     chars = " ".join([char for char in chars][:])
 
     # Remove too short sentences (or empty):
-    if language in ["ja", "ch"]:
+    if language in ["ja", "zh-CN"]:
         if len(chars) < 3:
             return None
     else:
@@ -330,11 +326,55 @@ def language_specific_preprocess(language, words):
             "0000SS0000", "ß"
         )  # replace 0000SS0000 back to ß as its initial presence in the corpus
 
-    if language == "fr":
-        # Replace J'y D'hui etc by J_ D_hui
-        words = words.replace("'", " ")
-        words = words.replace("’", " ")
-
+    elif language == "fr":  # SM
+        words = re.sub(
+            "[^’'A-Za-z0-9À-ÖØ-öø-ÿЀ-ӿéæœâçèàûî]+", " ", words
+        )
+        words = words.replace("’", "'")
+        words = words.replace("é", "é")
+        words = words.replace("æ", "ae")
+        words = words.replace("œ", "oe")
+        words = words.replace("â", "â")
+        words = words.replace("ç", "ç")
+        words = words.replace("è", "è")
+        words = words.replace("à", "à")
+        words = words.replace("û", "û")
+        words = words.replace("î", "î")
+        words = words.upper()
+
+        # Case of apostrophe collés
+        words = words.replace("L'", "L' ")
+        words = words.replace("L'  ", "L' ")
+        words = words.replace("S'", "S' ")
+        words = words.replace("S'  ", "S' ")
+        words = words.replace("D'", "D' ")
+        words = words.replace("D'  ", "D' ")
+        words = words.replace("J'", "J' ")
+        words = words.replace("J'  ", "J' ")
+        words = words.replace("N'", "N' ")
+        words = words.replace("N'  ", "N' ")
+        words = words.replace("C'", "C' ")
+        words = words.replace("C'  ", "C' ")
+        words = words.replace("QU'", "QU' ")
+        words = words.replace("QU'  ", "QU' ")
+        words = words.replace("M'", "M' ")
+        words = words.replace("M'  ", "M' ")
+
+        # Case of apostrophe qui encadre quelques mots
+        words = words.replace(" '", " ")
+        words = words.replace("A'", "A")
+        words = words.replace("B'", "B")
+        words = words.replace("E'", "E")
+        words = words.replace("F'", "F")
+        words = words.replace("G'", "G")
+        words = words.replace("K'", "K")
+        words = words.replace("Q'", "Q")
+        words = words.replace("V'", "V")
+        words = words.replace("W'", "W")
+        words = words.replace("Z'", "Z")
+        words = words.replace("O'", "O")
+        words = words.replace("X'", "X")
+        words = words.replace("AUJOURD' HUI", "AUJOURD'HUI")
     elif language == "ar":
         HAMZA = "\u0621"
         ALEF_MADDA = "\u0622"
diff --git a/recipes/CommonVoice/quantization/README.md b/recipes/CommonVoice/quantization/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..03c9de85570377612a4bae1b4e439962580ea4c6
--- /dev/null
+++ b/recipes/CommonVoice/quantization/README.md
@@ -0,0 +1,49 @@
+
+# K-means (Quantization)
+This folder contains recipes for training K-means clustering model for the CommonVoice Dataset.
+The model serves to quantize self-supervised representations into discrete representation. Thus representations can be used as a discrete audio input for various tasks including classification, ASR and speech generation.
+It supports  kmeans model using the features from  HuBERT, WAVLM or Wav2Vec.
+
+You can download CommonVoice at https://commonvoice.mozilla.org/en
+
+## Installing Extra Dependencies
+
+Before proceeding, ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal:
+
+```
+pip install -r extra_requirements.txt
+```
+
+# How to run:
+```shell
+python train.py hparams/train_with_{SSL_model}.yaml
+```
+
+# Results
+
+The output folders with checkpoints and logs can be found [here](https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0).
+
+The checkpoints can be also found at [this](https://huggingface.co/speechbrain/SSL_Quantization) HuggingFace repository.
+
+
+
+# **About SpeechBrain**
+- Website: https://speechbrain.github.io/
+- Code: https://github.com/speechbrain/speechbrain/
+- HuggingFace: https://huggingface.co/speechbrain/
+
+
+# **Citing SpeechBrain**
+Please, cite SpeechBrain if you use it for your research or business.
+
+```bibtex
+@misc{speechbrain,
+  title={{SpeechBrain}: A General-Purpose Speech Toolkit},
+  author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
+  year={2021},
+  eprint={2106.04624},
+  archivePrefix={arXiv},
+  primaryClass={eess.AS},
+  note={arXiv:2106.04624}
+}
+```
diff --git a/recipes/CommonVoice/quantization/common_voice_prepare.py b/recipes/CommonVoice/quantization/common_voice_prepare.py
new file mode 120000
index 0000000000000000000000000000000000000000..5dacdbfdd153b86d9362e8b891f0cf095b9bfc10
--- /dev/null
+++ b/recipes/CommonVoice/quantization/common_voice_prepare.py
@@ -0,0 +1 @@
+../common_voice_prepare.py
\ No newline at end of file
diff --git a/recipes/CommonVoice/quantization/extra-requirements.txt b/recipes/CommonVoice/quantization/extra-requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d5e06028d853376200623c16ebcf4992f4ae60c2
--- /dev/null
+++ b/recipes/CommonVoice/quantization/extra-requirements.txt
@@ -0,0 +1 @@
+scikit-learn
diff --git a/recipes/CommonVoice/quantization/hparams/train_with_hubert.yaml b/recipes/CommonVoice/quantization/hparams/train_with_hubert.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..88ffe4e0acb19bddf91fa2bb6239636e10e14e4d
--- /dev/null
+++ b/recipes/CommonVoice/quantization/hparams/train_with_hubert.yaml
@@ -0,0 +1,61 @@
+################################
+# Recipe for Training K-Means Clustering on CommonVoice Data
+# Using Self-Supervised Model-Based Representations
+#
+# It is used for creating discrete audio representations from CommonVoice data.
+#
+# Author: Pooneh Mousavi (2023)
+################################
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1986
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/CommonVoice/clustering/hubert/<seed>
+save_folder: !ref <output_folder>/save
+
+# Data files
+data_folder: !PLACEHOLDER  # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
+train_tsv_file: !ref <data_folder>/train.tsv  # Standard CommonVoice .tsv files
+accented_letters: False
+language: en # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
+train_csv: !ref <save_folder>/train.csv
+skip_prep: False # Skip data preparation
+sample_rate: 16000
+
+# We remove utterance slonger than 10s in the train/dev/test sets as
+# longer sentences certainly correspond to "open microphones".
+avoid_if_longer_than: 10.0
+
+ssl_hub: facebook/hubert-base-ls960
+freeze_feature_extractor: True
+freeze_ssl: True
+ssl_folder: !ref <save_folder>/hubert_checkpoint
+ssl_layer_num: 7
+batch_size: 128 # batch_size for loading and extracting features. It is different from kmeans_batch_size.
+dataloader_num_workers: 8
+sorting: ascending
+
+# Dataloader options
+dataloader_options:
+    batch_size: !ref <batch_size>
+    num_workers: !ref <dataloader_num_workers>
+
+ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT
+    source: !ref <ssl_hub>
+    output_norm: False
+    freeze: !ref <freeze_ssl>
+    freeze_feature_extractor: !ref <freeze_feature_extractor>
+    output_all_hiddens: True
+    save_path: !ref <ssl_folder>
+
+
+####################
+# Model Parameters #
+####################
+num_clusters: 128
+init: k-means++
+max_iter: 100
+kmeans_batch_size: 1000 # should be >= num_clusters
+tol: 0.0
+max_no_improvement: 100
+n_init: 20
+reassignment_ratio: 0.0
diff --git a/recipes/CommonVoice/quantization/hparams/train_with_wav2vec.yaml b/recipes/CommonVoice/quantization/hparams/train_with_wav2vec.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2ef55545208eb96ff81e4d3a8edb998bd4ba488d
--- /dev/null
+++ b/recipes/CommonVoice/quantization/hparams/train_with_wav2vec.yaml
@@ -0,0 +1,61 @@
+################################
+# Recipe for Training K-Means Clustering on CommonVoice Data
+# Using Self-Supervised Model-Based Representations
+#
+# It is used for creating discrete audio representations from CommonVoice data.
+#
+# Author: Pooneh Mousavi (2023)
+################################
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1986
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/CommonVoice/clustering/wav2vec/<seed>
+save_folder: !ref <output_folder>/save
+
+# Data files
+data_folder: !PLACEHOLDER  # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
+train_tsv_file: !ref <data_folder>/train.tsv  # Standard CommonVoice .tsv files
+accented_letters: False
+language: en # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
+train_csv: !ref <save_folder>/train.csv
+skip_prep: False # Skip data preparation
+sample_rate: 16000
+
+# We remove utterance slonger than 10s in the train/dev/test sets as
+# longer sentences certainly correspond to "open microphones".
+avoid_if_longer_than: 10.0
+
+ssl_hub: facebook/wav2vec2-large-960h-lv60-self
+freeze_feature_extractor: True
+freeze_ssl: True
+ssl_folder: !ref <save_folder>/wav2vec_checkpoint
+ssl_layer_num: 7
+batch_size: 128 # batch_size for loading and extracting features. It is different from kmeans_batch_size.
+dataloader_num_workers: 8
+sorting: ascending
+
+# Dataloader options
+dataloader_options:
+    batch_size: !ref <batch_size>
+    num_workers: !ref <dataloader_num_workers>
+
+ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
+    source: !ref <ssl_hub>
+    output_norm: False
+    freeze: !ref <freeze_ssl>
+    freeze_feature_extractor: !ref <freeze_feature_extractor>
+    output_all_hiddens: True
+    save_path: !ref <ssl_folder>
+
+
+####################
+# Model Parameters #
+####################
+num_clusters: 128
+init: k-means++
+max_iter: 100
+kmeans_batch_size: 1000 # should be >= num_clusters
+tol: 0.0
+max_no_improvement: 100
+n_init: 20
+reassignment_ratio: 0.0
diff --git a/recipes/CommonVoice/quantization/hparams/train_with_wavlm.yaml b/recipes/CommonVoice/quantization/hparams/train_with_wavlm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ec4851e557bbe9cdf2ab24f412897eefb26f7b0c
--- /dev/null
+++ b/recipes/CommonVoice/quantization/hparams/train_with_wavlm.yaml
@@ -0,0 +1,61 @@
+################################
+# Recipe for Training K-Means Clustering on CommonVoice Data
+# Using Self-Supervised Model-Based Representations
+#
+# It is used for creating discrete audio representations from CommonVoice data.
+#
+# Author: Pooneh Mousavi (2023)
+################################
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1986
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/CommonVoice/clustering/wavlm/<seed>
+save_folder: !ref <output_folder>/save
+
+# Data files
+data_folder: !PLACEHOLDER  # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
+train_tsv_file: !ref <data_folder>/train.tsv  # Standard CommonVoice .tsv files
+accented_letters: False
+language: en # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
+train_csv: !ref <save_folder>/train.csv
+skip_prep: False # Skip data preparation
+sample_rate: 16000
+
+# We remove utterance slonger than 10s in the train/dev/test sets as
+# longer sentences certainly correspond to "open microphones".
+avoid_if_longer_than: 10.0
+
+ssl_hub: microsoft/wavlm-large
+freeze_feature_extractor: True
+freeze_ssl: True
+ssl_folder: !ref <save_folder>/wavlm_checkpoint
+ssl_layer_num: 7
+batch_size: 128 # batch_size for loading and extracting features. It is different from kmeans_batch_size.
+dataloader_num_workers: 8
+sorting: ascending
+
+# Dataloader options
+dataloader_options:
+    batch_size: !ref <batch_size>
+    num_workers: !ref <dataloader_num_workers>
+
+ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM
+    source: !ref <ssl_hub>
+    output_norm: False
+    freeze: !ref <freeze_ssl>
+    freeze_feature_extractor: !ref <freeze_feature_extractor>
+    output_all_hiddens: True
+    save_path: !ref <ssl_folder>
+
+
+####################
+# Model Parameters #
+####################
+num_clusters: 128
+init: k-means++
+max_iter: 100
+kmeans_batch_size: 1000 # should be >= num_clusters
+tol: 0.0
+max_no_improvement: 100
+n_init: 20
+reassignment_ratio: 0.0
diff --git a/recipes/CommonVoice/quantization/train.py b/recipes/CommonVoice/quantization/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..261116ae5152e6f509b81b442994dd04915fa498
--- /dev/null
+++ b/recipes/CommonVoice/quantization/train.py
@@ -0,0 +1,154 @@
+"""
+Recipe  to train K-means clustering model on self-supervised representations.
+
+To run this recipe, do the following:
+> python train.py hparams/train_with_[SSL-model].yaml --data_folder=/path/to/LibriSPeech
+Author
+ * Pooneh Mousavi 2023
+"""
+
+import os
+import sys
+import logging
+import speechbrain as sb
+import torchaudio
+from speechbrain.utils.distributed import run_on_main
+from hyperpyyaml import load_hyperpyyaml
+from torch.utils.data import DataLoader
+from speechbrain.dataio.dataloader import LoopedLoader
+from speechbrain.utils.kmeans import fetch_kmeans_model, train, save_model
+
+
+logger = logging.getLogger(__name__)
+
+
+def dataio_prepare(hparams):
+    """This function prepares the datasets to be used in the brain class.
+    It also defines the data processing pipeline through user-defined functions."""
+
+    # 1. Define datasets
+    data_folder = hparams["data_folder"]
+
+    train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
+        csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
+    )
+
+    if hparams["sorting"] == "ascending":
+        # we sort training data to speed up training and get better results.
+        train_data = train_data.filtered_sorted(
+            sort_key="duration",
+            key_max_value={"duration": hparams["avoid_if_longer_than"]},
+        )
+        # when sorting do not shuffle in dataloader ! otherwise is pointless
+        hparams["dataloader_options"]["shuffle"] = False
+
+    elif hparams["sorting"] == "descending":
+        train_data = train_data.filtered_sorted(
+            sort_key="duration",
+            reverse=True,
+            key_max_value={"duration": hparams["avoid_if_longer_than"]},
+        )
+        # when sorting do not shuffle in dataloader ! otherwise is pointless
+        hparams["dataloader_options"]["shuffle"] = False
+
+    elif hparams["sorting"] == "random":
+        pass
+
+    else:
+        raise NotImplementedError(
+            "sorting must be random, ascending or descending"
+        )
+
+    datasets = [train_data]
+
+    # 2. Define audio pipeline:
+    @sb.utils.data_pipeline.takes("wav")
+    @sb.utils.data_pipeline.provides("sig")
+    def audio_pipeline(wav):
+        info = torchaudio.info(wav)
+        sig = sb.dataio.dataio.read_audio(wav)
+        resampled = torchaudio.transforms.Resample(
+            info.sample_rate, hparams["sample_rate"],
+        )(sig)
+        return resampled
+
+    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
+
+    # 4. Set output:
+    sb.dataio.dataset.set_output_keys(
+        datasets, ["id", "sig"],
+    )
+    return train_data
+
+
+if __name__ == "__main__":
+    # Load hyperparameters file with command-line overrides
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    # Dataset prep (parsing Librispeech)
+    from common_voice_prepare import prepare_common_voice  # noqa
+
+    # multi-gpu (ddp) save data preparation
+    run_on_main(
+        prepare_common_voice,
+        kwargs={
+            "data_folder": hparams["data_folder"],
+            "save_folder": hparams["save_folder"],
+            "train_tsv_file": hparams["train_tsv_file"],
+            "accented_letters": hparams["accented_letters"],
+            "language": hparams["language"],
+            "skip_prep": hparams["skip_prep"],
+        },
+    )
+
+    # Load SSL model
+    hparams["ssl_model"] = hparams["ssl_model"].to(run_opts["device"])
+
+    # Make training Dataloader
+    train_set = dataio_prepare(hparams)
+    if not (
+        isinstance(train_set, DataLoader) or isinstance(train_set, LoopedLoader)
+    ):
+        train_set = sb.dataio.dataloader.make_dataloader(
+            train_set, **hparams["dataloader_options"]
+        )
+
+    # Load pretrained KMeans model if it exists. Otherwise,  create new one.
+    checkpoint_path = os.path.join(
+        hparams["save_folder"], f"kmeans_{hparams['num_clusters']}.pt"
+    )
+    kmeans_model = fetch_kmeans_model(
+        n_clusters=hparams["num_clusters"],
+        init=hparams["init"],
+        max_iter=hparams["max_iter"],
+        batch_size=hparams["batch_size"],
+        tol=hparams["tol"],
+        max_no_improvement=hparams["max_no_improvement"],
+        n_init=hparams["n_init"],
+        reassignment_ratio=hparams["reassignment_ratio"],
+        random_state=hparams["seed"],
+        checkpoint_path=checkpoint_path,
+    )
+
+    # Train and save Kmeans model
+    train(
+        kmeans_model,
+        train_set,
+        hparams["ssl_model"],
+        hparams["ssl_layer_num"],
+        kmeans_batch_size=hparams["kmeans_batch_size"],
+        device=run_opts["device"],
+    )
+
+    logger.info(f"Saving kmeans model at {checkpoint_path}.")
+    save_model(kmeans_model, checkpoint_path)
diff --git a/recipes/CommonVoice/self-supervised-learning/wav2vec2/README.md b/recipes/CommonVoice/self-supervised-learning/wav2vec2/README.md
index 9ef297c1e9190989d41778da009815fd3bb6efc8..cd8faf81ddba0226fca21a086618124a025c6030 100644
--- a/recipes/CommonVoice/self-supervised-learning/wav2vec2/README.md
+++ b/recipes/CommonVoice/self-supervised-learning/wav2vec2/README.md
@@ -22,13 +22,13 @@ Do not forget to replace the `!PLACEHOLDER` variables in the yaml corresponding
 
 # Use a pretrained model for fine-tuning with SpeechBrain
 
-The checkpoint generated by this pretraining is a standard PyTorch checkpoint. If you wish to use it as any pretrained HuggingFace model, as you would do for all the recipes that we currently have for wav2vec 2.0 finetuning, you simply need to copy this checkpoint to a folder that contains the corresponding `config.json` and `preprocessor_config.json`. Indeed, SpeechBrain depends (for now) from HuggingFace to train the wav2vec 2.0 model, and these files are the way HuggingFace defines all the parameters of the model. They usually can be found directly on the HuggingFace repository. Then, you just have to use the `speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec` (e.g., [CommonVoice FR ASR](https://github.com/speechbrain/speechbrain/blob/develop/recipes/CommonVoice/ASR/CTC/hparams/train_fr_with_wav2vec.yaml)) class and give the `wav2vec2_hub:/my/path/to/my/speechbrain_wav2vec2_model` parameter, and your pretrained model will be loaded directly for downstream training!
+The checkpoint generated by this pretraining is a standard PyTorch checkpoint. If you wish to use it as any pretrained HuggingFace model, as you would do for all the recipes that we currently have for wav2vec 2.0 finetuning, you simply need to copy this checkpoint to a folder that contains the corresponding `config.json` and `preprocessor_config.json`. Indeed, SpeechBrain depends (for now) from HuggingFace to train the wav2vec 2.0 model, and these files are the way HuggingFace defines all the parameters of the model. They usually can be found directly on the HuggingFace repository. Then, you just have to use the `speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2` (e.g., [CommonVoice FR ASR](https://github.com/speechbrain/speechbrain/blob/develop/recipes/CommonVoice/ASR/CTC/hparams/train_fr_with_wav2vec.yaml)) class and give the `wav2vec2_hub:/my/path/to/my/speechbrain_wav2vec2_model` parameter, and your pretrained model will be loaded directly for downstream training!
 
 # Advices
 Training wav2vec 2.0 models is crazy w.r.t compute resources. For instance, this recipe only trains a BASE wav2vec 2.0 architecture, and it already requires 16 Tesla V100 for 7 days. Of course, you can scale this to your needs (e.g., you can work with 2 GPUs only), but it will take ages! Welcome to the wav2vec 2.0 world!
 
 Here is a list of the most important advices:
-- To train w2v2 models, it is **extremely** important to have an effective batch size as high as possible. For instance, the original BASE model is trained with batches containing 1.6H of speech. This means that (duration_per_minibatch * nb_gpu * gradient_accumulation) must be at least equal to 1.6H.
+- To train w2v2 models, it is **extremely** important to have an effective batch size as high as possible. For instance, the original BASE model is trained with batches containing 1.6H of speech. This means that (duration_per_minibatch * nb_gpu * grad_accumulation_factor) must be at least equal to 1.6H.
 - Do not train on sequences longer than 20s, this will blow your VRAM up and is useless for now. Indeed training with shorter sentences (10s) may work just as well.
 - Set the `n_warmup_steps` steps in such a way that it corresponds to 10% of the total training steps. The number of steps correspond to the actual number of call to .backward w.r.t the batch size.
 
diff --git a/recipes/CommonVoice/self-supervised-learning/wav2vec2/hparams/wav2vec2_base.yaml b/recipes/CommonVoice/self-supervised-learning/wav2vec2/hparams/wav2vec2_base.yaml
index a843c026e8fa1324677c0d200a86f2e7a1aa002c..e7ceed4f5cf9c0b0c47bc1ed433c9cb8dfe90282 100644
--- a/recipes/CommonVoice/self-supervised-learning/wav2vec2/hparams/wav2vec2_base.yaml
+++ b/recipes/CommonVoice/self-supervised-learning/wav2vec2/hparams/wav2vec2_base.yaml
@@ -27,11 +27,11 @@ skip_prep: False
 
 
 # We remove utterance slonger than 10s in the train/dev/test sets as
-# longer sentences certainly correspond to "open microphones".
+# longer sentences certainly correspond to open microphones.
 avoid_if_longer_than: 10.0
 avoid_if_shorter_than: 1.0
 
-# Training parameters
+####################### Training Parameters ####################################
 # Parameters are corresponding the the ones reported in the official wav2vec2
 # paper (for the masking).
 mask_length: 10
@@ -42,8 +42,7 @@ number_of_epochs: 100
 lr_adam: 2.0 # This will get reduced by the training scheduler
 weight_decay: 0.01
 d_model: 768  # Needed by the scheduler. 768 is for the BASE w2v2
-sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
@@ -51,15 +50,15 @@ ckpt_interval_minutes: 30 # save checkpoint every N min
 # With DDP batch_size is multiplied by N jobs
 # Must be 12 per GPU to fit 32GB of VRAM
 # IMPORTANT: To train w2v2 model, we recommand to have the effective batch_size
-# higher than 100 (batch_size * nb_gpu * gradient_accumulation)
+# higher than 100 (batch_size * nb_gpu * grad_accumulation_factor)
 # Examples are:
-# 32 Tesla V100 32GB — 12 * 32 * 1
-# 4 Tesla V100 32GB — 12 * 4 * {6-8}
+# 32 Tesla V100 32GB = 12 * 32 * 1
+# 4 Tesla V100 32GB = 12 * 4 * (6-8)
 batch_size: 12
 test_batch_size: 8
-gradient_accumulation: 8
+grad_accumulation_factor: 8
 num_workers: 4
-
+sorting: ascending
 dataloader_options:
     batch_size: !ref <batch_size>
     num_workers: !ref <num_workers>
@@ -72,19 +71,22 @@ test_dataloader_options:
 # Instead of the default setting. While the recipe will work directly by setting
 # it to True, you will first need to read the tutorial on dynamic batching to
 # properly adapt the hyperparameters to your GPU memory! Using Dynamic Batching
-# will drastically optimise your GPU utilization and decrease your training time.
+# will drastically optimise your GPU utilization and decrease your training time.
 # Be careful to also adjust the gradient accumulation when using dynamic batching.
 # This setup will work with 32GB GPUs.
 # Dynamic Batching parameters, if used are:
 dynamic_batching: False
-dyn_batch_len: 120 # Cumulative length of each batch, per gpu.
-max_batch_size: 64 # Max number of samples per batch, per gpu.
+max_batch_length: 120 # Cumulative length of each batch, per gpu.
+max_batch_ex: 64 # Max number of samples per batch, per gpu.
+shuffle: True
+num_buckets: 30
+
 dynamic_batch_sampler:
-    max_batch_len: !ref <dyn_batch_len>
-    max_batch_ex: !ref <max_batch_size>
-    shuffle_ex: True
+    max_batch_length: !ref <max_batch_length>
+    max_batch_ex: !ref <max_batch_ex>
+    shuffle: !ref <shuffle>
     batch_ordering: !ref <sorting>
-    num_buckets: 30
+    num_buckets: !ref <num_buckets>
 
 #
 # Functions and classes
@@ -92,7 +94,7 @@ dynamic_batch_sampler:
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2Pretrain
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2Pretrain
     source: !ref <wav2vec2_hub>
     save_path: !ref <wav2vec2_folder>
     mask_prob: !ref <mask_prob>
@@ -102,7 +104,7 @@ modules:
     wav2vec2: !ref <wav2vec2>
 
 opt_class: !name:torch.optim.AdamW
-    lr: 0 # Will be changed by the scheduler, but we start at 0!
+    lr: 0 # Will be changed by the scheduler, but we start at 0
     betas: (0.9, 0.98)
     eps: 0.000000001
     weight_decay: !ref <weight_decay>
diff --git a/recipes/CommonVoice/self-supervised-learning/wav2vec2/train_hf_wav2vec2.py b/recipes/CommonVoice/self-supervised-learning/wav2vec2/train_hf_wav2vec2.py
index e98ff8a2064ac03cb59bc9e795b85eac0d7be4f3..079008049a99029e2da04dcf9f6429a411de75b0 100644
--- a/recipes/CommonVoice/self-supervised-learning/wav2vec2/train_hf_wav2vec2.py
+++ b/recipes/CommonVoice/self-supervised-learning/wav2vec2/train_hf_wav2vec2.py
@@ -1,13 +1,4 @@
 #!/usr/bin/env python3
-
-import sys
-import torch
-import logging
-import speechbrain as sb
-import torchaudio
-from hyperpyyaml import load_hyperpyyaml
-from speechbrain.utils.distributed import run_on_main
-
 """Recipe for pretraining a wav2vec 2.0 model on CommonVoice EN. Note that it can be
 trained with ANY dataset as long as you provide the correct JSON or CSV file.
 
@@ -35,6 +26,13 @@ Authors
  * Titouan Parcollet 2021
  * Yan Gao 2021
 """
+import sys
+import torch
+import logging
+import speechbrain as sb
+import torchaudio
+from hyperpyyaml import load_hyperpyyaml
+from speechbrain.utils.distributed import run_on_main
 
 logger = logging.getLogger(__name__)
 
@@ -80,51 +78,10 @@ class W2VBrain(sb.core.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-
-        # Here we manage mixed precision
-        if self.auto_mix_prec:
-            with torch.cuda.amp.autocast():
-                predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-                loss = self.compute_objectives(
-                    predictions, batch, sb.Stage.TRAIN
-                )
-
-            # normalize the loss by gradient_accumulation step
-            self.scaler.scale(
-                loss / self.hparams.gradient_accumulation
-            ).backward()
-
-            if self.step % self.hparams.gradient_accumulation == 0:
-                # gradient clipping & early stop if loss is not fini
-                self.check_gradients(loss)
-
-                self.scaler.unscale_(self.optimizer)
-                self.scaler.step(self.optimizer)
-                self.scaler.update()
-                self.optimizer.zero_grad()
-
-                # anneal lr every update
-                self.hparams.noam_annealing(self.optimizer)
-        else:
-            predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-            loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-
-            # normalize the loss by gradient_accumulation step
-            (loss / self.hparams.gradient_accumulation).backward()
-
-            if self.step % self.hparams.gradient_accumulation == 0:
-                # gradient clipping & early stop if loss is not fini
-                self.check_gradients(loss)
-
-                self.optimizer.step()
-                self.optimizer.zero_grad()
-
-                # anneal lr every update
-                self.hparams.noam_annealing(self.optimizer)
-
-        return loss.detach()
+    def on_fit_batch_end(self, batch, outputs, loss, should_step):
+        """At the end of the optimizer step, apply noam annealing."""
+        if should_step:
+            self.hparams.noam_annealing(self.optimizer)
 
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
@@ -259,24 +216,13 @@ def dataio_prepare(hparams):
         from speechbrain.dataio.sampler import DynamicBatchSampler  # noqa
 
         dynamic_hparams = hparams["dynamic_batch_sampler"]
-        num_buckets = dynamic_hparams["num_buckets"]
 
         train_batch_sampler = DynamicBatchSampler(
-            train_data,
-            dynamic_hparams["max_batch_len"],
-            num_buckets=num_buckets,
-            length_func=lambda x: x["duration"],
-            shuffle=dynamic_hparams["shuffle_ex"],
-            batch_ordering=dynamic_hparams["batch_ordering"],
+            train_data, **dynamic_hparams, length_func=lambda x: x["duration"],
         )
 
         valid_batch_sampler = DynamicBatchSampler(
-            valid_data,
-            dynamic_hparams["max_batch_len"],
-            num_buckets=num_buckets,
-            length_func=lambda x: x["duration"],
-            shuffle=dynamic_hparams["shuffle_ex"],
-            batch_ordering=dynamic_hparams["batch_ordering"],
+            valid_data, **dynamic_hparams, length_func=lambda x: x["duration"],
         )
 
     return (
@@ -295,7 +241,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
diff --git a/recipes/DNS/README.md b/recipes/DNS/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..82eaf1a5ceefb5a0cacbb9fcb70be76c3f8f8917
--- /dev/null
+++ b/recipes/DNS/README.md
@@ -0,0 +1,143 @@
+# **Speech Enhancement for Microsoft Deep Noise Suppression (DNS) Challenge – ICASSP 2022**
+This repository contains training recipes for a speech enhancement system designed for the 4th Deep Noise Suppression Challenge, organized by Microsoft at Interspeech 2022. <br>
+The Deep Noise Suppression Challenge features two distinct tracks:
+1. **Real Time Non-Personalized DNS**
+2. Real Time Personalized DNS (PDNS) for Fullband Audio
+
+We focus on implementing solutions only for the first track, which involves real-time non-personalized DNS.
+
+- **Model and Data** : For this challenge, we employ the [Sepformer model](https://arxiv.org/abs/2010.13154v2) to train our speech enhancement system. Our training utilizes 500 hours of fullband audio.
+
+- **Evaluation Strategy** : We follow the official evaluation strategy outlined by the ITU-T P.835 subjective test framework. It measures speech quality, background noise quality, and overall audio quality. This is done using [DNSMOS P.835](https://arxiv.org/pdf/2110.01763.pdf), a machine learning-based model capable of predicting SIG (Speech Quality), BAK (Background Noise Quality), and OVRL (Overall Audio Quality).
+
+**Related links**
+- [Official Website](https://www.microsoft.com/en-us/research/academic-program/deep-noise-suppression-challenge-icassp-2022/)
+- [DNS-4 ICASSP 2022 github repository](https://github.com/microsoft/DNS-Challenge/tree/5582dcf5ba43155621de72a035eb54a7d233af14)
+
+## **DNS-4 dataset**
+DNS-4 dataset once decompressed, the directory structure and sizes of datasets are:
+```
+datasets_fullband 892G
++-- dev_testset 1.7G
++-- impulse_responses 5.9G
++-- noise_fullband 58G
+\-- clean_fullband 827G
+    +-- emotional_speech 2.4G
+    +-- french_speech 62G
+    +-- german_speech 319G
+    +-- italian_speech 42G
+    +-- read_speech 299G
+    +-- russian_speech 12G
+    +-- spanish_speech 65G
+    +-- vctk_wav48_silence_trimmed 27G
+    \-- VocalSet_48kHz_mono 974M
+```
+
+### **Required disk space**
+The `dns_download.py` download script downloads the Real-time DNS track data and de-compresses it. The compressed data takes around 550 GB of disk space and when de-compressed you would need 1 TB to store audio files. We bundle this decompressed audio into larger archives called as shards.
+However this is not the end, the downloaded clean-audio files, RIRs, and noisy-audio files are further used to synthesize clean-noisy audio pairs for training. Once again, we bundle the synthesized data into shards for efficient and faster accessibility. This means further space will be needed to store the synthesized clean-noisy-noise shards.
+
+**NOTE**
+- This dataset download process can be extremely time-consuming. With a total of 126 splits (train, noise and dev data), the script downloads each split in a serial order. The script also allows concurrent data download (by enabling `--parallel_download` param) by using multiple threads (equal to number of your CPU cores). This is helpful especially when you have access to a large cluster. (Alternatively, you can download all 126 splits and decompress them at once by using array job submission.)
+
+## **Installing Extra Dependencies**
+Before proceeding, ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal:
+
+```
+pip install -r extra_requirements.txt
+```
+
+## **Getting started**
+- STEP 1: Download DNS dataset.
+- STEP 2: Synthesize noisy data.
+- STEP 3: Begin training.
+
+## Step 1: **Downloading Real-time DNS track dataset and create the Webdataset shards**
+The DNS dataset can be downloaded by running the script below
+```
+python dns_download.py --compressed_path DNS-dataset --decompressed_path DNS-compressed
+```
+To use parallel downloading
+```
+python dns_download.py --compressed_path DNS-dataset --decompressed_path DNS-compressed --parallel_download
+```
+The compressed files are downloaded in `DNS-compressed` and further decompressed audio files can be found in `DNS-dataset`.
+
+Next, create webdataset shards
+```
+## webdataset shards for clean_fullband (choose one one language i.e. read, german etc. at a time)
+python create_wds_shards.py DNS-dataset/datasets_fullband/clean_fullband/<read_speech/german_speech/french_speech/...>/ DNS-shards/clean_fullband/
+
+## webdataset shards for noise_fullband
+python create_wds_shards.py DNS-dataset/datasets_fullband/noise_fullband/ DNS-shards/noise_fullband
+
+## webdataset shards for baseline dev-set
+python create_wds_shards.py DNS-dataset/datasets_fullband/dev_testset/noisy_testclips/ DNS-shards/devsets_fullband
+```
+## Step 2: **Synthesize noisy data and create the Webdataset shards**
+To synthesize clean-noisy audio for speech enhancement training (we add noise, RIR to clean fullband speech to synthesize clean-noisy pairs)
+```
+cd noisyspeech_synthesizer
+
+## synthesize read speech
+python noisyspeech_synthesizer_singleprocess.py noisyspeech_synthesizer.yaml --input_shards_dir ../DNS-shards --split_name read_speech --synthesized_data_dir synthesized_data_shards
+
+## synthesize German speech
+python noisyspeech_synthesizer_singleprocess.py noisyspeech_synthesizer.yaml --input_shards_dir ../DNS-shards --split_name german_speech --synthesized_data_dir synthesized_data_shards
+
+## synthesize Italian speech
+python noisyspeech_synthesizer_singleprocess.py noisyspeech_synthesizer.yaml --input_shards_dir ../DNS-shards --split_name italian_speech --synthesized_data_dir synthesized_data_shards
+
+## similarly do for spanish, russian and french.
+```
+*For more, please see `noisyspeech_synthesizer` on how to synthesize noisy files from clean audio and noise audio files.*
+
+## Step 3: **Begin training**
+To start training
+```
+cd enhancement
+python train.py hparams/sepformer-dns-16k.yaml --data_folder <path/to/synthesized_shards_data> --baseline_noisy_shards_folder <path/to/baseline_shards_data>
+```
+*For more details and how to perform evaluation, see `enhancement` folder on details about the main training script*
+
+# **About SpeechBrain**
+- Website: https://speechbrain.github.io/
+- Code: https://github.com/speechbrain/speechbrain/
+- HuggingFace: https://huggingface.co/speechbrain/
+
+
+# **Citing SpeechBrain**
+Please, cite SpeechBrain if you use it for your research or business.
+
+```bibtex
+@misc{speechbrain,
+  title={{SpeechBrain}: A General-Purpose Speech Toolkit},
+  author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
+  year={2021},
+  eprint={2106.04624},
+  archivePrefix={arXiv},
+  primaryClass={eess.AS},
+  note={arXiv:2106.04624}
+}
+```
+
+
+**Citing SepFormer**
+```bibtex
+@inproceedings{subakan2021attention,
+      title={Attention is All You Need in Speech Separation},
+      author={Cem Subakan and Mirco Ravanelli and Samuele Cornell and Mirko Bronzi and Jianyuan Zhong},
+      year={2021},
+      booktitle={ICASSP 2021}
+}
+```
+
+**Citing DNS-4 dataset (ICASSP 2022)**
+```bibtex
+@inproceedings{dubey2022icassp,
+  title={ICASSP 2022 Deep Noise Suppression Challenge},
+  author={Dubey, Harishchandra and Gopal, Vishak and Cutler, Ross and Matusevych, Sergiy and Braun, Sebastian and Eskimez, Emre Sefik and Thakker, Manthan and Yoshioka, Takuya and Gamper, Hannes and Aichner, Robert},
+  booktitle={ICASSP},
+  year={2022}
+}
+```
\ No newline at end of file
diff --git a/recipes/DNS/create_wds_shards.py b/recipes/DNS/create_wds_shards.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc52ab18da43992d46149b568f3859d7992e3722
--- /dev/null
+++ b/recipes/DNS/create_wds_shards.py
@@ -0,0 +1,186 @@
+################################################################################
+#
+# Converts the uncompressed DNS folder
+# {french,german,...}_speech/../<*.wav>
+# structure of DNS into a WebDataset format
+#
+# Author(s): Tanel Alumäe, Nik Vaessen, Sangeet Sagar (2023)
+################################################################################
+
+import os
+import json
+from tqdm import tqdm
+import pathlib
+import argparse
+import random
+from collections import defaultdict
+
+import librosa
+import torch
+import torchaudio
+import webdataset as wds
+
+################################################################################
+# methods for writing the shards
+
+ID_SEPARATOR = "&"
+
+
+def load_audio(audio_file_path: pathlib.Path) -> torch.Tensor:
+    t, sr = torchaudio.load(audio_file_path)
+
+    return t
+
+
+def write_shards(
+    dns_folder_path: pathlib.Path,
+    shards_path: pathlib.Path,
+    seed: int,
+    samples_per_shard: int,
+    min_dur: float,
+):
+    """
+    Parameters
+    ----------
+    dns_folder_path: folder where extracted DNS data is located
+    shards_path: folder to write shards of data to
+    seed: random seed used to initially shuffle data into shards
+    samples_per_shard: number of data samples to store in each shards.
+    """
+    # make sure output folder exist
+    shards_path.mkdir(parents=True, exist_ok=True)
+
+    # find all audio files
+    audio_files = sorted([f for f in dns_folder_path.rglob("*.wav")])
+
+    # create tuples (unique_sample_id, language_id, path_to_audio_file, duration)
+    data_tuples = []
+
+    # track statistics on data
+    all_language_ids = set()
+    sample_keys_per_language = defaultdict(list)
+
+    if "clean" in dns_folder_path.as_posix():
+        delim = "clean_fullband/"
+    elif "noise" in dns_folder_path.as_posix():
+        delim = "noise_fullband/"
+        lang = "noise"
+    elif "dev_testset" in dns_folder_path.as_posix():
+        delim = "dev_testset/"
+        lang = "baseline_noisytestset"
+    else:
+        delim = os.path.basename(dns_folder_path.as_posix())
+        lang = delim
+
+    for f in tqdm(audio_files):
+        # path should be
+        # {french,german,...}_speech/../<*.wav>
+        sub_path = f.as_posix().split(delim)[1]
+
+        loc = f.as_posix()
+        key = os.path.splitext(os.path.basename(sub_path))[0]
+        if "clean_fullband" in dns_folder_path.as_posix():
+            lang = key.split("_speech")[0]
+
+        dur = librosa.get_duration(path=loc)
+
+        # Period is not allowed in a WebDataset key name
+        key = key.replace(".", "_")
+        if dur > min_dur:
+            # store statistics
+            all_language_ids.add(lang)
+            sample_keys_per_language[lang].append(key)
+            t = (key, lang, loc, dur)
+            data_tuples.append(t)
+
+    all_language_ids = sorted(all_language_ids)
+
+    # write a meta.json file which contains statistics on the data
+    # which will be written to shards
+    meta_dict = {
+        "language_ids": list(all_language_ids),
+        "sample_keys_per_language": sample_keys_per_language,
+        "num_data_samples": len(data_tuples),
+    }
+
+    with (shards_path / "meta.json").open("w") as f:
+        json.dump(meta_dict, f, indent=4)
+
+    # shuffle the tuples so that each shard has a large variety in languages
+    random.seed(seed)
+    random.shuffle(data_tuples)
+
+    # write shards
+    all_keys = set()
+    shards_path.mkdir(exist_ok=True, parents=True)
+    pattern = str(shards_path / "shard") + "-%06d.tar"
+
+    with wds.ShardWriter(pattern, maxcount=samples_per_shard) as sink:
+        for key, language_id, f, duration in data_tuples:
+            # load the audio tensor
+            tensor = load_audio(f)
+
+            # verify key is unique
+            assert key not in all_keys
+            all_keys.add(key)
+
+            # create sample to write
+            sample = {
+                "__key__": key,
+                "audio.pth": tensor,
+                "language_id": language_id,
+            }
+
+            # write sample to sink
+            sink.write(sample)
+
+
+################################################################################
+# define CLI
+
+parser = argparse.ArgumentParser(
+    description="Convert DNS-4 to WebDataset shards"
+)
+
+parser.add_argument(
+    "dns_decompressed_path",
+    type=pathlib.Path,
+    help="directory containing the (decompressed) DNS dataset",
+)
+parser.add_argument(
+    "shards_path", type=pathlib.Path, help="directory to write shards to"
+)
+parser.add_argument(
+    "--seed",
+    type=int,
+    default=12345,
+    help="random seed used for shuffling data before writing to shard",
+)
+parser.add_argument(
+    "--samples_per_shard",
+    type=int,
+    default=5000,
+    help="the maximum amount of samples placed in each shard. The last shard "
+    "will most likely contain fewer samples.",
+)
+parser.add_argument(
+    "--min-duration",
+    type=float,
+    default=3.0,
+    help="Minimum duration of the audio",
+)
+
+
+################################################################################
+# execute script
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+
+    write_shards(
+        args.dns_decompressed_path,
+        args.shards_path,
+        args.seed,
+        args.samples_per_shard,
+        args.min_duration,
+    )
diff --git a/recipes/DNS/dns_download.py b/recipes/DNS/dns_download.py
new file mode 100644
index 0000000000000000000000000000000000000000..84381f243029326cd63d15292c61c687ae51215c
--- /dev/null
+++ b/recipes/DNS/dns_download.py
@@ -0,0 +1,600 @@
+#!/usr/bin/env/python3
+"""
+Recipe for downloading DNS-4 dataset- training,
+baseline DEV noisyset, blind testset
+Source:
+https://github.com/microsoft/DNS-Challenge
+https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-4.sh
+
+Disk-space (compressed): 550 GB
+Disk-space (decompressed): 1 TB
+
+NOTE:
+    1. Some of the azure links provided by Microsoft are not perfect and data
+    download may stop mid-way through the download process. Hence we validate
+    download size of each of the file.
+    2. Instead of using the impulse response files provided in the challenge,
+    we opt to download them from OPENSLR. OPENSLR offers both real and synthetic
+    RIRs, while the challenge offers only real RIRs.
+
+Authors
+    * Sangeet Sagar 2022
+"""
+
+import os
+import ssl
+import shutil
+import zipfile
+import tarfile
+import certifi
+import argparse
+import fileinput
+import requests
+import urllib.request
+from tqdm.auto import tqdm
+from concurrent.futures import ThreadPoolExecutor
+
+BLOB_NAMES = [
+    "clean_fullband/datasets_fullband.clean_fullband.VocalSet_48kHz_mono_000_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.emotional_speech_000_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.french_speech_000_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.french_speech_001_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.french_speech_002_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.french_speech_003_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.french_speech_004_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.french_speech_005_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.french_speech_006_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.french_speech_007_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.french_speech_008_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_000_0.00_3.47.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_001_3.47_3.64.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_002_3.64_3.74.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_003_3.74_3.81.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_004_3.81_3.86.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_005_3.86_3.91.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_006_3.91_3.96.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_007_3.96_4.00.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_008_4.00_4.04.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_009_4.04_4.08.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_010_4.08_4.12.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_011_4.12_4.16.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_012_4.16_4.21.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_013_4.21_4.26.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_014_4.26_4.33.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_015_4.33_4.43.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_016_4.43_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_017_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_018_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_019_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_020_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_021_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_022_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_023_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_024_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_025_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_026_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_027_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_028_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_029_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_030_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_031_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_032_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_033_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_034_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_035_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_036_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_037_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_038_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_039_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_040_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_041_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.german_speech_042_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.italian_speech_000_0.00_3.98.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.italian_speech_001_3.98_4.21.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.italian_speech_002_4.21_4.40.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.italian_speech_003_4.40_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.italian_speech_004_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.italian_speech_005_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_000_0.00_3.75.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_001_3.75_3.88.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_002_3.88_3.96.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_003_3.96_4.02.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_004_4.02_4.06.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_005_4.06_4.10.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_006_4.10_4.13.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_007_4.13_4.16.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_008_4.16_4.19.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_009_4.19_4.21.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_010_4.21_4.24.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_011_4.24_4.26.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_012_4.26_4.29.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_013_4.29_4.31.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_014_4.31_4.33.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_015_4.33_4.35.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_016_4.35_4.38.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_017_4.38_4.40.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_018_4.40_4.42.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_019_4.42_4.45.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_020_4.45_4.48.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_021_4.48_4.52.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_022_4.52_4.57.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_023_4.57_4.67.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_024_4.67_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_025_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_026_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_027_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_028_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_029_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_030_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_031_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_032_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_033_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_034_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_035_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_036_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_037_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_038_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.read_speech_039_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.russian_speech_000_0.00_4.31.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.russian_speech_001_4.31_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.spanish_speech_000_0.00_4.09.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.spanish_speech_001_4.09_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.spanish_speech_002_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.spanish_speech_003_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.spanish_speech_004_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.spanish_speech_005_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.spanish_speech_006_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.spanish_speech_007_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.spanish_speech_008_NA_NA.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.vctk_wav48_silence_trimmed_000.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.vctk_wav48_silence_trimmed_001.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.vctk_wav48_silence_trimmed_002.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.vctk_wav48_silence_trimmed_003.tar.bz2",
+    "clean_fullband/datasets_fullband.clean_fullband.vctk_wav48_silence_trimmed_004.tar.bz2",
+    "noise_fullband/datasets_fullband.noise_fullband.audioset_000.tar.bz2",
+    "noise_fullband/datasets_fullband.noise_fullband.audioset_001.tar.bz2",
+    "noise_fullband/datasets_fullband.noise_fullband.audioset_002.tar.bz2",
+    "noise_fullband/datasets_fullband.noise_fullband.audioset_003.tar.bz2",
+    "noise_fullband/datasets_fullband.noise_fullband.audioset_004.tar.bz2",
+    "noise_fullband/datasets_fullband.noise_fullband.audioset_005.tar.bz2",
+    "noise_fullband/datasets_fullband.noise_fullband.audioset_006.tar.bz2",
+    "noise_fullband/datasets_fullband.noise_fullband.freesound_000.tar.bz2",
+    "noise_fullband/datasets_fullband.noise_fullband.freesound_001.tar.bz2",
+    "datasets_fullband.dev_testset_000.tar.bz2",
+]
+
+AZURE_URL = (
+    "https://dns4public.blob.core.windows.net/dns4archive/datasets_fullband"
+)
+
+# Impulse reponse and Blind testset
+OTHER_URLS = {
+    "impulse_responses": [
+        "https://www.openslr.org/resources/26/sim_rir_16k.zip",
+        "https://www.openslr.org/resources/28/rirs_noises.zip",
+    ],
+    "blind_testset": [
+        "https://dns4public.blob.core.windows.net/dns4archive/blind_testset_bothtracks.zip"
+    ],
+}
+
+RIR_table_simple_URL = "https://raw.githubusercontent.com/microsoft/DNS-Challenge/0443a12f5e6e7bec310f453cf0d9637ca28e0eea/datasets/acoustic_params/RIR_table_simple.csv"
+
+SPLIT_LIST = [
+    "dev_testset",
+    "impulse_responses",
+    "noise_fullband",
+    "emotional_speech",
+    "french_speech",
+    "german_speech",
+    "italian_speech",
+    "read_speech",
+    "russian_speech",
+    "spanish_speech",
+    "vctk_wav48_silence_trimmed",
+    "VocalSet_48kHz_mono",
+]
+
+
+def prepare_download():
+    """
+    Downloads and prepares various data files and resources. It
+    downloads real-time DNS track data files (train set and dev
+    noisy set).
+    """
+    # Real-time DNS track (train set + dev noisy set)
+    for file_url in BLOB_NAMES:
+        for split in SPLIT_LIST:
+            if split in file_url:
+                split_name = split
+
+        split_path = os.path.join(COMPRESSED_PATH, split_name)
+        if not os.path.exists(split_path):
+            os.makedirs(split_path)
+        if not os.path.exists(DECOMPRESSED_PATH):
+            os.makedirs(DECOMPRESSED_PATH)
+
+        filename = file_url.split("/")[-1]
+        download_path = os.path.join(split_path, filename)
+        download_url = AZURE_URL + "/" + file_url
+
+        if not validate_file(download_url, download_path):
+            if os.path.exists(download_path):
+                resume_byte_pos = os.path.getsize(download_path)
+            else:
+                resume_byte_pos = None
+
+            download_file(
+                download_url,
+                download_path,
+                split_name,
+                filename,
+                resume_byte_pos=resume_byte_pos,
+            )
+        else:
+            print(", \tDownload complete. Skipping")
+        decompress_file(download_path, DECOMPRESSED_PATH, split_name)
+
+    # Download RIR (impulse response) & BLIND testset
+    rir_blind_test_download()
+
+
+def rir_blind_test_download():
+    """
+    Download the RIRs (room impulse responses), and the blind
+    test set.
+    """
+    # RIR (impulse response) & BLIND testset
+    for split_name, download_urls in OTHER_URLS.items():
+        for file_url in download_urls:
+            split_path = os.path.join(COMPRESSED_PATH, split_name)
+            if not os.path.exists(split_path):
+                os.makedirs(split_path)
+
+            filename = file_url.split("/")[-1]
+            download_path = os.path.join(split_path, filename)
+
+            if not validate_file(file_url, download_path):
+                if os.path.exists(download_path):
+                    resume_byte_pos = os.path.getsize(download_path)
+                else:
+                    resume_byte_pos = None
+
+                download_file(
+                    file_url,
+                    download_path,
+                    split_name,
+                    filename,
+                    resume_byte_pos=resume_byte_pos,
+                )
+            else:
+                print(", \tDownload complete. Skipping")
+            decompress_file(
+                download_path,
+                os.path.join(DECOMPRESSED_PATH, split_name),
+                split_name,
+            )
+
+    # Download RIRs simple table
+    file_path = os.path.join(
+        DECOMPRESSED_PATH, "impulse_responses", "RIR_table_simple.csv"
+    )
+    response = requests.get(RIR_table_simple_URL)
+    if response.status_code == 200:
+        with open(file_path, "wb") as file:
+            file.write(response.content)
+        print("\nRIR_simple_table downloaded successfully.")
+
+    else:
+        print(
+            f"\nFailed to download RIR_simple_table. Status code: {response.status_code}"
+        )
+
+
+def download_file(
+    download_url, download_path, split_name, filename, resume_byte_pos=None
+):
+    """
+    Download file from given URL
+
+    Arguments
+    ---------
+    download_url : str
+        URL of file being downloaded
+    download_path : str
+        Full path of the file that is to be downloaded
+        (or already downloaded)
+    split_name : str
+        Split name of the file being downloaded
+        e.g. read_speech
+    filename : str
+        Fielname of the file being downloaded
+    resume_byte_pos: (int, optional)
+        Starting byte position for resuming the download.
+        Default is None, which means a fresh download.
+
+    Returns
+    -------
+    bool
+        If True, the file need not be downloaded again.
+        Else the download might have failed or is incomplete.
+    """
+    print("Downloading:", split_name, "=>", filename)
+    resume_header = (
+        {"Range": f"bytes={resume_byte_pos}-"} if resume_byte_pos else None
+    )
+    response = requests.get(download_url, headers=resume_header, stream=True)
+    file_size = int(response.headers.get("Content-Length"))
+
+    mode = "ab" if resume_byte_pos else "wb"
+    initial_pos = resume_byte_pos if resume_byte_pos else 0
+
+    with open(download_path, mode) as f:
+        with tqdm(
+            total=file_size,
+            unit="B",
+            unit_scale=True,
+            unit_divisor=1024,
+            initial=initial_pos,
+            miniters=1,
+        ) as pbar:
+            for chunk in response.iter_content(32 * 1024):
+                f.write(chunk)
+                pbar.update(len(chunk))
+
+    # Validate downloaded file
+    if validate_file(download_url, download_path):
+        return True
+    else:
+        print("Download failed. Moving on.")
+        return False
+
+
+def download_file_parallel(args):
+    """
+    Downloads a file in parallel using the provided arguments. It
+    makes use of `download_file` function to download the required file.
+
+    Arguments
+    ---------
+    args : tuple
+        Tuple containing the download URL, download path, split
+        name, filename, and required bytes to be downloaded.
+    """
+    download_url, download_path, split_name, filename, resume_byte_pos = args
+    download_file(
+        download_url,
+        download_path,
+        split_name,
+        filename,
+        resume_byte_pos=resume_byte_pos,
+    )
+
+
+def parallel_download():
+    """
+    Perform parallel download of files using `using ThreadPoolExecutor`.
+    """
+    with ThreadPoolExecutor() as executor:
+        futures = []
+        for file_url in BLOB_NAMES:
+            for split in SPLIT_LIST:
+                if split in file_url:
+                    split_name = split
+            split_path = os.path.join(COMPRESSED_PATH, split_name)
+            if not os.path.exists(split_path):
+                os.makedirs(split_path)
+            if not os.path.exists(DECOMPRESSED_PATH):
+                os.makedirs(DECOMPRESSED_PATH)
+
+            filename = file_url.split("/")[-1]
+            download_path = os.path.join(split_path, filename)
+            download_url = AZURE_URL + "/" + file_url
+
+            if not validate_file(download_url, download_path):
+                if os.path.exists(download_path):
+                    resume_byte_pos = os.path.getsize(download_path)
+                else:
+                    resume_byte_pos = None
+                args = (
+                    download_url,
+                    download_path,
+                    split_name,
+                    filename,
+                    resume_byte_pos,
+                )
+                futures.append(executor.submit(download_file_parallel, args))
+                # download_file(download_url, download_path, split_name, filename)
+                # decompress_file(download_path, DECOMPRESSED_PATH)
+            else:
+                print(", \tDownload complete. Skipping")
+                decompress_file(download_path, DECOMPRESSED_PATH, split_name)
+
+        for future in futures:
+            future.result()
+
+    # Download RIR (impulse response) & BLIND testset
+    rir_blind_test_download()
+
+
+def decompress_file(file, decompress_path, split_name):
+    """
+    Decompress the downloaded file if the target folder does not exist.
+
+    Arguments
+    ---------
+    file : str
+        Path to the compressed downloaded file
+    decompress_path : str
+        Path to store the decompressed audio files
+    """
+    for _, dirs, _ in os.walk(decompress_path):
+        if split_name in dirs:
+            print("\tDecompression skipped. Folder already exists.")
+            return True
+
+    if "sim_rir_16k" in file:
+        slr26_dir = os.path.join(decompress_path, "SLR26")
+        if os.path.exists(slr26_dir):
+            print("\tDecompression skipped. Folder already exists.")
+            return True
+
+    if "rirs_noises" in file:
+        slr28_dir = os.path.join(decompress_path, "SLR28")
+        if os.path.exists(slr28_dir):
+            print("\tDecompression skipped. Folder already exists.")
+            return True
+
+    print("\tDecompressing...")
+    file_extension = os.path.splitext(file)[-1].lower()
+    if file_extension == ".zip":
+        zip = zipfile.ZipFile(file, "r")
+        zip.extractall(decompress_path)
+        rename_rirs(decompress_path)
+
+    elif file_extension == ".bz2":
+        tar = tarfile.open(file, "r:bz2")
+        tar.extractall(decompress_path)
+        tar.close()
+    else:
+        print("Unsupported file format. Only zip and bz2 files are supported.")
+    # os.remove(file)
+
+
+def rename_rirs(decompress_path):
+    """
+    Rename directories containing simulated room impulse responses
+    (RIRs).
+
+    Arguments
+    ---------
+        decompress_path (str): The path to the directory containing the RIRs
+
+    Returns
+    -------
+        None
+    """
+    try:
+        os.rename(
+            os.path.join(decompress_path, "simulated_rirs_16k"),
+            os.path.join(decompress_path, "SLR26"),
+        )
+    except Exception:
+        pass
+    try:
+        os.rename(
+            os.path.join(decompress_path, "RIRS_NOISES"),
+            os.path.join(decompress_path, "SLR28"),
+        )
+    except Exception:
+        pass
+
+
+def validate_file(download_url, download_path):
+    """
+    Validate the downloaded file and resume the download if needed.
+
+    Arguments
+    ---------
+    download_url : str
+        URL of the file being downloaded
+    download_path : str
+        Full path of the file that is to be downloaded
+        (or already downloaded)
+
+    Returns
+    -------
+    bool
+        If True, the file need not be downloaded again.
+        Else, either the file is not yet downloaded or
+        partially downloaded, thus resume the download.
+    """
+    if not os.path.isfile(download_path):
+        # File not yet downloaded
+        return False
+
+    # Get file size in MB
+    actual_size = urllib.request.urlopen(
+        download_url,
+        context=ssl.create_default_context(cafile=certifi.where()),
+    ).length
+
+    download_size = os.path.getsize(download_path)
+
+    print(
+        "File: {}, \t downloaded {} MB out of {} MB".format(
+            download_path.split("/")[-1],
+            download_size // (1024 * 1024),
+            actual_size // (1024 * 1024),
+        ),
+        end="",
+    )
+    # Set a margin of 100 MB. We skip re-downloading the file if downloaded
+    # size differs from actual size by max 100 MB. More than this margin,
+    # re-download is to attempted.
+    if actual_size - download_size < 100 * 1024 * 1024:
+        return True
+    else:
+        print(", \tIncomplete download. Resuming...")
+        return False
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Download and extract DNS dataset."
+    )
+    parser.add_argument(
+        "--compressed_path",
+        type=str,
+        default="DNS-compressed",
+        help="Path to store the compressed data.",
+    )
+    parser.add_argument(
+        "--decompressed_path",
+        type=str,
+        default="DNS-dataset",
+        help="Path to store the decompressed data.",
+    )
+
+    parser.add_argument(
+        "--parallel_download",
+        action="store_true",
+        help="Use parallel download.",
+    )
+
+    args = parser.parse_args()
+
+    COMPRESSED_PATH = args.compressed_path
+    DECOMPRESSED_PATH = args.decompressed_path
+
+    if args.parallel_download:
+        parallel_download()
+    else:
+        prepare_download()
+
+    # Modfy contents inside RIR_simple_table.csv
+    file_path = os.path.join(
+        DECOMPRESSED_PATH, "impulse_responses", "RIR_table_simple.csv"
+    )
+    full_path = os.path.abspath(os.path.dirname(file_path))
+
+    replacements = {
+        "datasets/impulse_responses/SLR26/simulated_rirs_16k": os.path.join(
+            full_path, "SLR26"
+        ),
+        "datasets/impulse_responses/SLR28/RIRS_NOISES": os.path.join(
+            full_path, "SLR28"
+        ),
+    }
+
+    # Perform the replacements directly in the file using fileinput module
+    with fileinput.FileInput(file_path, inplace=True) as file:
+        for line in file:
+            for original, replacement in replacements.items():
+                line = line.replace(original, replacement)
+            print(line, end="")
+
+    if not os.path.exists(
+        os.path.join("noisyspeech_synthesizer", "RIR_table_simple.csv")
+    ):
+        shutil.move(file_path, "noisyspeech_synthesizer")
diff --git a/recipes/DNS/enhancement/README.md b/recipes/DNS/enhancement/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ee4f1d01c7659ce5b387da53e4e2fe4099c99e1b
--- /dev/null
+++ b/recipes/DNS/enhancement/README.md
@@ -0,0 +1,74 @@
+# **Speech enhancement with Microsoft DNS dataset**
+This folder contains the recipe for speech enhancement on Deep Noise Suppression (DNS) Challenge 4 (ICASSP 2022) dataset using SepFormer.
+
+For data download and prepration, please refer to the `README.md` in `recipes/DNS/`
+
+## **Start training**
+```
+python train.py hparams/sepformer-dns-16k.yaml --data_folder <path/to/synthesized_shards_data> --baseline_noisy_shards_folder <path/to/baseline_dev_shards_data>
+```
+## **DNSMOS Evaluation on baseline-testclips**
+*Reference: [Offical repo](https://github.com/microsoft/DNS-Challenge/tree/master/DNSMOS) <br>*
+Download the evalution models from [Offical repo](https://github.com/microsoft/DNS-Challenge/tree/master/DNSMOS) and save it under `DNSMOS`. Then, to run DNSMOS evalution on the baseline-testclips saved in the above step.
+```
+# Model=SepFormer
+python dnsmos_local.py -t results/sepformer-enhancement-16k/1234/save/baseline_audio_results/enhanced_testclips/ -o dnsmos_enhance.csv
+
+# Model=Noisy
+python dnsmos_local.py -t <path-to/datasets_fullband/dev_testset/noisy_testclips/> -o dnsmos_noisy.csv
+```
+
+## **Results**
+1. The DNS challenge doesn't provide the ground-truth clean files for dev test. Therefore, we randomly separate out 5% of training set as valid set so that we can compute valid stats like Si-SNR and PESQ during validation. Here we show validation performance.
+
+      | Sampling rate | Valid Si-SNR | Valid PESQ | HuggingFace link	| Full Model link |
+      |---------------|--------------|------------|-------------------|------------|
+      | 16k           | -10.6        | 2.06       | [HuggingFace](https://huggingface.co/speechbrain/sepformer-dns4-16k-enhancement) |  https://www.dropbox.com/sh/d3rp5d3gjysvy7c/AACmwcEkm_IFvaW1lt2GdtQka?dl=0          |
+
+2. Evaluation on DNS4 2022 baseline dev set using DNSMOS.
+
+    | Model      | SIG    | BAK    | OVRL   |
+    |------------|--------|--------|--------|
+    | Noisy      | 2.984  | 2.560  | 2.205  |
+    | Baseline: NSNet2| 3.014  | 3.942  | 2.712  |
+    | **SepFormer**  | 2.999  | 3.076  | 2.437  |
+
+We performed 45 epochs of training for the enhancement using an 8 X RTXA6000 48GB GPU. On average, each epoch took approximately 9.25 hours to complete. **Consider training it for atleast 90-100 epochs for superior performance.**
+
+**NOTE**
+- Refer [NSNet2](https://github.com/microsoft/DNS-Challenge/tree/5582dcf5ba43155621de72a035eb54a7d233af14/NSNet2-baseline) on how to perform enhancement on baseline dev set (noisy testclips) using the baseline model- NSNet2.
+
+## **Computing power**
+Kindly be aware that in terms of computational power, training can be extremely resource demanding due to the dataset's large size and the complexity of the SepFormer model. To handle the size of 1300 hours of clean-noisy pairs, we employed a multi-GPU distributed data-parallel (DDP) training scheme on an Nvidia 8 X RTXA6000 48GB GPU. The training process lasted for 17 days, for just 45 epochs.
+
+## **About SpeechBrain**
+- Website: https://speechbrain.github.io/
+- Code: https://github.com/speechbrain/speechbrain/
+- HuggingFace: https://huggingface.co/speechbrain/
+
+
+## **Citing SpeechBrain**
+Please, cite SpeechBrain if you use it for your research or business.
+
+```bibtex
+@misc{speechbrain,
+  title={{SpeechBrain}: A General-Purpose Speech Toolkit},
+  author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
+  year={2021},
+  eprint={2106.04624},
+  archivePrefix={arXiv},
+  primaryClass={eess.AS},
+  note={arXiv:2106.04624}
+}
+```
+
+
+**Citing SepFormer**
+```bibtex
+@inproceedings{subakan2021attention,
+      title={Attention is All You Need in Speech Separation},
+      author={Cem Subakan and Mirco Ravanelli and Samuele Cornell and Mirko Bronzi and Jianyuan Zhong},
+      year={2021},
+      booktitle={ICASSP 2021}
+}
+```
diff --git a/recipes/DNS/enhancement/composite_eval.py b/recipes/DNS/enhancement/composite_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..68353caf515ea1b91cc53032864a415aef074cd5
--- /dev/null
+++ b/recipes/DNS/enhancement/composite_eval.py
@@ -0,0 +1,466 @@
+"""Composite objective enhancement scores in Python (CSIG, CBAK, COVL)
+
+Taken from https://github.com/facebookresearch/denoiser/blob/master/scripts/matlab_eval.py
+
+Authors
+ * adiyoss (https://github.com/adiyoss)
+"""
+
+from scipy.linalg import toeplitz
+from tqdm import tqdm
+from pesq import pesq
+import librosa
+import numpy as np
+import os
+import sys
+
+
+def eval_composite(ref_wav, deg_wav, sample_rate):
+    """Evaluate audio quality metrics based on reference
+    and degraded audio signals.
+    This function computes various audio quality metrics,
+    including PESQ, CSIG, CBAK, and COVL, based on the
+    reference and degraded audio signals provided.
+    """
+    ref_wav = ref_wav.reshape(-1)
+    deg_wav = deg_wav.reshape(-1)
+
+    alpha = 0.95
+    len_ = min(ref_wav.shape[0], deg_wav.shape[0])
+    ref_wav = ref_wav[:len_]
+    deg_wav = deg_wav[:len_]
+
+    # Compute WSS measure
+    wss_dist_vec = wss(ref_wav, deg_wav, sample_rate)
+    wss_dist_vec = sorted(wss_dist_vec, reverse=False)
+    wss_dist = np.mean(wss_dist_vec[: int(round(len(wss_dist_vec) * alpha))])
+
+    # Compute LLR measure
+    LLR_dist = llr(ref_wav, deg_wav, sample_rate)
+    LLR_dist = sorted(LLR_dist, reverse=False)
+    LLRs = LLR_dist
+    LLR_len = round(len(LLR_dist) * alpha)
+    llr_mean = np.mean(LLRs[:LLR_len])
+
+    # Compute the SSNR
+    snr_mean, segsnr_mean = SSNR(ref_wav, deg_wav, sample_rate)
+    segSNR = np.mean(segsnr_mean)
+
+    # Compute the PESQ
+    pesq_raw = PESQ(ref_wav, deg_wav, sample_rate)
+
+    Csig = 3.093 - 1.029 * llr_mean + 0.603 * pesq_raw - 0.009 * wss_dist
+    Csig = trim_mos(Csig)
+    Cbak = 1.634 + 0.478 * pesq_raw - 0.007 * wss_dist + 0.063 * segSNR
+    Cbak = trim_mos(Cbak)
+    Covl = 1.594 + 0.805 * pesq_raw - 0.512 * llr_mean - 0.007 * wss_dist
+    Covl = trim_mos(Covl)
+
+    return {"pesq": pesq_raw, "csig": Csig, "cbak": Cbak, "covl": Covl}
+
+
+# ----------------------------- HELPERS ------------------------------------ #
+def trim_mos(val):
+    """Trim a value to be within the MOS (Mean Opinion Score)
+    range [1, 5].
+    """
+    return min(max(val, 1), 5)
+
+
+def lpcoeff(speech_frame, model_order):
+    """Calculate linear prediction (LP) coefficients using
+    the autocorrelation method.
+    """
+    # (1) Compute Autocor lags
+    winlength = speech_frame.shape[0]
+    R = []
+    for k in range(model_order + 1):
+        first = speech_frame[: (winlength - k)]
+        second = speech_frame[k:winlength]
+        R.append(np.sum(first * second))
+
+    # (2) Lev-Durbin
+    a = np.ones((model_order,))
+    E = np.zeros((model_order + 1,))
+    rcoeff = np.zeros((model_order,))
+    E[0] = R[0]
+    for i in range(model_order):
+        if i == 0:
+            sum_term = 0
+        else:
+            a_past = a[:i]
+            sum_term = np.sum(a_past * np.array(R[i:0:-1]))
+        rcoeff[i] = (R[i + 1] - sum_term) / E[i]
+        a[i] = rcoeff[i]
+        if i > 0:
+            a[:i] = a_past[:i] - rcoeff[i] * a_past[::-1]
+        E[i + 1] = (1 - rcoeff[i] * rcoeff[i]) * E[i]
+    acorr = np.array(R, dtype=np.float32)
+    refcoeff = np.array(rcoeff, dtype=np.float32)
+    a = a * -1
+    lpparams = np.array([1] + list(a), dtype=np.float32)
+    acorr = np.array(acorr, dtype=np.float32)
+    refcoeff = np.array(refcoeff, dtype=np.float32)
+    lpparams = np.array(lpparams, dtype=np.float32)
+
+    return acorr, refcoeff, lpparams
+
+
+# -------------------------------------------------------------------------- #
+
+# ---------------------- Speech Quality Metric ----------------------------- #
+def PESQ(ref_wav, deg_wav, sample_rate):
+    """Compute PESQ score.
+    """
+    psq_mode = "wb" if sample_rate == 16000 else "nb"
+    return pesq(sample_rate, ref_wav, deg_wav, psq_mode)
+
+
+def SSNR(ref_wav, deg_wav, srate=16000, eps=1e-10):
+    """ Segmental Signal-to-Noise Ratio Objective Speech Quality Measure
+        This function implements the segmental signal-to-noise ratio
+        as defined in [1, p. 45] (see Equation 2.12).
+    """
+    clean_speech = ref_wav
+    processed_speech = deg_wav
+    clean_length = ref_wav.shape[0]
+
+    # scale both to have same dynamic range. Remove DC too.
+    clean_speech -= clean_speech.mean()
+    processed_speech -= processed_speech.mean()
+    processed_speech *= np.max(np.abs(clean_speech)) / np.max(
+        np.abs(processed_speech)
+    )
+
+    # Signal-to-Noise Ratio
+    dif = ref_wav - deg_wav
+    overall_snr = 10 * np.log10(
+        np.sum(ref_wav ** 2) / (np.sum(dif ** 2) + 10e-20)
+    )
+    # global variables
+    winlength = int(np.round(30 * srate / 1000))  # 30 msecs
+    skiprate = winlength // 4
+    MIN_SNR = -10
+    MAX_SNR = 35
+
+    # For each frame, calculate SSNR
+    num_frames = int(clean_length / skiprate - (winlength / skiprate))
+    start = 0
+    time = np.linspace(1, winlength, winlength) / (winlength + 1)
+    window = 0.5 * (1 - np.cos(2 * np.pi * time))
+    segmental_snr = []
+
+    for frame_count in range(int(num_frames)):
+        # (1) get the frames for the test and ref speech.
+        # Apply Hanning Window
+        clean_frame = clean_speech[start : start + winlength]
+        processed_frame = processed_speech[start : start + winlength]
+        clean_frame = clean_frame * window
+        processed_frame = processed_frame * window
+
+        # (2) Compute Segmental SNR
+        signal_energy = np.sum(clean_frame ** 2)
+        noise_energy = np.sum((clean_frame - processed_frame) ** 2)
+        segmental_snr.append(
+            10 * np.log10(signal_energy / (noise_energy + eps) + eps)
+        )
+        segmental_snr[-1] = max(segmental_snr[-1], MIN_SNR)
+        segmental_snr[-1] = min(segmental_snr[-1], MAX_SNR)
+        start += int(skiprate)
+    return overall_snr, segmental_snr
+
+
+def wss(ref_wav, deg_wav, srate):
+    """ Calculate Weighted Spectral Slope (WSS) distortion
+    measure between reference and degraded audio signals.
+    This function computes the WSS distortion measure using
+    critical band filters and spectral slope differences.
+    """
+    clean_speech = ref_wav
+    processed_speech = deg_wav
+    clean_length = ref_wav.shape[0]
+    processed_length = deg_wav.shape[0]
+
+    assert clean_length == processed_length, clean_length
+
+    winlength = round(30 * srate / 1000.0)  # 240 wlen in samples
+    skiprate = np.floor(winlength / 4)
+    max_freq = srate / 2
+    num_crit = 25  # num of critical bands
+
+    n_fft = int(2 ** np.ceil(np.log(2 * winlength) / np.log(2)))
+    n_fftby2 = int(n_fft / 2)
+    Kmax = 20
+    Klocmax = 1
+
+    # Critical band filter definitions (Center frequency and BW in Hz)
+    cent_freq = [
+        50.0,
+        120,
+        190,
+        260,
+        330,
+        400,
+        470,
+        540,
+        617.372,
+        703.378,
+        798.717,
+        904.128,
+        1020.38,
+        1148.30,
+        1288.72,
+        1442.54,
+        1610.70,
+        1794.16,
+        1993.93,
+        2211.08,
+        2446.71,
+        2701.97,
+        2978.04,
+        3276.17,
+        3597.63,
+    ]
+    bandwidth = [
+        70.0,
+        70,
+        70,
+        70,
+        70,
+        70,
+        70,
+        77.3724,
+        86.0056,
+        95.3398,
+        105.411,
+        116.256,
+        127.914,
+        140.423,
+        153.823,
+        168.154,
+        183.457,
+        199.776,
+        217.153,
+        235.631,
+        255.255,
+        276.072,
+        298.126,
+        321.465,
+        346.136,
+    ]
+
+    bw_min = bandwidth[0]  # min critical bandwidth
+
+    # set up critical band filters. Note here that Gaussianly shaped filters
+    # are used. Also, the sum of the filter weights are equivalent for each
+    # critical band filter. Filter less than -30 dB and set to zero.
+    min_factor = np.exp(-30.0 / (2 * 2.303))  # -30 dB point of filter
+
+    crit_filter = np.zeros((num_crit, n_fftby2))
+    all_f0 = []
+    for i in range(num_crit):
+        f0 = (cent_freq[i] / max_freq) * (n_fftby2)
+        all_f0.append(np.floor(f0))
+        bw = (bandwidth[i] / max_freq) * (n_fftby2)
+        norm_factor = np.log(bw_min) - np.log(bandwidth[i])
+        j = list(range(n_fftby2))
+        crit_filter[i, :] = np.exp(
+            -11 * (((j - np.floor(f0)) / bw) ** 2) + norm_factor
+        )
+        crit_filter[i, :] = crit_filter[i, :] * (crit_filter[i, :] > min_factor)
+
+    # For each frame of input speech, compute Weighted Spectral Slope Measure
+    num_frames = int(clean_length / skiprate - (winlength / skiprate))
+    start = 0  # starting sample
+    time = np.linspace(1, winlength, winlength) / (winlength + 1)
+    window = 0.5 * (1 - np.cos(2 * np.pi * time))
+    distortion = []
+
+    for frame_count in range(num_frames):
+        # (1) Get the Frames for the test and reference speeech.
+        # Multiply by Hanning window.
+        clean_frame = clean_speech[start : start + winlength]
+        processed_frame = processed_speech[start : start + winlength]
+        clean_frame = clean_frame * window
+        processed_frame = processed_frame * window
+
+        # (2) Compuet Power Spectrum of clean and processed
+        clean_spec = np.abs(np.fft.fft(clean_frame, n_fft)) ** 2
+        processed_spec = np.abs(np.fft.fft(processed_frame, n_fft)) ** 2
+        clean_energy = [None] * num_crit
+        processed_energy = [None] * num_crit
+
+        # (3) Compute Filterbank output energies (in dB)
+        for i in range(num_crit):
+            clean_energy[i] = np.sum(clean_spec[:n_fftby2] * crit_filter[i, :])
+            processed_energy[i] = np.sum(
+                processed_spec[:n_fftby2] * crit_filter[i, :]
+            )
+        clean_energy = np.array(clean_energy).reshape(-1, 1)
+        eps = np.ones((clean_energy.shape[0], 1)) * 1e-10
+        clean_energy = np.concatenate((clean_energy, eps), axis=1)
+        clean_energy = 10 * np.log10(np.max(clean_energy, axis=1))
+        processed_energy = np.array(processed_energy).reshape(-1, 1)
+        processed_energy = np.concatenate((processed_energy, eps), axis=1)
+        processed_energy = 10 * np.log10(np.max(processed_energy, axis=1))
+
+        # (4) Compute Spectral Shape (dB[i+1] - dB[i])
+        clean_slope = clean_energy[1:num_crit] - clean_energy[: num_crit - 1]
+        processed_slope = (
+            processed_energy[1:num_crit] - processed_energy[: num_crit - 1]
+        )
+
+        # (5) Find the nearest peak locations in the spectra to each
+        # critical band. If the slope is negative, we search
+        # to the left. If positive, we search to the right.
+        clean_loc_peak = []
+        processed_loc_peak = []
+        for i in range(num_crit - 1):
+            if clean_slope[i] > 0:
+                # search to the right
+                n = i
+                while n < num_crit - 1 and clean_slope[n] > 0:
+                    n += 1
+                clean_loc_peak.append(clean_energy[n - 1])
+            else:
+                # search to the left
+                n = i
+                while n >= 0 and clean_slope[n] <= 0:
+                    n -= 1
+                clean_loc_peak.append(clean_energy[n + 1])
+            # find the peaks in the processed speech signal
+            if processed_slope[i] > 0:
+                n = i
+                while n < num_crit - 1 and processed_slope[n] > 0:
+                    n += 1
+                processed_loc_peak.append(processed_energy[n - 1])
+            else:
+                n = i
+                while n >= 0 and processed_slope[n] <= 0:
+                    n -= 1
+                processed_loc_peak.append(processed_energy[n + 1])
+
+        # (6) Compuet the WSS Measure for this frame. This includes
+        # determination of the weighting functino
+        dBMax_clean = max(clean_energy)
+        dBMax_processed = max(processed_energy)
+
+        # The weights are calculated by averaging individual
+        # weighting factors from the clean and processed frame.
+        # These weights W_clean and W_processed should range
+        # from 0 to 1 and place more emphasis on spectral
+        # peaks and less emphasis on slope differences in spectral
+        # valleys.  This procedure is described on page 1280 of
+        # Klatt's 1982 ICASSP paper.
+        clean_loc_peak = np.array(clean_loc_peak)
+        processed_loc_peak = np.array(processed_loc_peak)
+        Wmax_clean = Kmax / (Kmax + dBMax_clean - clean_energy[: num_crit - 1])
+        Wlocmax_clean = Klocmax / (
+            Klocmax + clean_loc_peak - clean_energy[: num_crit - 1]
+        )
+        W_clean = Wmax_clean * Wlocmax_clean
+        Wmax_processed = Kmax / (
+            Kmax + dBMax_processed - processed_energy[: num_crit - 1]
+        )
+        Wlocmax_processed = Klocmax / (
+            Klocmax + processed_loc_peak - processed_energy[: num_crit - 1]
+        )
+        W_processed = Wmax_processed * Wlocmax_processed
+        W = (W_clean + W_processed) / 2
+        distortion.append(
+            np.sum(
+                W
+                * (
+                    clean_slope[: num_crit - 1]
+                    - processed_slope[: num_crit - 1]
+                )
+                ** 2
+            )
+        )
+
+        # this normalization is not part of Klatt's paper, but helps
+        # to normalize the meaasure. Here we scale the measure by the sum of the
+        # weights
+        distortion[frame_count] = distortion[frame_count] / np.sum(W)
+        start += int(skiprate)
+    return distortion
+
+
+def llr(ref_wav, deg_wav, srate):
+    """Calculate Log Likelihood Ratio (LLR) distortion measure
+    between reference and degraded audio signals. This function
+    computes the LLR distortion measure between reference and
+    degraded audio signals using LPC analysis and autocorrelation
+    logs.
+    """
+    clean_speech = ref_wav
+    processed_speech = deg_wav
+    clean_length = ref_wav.shape[0]
+    processed_length = deg_wav.shape[0]
+    assert clean_length == processed_length, clean_length
+
+    winlength = round(30 * srate / 1000.0)  # 240 wlen in samples
+    skiprate = np.floor(winlength / 4)
+    if srate < 10000:
+        # LPC analysis order
+        P = 10
+    else:
+        P = 16
+
+    # For each frame of input speech, calculate the Log Likelihood Ratio
+    num_frames = int(clean_length / skiprate - (winlength / skiprate))
+    start = 0
+    time = np.linspace(1, winlength, winlength) / (winlength + 1)
+    window = 0.5 * (1 - np.cos(2 * np.pi * time))
+    distortion = []
+
+    for frame_count in range(num_frames):
+        # (1) Get the Frames for the test and reference speeech.
+        # Multiply by Hanning window.
+        clean_frame = clean_speech[start : start + winlength]
+        processed_frame = processed_speech[start : start + winlength]
+        clean_frame = clean_frame * window
+        processed_frame = processed_frame * window
+
+        # (2) Get the autocorrelation logs and LPC params used
+        # to compute the LLR measure
+        R_clean, Ref_clean, A_clean = lpcoeff(clean_frame, P)
+        R_processed, Ref_processed, A_processed = lpcoeff(processed_frame, P)
+        A_clean = A_clean[None, :]
+        A_processed = A_processed[None, :]
+
+        # (3) Compute the LLR measure
+        numerator = A_processed.dot(toeplitz(R_clean)).dot(A_processed.T)
+        denominator = A_clean.dot(toeplitz(R_clean)).dot(A_clean.T)
+
+        if (numerator / denominator) <= 0:
+            print(f"Numerator: {numerator}")
+            print(f"Denominator: {denominator}")
+
+        log_ = np.log(numerator / denominator)
+        distortion.append(np.squeeze(log_))
+        start += int(skiprate)
+    return np.nan_to_num(np.array(distortion))
+
+
+# -------------------------------------------------------------------------- #
+
+if __name__ == "__main__":
+    clean_path = sys.argv[1]
+    enhanced_path = sys.argv[2]
+    csig, cbak, covl, count = 0, 0, 0, 0
+    for _file in tqdm(os.listdir(clean_path)):
+        if _file.endswith("wav"):
+            clean_path_f = os.path.join(clean_path, _file)
+            enhanced_path_f = os.path.join(
+                enhanced_path, _file[:-4] + "_enhanced.wav"
+            )
+            clean_sig = librosa.load(clean_path_f, sr=None)[0]
+            enhanced_sig = librosa.load(enhanced_path_f, sr=None)[0]
+            res = eval_composite(clean_sig, enhanced_sig)
+            csig += res["csig"]
+            cbak += res["cbak"]
+            covl += res["covl"]
+            pesq += res["pesq"]
+            count += 1
+    print(f"CSIG: {csig/count}, CBAK: {cbak/count}, COVL: {covl/count}")
diff --git a/recipes/DNS/enhancement/dnsmos_local.py b/recipes/DNS/enhancement/dnsmos_local.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e334e88527737b9b3f10d918aec81aed6eaae2c
--- /dev/null
+++ b/recipes/DNS/enhancement/dnsmos_local.py
@@ -0,0 +1,195 @@
+"""
+Usage:
+    python dnsmos_local.py -t path/to/sepformer_enhc_clips -o dnsmos_enhance.csv
+
+Ownership: Microsoft
+"""
+
+import argparse
+import concurrent.futures
+import glob
+import os
+
+import librosa
+import numpy as np
+import onnxruntime as ort
+import pandas as pd
+import soundfile as sf
+from tqdm import tqdm
+
+SAMPLING_RATE = 16000
+INPUT_LENGTH = 9.01
+
+
+class ComputeScore:
+    """A class for computing MOS scores using an ONNX model and polynomial fitting.
+    """
+
+    def __init__(self, primary_model_path) -> None:
+        """Initialize the ComputeScore class.
+        """
+        self.onnx_sess = ort.InferenceSession(primary_model_path)
+
+    def get_polyfit_val(self, sig, bak, ovr, is_personalized_MOS):
+        """Calculate MOS scores using polynomial fitting.
+        Returns a tuple containing MOS scores for speech,
+        background, and overall quality.
+        """
+        # if is_personalized_MOS:
+        #     p_ovr = np.poly1d([-0.00533021,  0.005101  ,  1.18058466, -0.11236046])
+        #     p_sig = np.poly1d([-0.01019296,  0.02751166,  1.19576786, -0.24348726])
+        #     p_bak = np.poly1d([-0.04976499,  0.44276479, -0.1644611 ,  0.96883132])
+        # else:
+        p_ovr = np.poly1d([-0.06766283, 1.11546468, 0.04602535])
+        p_sig = np.poly1d([-0.08397278, 1.22083953, 0.0052439])
+        p_bak = np.poly1d([-0.13166888, 1.60915514, -0.39604546])
+
+        sig_poly = p_sig(sig)
+        bak_poly = p_bak(bak)
+        ovr_poly = p_ovr(ovr)
+
+        return sig_poly, bak_poly, ovr_poly
+
+    def __call__(self, fpath, sampling_rate, is_personalized_MOS):
+        """Compute MOS scores for an audio segment.
+        """
+        aud, input_fs = sf.read(fpath)
+        fs = sampling_rate
+        if input_fs != fs:
+            audio = librosa.resample(aud, input_fs, fs)
+        else:
+            audio = aud
+        actual_audio_len = len(audio)
+        len_samples = int(INPUT_LENGTH * fs)
+        while len(audio) < len_samples:
+            audio = np.append(audio, audio)
+
+        num_hops = int(np.floor(len(audio) / fs) - INPUT_LENGTH) + 1
+        hop_len_samples = fs
+        predicted_mos_sig_seg_raw = []
+        predicted_mos_bak_seg_raw = []
+        predicted_mos_ovr_seg_raw = []
+        predicted_mos_sig_seg = []
+        predicted_mos_bak_seg = []
+        predicted_mos_ovr_seg = []
+
+        for idx in range(num_hops):
+            audio_seg = audio[
+                int(idx * hop_len_samples) : int(
+                    (idx + INPUT_LENGTH) * hop_len_samples
+                )
+            ]
+            if len(audio_seg) < len_samples:
+                continue
+
+            input_features = np.array(audio_seg).astype("float32")[
+                np.newaxis, :
+            ]
+            oi = {"input_1": input_features}
+            mos_sig_raw, mos_bak_raw, mos_ovr_raw = self.onnx_sess.run(
+                None, oi
+            )[0][0]
+            mos_sig, mos_bak, mos_ovr = self.get_polyfit_val(
+                mos_sig_raw, mos_bak_raw, mos_ovr_raw, is_personalized_MOS=0
+            )
+            predicted_mos_sig_seg_raw.append(mos_sig_raw)
+            predicted_mos_bak_seg_raw.append(mos_bak_raw)
+            predicted_mos_ovr_seg_raw.append(mos_ovr_raw)
+            predicted_mos_sig_seg.append(mos_sig)
+            predicted_mos_bak_seg.append(mos_bak)
+            predicted_mos_ovr_seg.append(mos_ovr)
+
+        clip_dict = {
+            "filename": fpath,
+            "len_in_sec": actual_audio_len / fs,
+            "sr": fs,
+        }
+        clip_dict["num_hops"] = num_hops
+        clip_dict["OVRL_raw"] = np.mean(predicted_mos_ovr_seg_raw)
+        clip_dict["SIG_raw"] = np.mean(predicted_mos_sig_seg_raw)
+        clip_dict["BAK_raw"] = np.mean(predicted_mos_bak_seg_raw)
+        clip_dict["OVRL"] = np.mean(predicted_mos_ovr_seg)
+        clip_dict["SIG"] = np.mean(predicted_mos_sig_seg)
+        clip_dict["BAK"] = np.mean(predicted_mos_bak_seg)
+        return clip_dict
+
+
+def main(args):
+    models = glob.glob(os.path.join(args.testset_dir, "*"))
+    audio_clips_list = []
+
+    if args.personalized_MOS:
+        primary_model_path = os.path.join("pDNSMOS", "sig_bak_ovr.onnx")
+    else:
+        primary_model_path = os.path.join("DNSMOS", "sig_bak_ovr.onnx")
+
+    compute_score = ComputeScore(primary_model_path)
+
+    rows = []
+    clips = []
+    clips = glob.glob(os.path.join(args.testset_dir, "*.wav"))
+    is_personalized_eval = args.personalized_MOS
+    desired_fs = SAMPLING_RATE
+    for m in tqdm(models):
+        max_recursion_depth = 10
+        audio_path = os.path.join(args.testset_dir, m)
+        audio_clips_list = glob.glob(os.path.join(audio_path, "*.wav"))
+        while len(audio_clips_list) == 0 and max_recursion_depth > 0:
+            audio_path = os.path.join(audio_path, "**")
+            audio_clips_list = glob.glob(os.path.join(audio_path, "*.wav"))
+            max_recursion_depth -= 1
+        clips.extend(audio_clips_list)
+
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        future_to_url = {
+            executor.submit(
+                compute_score, clip, desired_fs, is_personalized_eval
+            ): clip
+            for clip in clips
+        }
+        for future in tqdm(concurrent.futures.as_completed(future_to_url)):
+            clip = future_to_url[future]
+            try:
+                data = future.result()
+            except Exception as exc:
+                print("%r generated an exception: %s" % (clip, exc))
+            else:
+                rows.append(data)
+
+    df = pd.DataFrame(rows)
+    if args.csv_path:
+        csv_path = args.csv_path
+        df.to_csv(csv_path)
+    else:
+        print(df.describe())
+
+    print("======== DNSMOS scores ======== ")
+    print("SIG:", df.loc[:, "SIG"].mean())
+    print("BAK:", df.loc[:, "BAK"].mean())
+    print("OVRL:", df.loc[:, "OVRL"].mean())
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-t",
+        "--testset_dir",
+        default=".",
+        help="Path to the dir containing audio clips in .wav to be evaluated",
+    )
+    parser.add_argument(
+        "-o",
+        "--csv_path",
+        default=None,
+        help="Dir to the csv that saves the results",
+    )
+    parser.add_argument(
+        "-p",
+        "--personalized_MOS",
+        action="store_true",
+        help="Flag to indicate if personalized MOS score is needed or regular",
+    )
+
+    args = parser.parse_args()
+
+    main(args)
diff --git a/recipes/DNS/enhancement/extra_requirements.txt b/recipes/DNS/enhancement/extra_requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0c9d0e2fce43d3e973ed4333d087bc51019fc987
--- /dev/null
+++ b/recipes/DNS/enhancement/extra_requirements.txt
@@ -0,0 +1,8 @@
+librosa
+mir_eval
+onnxruntime
+pesq
+pyroomacoustics==0.3.1
+pystoi
+tensorboard
+webdataset
diff --git a/recipes/DNS/enhancement/hparams/sepformer-dns-16k.yaml b/recipes/DNS/enhancement/hparams/sepformer-dns-16k.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..87a07c97aba76e5bf0d0e3f500ee261f1b935cec
--- /dev/null
+++ b/recipes/DNS/enhancement/hparams/sepformer-dns-16k.yaml
@@ -0,0 +1,203 @@
+# ################################
+# Model: SepFormer model for speech enhancement
+# https://arxiv.org/abs/2010.13154
+#
+# Author:  Sangeet Sagar 2022
+# Dataset : Microsoft-DNS 4
+# ################################
+
+# Basic parameters
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1234
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/sepformer-enhancement-16k/<seed>
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# Data params
+data_folder: !PLACEHOLDER   # ../noisyspeech_synthesizer/synthesized_data_shards/
+train_data: !ref <data_folder>/train_shards/
+valid_data: !ref <data_folder>/valid_shards/
+baseline_noisy_shards_folder: !PLACEHOLDER     # ../DNS-shards/devsets_fullband/
+baseline_shards: !ref <baseline_noisy_shards_folder>/shard-{000000..999999}.tar
+
+# Set to a directory on a large disk if using Webdataset shards hosted on the web.
+shard_cache_dir:
+
+# Basic parameters
+use_tensorboard: True
+tensorboard_logs: !ref <output_folder>/logs/
+dereverberate: False
+
+# Experiment params
+precision: fp16 # bf16, fp16 or fp32
+test_only: False
+num_spks: 1
+noprogressbar: False
+save_audio: True    # Save estimated sources on disk
+sample_rate: 16000
+audio_length: 4 # seconds
+n_audio_to_save: 20
+
+####################### Training Parameters ####################################
+N_epochs: 100
+batch_size: 4
+batch_size_test: 1
+lr: 0.00015
+clip_grad_norm: 5
+loss_upper_lim: 999999  # this is the upper limit for an acceptable loss
+# if True, the training sequences are cut to a specified length
+limit_training_signal_len: False
+# this is the length of sequences if we choose to limit
+# the signal length of training sequences
+training_signal_len: 32000
+ckpt_interval_minutes: 60
+
+# Parameters for data augmentation
+use_wavedrop: False
+use_speedperturb: True
+use_rand_shift: False
+min_shift: -8000
+max_shift: 8000
+
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
+
+# loss thresholding -- this thresholds the training loss
+threshold_byloss: True
+threshold: -30
+
+# Encoder parameters
+N_encoder_out: 256
+out_channels: 256
+kernel_size: 16
+kernel_stride: 8
+
+# Dataloader options
+dataloader_opts:
+    batch_size: !ref <batch_size>
+    num_workers: 3
+
+dataloader_opts_valid:
+    batch_size: !ref <batch_size>
+    num_workers: 3
+
+dataloader_opts_test:
+    batch_size: !ref <batch_size_test>
+    num_workers: 3
+
+# Specifying the network
+Encoder: !new:speechbrain.lobes.models.dual_path.Encoder
+    kernel_size: !ref <kernel_size>
+    out_channels: !ref <N_encoder_out>
+
+SBtfintra: !new:speechbrain.lobes.models.dual_path.SBTransformerBlock
+    num_layers: 8
+    d_model: !ref <out_channels>
+    nhead: 8
+    d_ffn: 1024
+    dropout: 0
+    use_positional_encoding: True
+    norm_before: True
+
+SBtfinter: !new:speechbrain.lobes.models.dual_path.SBTransformerBlock
+    num_layers: 8
+    d_model: !ref <out_channels>
+    nhead: 8
+    d_ffn: 1024
+    dropout: 0
+    use_positional_encoding: True
+    norm_before: True
+
+MaskNet: !new:speechbrain.lobes.models.dual_path.Dual_Path_Model
+    num_spks: !ref <num_spks>
+    in_channels: !ref <N_encoder_out>
+    out_channels: !ref <out_channels>
+    num_layers: 2
+    K: 250
+    intra_model: !ref <SBtfintra>
+    inter_model: !ref <SBtfinter>
+    norm: ln
+    linear_layer_after_inter_intra: False
+    skip_around_intra: True
+
+Decoder: !new:speechbrain.lobes.models.dual_path.Decoder
+    in_channels: !ref <N_encoder_out>
+    out_channels: 1
+    kernel_size: !ref <kernel_size>
+    stride: !ref <kernel_stride>
+    bias: False
+
+optimizer: !name:torch.optim.Adam
+    lr: !ref <lr>
+    weight_decay: 0
+
+loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper
+
+lr_scheduler: !new:speechbrain.nnet.schedulers.ReduceLROnPlateau
+    factor: 0.5
+    patience: 2
+    dont_halve_until_epoch: 85
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <N_epochs>
+
+modules:
+    encoder: !ref <Encoder>
+    decoder: !ref <Decoder>
+    masknet: !ref <MaskNet>
+
+save_all_checkpoints: False
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        encoder: !ref <Encoder>
+        decoder: !ref <Decoder>
+        masknet: !ref <MaskNet>
+        counter: !ref <epoch_counter>
+        lr_scheduler: !ref <lr_scheduler>
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+## Uncomment to fine-tune a pre-trained model.
+# pretrained_enhancement: !new:speechbrain.utils.parameter_transfer.Pretrainer
+#     collect_in: !ref <save_folder>
+#     loadables:
+#         encoder: !ref <Encoder>
+#         decoder: !ref <Decoder>
+#         masknet: !ref <MaskNet>
+#     paths:
+#         encoder: !PLACEHOLDER
+#         decoder: !PLACEHOLDER
+#         masknet: !PLACEHOLDER
diff --git a/recipes/DNS/enhancement/train.py b/recipes/DNS/enhancement/train.py
new file mode 100755
index 0000000000000000000000000000000000000000..de2b6cf235fa8023fc3075b79d28887c68cbc44d
--- /dev/null
+++ b/recipes/DNS/enhancement/train.py
@@ -0,0 +1,875 @@
+#!/usr/bin/env/python3
+"""Recipe for training a speech enhancement system on Microsoft DNS
+(Deep Noise Suppression) challenge dataset using SepFormer architecture.
+The system employs an encoder,a decoder, and a masking network.
+
+To run this recipe, do the following:
+python train.py hparams/sepformer-dns-16k.yaml --data_folder <path/to/synthesized_shards_data> --baseline_noisy_shards_folder <path/to/baseline_shards_data>
+
+The experiment file is flexible enough to support different neural
+networks. By properly changing the parameter files, you can try
+different architectures.
+
+Authors
+ * Sangeet Sagar 2022
+ * Cem Subakan 2020
+ * Mirco Ravanelli 2020
+ * Samuele Cornell 2020
+ * Mirko Bronzi 2020
+ * Jianyuan Zhong 2020
+"""
+
+import os
+import glob
+import sys
+import csv
+import json
+import logging
+import numpy as np
+from tqdm import tqdm
+from typing import Dict
+from functools import partial
+
+import torch
+import torchaudio
+import braceexpand
+import webdataset as wds
+import torch.nn.functional as F
+
+import speechbrain as sb
+from hyperpyyaml import load_hyperpyyaml
+from composite_eval import eval_composite
+import speechbrain.nnet.schedulers as schedulers
+from speechbrain.utils.distributed import run_on_main
+from speechbrain.utils.metric_stats import MetricStats
+from speechbrain.processing.features import spectral_magnitude
+from speechbrain.dataio.batch import PaddedBatch
+from speechbrain.core import AMPConfig
+
+from pesq import pesq
+from pystoi import stoi
+
+
+# Define training procedure
+class Enhancement(sb.Brain):
+    def compute_forward(self, noisy, clean, stage, noise=None):
+        """Forward computations from the noisy to the separated signals."""
+        # Unpack lists and put tensors in the right device
+        noisy, noisy_lens = noisy
+        noisy, noisy_lens = noisy.to(self.device), noisy_lens.to(self.device)
+        # Convert clean to tensor
+        clean = clean[0].unsqueeze(-1).to(self.device)
+
+        # Add speech distortions
+        if stage == sb.Stage.TRAIN:
+            with torch.no_grad():
+                if self.hparams.use_speedperturb or self.hparams.use_rand_shift:
+                    noisy, clean = self.add_speed_perturb(clean, noisy_lens)
+
+                    # Reverb already added, not adding any reverb
+                    clean_rev = clean
+                    noisy = clean.sum(-1)
+                    # if we reverberate, we set the clean to be reverberant
+                    if not self.hparams.dereverberate:
+                        clean = clean_rev
+
+                    noise = noise.to(self.device)
+                    len_noise = noise.shape[1]
+                    len_noisy = noisy.shape[1]
+                    min_len = min(len_noise, len_noisy)
+
+                    # add the noise
+                    noisy = noisy[:, :min_len] + noise[:, :min_len]
+
+                    # fix the length of clean also
+                    clean = clean[:, :min_len, :]
+
+                if self.hparams.use_wavedrop:
+                    noisy = self.hparams.drop_chunk(noisy, noisy_lens)
+                    noisy = self.hparams.drop_freq(noisy)
+
+                if self.hparams.limit_training_signal_len:
+                    noisy, clean = self.cut_signals(noisy, clean)
+
+        # Enhancement
+        if self.use_freq_domain:
+            noisy_w = self.compute_feats(noisy)
+            est_mask = self.modules.masknet(noisy_w)
+
+            sep_h = noisy_w * est_mask
+            est_source = self.hparams.resynth(torch.expm1(sep_h), noisy)
+        else:
+            noisy_w = self.hparams.Encoder(noisy)
+            est_mask = self.modules.masknet(noisy_w)
+
+            sep_h = noisy_w * est_mask
+            est_source = self.hparams.Decoder(sep_h[0])
+
+        # T changed after conv1d in encoder, fix it here
+        T_origin = noisy.size(1)
+        T_est = est_source.size(1)
+        est_source = est_source.squeeze(-1)
+        if T_origin > T_est:
+            est_source = F.pad(est_source, (0, T_origin - T_est))
+        else:
+            est_source = est_source[:, :T_origin]
+
+        return [est_source, sep_h], clean.squeeze(-1)
+
+    def compute_feats(self, wavs):
+        """Feature computation pipeline"""
+        feats = self.hparams.Encoder(wavs)
+        feats = spectral_magnitude(feats, power=0.5)
+        feats = torch.log1p(feats)
+        return feats
+
+    def compute_objectives(self, predictions, clean):
+        """Computes the si-snr loss"""
+        predicted_wavs, predicted_specs = predictions
+
+        if self.use_freq_domain:
+            target_specs = self.compute_feats(clean)
+            return self.hparams.loss(target_specs, predicted_specs)
+        else:
+            return self.hparams.loss(
+                clean.unsqueeze(-1), predicted_wavs.unsqueeze(-1)
+            )
+
+    def fit_batch(self, batch):
+        """Trains one batch"""
+        amp = AMPConfig.from_name(self.precision)
+        should_step = (self.step % self.grad_accumulation_factor) == 0
+
+        # Unpacking batch list
+        noisy = batch.noisy_sig
+        clean = batch.clean_sig
+        noise = batch.noise_sig[0]
+
+        with self.no_sync(not should_step):
+            if self.use_amp:
+                with torch.autocast(
+                    dtype=amp.dtype, device_type=torch.device(self.device).type,
+                ):
+                    predictions, clean = self.compute_forward(
+                        noisy, clean, sb.Stage.TRAIN, noise
+                    )
+                    loss = self.compute_objectives(predictions, clean)
+
+                    # hard threshold the easy dataitems
+                    if self.hparams.threshold_byloss:
+                        th = self.hparams.threshold
+                        loss_to_keep = loss[loss > th]
+                        if loss_to_keep.nelement() > 0:
+                            loss = loss_to_keep.mean()
+                    else:
+                        loss = loss.mean()
+
+                if (
+                    loss < self.hparams.loss_upper_lim and loss.nelement() > 0
+                ):  # the fix for computational problems
+                    self.scaler.scale(loss).backward()
+                    if self.hparams.clip_grad_norm >= 0:
+                        self.scaler.unscale_(self.optimizer)
+                        torch.nn.utils.clip_grad_norm_(
+                            self.modules.parameters(),
+                            self.hparams.clip_grad_norm,
+                        )
+                    self.scaler.step(self.optimizer)
+                    self.scaler.update()
+                else:
+                    self.nonfinite_count += 1
+                    logger.info(
+                        "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
+                            self.nonfinite_count
+                        )
+                    )
+                    loss.data = torch.tensor(0).to(self.device)
+            else:
+                predictions, clean = self.compute_forward(
+                    noisy, clean, sb.Stage.TRAIN, noise
+                )
+                loss = self.compute_objectives(predictions, clean)
+
+                if self.hparams.threshold_byloss:
+                    th = self.hparams.threshold
+                    loss_to_keep = loss[loss > th]
+                    if loss_to_keep.nelement() > 0:
+                        loss = loss_to_keep.mean()
+                else:
+                    loss = loss.mean()
+
+                if (
+                    loss < self.hparams.loss_upper_lim and loss.nelement() > 0
+                ):  # the fix for computational problems
+                    loss.backward()
+                    if self.hparams.clip_grad_norm >= 0:
+                        torch.nn.utils.clip_grad_norm_(
+                            self.modules.parameters(),
+                            self.hparams.clip_grad_norm,
+                        )
+                    self.optimizer.step()
+                else:
+                    self.nonfinite_count += 1
+                    logger.info(
+                        "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
+                            self.nonfinite_count
+                        )
+                    )
+                    loss.data = torch.tensor(0).to(self.device)
+        self.optimizer.zero_grad()
+
+        return loss.detach().cpu()
+
+    def evaluate_batch(self, batch, stage):
+        """Computations needed for validation/test batches"""
+
+        snt_id = batch.id
+        noisy = batch.noisy_sig
+        clean = batch.clean_sig
+
+        with torch.no_grad():
+            predictions, clean = self.compute_forward(noisy, clean, stage)
+            loss = self.compute_objectives(predictions, clean)
+            loss = torch.mean(loss)
+
+        if stage != sb.Stage.TRAIN:
+            self.pesq_metric.append(
+                ids=batch.id, predict=predictions[0].cpu(), target=clean.cpu()
+            )
+
+        # Manage audio file saving
+        if stage == sb.Stage.TEST and self.hparams.save_audio:
+            if hasattr(self.hparams, "n_audio_to_save"):
+                if self.hparams.n_audio_to_save > 0:
+                    self.save_audio(snt_id[0], noisy, clean, predictions[0])
+                    self.hparams.n_audio_to_save += -1
+            else:
+                self.save_audio(snt_id[0], noisy, clean, predictions[0])
+
+        return loss.detach()
+
+    def on_stage_start(self, stage, epoch=None):
+        """Gets called at the beginning of each epoch"""
+        if stage != sb.Stage.TRAIN:
+            # Define function taking (prediction, target) for parallel eval
+            def pesq_eval(pred_wav, target_wav):
+                """Computes the PESQ evaluation metric"""
+                psq_mode = "wb" if self.hparams.sample_rate == 16000 else "nb"
+                try:
+                    return pesq(
+                        fs=self.hparams.sample_rate,
+                        ref=target_wav.numpy(),
+                        deg=pred_wav.numpy(),
+                        mode=psq_mode,
+                    )
+                except Exception:
+                    print("pesq encountered an error for this data item")
+                    return 0
+
+            self.pesq_metric = MetricStats(
+                metric=pesq_eval, n_jobs=1, batch_eval=False
+            )
+
+    def on_stage_end(self, stage, stage_loss, epoch):
+        """Gets called at the end of a epoch."""
+        # Compute/store important stats
+        stage_stats = {"si-snr": stage_loss}
+        if stage == sb.Stage.TRAIN:
+            self.train_stats = stage_stats
+        else:
+            stats = {
+                "si-snr": stage_loss,
+                "pesq": self.pesq_metric.summarize("average"),
+            }
+
+        # Perform end-of-iteration things, like annealing, logging, etc.
+        if stage == sb.Stage.VALID:
+            # Save valid logs in TensorBoard
+            valid_stats = {
+                "Epochs": epoch,
+                "Valid SI-SNR": stage_loss,
+                "Valid PESQ": self.pesq_metric.summarize("average"),
+            }
+            if self.hparams.use_tensorboard:
+                self.hparams.tensorboard_train_logger.log_stats(valid_stats)
+
+            # Learning rate annealing
+            if isinstance(
+                self.hparams.lr_scheduler, schedulers.ReduceLROnPlateau
+            ):
+                current_lr, next_lr = self.hparams.lr_scheduler(
+                    [self.optimizer], epoch, stage_loss
+                )
+                schedulers.update_learning_rate(self.optimizer, next_lr)
+            else:
+                # if we do not use the reducelronplateau, we do not change the lr
+                current_lr = self.hparams.optimizer.optim.param_groups[0]["lr"]
+
+            self.hparams.train_logger.log_stats(
+                stats_meta={"epoch": epoch, "lr": current_lr},
+                train_stats=self.train_stats,
+                valid_stats=stats,
+            )
+            if (
+                hasattr(self.hparams, "save_all_checkpoints")
+                and self.hparams.save_all_checkpoints
+            ):
+                self.checkpointer.save_checkpoint(meta={"pesq": stats["pesq"]})
+            else:
+                self.checkpointer.save_and_keep_only(
+                    meta={"pesq": stats["pesq"]}, max_keys=["pesq"],
+                )
+        elif stage == sb.Stage.TEST:
+            self.hparams.train_logger.log_stats(
+                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
+                test_stats=stats,
+            )
+
+    def add_speed_perturb(self, clean, targ_lens):
+        """Adds speed perturbation and random_shift to the input signals"""
+
+        min_len = -1
+        recombine = False
+
+        if self.hparams.use_speedperturb:
+            # Performing speed change (independently on each source)
+            new_clean = []
+            recombine = True
+
+            for i in range(clean.shape[-1]):
+                new_target = self.hparams.speed_perturb(clean[:, :, i])
+                new_clean.append(new_target)
+                if i == 0:
+                    min_len = new_target.shape[-1]
+                else:
+                    if new_target.shape[-1] < min_len:
+                        min_len = new_target.shape[-1]
+
+            if self.hparams.use_rand_shift:
+                # Performing random_shift (independently on each source)
+                recombine = True
+                for i in range(clean.shape[-1]):
+                    rand_shift = torch.randint(
+                        self.hparams.min_shift, self.hparams.max_shift, (1,)
+                    )
+                    new_clean[i] = new_clean[i].to(self.device)
+                    new_clean[i] = torch.roll(
+                        new_clean[i], shifts=(rand_shift[0],), dims=1
+                    )
+
+            # Re-combination
+            if recombine:
+                if self.hparams.use_speedperturb:
+                    clean = torch.zeros(
+                        clean.shape[0],
+                        min_len,
+                        clean.shape[-1],
+                        device=clean.device,
+                        dtype=torch.float,
+                    )
+                for i, new_target in enumerate(new_clean):
+                    clean[:, :, i] = new_clean[i][:, 0:min_len]
+
+        noisy = clean.sum(-1)
+        return noisy, clean
+
+    def cut_signals(self, noisy, clean):
+        """This function selects a random segment of a given length withing the noisy.
+        The corresponding clean are selected accordingly"""
+        randstart = torch.randint(
+            0,
+            1 + max(0, noisy.shape[1] - self.hparams.training_signal_len),
+            (1,),
+        ).item()
+        clean = clean[
+            :, randstart : randstart + self.hparams.training_signal_len, :
+        ]
+        noisy = noisy[
+            :, randstart : randstart + self.hparams.training_signal_len
+        ]
+        return noisy, clean
+
+    def reset_layer_recursively(self, layer):
+        """Reinitializes the parameters of the neural networks"""
+        if hasattr(layer, "reset_parameters"):
+            layer.reset_parameters()
+        for child_layer in layer.modules():
+            if layer != child_layer:
+                self.reset_layer_recursively(child_layer)
+
+    def save_results(self, valid_data):
+        """This script calculates the SDR and SI-SNR metrics
+        and stores them in a CSV file. As this evaluation
+        method depends on a gold-standard reference signal,
+        it is applied exclusively to the valid set and excludes
+        the baseline data.
+        """
+        # This package is required for SDR computation
+        from mir_eval.separation import bss_eval_sources
+
+        # Create folders where to store audio
+        save_file = os.path.join(
+            self.hparams.output_folder, "valid_results.csv"
+        )
+
+        # Variable init
+        all_sdrs = []
+        all_sdrs_i = []
+        all_sisnrs = []
+        all_sisnrs_i = []
+        all_pesqs = []
+        all_stois = []
+        all_csigs = []
+        all_cbaks = []
+        all_covls = []
+        csv_columns = [
+            "snt_id",
+            "sdr",
+            "sdr_i",
+            "si-snr",
+            "si-snr_i",
+            "pesq",
+            "stoi",
+            "csig",
+            "cbak",
+            "covl",
+        ]
+
+        valid_loader = sb.dataio.dataloader.make_dataloader(
+            valid_data, **self.hparams.dataloader_opts_test
+        )
+
+        with open(save_file, "w") as results_csv:
+            writer = csv.DictWriter(results_csv, fieldnames=csv_columns)
+            writer.writeheader()
+
+            # Loop over all test sentence
+            with tqdm(valid_loader, dynamic_ncols=True) as t:
+                for i, batch in enumerate(t):
+                    # Apply Enhancement
+                    noisy, noisy_len = batch.noisy_sig
+                    snt_id = batch.id
+                    clean = batch.clean_sig
+
+                    with torch.no_grad():
+                        predictions, clean = self.compute_forward(
+                            batch.noisy_sig, clean, sb.Stage.TEST
+                        )
+
+                    # Compute PESQ
+                    psq_mode = (
+                        "wb" if self.hparams.sample_rate == 16000 else "nb"
+                    )
+
+                    try:
+                        # Compute SI-SNR
+                        sisnr = self.compute_objectives(predictions, clean)
+
+                        # Compute SI-SNR improvement
+                        noisy_signal = noisy
+
+                        noisy_signal = noisy_signal.to(clean.device)
+                        sisnr_baseline = self.compute_objectives(
+                            [noisy_signal.squeeze(-1), None], clean
+                        )
+                        sisnr_i = sisnr - sisnr_baseline
+
+                        # Compute SDR
+                        sdr, _, _, _ = bss_eval_sources(
+                            clean[0].t().cpu().numpy(),
+                            predictions[0][0].t().detach().cpu().numpy(),
+                        )
+
+                        sdr_baseline, _, _, _ = bss_eval_sources(
+                            clean[0].t().cpu().numpy(),
+                            noisy_signal[0].t().detach().cpu().numpy(),
+                        )
+
+                        sdr_i = sdr.mean() - sdr_baseline.mean()
+
+                        # Compute PESQ
+                        psq = pesq(
+                            self.hparams.sample_rate,
+                            clean.squeeze().cpu().numpy(),
+                            predictions[0].squeeze().cpu().numpy(),
+                            mode=psq_mode,
+                        )
+                        # Compute STOI
+                        stoi_score = stoi(
+                            clean.squeeze().cpu().numpy(),
+                            predictions[0].squeeze().cpu().numpy(),
+                            fs_sig=self.hparams.sample_rate,
+                            extended=False,
+                        )
+                        # Compute CSIG, CBAK, COVL
+                        composite_metrics = eval_composite(
+                            clean.squeeze().cpu().numpy(),
+                            predictions[0].squeeze().cpu().numpy(),
+                            self.hparams.sample_rate,
+                        )
+                    except Exception:
+                        # this handles all sorts of error that may
+                        # occur when evaluating an enhanced file.
+                        continue
+
+                    # Saving on a csv file
+                    row = {
+                        "snt_id": snt_id[0],
+                        "sdr": sdr.mean(),
+                        "sdr_i": sdr_i,
+                        "si-snr": -sisnr.item(),
+                        "si-snr_i": -sisnr_i.item(),
+                        "pesq": psq,
+                        "stoi": stoi_score,
+                        "csig": composite_metrics["csig"],
+                        "cbak": composite_metrics["cbak"],
+                        "covl": composite_metrics["covl"],
+                    }
+                    writer.writerow(row)
+
+                    # Metric Accumulation
+                    all_sdrs.append(sdr.mean())
+                    all_sdrs_i.append(sdr_i.mean())
+                    all_sisnrs.append(-sisnr.item())
+                    all_sisnrs_i.append(-sisnr_i.item())
+                    all_pesqs.append(psq)
+                    all_stois.append(stoi_score)
+                    all_csigs.append(composite_metrics["csig"])
+                    all_cbaks.append(composite_metrics["cbak"])
+                    all_covls.append(composite_metrics["covl"])
+
+                row = {
+                    "snt_id": "avg",
+                    "sdr": np.array(all_sdrs).mean(),
+                    "sdr_i": np.array(all_sdrs_i).mean(),
+                    "si-snr": np.array(all_sisnrs).mean(),
+                    "si-snr_i": np.array(all_sisnrs_i).mean(),
+                    "pesq": np.array(all_pesqs).mean(),
+                    "stoi": np.array(all_stois).mean(),
+                    "csig": np.array(all_csigs).mean(),
+                    "cbak": np.array(all_cbaks).mean(),
+                    "covl": np.array(all_covls).mean(),
+                }
+                writer.writerow(row)
+
+        logger.info("Mean SISNR is {}".format(np.array(all_sisnrs).mean()))
+        logger.info("Mean SISNRi is {}".format(np.array(all_sisnrs_i).mean()))
+        logger.info("Mean SDR is {}".format(np.array(all_sdrs).mean()))
+        logger.info("Mean SDRi is {}".format(np.array(all_sdrs_i).mean()))
+        logger.info("Mean PESQ {}".format(np.array(all_pesqs).mean()))
+        logger.info("Mean STOI {}".format(np.array(all_stois).mean()))
+        logger.info("Mean CSIG {}".format(np.array(all_csigs).mean()))
+        logger.info("Mean CBAK {}".format(np.array(all_cbaks).mean()))
+        logger.info("Mean COVL {}".format(np.array(all_covls).mean()))
+
+    def save_audio(self, snt_id, noisy, clean, predictions):
+        "saves the test audio (noisy, clean, and estimated sources) on disk"
+        print("Saving enhanced sources (valid set)")
+
+        # Create output folders
+        save_path = os.path.join(
+            self.hparams.save_folder, "valid_audio_results"
+        )
+        save_path_enhanced = os.path.join(save_path, "enhanced_sources")
+        save_path_clean = os.path.join(save_path, "clean_sources")
+        save_path_noisy = os.path.join(save_path, "noisy_sources")
+
+        for path in [save_path_enhanced, save_path_clean, save_path_noisy]:
+            if not os.path.exists(path):
+                os.makedirs(path)
+
+        # Estimated source
+        signal = predictions[0, :]
+        signal = signal / signal.abs().max()
+        save_file = os.path.join(
+            save_path_enhanced, "item{}_sourcehat.wav".format(snt_id)
+        )
+        torchaudio.save(
+            save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate
+        )
+
+        # Original source
+        signal = clean[0, :]
+        signal = signal / signal.abs().max()
+        save_file = os.path.join(
+            save_path_clean, "item{}_source.wav".format(snt_id)
+        )
+        torchaudio.save(
+            save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate
+        )
+
+        # Noisy source
+        signal = noisy[0][0, :]
+        signal = signal / signal.abs().max()
+        save_file = os.path.join(
+            save_path_noisy, "item{}_noisy.wav".format(snt_id)
+        )
+        torchaudio.save(
+            save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate
+        )
+
+
+def dataio_prep(hparams):
+    """Creates data processing pipeline"""
+    speech_dirs = [
+        "read_speech",
+        "german_speech",
+        "french_speech",
+        "italian_speech",
+        "spanish_speech",
+        "russian_speech",
+    ]
+    audio_length = hparams["audio_length"]
+
+    train_shard_patterns = []
+    for dir in speech_dirs:
+        if not os.path.exists(os.path.join(hparams["train_data"], dir)):
+            dir = ""
+        shard_pattern = os.path.join(hparams["train_data"], dir, "shard-*.tar")
+        shard_files = glob.glob(shard_pattern)
+        train_shard_patterns.extend(shard_files)
+
+    valid_shard_patterns = []
+    for dir in speech_dirs:
+        if not os.path.exists(os.path.join(hparams["valid_data"], dir)):
+            dir = ""
+        shard_pattern = os.path.join(hparams["valid_data"], dir, "shard-*.tar")
+        shard_files = glob.glob(shard_pattern)
+        valid_shard_patterns.extend(shard_files)
+
+    def meta_loader(split_path):
+        # Initialize the total number of samples
+        total_samples = 0
+
+        # Walk through the all subdirs
+        # eg. german_speech, read_speech, ...
+        for root, _, files in os.walk(split_path):
+            for file in files:
+                if file == "meta.json":
+                    meta_json_path = os.path.join(root, file)
+                    with open(meta_json_path, "rb") as f:
+                        meta = json.load(f)
+                    total_samples += meta.get("num_data_samples", 0)
+
+        return total_samples
+
+    def train_audio_pipeline(sample_dict: Dict, random_chunk=True):
+        key = sample_dict["__key__"]
+        clean_wav = sample_dict["clean_file"]
+        noise_wav = sample_dict["noise_file"]
+        noisy_wav = sample_dict["noisy_file"]
+        clean_sig = sample_dict["clean_audio.pth"].squeeze()
+        noise_sig = sample_dict["noise_audio.pth"].squeeze()
+        noisy_sig = sample_dict["noisy_audio.pth"].squeeze()
+
+        return {
+            "id": key,
+            "clean_wav": clean_wav,
+            "clean_sig": clean_sig,
+            "noise_wav": noise_wav,
+            "noise_sig": noise_sig,
+            "noisy_wav": noisy_wav,
+            "noisy_sig": noisy_sig,
+        }
+
+    def baseline_audio_pipeline(sample_dict: Dict, random_chunk=True):
+        key = sample_dict["__key__"]
+        noisy_sig = sample_dict["audio.pth"].squeeze()
+
+        return {
+            "id": key,
+            "noisy_wav": key,
+            "noisy_sig": noisy_sig,
+        }
+
+    def create_combined_dataset(shard_patterns, cache_dir):
+        # mix multiple datasets, where each dataset consists of multiple shards
+        # e.g. combine read_speech, german_speech etc. each with multiple shards.
+        urls = [
+            url
+            for shard in shard_patterns
+            for url in braceexpand.braceexpand(shard)
+        ]
+
+        combined_dataset = (
+            wds.WebDataset(urls, shardshuffle=True, cache_dir=cache_dir,)
+            .repeat()
+            .shuffle(1000)
+            .decode("pil")
+            .map(partial(train_audio_pipeline, random_chunk=True))
+        )
+
+        return combined_dataset
+
+    train_data = create_combined_dataset(
+        train_shard_patterns, hparams["shard_cache_dir"]
+    )
+    train_samples = meta_loader(hparams["train_data"])
+    logger.info(f"Training data- Number of samples: {train_samples}")
+    logger.info(
+        f"Training data - Total duration: {train_samples * audio_length/ 3600:.2f} hours"
+    )
+
+    valid_data = create_combined_dataset(
+        valid_shard_patterns, hparams["shard_cache_dir"]
+    )
+    valid_samples = meta_loader(hparams["valid_data"])
+    logger.info(f"Valid data- Number of samples: {valid_samples}")
+    logger.info(
+        f"Valid data - Total duration: {valid_samples * audio_length  / 3600:.2f} hours"
+    )
+
+    baseline_data = (
+        wds.WebDataset(
+            hparams["baseline_shards"], cache_dir=hparams["shard_cache_dir"],
+        )
+        .repeat()
+        .shuffle(1000)
+        .decode("pil")
+        .map(partial(baseline_audio_pipeline, random_chunk=True))
+    )
+
+    return train_data, valid_data, train_samples, valid_samples, baseline_data
+
+
+if __name__ == "__main__":
+    # Load hyperparameters file with command-line overrides
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # Initialize ddp (useful only for multi-GPU DDP training)
+    sb.utils.distributed.ddp_init_group(run_opts)
+
+    # Logger info
+    logger = logging.getLogger(__name__)
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    # Update precision to bf16 if the device is CPU and precision is fp16
+    if run_opts.get("device") == "cpu" and hparams.get("precision") == "fp16":
+        hparams["precision"] = "bf16"
+
+    if hparams["use_tensorboard"]:
+        from speechbrain.utils.train_logger import TensorboardLogger
+
+        hparams["tensorboard_train_logger"] = TensorboardLogger(
+            hparams["tensorboard_logs"]
+        )
+
+    (
+        train_data,
+        valid_data,
+        num_train_samples,
+        num_valid_samples,
+        baseline_data,
+    ) = dataio_prep(hparams)
+
+    # add collate_fn to dataloader options
+    hparams["dataloader_opts"]["collate_fn"] = PaddedBatch
+    hparams["dataloader_opts_valid"]["collate_fn"] = PaddedBatch
+    hparams["dataloader_opts_test"]["collate_fn"] = PaddedBatch
+
+    hparams["dataloader_opts"]["looped_nominal_epoch"] = (
+        num_train_samples // hparams["dataloader_opts"]["batch_size"]
+    )
+    hparams["dataloader_opts_valid"]["looped_nominal_epoch"] = (
+        num_valid_samples // hparams["dataloader_opts_valid"]["batch_size"]
+    )
+    hparams["dataloader_opts_test"]["looped_nominal_epoch"] = (
+        num_valid_samples // hparams["dataloader_opts_test"]["batch_size"]
+    )
+
+    # Load pretrained model if pretrained_enhancement is present in the yaml
+    if "pretrained_enhancement" in hparams:
+        run_on_main(hparams["pretrained_enhancement"].collect_files)
+        hparams["pretrained_enhancement"].load_collected()
+
+    # Brain class initialization
+    enhancement = Enhancement(
+        modules=hparams["modules"],
+        opt_class=hparams["optimizer"],
+        hparams=hparams,
+        run_opts=run_opts,
+        checkpointer=hparams["checkpointer"],
+    )
+
+    # re-initialize the parameters if we don't use a pretrained model
+    if "pretrained_enhancement" not in hparams:
+        for module in enhancement.modules.values():
+            enhancement.reset_layer_recursively(module)
+
+    # determine if frequency domain enhancement or not
+    use_freq_domain = hparams.get("use_freq_domain", False)
+    enhancement.use_freq_domain = use_freq_domain
+
+    if not hparams["test_only"]:
+        # Training
+        enhancement.fit(
+            enhancement.hparams.epoch_counter,
+            train_data,
+            valid_data,
+            train_loader_kwargs=hparams["dataloader_opts"],
+            valid_loader_kwargs=hparams["dataloader_opts_valid"],
+        )
+
+    ## Evaluation on valid data
+    # (since our test set is blind)
+    enhancement.evaluate(
+        valid_data,
+        max_key="pesq",
+        test_loader_kwargs=hparams["dataloader_opts_valid"],
+    )
+    enhancement.save_results(valid_data)
+
+    ## Save enhanced sources of baseline noisy testclips
+    def save_baseline_audio(snt_id, predictions):
+        "saves the  estimated sources on disk"
+        # Create outout folder
+        save_path = os.path.join(
+            hparams["save_folder"], "baseline_audio_results"
+        )
+        save_path_enhanced = os.path.join(save_path, "enhanced_testclips")
+
+        if not os.path.exists(save_path_enhanced):
+            os.makedirs(save_path_enhanced)
+
+        # Estimated source
+        signal = predictions[0, :]
+        signal = signal / signal.abs().max()
+        save_file = os.path.join(save_path_enhanced, snt_id) + ".wav"
+
+        torchaudio.save(
+            save_file, signal.unsqueeze(0).cpu(), hparams["sample_rate"]
+        )
+
+    test_loader = sb.dataio.dataloader.make_dataloader(
+        baseline_data, **hparams["dataloader_opts_test"]
+    )
+
+    # Loop over all noisy baseline shards and save the enahanced clips
+    print("Saving enhanced sources (baseline set)")
+    with tqdm(test_loader, dynamic_ncols=True) as t:
+        for i, batch in enumerate(t):
+            # Apply Enhancement
+            snt_id = batch.id[0]
+
+            with torch.no_grad():
+                # Since only noisy sources are provided for baseline
+                # we use the compute_forward function with the same noisy
+                # signal for all inputs. (ugly hack)
+                predictions, clean = enhancement.compute_forward(
+                    batch.noisy_sig,
+                    batch.noisy_sig,
+                    batch.noisy_sig,
+                    sb.Stage.TEST,
+                )
+                predictions = predictions[0]
+
+            # Write enhanced wavs
+            save_baseline_audio(snt_id, predictions)
diff --git a/recipes/DNS/noisyspeech_synthesizer/README.md b/recipes/DNS/noisyspeech_synthesizer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b4c3e6870ff35054006a6b1d9a6b88dd15a44c19
--- /dev/null
+++ b/recipes/DNS/noisyspeech_synthesizer/README.md
@@ -0,0 +1,34 @@
+# **DNS: Noisy speech synthesizer**
+This folder contains scripts to synthesize noisy audio for training.
+Scripts have been taken from [official GitHub repo](https://github.com/microsoft/DNS-Challenge).
+
+Modify parameters like `sampling_rate`, `audio_length` , `total_hours` etc in the YAML file as per your requirement.
+
+## Synthesize clean-noisy data and create the Webdataset shards
+Synthesize clean-noisy data and create WebDataset shards.
+
+### **Usage**
+To create noisy dataset, run
+```
+## synthesize read speech
+python noisyspeech_synthesizer_singleprocess.py noisyspeech_synthesizer.yaml --input_shards_dir ../DNS-shards --split_name read_speech --synthesized_data_dir synthesized_data_shards
+
+## synthesize German speech
+python noisyspeech_synthesizer_singleprocess.py noisyspeech_synthesizer.yaml --input_shards_dir ../DNS-shards --split_name german_speech --synthesized_data_dir synthesized_data_shards
+
+## synthesize Italian speech
+python noisyspeech_synthesizer_singleprocess.py noisyspeech_synthesizer.yaml --input_shards_dir ../DNS-shards --split_name italian_speech --synthesized_data_dir synthesized_data_shards
+
+## synthesize French speech
+python noisyspeech_synthesizer_singleprocess.py noisyspeech_synthesizer.yaml --input_shards_dir ../DNS-shards --split_name french_speech --synthesized_data_dir synthesized_data_shards
+
+## synthesize Spanish speech
+python noisyspeech_synthesizer_singleprocess.py noisyspeech_synthesizer.yaml --input_shards_dir ../DNS-shards --split_name spanish_speech --synthesized_data_dir synthesized_data_shards
+
+## synthesize Russian speech
+python noisyspeech_synthesizer_singleprocess.py noisyspeech_synthesizer.yaml --input_shards_dir ../DNS-shards --split_name russian_speech --synthesized_data_dir synthesized_data_shards
+```
+
+It's recommended to execute these commands in parallel for quicker synthesis.
+
+**Time** : It takes about 140 HRS to synthesize a dataset of 500 HRS. This calls the need for dynamic mixing.
diff --git a/recipes/DNS/noisyspeech_synthesizer/audiolib.py b/recipes/DNS/noisyspeech_synthesizer/audiolib.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a1787923ce349a5972a34419831e6c4e9c24894
--- /dev/null
+++ b/recipes/DNS/noisyspeech_synthesizer/audiolib.py
@@ -0,0 +1,352 @@
+"""
+Source: https://github.com/microsoft/DNS-Challenge
+Ownership: Microsoft
+
+* Author
+    chkarada
+"""
+
+import os
+import numpy as np
+import soundfile as sf
+import subprocess
+import glob
+import librosa
+
+EPS = np.finfo(float).eps
+np.random.seed(0)
+
+
+def is_clipped(audio, clipping_threshold=0.99):
+    """Check if an audio signal is clipped.
+    """
+    return any(abs(audio) > clipping_threshold)
+
+
+def normalize(audio, target_level=-25):
+    """Normalize the signal to the target level"""
+    rms = (audio ** 2).mean() ** 0.5
+    scalar = 10 ** (target_level / 20) / (rms + EPS)
+    audio = audio * scalar
+    return audio
+
+
+def normalize_segmental_rms(audio, rms, target_level=-25):
+    """Normalize the signal to the target level
+    based on segmental RMS"""
+    scalar = 10 ** (target_level / 20) / (rms + EPS)
+    audio = audio * scalar
+    return audio
+
+
+def audioread(path, norm=False, start=0, stop=None, target_level=-25):
+    """Function to read audio"""
+
+    path = os.path.abspath(path)
+    if not os.path.exists(path):
+        raise ValueError("[{}] does not exist!".format(path))
+    try:
+        audio, sample_rate = sf.read(path, start=start, stop=stop)
+    except RuntimeError:  # fix for sph pcm-embedded shortened v2
+        print("WARNING: Audio type not supported")
+        return (None, None)
+
+    if len(audio.shape) == 1:  # mono
+        if norm:
+            rms = (audio ** 2).mean() ** 0.5
+            scalar = 10 ** (target_level / 20) / (rms + EPS)
+            audio = audio * scalar
+    else:  # multi-channel
+        audio = audio.T
+        audio = audio.sum(axis=0) / audio.shape[0]
+        if norm:
+            audio = normalize(audio, target_level)
+
+    return audio, sample_rate
+
+
+def audiowrite(
+    destpath,
+    audio,
+    sample_rate=16000,
+    norm=False,
+    target_level=-25,
+    clipping_threshold=0.99,
+    clip_test=False,
+):
+    """Function to write audio"""
+
+    if clip_test:
+        if is_clipped(audio, clipping_threshold=clipping_threshold):
+            raise ValueError(
+                "Clipping detected in audiowrite()! "
+                + destpath
+                + " file not written to disk."
+            )
+
+    if norm:
+        audio = normalize(audio, target_level)
+        max_amp = max(abs(audio))
+        if max_amp >= clipping_threshold:
+            audio = audio / max_amp * (clipping_threshold - EPS)
+
+    destpath = os.path.abspath(destpath)
+    destdir = os.path.dirname(destpath)
+
+    if not os.path.exists(destdir):
+        os.makedirs(destdir)
+
+    sf.write(destpath, audio, sample_rate)
+    return
+
+
+def add_reverb(sasxExe, input_wav, filter_file, output_wav):
+    """ Function to add reverb"""
+    command_sasx_apply_reverb = "{0} -r {1} \
+        -f {2} -o {3}".format(
+        sasxExe, input_wav, filter_file, output_wav
+    )
+
+    subprocess.call(command_sasx_apply_reverb)
+    return output_wav
+
+
+def add_clipping(audio, max_thresh_perc=0.8):
+    """Function to add clipping"""
+    threshold = max(abs(audio)) * max_thresh_perc
+    audioclipped = np.clip(audio, -threshold, threshold)
+    return audioclipped
+
+
+def adsp_filter(Adspvqe, nearEndInput, nearEndOutput, farEndInput):
+
+    command_adsp_clean = "{0} --breakOnErrors 0 --sampleRate 16000 --useEchoCancellation 0 \
+                    --operatingMode 2 --useDigitalAgcNearend 0 --useDigitalAgcFarend 0 \
+                    --useVirtualAGC 0 --useComfortNoiseGenerator 0 --useAnalogAutomaticGainControl 0 \
+                    --useNoiseReduction 0 --loopbackInputFile {1} --farEndInputFile {2} \
+                    --nearEndInputFile {3} --nearEndOutputFile {4}".format(
+        Adspvqe, farEndInput, farEndInput, nearEndInput, nearEndOutput
+    )
+    subprocess.call(command_adsp_clean)
+
+
+def snr_mixer(
+    params, clean, noise, snr, target_level=-25, clipping_threshold=0.99
+):
+    """Function to mix clean speech and noise at various SNR levels"""
+    # cfg = params['cfg']
+    if len(clean) > len(noise):
+        noise = np.append(noise, np.zeros(len(clean) - len(noise)))
+    else:
+        clean = np.append(clean, np.zeros(len(noise) - len(clean)))
+
+    # Normalizing to -25 dB FS
+    clean = clean / (max(abs(clean)) + EPS)
+    clean = normalize(clean, target_level)
+    rmsclean = (clean ** 2).mean() ** 0.5
+
+    noise = noise / (max(abs(noise)) + EPS)
+    noise = normalize(noise, target_level)
+    rmsnoise = (noise ** 2).mean() ** 0.5
+
+    # Set the noise level for a given SNR
+    noisescalar = rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)
+    noisenewlevel = noise * noisescalar
+
+    # Mix noise and clean speech
+    noisyspeech = clean + noisenewlevel
+
+    # Randomly select RMS value between -15 dBFS and -35 dBFS and normalize noisyspeech with that value
+    # There is a chance of clipping that might happen with very less probability, which is not a major issue.
+    noisy_rms_level = np.random.randint(
+        params["target_level_lower"], params["target_level_upper"]
+    )
+    rmsnoisy = (noisyspeech ** 2).mean() ** 0.5
+    scalarnoisy = 10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)
+    noisyspeech = noisyspeech * scalarnoisy
+    clean = clean * scalarnoisy
+    noisenewlevel = noisenewlevel * scalarnoisy
+
+    # Final check to see if there are any amplitudes exceeding +/- 1. If so, normalize all the signals accordingly
+    if is_clipped(noisyspeech):
+        noisyspeech_maxamplevel = max(abs(noisyspeech)) / (
+            clipping_threshold - EPS
+        )
+        noisyspeech = noisyspeech / noisyspeech_maxamplevel
+        clean = clean / noisyspeech_maxamplevel
+        noisenewlevel = noisenewlevel / noisyspeech_maxamplevel
+        noisy_rms_level = int(
+            20
+            * np.log10(scalarnoisy / noisyspeech_maxamplevel * (rmsnoisy + EPS))
+        )
+
+    return clean, noisenewlevel, noisyspeech, noisy_rms_level
+
+
+def segmental_snr_mixer(
+    params, clean, noise, snr, target_level=-25, clipping_threshold=0.99
+):
+    """Function to mix clean speech and noise at various segmental SNR levels"""
+    # cfg = params['cfg']
+    if len(clean) > len(noise):
+        noise = np.append(noise, np.zeros(len(clean) - len(noise)))
+    else:
+        clean = np.append(clean, np.zeros(len(noise) - len(clean)))
+    clean = clean / (max(abs(clean)) + EPS)
+    noise = noise / (max(abs(noise)) + EPS)
+    rmsclean, rmsnoise = active_rms(clean=clean, noise=noise)
+    clean = normalize_segmental_rms(
+        clean, rms=rmsclean, target_level=target_level
+    )
+    noise = normalize_segmental_rms(
+        noise, rms=rmsnoise, target_level=target_level
+    )
+    # Set the noise level for a given SNR
+    noisescalar = rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)
+    noisenewlevel = noise * noisescalar
+
+    # Mix noise and clean speech
+    noisyspeech = clean + noisenewlevel
+    # Randomly select RMS value between -15 dBFS and -35 dBFS and normalize noisyspeech with that value
+    # There is a chance of clipping that might happen with very less probability, which is not a major issue.
+    noisy_rms_level = np.random.randint(
+        params["target_level_lower"], params["target_level_upper"]
+    )
+    rmsnoisy = (noisyspeech ** 2).mean() ** 0.5
+    scalarnoisy = 10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)
+    noisyspeech = noisyspeech * scalarnoisy
+    clean = clean * scalarnoisy
+    noisenewlevel = noisenewlevel * scalarnoisy
+    # Final check to see if there are any amplitudes exceeding +/- 1. If so, normalize all the signals accordingly
+    if is_clipped(noisyspeech):
+        noisyspeech_maxamplevel = max(abs(noisyspeech)) / (
+            clipping_threshold - EPS
+        )
+        noisyspeech = noisyspeech / noisyspeech_maxamplevel
+        clean = clean / noisyspeech_maxamplevel
+        noisenewlevel = noisenewlevel / noisyspeech_maxamplevel
+        noisy_rms_level = int(
+            20
+            * np.log10(scalarnoisy / noisyspeech_maxamplevel * (rmsnoisy + EPS))
+        )
+
+    return clean, noisenewlevel, noisyspeech, noisy_rms_level
+
+
+def active_rms(clean, noise, fs=16000, energy_thresh=-50):
+    """Returns the clean and noise RMS of the noise calculated only in the active portions"""
+    window_size = 100  # in ms
+    window_samples = int(fs * window_size / 1000)
+    sample_start = 0
+    noise_active_segs = []
+    clean_active_segs = []
+
+    while sample_start < len(noise):
+        sample_end = min(sample_start + window_samples, len(noise))
+        noise_win = noise[sample_start:sample_end]
+        clean_win = clean[sample_start:sample_end]
+        noise_seg_rms = (noise_win ** 2).mean() ** 0.5
+        # Considering frames with energy
+        if noise_seg_rms > energy_thresh:
+            noise_active_segs = np.append(noise_active_segs, noise_win)
+            clean_active_segs = np.append(clean_active_segs, clean_win)
+        sample_start += window_samples
+
+    if len(noise_active_segs) != 0:
+        noise_rms = (noise_active_segs ** 2).mean() ** 0.5
+    else:
+        noise_rms = EPS
+
+    if len(clean_active_segs) != 0:
+        clean_rms = (clean_active_segs ** 2).mean() ** 0.5
+    else:
+        clean_rms = EPS
+
+    return clean_rms, noise_rms
+
+
+def activitydetector(audio, fs=16000, energy_thresh=0.13, target_level=-25):
+    """Return the percentage of the time the audio signal is above an energy threshold"""
+
+    audio = normalize(audio, target_level)
+    window_size = 50  # in ms
+    window_samples = int(fs * window_size / 1000)
+    sample_start = 0
+    cnt = 0
+    prev_energy_prob = 0
+    active_frames = 0
+
+    a = -1
+    b = 0.2
+    alpha_rel = 0.05
+    alpha_att = 0.8
+
+    while sample_start < len(audio):
+        sample_end = min(sample_start + window_samples, len(audio))
+        audio_win = audio[sample_start:sample_end]
+        frame_rms = 20 * np.log10(sum(audio_win ** 2) + EPS)
+        frame_energy_prob = 1.0 / (1 + np.exp(-(a + b * frame_rms)))
+
+        if frame_energy_prob > prev_energy_prob:
+            smoothed_energy_prob = (
+                frame_energy_prob * alpha_att
+                + prev_energy_prob * (1 - alpha_att)
+            )
+        else:
+            smoothed_energy_prob = (
+                frame_energy_prob * alpha_rel
+                + prev_energy_prob * (1 - alpha_rel)
+            )
+
+        if smoothed_energy_prob > energy_thresh:
+            active_frames += 1
+        prev_energy_prob = frame_energy_prob
+        sample_start += window_samples
+        cnt += 1
+
+    perc_active = active_frames / cnt
+    return perc_active
+
+
+def resampler(input_dir, target_sr=16000, ext="*.wav"):
+    """Resamples the audio files in input_dir to target_sr"""
+    files = glob.glob(f"{input_dir}/" + ext)
+    for pathname in files:
+        print(pathname)
+        try:
+            audio, fs = audioread(pathname)
+            audio_resampled = librosa.core.resample(audio, fs, target_sr)
+            audiowrite(pathname, audio_resampled, target_sr)
+        except:  # noqa
+            continue
+
+
+def audio_segmenter(input_dir, dest_dir, segment_len=10, ext="*.wav"):
+    """Segments the audio clips in dir to segment_len in secs"""
+    files = glob.glob(f"{input_dir}/" + ext)
+    for i in range(len(files)):
+        audio, fs = audioread(files[i])
+
+        if (
+            len(audio) > (segment_len * fs)
+            and len(audio) % (segment_len * fs) != 0
+        ):
+            audio = np.append(
+                audio,
+                audio[0 : segment_len * fs - (len(audio) % (segment_len * fs))],
+            )
+        if len(audio) < (segment_len * fs):
+            while len(audio) < (segment_len * fs):
+                audio = np.append(audio, audio)
+            audio = audio[: segment_len * fs]
+
+        num_segments = int(len(audio) / (segment_len * fs))
+        audio_segments = np.split(audio, num_segments)
+
+        basefilename = os.path.basename(files[i])
+        basename, ext = os.path.splitext(basefilename)
+
+        for j in range(len(audio_segments)):
+            newname = basename + "_" + str(j) + ext
+            destpath = os.path.join(dest_dir, newname)
+            audiowrite(destpath, audio_segments[j], fs)
diff --git a/recipes/DNS/noisyspeech_synthesizer/noisyspeech_synthesizer.yaml b/recipes/DNS/noisyspeech_synthesizer/noisyspeech_synthesizer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..26e3073702f310c314fbafd27c614e1e246a96c4
--- /dev/null
+++ b/recipes/DNS/noisyspeech_synthesizer/noisyspeech_synthesizer.yaml
@@ -0,0 +1,101 @@
+# yamllint disable
+################################
+# Configuration for generating Noisy Speech Dataset
+# - sampling_rate: Specify the sampling rate. Default is 16 kHz
+# - audioformat: default is .wav
+# - audio_length: Minimum Length of each audio clip (noisy and clean speech)
+#   in seconds that will be generated by augmenting utterances.
+# - silence_length: Duration of silence introduced between clean speech
+#   utterances.
+# - total_hours: Total number of hours of data required. Units are in hours.
+# - snr_lower: Lower bound for SNR required (default: 0 dB)
+# - snr_upper: Upper bound for SNR required (default: 40 dB)
+# - target_level_lower: Lower bound for the target audio level
+#   before audiowrite (default: -35 dB)
+# - target_level_upper: Upper bound for the target audio level
+#   before audiowrite (default: -15 dB)
+# - total_snrlevels: Number of SNR levels required (default: 5, which means
+#   there are 5 levels between snr_lower and snr_upper)
+# - clean_activity_threshold: Activity threshold for clean speech
+# - noise_activity_threshold: Activity threshold for noise
+# - fileindex_start: Starting file ID that will be used in filenames
+# - fileindex_end: Last file ID that will be used in filenames
+# - is_test_set: Set it to True if it is the test set, else False for the
+# - log_dir: Specify path to the directory to store all the log files
+# ################################
+# yamllint enable
+
+
+# Data storage params
+input_shards_dir: !PLACEHOLDER  #  ../DNS-shards
+split_name: !PLACEHOLDER # read_speech, german_speech, italian_speech, french_speech etc
+rirs: RIR_table_simple.csv
+
+# Noisy data synthesis params
+sampling_rate: 16000 # sampling rate of synthesized signal
+audioformat: "*.wav"
+audio_length: 4
+silence_length: 0.2
+total_hours: 100
+snr_lower: -5
+snr_upper: 15
+randomize_snr: True
+target_level_lower: -35
+target_level_upper: -15
+total_snrlevels: 21
+clean_activity_threshold: 0.6
+noise_activity_threshold: 0.0
+fileindex_start: None
+fileindex_end: None
+is_test_set: False
+
+# Source dir
+rir_table_csv: !ref <rirs>
+
+# Directory path where Webdatasets of DNS clean and noise shards are located.
+input_sampling_rate: 48000  # sampling rate of input signal
+clean_meta: !ref <input_shards_dir>/clean_fullband/<split_name>/meta.json
+noise_meta: !ref <input_shards_dir>/noise_fullband/meta.json
+clean_fullband_shards: !ref <input_shards_dir>/clean_fullband/<split_name>/shard-{000000..999999}.tar
+noise_fullband_shards: !ref <input_shards_dir>/noise_fullband/shard-{000000..999999}.tar
+
+# Configuration for synthesizing shards of clean-noisy pairs.
+samples_per_shard: 5000
+
+# Destination directory for storing shards of synthesized data.
+synthesized_data_dir: !PLACEHOLDER  #  synthesized_data_shards
+train_shard_destination: !ref <synthesized_data_dir>/train_shards/<split_name>
+valid_shard_destination: !ref <synthesized_data_dir>/valid_shards/<split_name>
+
+# Set to a directory on a large disk if using Webdataset shards hosted on the web.
+shard_cache_dir:
+
+# These can be skipped. (uncomment if you want to use them)
+# clean_singing: !PLACEHOLDER # ../DNS-shards/clean_fullband/VocalSet_48kHz_mono/
+# clean_emotion: !PLACEHOLDER # ../DNS-shards/clean_fullband/emotional_speech/
+## Aishell data needs to be downloaded separately.
+# clean_mandarin: !PLACEHOLDER # ../DNS-shards/clean_fullband/mandrin_speech/data_aishell
+
+log_dir: !ref <split_name>_logs
+noise_types_excluded: None
+
+## Config: add singing voice to clean speech
+use_singing_data: 0 # 0 for no, 1 for yes
+# 1 for only male, 2 for only female, 3 (default) for both male and female
+singing_choice: 3
+
+## Config: add emotional data to clean speech
+# 0 for no, 1 for yes
+use_emotion_data: 0
+
+## Config: add Chinese (mandarin) data to clean speech
+# 0 for no, 1 for yes
+use_mandarin_data: 0
+
+## Config: add reverb to clean speech
+# 1 for only real rir, 2 for only synthetic rir, 3 (default) use both real and synthetic
+rir_choice: 3
+# lower bound of t60 range in seconds
+lower_t60: 0.3
+# upper bound of t60 range in seconds
+upper_t60: 1.3
diff --git a/recipes/DNS/noisyspeech_synthesizer/noisyspeech_synthesizer_singleprocess.py b/recipes/DNS/noisyspeech_synthesizer/noisyspeech_synthesizer_singleprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..478c786d054a5e3e5f832476802c4e27a11c2ab1
--- /dev/null
+++ b/recipes/DNS/noisyspeech_synthesizer/noisyspeech_synthesizer_singleprocess.py
@@ -0,0 +1,720 @@
+"""
+Source: https://github.com/microsoft/DNS-Challenge
+Ownership: Microsoft
+
+This script will attempt to use each clean and noise
+webdataset shards to synthesize clean-noisy pairs of
+audio. The output is again stored in webdataset shards.
+
+* Author
+    chkarada
+
+* Further modified
+    Sangeet Sagar (2023)
+"""
+
+# Note: This single process audio synthesizer will attempt to use each clean
+# speech sourcefile once (from the webdataset shards), as it does not
+# randomly sample from these files
+
+import sys
+import os
+from pathlib import Path
+import random
+import time
+
+import numpy as np
+from scipy import signal
+from scipy.io import wavfile
+
+import librosa
+
+import utils
+from audiolib import (
+    segmental_snr_mixer,
+    activitydetector,
+    is_clipped,
+)
+
+import pandas as pd
+import json
+from functools import partial
+from typing import Dict
+from collections import defaultdict
+
+
+import speechbrain as sb
+import webdataset as wds
+from hyperpyyaml import load_hyperpyyaml
+import torch
+
+np.random.seed(5)
+random.seed(5)
+
+MAXTRIES = 50
+MAXFILELEN = 100
+
+start = time.time()
+
+
+def add_pyreverb(clean_speech, rir):
+    """
+    Add reverb to cean signal
+    """
+    reverb_speech = signal.fftconvolve(clean_speech, rir, mode="full")
+
+    # make reverb_speech same length as clean_speech
+    reverb_speech = reverb_speech[0 : clean_speech.shape[0]]
+
+    return reverb_speech
+
+
+def build_audio(is_clean, params, index, audio_samples_length=-1):
+    """Construct an audio signal from source files"""
+
+    fs_output = params["fs"]
+    silence_length = params["silence_length"]
+    if audio_samples_length == -1:
+        audio_samples_length = int(params["audio_length"] * params["fs"])
+
+    output_audio = np.zeros(0)
+    remaining_length = audio_samples_length
+    files_used = []
+    clipped_files = []
+
+    if is_clean:
+        data_iterator = iter(params["clean_data"])
+        idx = index
+    else:
+        data_iterator = iter(params["noise_data"])
+        idx = index
+
+    # initialize silence
+    silence = np.zeros(int(fs_output * silence_length))
+
+    # iterate through multiple clips until we have a long enough signal
+    tries_left = MAXTRIES
+    while remaining_length > 0 and tries_left > 0:
+        # read next audio file and resample if necessary
+        fs_input = params["fs_input"]
+        batch = next(data_iterator)
+        input_audio = batch["sig"].numpy()
+
+        if input_audio is None:
+            sys.stderr.write(
+                "\nWARNING: Cannot read file: %s\n" % batch["__key__"]
+            )
+            continue
+        if fs_input != fs_output:
+            input_audio = librosa.resample(
+                input_audio, orig_sr=fs_input, target_sr=fs_output
+            )
+
+        # if current file is longer than remaining desired length, and this is
+        # noise generation or this is training set, subsample it randomly
+        if len(input_audio) > remaining_length and (
+            not is_clean or not params["is_test_set"]
+        ):
+            idx_seg = np.random.randint(0, len(input_audio) - remaining_length)
+            input_audio = input_audio[idx_seg : idx_seg + remaining_length]
+
+        # check for clipping, and if found move onto next file
+        if is_clipped(input_audio):
+            clipped_files.append(batch["__key__"])
+            tries_left -= 1
+            continue
+
+        # concatenate current input audio to output audio stream
+        files_used.append(batch["__key__"])
+        output_audio = np.append(output_audio, input_audio)
+        remaining_length -= len(input_audio)
+
+        # add some silence if we have not reached desired audio length
+        if remaining_length > 0:
+            silence_len = min(remaining_length, len(silence))
+            output_audio = np.append(output_audio, silence[:silence_len])
+            remaining_length -= silence_len
+
+    if tries_left == 0 and not is_clean and "noise_data" in params.keys():
+        print(
+            "There are not enough non-clipped files in the "
+            + "given noise directory to complete the audio build"
+        )
+        return [], [], clipped_files, idx
+
+    return output_audio, files_used, clipped_files, idx
+
+
+def gen_audio(is_clean, params, index, audio_samples_length=-1):
+    """Calls build_audio() to get an audio signal, and verify that it meets the
+    activity threshold"""
+
+    clipped_files = []
+    low_activity_files = []
+    if audio_samples_length == -1:
+        audio_samples_length = int(params["audio_length"] * params["fs"])
+    if is_clean:
+        activity_threshold = params["clean_activity_threshold"]
+    else:
+        activity_threshold = params["noise_activity_threshold"]
+
+    while True:
+        audio, source_files, new_clipped_files, index = build_audio(
+            is_clean, params, index, audio_samples_length
+        )
+
+        clipped_files += new_clipped_files
+        if len(audio) < audio_samples_length:
+            continue
+
+        if activity_threshold == 0.0:
+            break
+
+        percactive = activitydetector(audio=audio)
+        if percactive > activity_threshold:
+            break
+        else:
+            low_activity_files += source_files
+
+    return audio, source_files, clipped_files, low_activity_files, index
+
+
+def main_gen(params):
+    """Calls gen_audio() to generate the audio signals, verifies that they meet
+    the requirements, and writes the files to storage"""
+
+    clean_source_files = []
+    clean_clipped_files = []
+    clean_low_activity_files = []
+    noise_source_files = []
+    noise_clipped_files = []
+    noise_low_activity_files = []
+
+    clean_index = 0
+    noise_index = 0
+
+    # write shards
+    train_shards_path = Path(params["train_shard_destination"])
+    train_shards_path.mkdir(exist_ok=True, parents=True)
+    valid_shards_path = Path(params["valid_shard_destination"])
+    valid_shards_path.mkdir(exist_ok=True, parents=True)
+
+    all_keys = set()
+    train_pattern = str(train_shards_path / "shard") + "-%06d.tar"
+    valid_pattern = str(valid_shards_path / "shard") + "-%06d.tar"
+    samples_per_shard = params["samples_per_shard"]
+
+    # track statistics on data
+    train_sample_keys = defaultdict(list)
+    valid_sample_keys = defaultdict(list)
+
+    # Define the percentage of data to be used for validation
+    validation_percentage = 0.05
+
+    # Calculate the number of samples for training and validation
+    total_samples = params["fileindex_end"] - params["fileindex_start"] + 1
+    num_validation_samples = int(total_samples * validation_percentage)
+
+    # Define separate ShardWriters for training and validation data
+    train_writer = wds.ShardWriter(train_pattern, maxcount=samples_per_shard)
+    valid_writer = wds.ShardWriter(valid_pattern, maxcount=samples_per_shard)
+
+    # Initialize counters and data lists for statistics
+    file_num = params["fileindex_start"]
+    train_data_tuples = []
+    valid_data_tuples = []
+
+    while file_num <= params["fileindex_end"]:
+        print(
+            "\rFiles synthesized {:4d}/{:4d}".format(
+                file_num, params["fileindex_end"]
+            ),
+            end="",
+        )
+        # CLEAN SPEECH GENERATION
+        clean, clean_sf, clean_cf, clean_laf, clean_index = gen_audio(
+            True, params, clean_index
+        )
+
+        # add reverb with selected RIR
+        rir_index = random.randint(0, len(params["myrir"]) - 1)
+
+        my_rir = os.path.normpath(os.path.join(params["myrir"][rir_index]))
+        (fs_rir, samples_rir) = wavfile.read(my_rir)
+
+        my_channel = int(params["mychannel"][rir_index])
+
+        if samples_rir.ndim == 1:
+            samples_rir_ch = np.array(samples_rir)
+
+        elif my_channel > 1:
+            samples_rir_ch = samples_rir[:, my_channel - 1]
+        else:
+            samples_rir_ch = samples_rir[:, my_channel - 1]
+            # print(samples_rir.shape)
+            # print(my_channel)
+
+        # REVERB MIXED TO THE CLEAN SPEECH
+        clean = add_pyreverb(clean, samples_rir_ch)
+
+        # generate noise
+        noise, noise_sf, noise_cf, noise_laf, noise_index = gen_audio(
+            False, params, noise_index, len(clean)
+        )
+
+        clean_clipped_files += clean_cf
+        clean_low_activity_files += clean_laf
+        noise_clipped_files += noise_cf
+        noise_low_activity_files += noise_laf
+
+        # mix clean speech and noise
+        # if specified, use specified SNR value
+        if not params["randomize_snr"]:
+            snr = params["snr"]
+        # use a randomly sampled SNR value between the specified bounds
+        else:
+            snr = np.random.randint(params["snr_lower"], params["snr_upper"])
+
+        # NOISE ADDED TO THE REVERBED SPEECH
+        clean_snr, noise_snr, noisy_snr, target_level = segmental_snr_mixer(
+            params=params, clean=clean, noise=noise, snr=snr
+        )
+        # Uncomment the below lines if you need segmental SNR and comment the above lines using snr_mixer
+        # clean_snr, noise_snr, noisy_snr, target_level = segmental_snr_mixer(params=params,
+        #                                                         clean=clean,
+        #                                                          noise=noise,
+        #                                                         snr=snr)
+        # unexpected clipping
+        if (
+            is_clipped(clean_snr)
+            or is_clipped(noise_snr)
+            or is_clipped(noisy_snr)
+        ):
+            print(
+                "\nWarning: File #"
+                + str(file_num)
+                + " has unexpected clipping, "
+                + "returning without writing audio to disk"
+            )
+            continue
+
+        clean_source_files += clean_sf
+        noise_source_files += noise_sf
+
+        # write resultant audio streams to files
+        hyphen = "-"
+        clean_source_filenamesonly = [
+            i[:-4].split(os.path.sep)[-1] for i in clean_sf
+        ]
+        clean_files_joined = hyphen.join(clean_source_filenamesonly)[
+            :MAXFILELEN
+        ]
+        noise_source_filenamesonly = [
+            i[:-4].split(os.path.sep)[-1] for i in noise_sf
+        ]
+        noise_files_joined = hyphen.join(noise_source_filenamesonly)[
+            :MAXFILELEN
+        ]
+
+        noisyfilename = (
+            clean_files_joined
+            + "_"
+            + noise_files_joined
+            + "_snr"
+            + str(snr)
+            + "_tl"
+            + str(target_level)
+            + "_fileid_"
+            + str(file_num)
+        )
+
+        # Period is not allowed in a WebDataset key name
+        cleanfilename = "clean_fileid_" + str(file_num)
+        cleanfilename = cleanfilename.replace(".", "_")
+        noisefilename = "noise_fileid_" + str(file_num)
+        noisefilename = noisefilename.replace(".", "_")
+
+        file_num += 1
+
+        # store statistics
+        key = noisyfilename
+        key = key.replace(".", "_")
+        lang = params["split_name"].split("_")[0]
+        t = (key, lang)
+
+        # verify key is unique
+        assert cleanfilename not in all_keys
+        all_keys.add(cleanfilename)
+
+        # Split the data between training and validation based on the file number
+        if file_num % total_samples <= num_validation_samples:
+            # Write to validation set
+            valid_sample_keys[lang].append(key)
+            valid_data_tuples.append(t)
+            sample = {
+                "__key__": key,
+                "noisy_file": key,
+                "clean_file": cleanfilename,
+                "noise_file": noisefilename,
+                "clean_audio.pth": torch.tensor(clean_snr).to(torch.float32),
+                "noise_audio.pth": torch.tensor(noise_snr).to(torch.float32),
+                "noisy_audio.pth": torch.tensor(noisy_snr).to(torch.float32),
+            }
+            valid_writer.write(sample)
+        else:
+            # Write to training set
+            train_sample_keys[lang].append(key)
+            train_data_tuples.append(t)
+            sample = {
+                "__key__": key,
+                "noisy_file": key,
+                "clean_file": cleanfilename,
+                "noise_file": noisefilename,
+                "clean_audio.pth": torch.tensor(clean_snr).to(torch.float32),
+                "noise_audio.pth": torch.tensor(noise_snr).to(torch.float32),
+                "noisy_audio.pth": torch.tensor(noisy_snr).to(torch.float32),
+            }
+            train_writer.write(sample)
+
+    train_writer.close()
+    valid_writer.close()
+
+    # Write meta.json files for both training and validation
+    train_meta_dict = {
+        "language_id": lang,
+        "sample_keys_per_language": train_sample_keys,
+        "num_data_samples": len(train_data_tuples),
+    }
+    valid_meta_dict = {
+        "language_id": lang,
+        "sample_keys_per_language": valid_sample_keys,
+        "num_data_samples": len(valid_data_tuples),
+    }
+
+    with (train_shards_path / "meta.json").open("w") as f:
+        json.dump(train_meta_dict, f, indent=4)
+
+    with (valid_shards_path / "meta.json").open("w") as f:
+        json.dump(valid_meta_dict, f, indent=4)
+
+    return (
+        clean_source_files,
+        clean_clipped_files,
+        clean_low_activity_files,
+        noise_source_files,
+        noise_clipped_files,
+        noise_low_activity_files,
+    )
+
+
+def main_body():  # noqa
+    """Main body of this file"""
+
+    params = dict()
+
+    # Load hyperparameters file with command-line overrides
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # Data Directories and Settings
+    params["split_name"] = hparams["split_name"]
+
+    # Audio Settings
+    params["fs"] = int(hparams["sampling_rate"])
+    params["fs_input"] = int(
+        hparams["input_sampling_rate"]
+    )  # Sampling rate of input data
+    params["audioformat"] = hparams["audioformat"]
+    params["audio_length"] = float(hparams["audio_length"])
+    params["silence_length"] = float(hparams["silence_length"])
+    params["total_hours"] = float(hparams["total_hours"])
+
+    # Clean Data Categories
+    params["use_singing_data"] = int(hparams["use_singing_data"])
+    if hasattr(hparams, "clean_singing"):
+        params["clean_singing"] = str(hparams["clean_singing"])
+    params["singing_choice"] = int(hparams["singing_choice"])
+
+    params["use_emotion_data"] = int(hparams["use_emotion_data"])
+    if hasattr(hparams, "clean_emotion"):
+        params["clean_emotion"] = str(hparams["clean_emotion"])
+
+    params["use_mandarin_data"] = int(hparams["use_mandarin_data"])
+    if hasattr(hparams, "clean_mandarin"):
+        params["clean_mandarin"] = str(hparams["clean_mandarin"])
+
+    # Room Impulse Response (RIR) Settings
+    params["rir_choice"] = int(hparams["rir_choice"])
+    params["lower_t60"] = float(hparams["lower_t60"])
+    params["upper_t60"] = float(hparams["upper_t60"])
+    params["rir_table_csv"] = str(hparams["rir_table_csv"])
+
+    # File Indexing
+    if (
+        hparams["fileindex_start"] != "None"
+        and hparams["fileindex_end"] != "None"
+    ):
+        params["num_files"] = int(hparams["fileindex_end"]) - int(
+            params["fileindex_start"]
+        )
+        params["fileindex_start"] = int(hparams["fileindex_start"])
+        params["fileindex_end"] = int(hparams["fileindex_end"])
+    else:
+        params["num_files"] = int(
+            (params["total_hours"] * 60 * 60) / params["audio_length"]
+        )
+        params["fileindex_start"] = 0
+        params["fileindex_end"] = params["num_files"]
+
+    print("Number of files to be synthesized:", params["num_files"])
+
+    # Data Generation and Synthesis Settings
+    params["is_test_set"] = utils.str2bool(str(hparams["is_test_set"]))
+    params["clean_activity_threshold"] = float(
+        hparams["clean_activity_threshold"]
+    )
+    params["noise_activity_threshold"] = float(
+        hparams["noise_activity_threshold"]
+    )
+    params["snr_lower"] = int(hparams["snr_lower"])
+    params["snr_upper"] = int(hparams["snr_upper"])
+    params["randomize_snr"] = utils.str2bool(str(hparams["randomize_snr"]))
+    params["target_level_lower"] = int(hparams["target_level_lower"])
+    params["target_level_upper"] = int(hparams["target_level_upper"])
+
+    if hasattr(hparams, "snr"):
+        params["snr"] = int(hparams["snr"])
+    else:
+        params["snr"] = int((params["snr_lower"] + params["snr_upper"]) / 2)
+
+    # Synthesized Data Destination
+    params["samples_per_shard"] = hparams["samples_per_shard"]
+    params["train_shard_destination"] = hparams["train_shard_destination"]
+    params["valid_shard_destination"] = hparams["valid_shard_destination"]
+
+    #### Shard data extraction ~~~
+    # load the meta info json file
+
+    with wds.gopen(hparams["clean_meta"], "rb") as f:
+        clean_meta = json.load(f)
+    with wds.gopen(hparams["noise_meta"], "rb") as f:
+        noise_meta = json.load(f)
+
+    def audio_pipeline(sample_dict: Dict, random_chunk=True):
+        key = sample_dict["__key__"]
+        audio_tensor = sample_dict["audio.pth"]
+
+        sig = audio_tensor.squeeze()
+
+        return {
+            "sig": sig,
+            "id": key,
+        }
+
+    clean_data = (
+        wds.WebDataset(
+            hparams["clean_fullband_shards"],
+            cache_dir=hparams["shard_cache_dir"],
+        )
+        .repeat()
+        .shuffle(1000)
+        .decode("pil")
+        .map(partial(audio_pipeline, random_chunk=True))
+    )
+    print(f"Clean data consist of {clean_meta['num_data_samples']} samples")
+
+    noise_data = (
+        wds.WebDataset(
+            hparams["noise_fullband_shards"],
+            cache_dir=hparams["shard_cache_dir"],
+        )
+        .repeat()
+        .shuffle(1000)
+        .decode("pil")
+        .map(partial(audio_pipeline, random_chunk=True))
+    )
+    print(f"Noise data consist of {noise_meta['num_data_samples']} samples")
+
+    params["clean_data"] = clean_data
+    params["noise_data"] = noise_data
+
+    # add singing voice to clean speech
+    if params["use_singing_data"] == 1:
+        raise NotImplementedError("Add sining voice to clean speech")
+    else:
+        print("NOT using singing data for training!")
+
+    # add emotion data to clean speech
+    if params["use_emotion_data"] == 1:
+        raise NotImplementedError("Add emotional data to clean speech")
+    else:
+        print("NOT using emotion data for training!")
+
+    # add mandarin data to clean speech
+    if params["use_mandarin_data"] == 1:
+        raise NotImplementedError("Add Mandarin data to clean speech")
+    else:
+        print("NOT using non-english (Mandarin) data for training!")
+
+    # rir
+    temp = pd.read_csv(
+        params["rir_table_csv"],
+        skiprows=[1],
+        sep=",",
+        header=None,
+        names=["wavfile", "channel", "T60_WB", "C50_WB", "isRealRIR"],
+    )
+    temp.keys()
+    # temp.wavfile
+
+    rir_wav = temp["wavfile"][1:]  # 115413
+    rir_channel = temp["channel"][1:]
+    rir_t60 = temp["T60_WB"][1:]
+    rir_isreal = temp["isRealRIR"][1:]
+
+    rir_wav2 = [w.replace("\\", "/") for w in rir_wav]
+    rir_channel2 = [w for w in rir_channel]
+    rir_t60_2 = [w for w in rir_t60]
+    rir_isreal2 = [w for w in rir_isreal]
+
+    myrir = []
+    mychannel = []
+    myt60 = []
+
+    lower_t60 = params["lower_t60"]
+    upper_t60 = params["upper_t60"]
+
+    if params["rir_choice"] == 1:  # real 3076 IRs
+        real_indices = [i for i, x in enumerate(rir_isreal2) if x == "1"]
+
+        chosen_i = []
+        for i in real_indices:
+            if (float(rir_t60_2[i]) >= lower_t60) and (
+                float(rir_t60_2[i]) <= upper_t60
+            ):
+                chosen_i.append(i)
+
+        myrir = [rir_wav2[i] for i in chosen_i]
+        mychannel = [rir_channel2[i] for i in chosen_i]
+        myt60 = [rir_t60_2[i] for i in chosen_i]
+
+    elif params["rir_choice"] == 2:  # synthetic 112337 IRs
+        synthetic_indices = [i for i, x in enumerate(rir_isreal2) if x == "0"]
+
+        chosen_i = []
+        for i in synthetic_indices:
+            if (float(rir_t60_2[i]) >= lower_t60) and (
+                float(rir_t60_2[i]) <= upper_t60
+            ):
+                chosen_i.append(i)
+
+        myrir = [rir_wav2[i] for i in chosen_i]
+        mychannel = [rir_channel2[i] for i in chosen_i]
+        myt60 = [rir_t60_2[i] for i in chosen_i]
+
+    elif params["rir_choice"] == 3:  # both real and synthetic
+        all_indices = [i for i, x in enumerate(rir_isreal2)]
+
+        chosen_i = []
+        for i in all_indices:
+            if (float(rir_t60_2[i]) >= lower_t60) and (
+                float(rir_t60_2[i]) <= upper_t60
+            ):
+                chosen_i.append(i)
+
+        myrir = [rir_wav2[i] for i in chosen_i]
+        mychannel = [rir_channel2[i] for i in chosen_i]
+        myt60 = [rir_t60_2[i] for i in chosen_i]
+
+    else:  # default both real and synthetic
+        all_indices = [i for i, x in enumerate(rir_isreal2)]
+
+        chosen_i = []
+        for i in all_indices:
+            if (float(rir_t60_2[i]) >= lower_t60) and (
+                float(rir_t60_2[i]) <= upper_t60
+            ):
+                chosen_i.append(i)
+
+        myrir = [rir_wav2[i] for i in chosen_i]
+        mychannel = [rir_channel2[i] for i in chosen_i]
+        myt60 = [rir_t60_2[i] for i in chosen_i]
+
+    params["myrir"] = myrir
+    params["mychannel"] = mychannel
+    params["myt60"] = myt60
+
+    # Call main_gen() to generate audio
+    (
+        clean_source_files,
+        clean_clipped_files,
+        clean_low_activity_files,
+        noise_source_files,
+        noise_clipped_files,
+        noise_low_activity_files,
+    ) = main_gen(params)
+
+    # Create log directory if needed, and write log files of clipped and low activity files
+    log_dir = utils.get_dir(hparams, "log_dir", "Logs")
+
+    utils.write_log_file(
+        log_dir, "source_files.csv", clean_source_files + noise_source_files
+    )
+    utils.write_log_file(
+        log_dir, "clipped_files.csv", clean_clipped_files + noise_clipped_files
+    )
+    utils.write_log_file(
+        log_dir,
+        "low_activity_files.csv",
+        clean_low_activity_files + noise_low_activity_files,
+    )
+
+    # Compute and print stats about percentange of clipped and low activity files
+    total_clean = (
+        len(clean_source_files)
+        + len(clean_clipped_files)
+        + len(clean_low_activity_files)
+    )
+    total_noise = (
+        len(noise_source_files)
+        + len(noise_clipped_files)
+        + len(noise_low_activity_files)
+    )
+    pct_clean_clipped = round(len(clean_clipped_files) / total_clean * 100, 1)
+    pct_noise_clipped = round(len(noise_clipped_files) / total_noise * 100, 1)
+    pct_clean_low_activity = round(
+        len(clean_low_activity_files) / total_clean * 100, 1
+    )
+    pct_noise_low_activity = round(
+        len(noise_low_activity_files) / total_noise * 100, 1
+    )
+
+    print(
+        "\nOf the "
+        + str(total_clean)
+        + " clean speech files analyzed, "
+        + str(pct_clean_clipped)
+        + "% had clipping, and "
+        + str(pct_clean_low_activity)
+        + "% had low activity "
+        + "(below "
+        + str(params["clean_activity_threshold"] * 100)
+        + "% active percentage)"
+    )
+    print(
+        "Of the "
+        + str(total_noise)
+        + " noise files analyzed, "
+        + str(pct_noise_clipped)
+        + "% had clipping, and "
+        + str(pct_noise_low_activity)
+        + "% had low activity "
+        + "(below "
+        + str(params["noise_activity_threshold"] * 100)
+        + "% active percentage)"
+    )
+
+
+if __name__ == "__main__":
+    main_body()
diff --git a/recipes/DNS/noisyspeech_synthesizer/utils.py b/recipes/DNS/noisyspeech_synthesizer/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f3826a906f347b8a5e9f1c2cb91e2d288d7fe26
--- /dev/null
+++ b/recipes/DNS/noisyspeech_synthesizer/utils.py
@@ -0,0 +1,54 @@
+"""
+Source: https://github.com/microsoft/DNS-Challenge
+Ownership: Microsoft
+
+* Author
+    rocheng
+"""
+import os
+import csv
+from shutil import copyfile
+import glob
+
+
+def get_dir(cfg, param_name, new_dir_name):
+    """Helper function to retrieve directory name if it exists,
+       create it if it doesn't exist"""
+
+    if param_name in cfg:
+        dir_name = cfg[param_name]
+    else:
+        dir_name = os.path.join(os.path.dirname(__file__), new_dir_name)
+    if not os.path.exists(dir_name):
+        os.makedirs(dir_name)
+    return dir_name
+
+
+def write_log_file(log_dir, log_filename, data):
+    """Helper function to write log file"""
+    # data = zip(*data)
+    with open(
+        os.path.join(log_dir, log_filename), mode="w", newline=""
+    ) as csvfile:
+        csvwriter = csv.writer(
+            csvfile, delimiter=" ", quotechar="|", quoting=csv.QUOTE_MINIMAL
+        )
+        for row in data:
+            csvwriter.writerow([row])
+
+
+def str2bool(string):
+    """Convert a string to a boolean value.
+    """
+    return string.lower() in ("yes", "true", "t", "1")
+
+
+def rename_copyfile(src_path, dest_dir, prefix="", ext="*.wav"):
+    """Copy and rename files from a source directory to a destination directory.
+    """
+    srcfiles = glob.glob(f"{src_path}/" + ext)
+    for i in range(len(srcfiles)):
+        dest_path = os.path.join(
+            dest_dir, prefix + "_" + os.path.basename(srcfiles[i])
+        )
+        copyfile(srcfiles[i], dest_path)
diff --git a/recipes/DVoice/ASR/CTC/hparams/train_amh_with_wav2vec.yaml b/recipes/DVoice/ASR/CTC/hparams/train_amh_with_wav2vec.yaml
index a52d737d22f193dd0df27ada16226bc47525e845..e9e1f43100f823790eec7e5d2649688dbb5d5574 100644
--- a/recipes/DVoice/ASR/CTC/hparams/train_amh_with_wav2vec.yaml
+++ b/recipes/DVoice/ASR/CTC/hparams/train_amh_with_wav2vec.yaml
@@ -32,12 +32,12 @@ skip_prep: False # Skip data preparation
 # longer sentences certainly correspond to "open microphones".
 avoid_if_longer_than: 15.0
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 30
 lr: 1.0
 lr_wav2vec: 0.0001
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
@@ -58,7 +58,7 @@ test_dataloader_options:
 token_type: char  # ["unigram", "bpe", "char"]
 character_coverage: 1.0
 
-# Model parameters
+####################### Model Parameters #######################################
 wav2vec_output_dim: 1024
 dnn_neurons: 1024
 freeze_wav2vec: False
@@ -78,10 +78,41 @@ eos_index: 2
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    concat_original: True
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
 enc: !new:speechbrain.nnet.containers.Sequential
     input_shape: [null, null, !ref <wav2vec_output_dim>]
     linear1: !name:speechbrain.nnet.linear.Linear
@@ -104,7 +135,7 @@ enc: !new:speechbrain.nnet.containers.Sequential
     bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
     activation3: !new:torch.nn.LeakyReLU
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec>
diff --git a/recipes/DVoice/ASR/CTC/hparams/train_dar_with_wav2vec.yaml b/recipes/DVoice/ASR/CTC/hparams/train_dar_with_wav2vec.yaml
index 71f4be0936195fb96f8ae731d0af0a1ba616ec9e..d1e2c66842028ae18a908eb06ba949e019977d66 100644
--- a/recipes/DVoice/ASR/CTC/hparams/train_dar_with_wav2vec.yaml
+++ b/recipes/DVoice/ASR/CTC/hparams/train_dar_with_wav2vec.yaml
@@ -32,12 +32,12 @@ skip_prep: False # Skip data preparation
 # longer sentences certainly correspond to "open microphones".
 avoid_if_longer_than: 15.0
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 30
 lr: 1.0
 lr_wav2vec: 0.0001
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
@@ -58,7 +58,7 @@ test_dataloader_options:
 token_type: char  # ["unigram", "bpe", "char"]
 character_coverage: 1.0
 
-# Model parameters
+####################### Model Parameters #######################################
 wav2vec_output_dim: 1024
 dnn_neurons: 1024
 freeze_wav2vec: False
@@ -78,10 +78,42 @@ eos_index: 2
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    concat_original: True
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
+
 enc: !new:speechbrain.nnet.containers.Sequential
     input_shape: [null, null, !ref <wav2vec_output_dim>]
     linear1: !name:speechbrain.nnet.linear.Linear
@@ -104,7 +136,7 @@ enc: !new:speechbrain.nnet.containers.Sequential
     bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
     activation3: !new:torch.nn.LeakyReLU
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec>
diff --git a/recipes/DVoice/ASR/CTC/hparams/train_fon_with_wav2vec.yaml b/recipes/DVoice/ASR/CTC/hparams/train_fon_with_wav2vec.yaml
index 9ac55d60e3862330d75dfa5c551747d4c63433dd..fca0230de227bd68b964653f04808e43d0e484c9 100644
--- a/recipes/DVoice/ASR/CTC/hparams/train_fon_with_wav2vec.yaml
+++ b/recipes/DVoice/ASR/CTC/hparams/train_fon_with_wav2vec.yaml
@@ -32,12 +32,12 @@ skip_prep: False # Skip data preparation
 # longer sentences certainly correspond to "open microphones".
 avoid_if_longer_than: 15.0
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 30
 lr: 1.0
 lr_wav2vec: 0.0001
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
@@ -58,7 +58,7 @@ test_dataloader_options:
 token_type: char  # ["unigram", "bpe", "char"]
 character_coverage: 1.0
 
-# Model parameters
+####################### Model Parameters #######################################
 wav2vec_output_dim: 1024
 dnn_neurons: 1024
 freeze_wav2vec: False
@@ -78,10 +78,42 @@ eos_index: 2
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    concat_original: True
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
+
 enc: !new:speechbrain.nnet.containers.Sequential
     input_shape: [null, null, !ref <wav2vec_output_dim>]
     linear1: !name:speechbrain.nnet.linear.Linear
@@ -104,7 +136,7 @@ enc: !new:speechbrain.nnet.containers.Sequential
     bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
     activation3: !new:torch.nn.LeakyReLU
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec>
diff --git a/recipes/DVoice/ASR/CTC/hparams/train_multi_with_wav2vec.yaml b/recipes/DVoice/ASR/CTC/hparams/train_multi_with_wav2vec.yaml
index 358e59546d9fe3759d48e6b535248ac4e0d41727..89fedade8f51ce16f72897b6d763993ed0f36d04 100644
--- a/recipes/DVoice/ASR/CTC/hparams/train_multi_with_wav2vec.yaml
+++ b/recipes/DVoice/ASR/CTC/hparams/train_multi_with_wav2vec.yaml
@@ -31,12 +31,12 @@ skip_prep: False # Skip data preparation
 # longer sentences certainly correspond to "open microphones".
 avoid_if_longer_than: 15.0
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 30
 lr: 1.0
 lr_wav2vec: 0.0001
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
@@ -57,7 +57,7 @@ test_dataloader_options:
 token_type: char  # ["unigram", "bpe", "char"]
 character_coverage: 1.0
 
-# Model parameters
+####################### Model Parameters #######################################
 wav2vec_output_dim: 1024
 dnn_neurons: 1024
 freeze_wav2vec: False
@@ -77,10 +77,42 @@ eos_index: 2
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    concat_original: True
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
+
 enc: !new:speechbrain.nnet.containers.Sequential
     input_shape: [null, null, !ref <wav2vec_output_dim>]
     linear1: !name:speechbrain.nnet.linear.Linear
@@ -103,7 +135,7 @@ enc: !new:speechbrain.nnet.containers.Sequential
     bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
     activation3: !new:torch.nn.LeakyReLU
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec>
diff --git a/recipes/DVoice/ASR/CTC/hparams/train_sw_with_wav2vec.yaml b/recipes/DVoice/ASR/CTC/hparams/train_sw_with_wav2vec.yaml
index ec3f1fd9cd157b3834dd40f03f593ec84ce9007a..0194fd8776e231be71a999e9f87d5c5b675cdf69 100644
--- a/recipes/DVoice/ASR/CTC/hparams/train_sw_with_wav2vec.yaml
+++ b/recipes/DVoice/ASR/CTC/hparams/train_sw_with_wav2vec.yaml
@@ -32,12 +32,12 @@ skip_prep: False # Skip data preparation
 # longer sentences certainly correspond to "open microphones".
 avoid_if_longer_than: 15.0
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 30
 lr: 1.0
 lr_wav2vec: 0.0001
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
@@ -58,7 +58,7 @@ test_dataloader_options:
 token_type: char  # ["unigram", "bpe", "char"]
 character_coverage: 1.0
 
-# Model parameters
+####################### Model Parameters #######################################
 wav2vec_output_dim: 1024
 dnn_neurons: 1024
 freeze_wav2vec: False
@@ -78,10 +78,42 @@ eos_index: 2
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    concat_original: True
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
+
 enc: !new:speechbrain.nnet.containers.Sequential
     input_shape: [null, null, !ref <wav2vec_output_dim>]
     linear1: !name:speechbrain.nnet.linear.Linear
@@ -104,7 +136,7 @@ enc: !new:speechbrain.nnet.containers.Sequential
     bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
     activation3: !new:torch.nn.LeakyReLU
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec>
diff --git a/recipes/DVoice/ASR/CTC/hparams/train_wol_with_wav2vec.yaml b/recipes/DVoice/ASR/CTC/hparams/train_wol_with_wav2vec.yaml
index 5fe3e3133a3fe496ceea73648919cf98cd95816f..8470ce3a1c81b603cd54093f763c8323a28032cc 100644
--- a/recipes/DVoice/ASR/CTC/hparams/train_wol_with_wav2vec.yaml
+++ b/recipes/DVoice/ASR/CTC/hparams/train_wol_with_wav2vec.yaml
@@ -32,12 +32,12 @@ skip_prep: False # Skip data preparation
 # longer sentences certainly correspond to "open microphones".
 avoid_if_longer_than: 15.0
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 30
 lr: 1.0
 lr_wav2vec: 0.0001
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
@@ -58,7 +58,7 @@ test_dataloader_options:
 token_type: char  # ["unigram", "bpe", "char"]
 character_coverage: 1.0
 
-# Model parameters
+####################### Model Parameters #######################################
 wav2vec_output_dim: 1024
 dnn_neurons: 1024
 freeze_wav2vec: False
@@ -78,10 +78,41 @@ eos_index: 2
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    concat_original: True
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
 enc: !new:speechbrain.nnet.containers.Sequential
     input_shape: [null, null, !ref <wav2vec_output_dim>]
     linear1: !name:speechbrain.nnet.linear.Linear
@@ -104,7 +135,7 @@ enc: !new:speechbrain.nnet.containers.Sequential
     bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
     activation3: !new:torch.nn.LeakyReLU
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec>
diff --git a/recipes/DVoice/ASR/CTC/train_with_wav2vec2.py b/recipes/DVoice/ASR/CTC/train_with_wav2vec2.py
index 36844bb8e341190ca0c5fd1444f37c0dc6b3f8e4..12bcb869588b8f36344842af094a80f036711e65 100644
--- a/recipes/DVoice/ASR/CTC/train_with_wav2vec2.py
+++ b/recipes/DVoice/ASR/CTC/train_with_wav2vec2.py
@@ -43,12 +43,11 @@ class ASR(sb.core.Brain):
 
         batch = batch.to(self.device)
         wavs, wav_lens = batch.sig
-        tokens_bos, _ = batch.tokens_bos
         wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
 
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
 
         # Forward pass
         feats = self.modules.wav2vec2(wavs, wav_lens)
@@ -64,9 +63,13 @@ class ASR(sb.core.Brain):
         p_ctc, wav_lens = predictions
 
         ids = batch.id
-        tokens_eos, tokens_eos_lens = batch.tokens_eos
         tokens, tokens_lens = batch.tokens
 
+        # Label Augmentation
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            tokens = self.hparams.wav_augment.replicate_labels(tokens)
+            tokens_lens = self.hparams.wav_augment.replicate_labels(tokens_lens)
+
         loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
 
         if stage != sb.Stage.TRAIN:
@@ -86,61 +89,6 @@ class ASR(sb.core.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        should_step = self.step % self.grad_accumulation_factor == 0
-        # Managing automatic mixed precision
-        # TOFIX: CTC fine-tuning currently is unstable
-        # This is certainly due to CTC being done in fp16 instead of fp32
-        if self.auto_mix_prec:
-            with torch.cuda.amp.autocast():
-                with self.no_sync():
-                    outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-                loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-            with self.no_sync(not should_step):
-                self.scaler.scale(
-                    loss / self.grad_accumulation_factor
-                ).backward()
-            if should_step:
-
-                if not self.hparams.wav2vec2.freeze:
-                    self.scaler.unscale_(self.wav2vec_optimizer)
-                self.scaler.unscale_(self.model_optimizer)
-                if self.check_gradients(loss):
-                    if not self.hparams.wav2vec2.freeze:
-                        self.scaler.step(self.wav2vec_optimizer)
-                    self.scaler.step(self.model_optimizer)
-                self.scaler.update()
-                self.zero_grad()
-                self.optimizer_step += 1
-        else:
-            # This is mandatory because HF models have a weird behavior with DDP
-            # on the forward pass
-            with self.no_sync():
-                outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-
-            loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-
-            with self.no_sync(not should_step):
-                (loss / self.grad_accumulation_factor).backward()
-            if should_step:
-                if self.check_gradients(loss):
-                    if not self.hparams.wav2vec2.freeze:
-                        self.wav2vec_optimizer.step()
-                    self.model_optimizer.step()
-                self.zero_grad()
-                self.optimizer_step += 1
-
-        self.on_fit_batch_end(batch, outputs, loss, should_step)
-        return loss.detach().cpu()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        with torch.no_grad():
-            loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
         if stage != sb.Stage.TRAIN:
@@ -196,6 +144,10 @@ class ASR(sb.core.Brain):
     def init_optimizers(self):
         "Initializes the wav2vec2 optimizer and model optimizer"
 
+        self.model_optimizer = self.hparams.model_opt_class(
+            self.hparams.model.parameters()
+        )
+
         # If the wav2vec encoder is unfrozen, we create the optimizer
         if not self.hparams.wav2vec2.freeze:
             self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
@@ -205,18 +157,25 @@ class ASR(sb.core.Brain):
                 self.checkpointer.add_recoverable(
                     "wav2vec_opt", self.wav2vec_optimizer
                 )
-
-        self.model_optimizer = self.hparams.model_opt_class(
-            self.hparams.model.parameters()
-        )
+            self.optimizers_dict = {
+                "wav2vec_optimizer": self.wav2vec_optimizer,
+                "model_optimizer": self.model_optimizer,
+            }
+        else:
+            self.optimizers_dict = {"model_optimizer": self.model_optimizer}
 
         if self.checkpointer is not None:
             self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
 
-    def zero_grad(self, set_to_none=False):
+    def freeze_optimizers(self, optimizers):
+        """Freezes the wav2vec2 optimizer according to the warmup steps"""
+        valid_optimizers = {}
         if not self.hparams.wav2vec2.freeze:
-            self.wav2vec_optimizer.zero_grad(set_to_none)
-        self.model_optimizer.zero_grad(set_to_none)
+            valid_optimizers["wav2vec_optimizer"] = optimizers[
+                "wav2vec_optimizer"
+            ]
+        valid_optimizers["model_optimizer"] = optimizers["model_optimizer"]
+        return valid_optimizers
 
 
 # Define custom data procedure
@@ -316,7 +275,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
diff --git a/recipes/DVoice/dvoice_prepare.py b/recipes/DVoice/dvoice_prepare.py
index 8cedaebd514e72f8d22c5bca40442d37a0bd6071..f471f024049cc657b614a095f78a54f0bec45469 100644
--- a/recipes/DVoice/dvoice_prepare.py
+++ b/recipes/DVoice/dvoice_prepare.py
@@ -11,7 +11,6 @@ import os
 import csv
 import re
 import logging
-import torchaudio
 import unicodedata
 from tqdm.contrib import tzip
 import random
@@ -372,12 +371,6 @@ def create_csv(
         spk_id = line.split("\t")[0].replace(".wav", "")
         snt_id = os.path.basename(file_name)
 
-        # Setting torchaudio backend to sox-io (needed to read mp3 files)
-        if torchaudio.get_audio_backend() != "sox_io":
-            logger.warning("This recipe needs the sox-io backend of torchaudio")
-            logger.warning("The torchaudio backend is changed to sox_io")
-            torchaudio.set_audio_backend("sox_io")
-
         duration = float(line.split("\t")[2])
         total_duration += duration
 
diff --git a/recipes/ESC50/classification/hparams/cnn14_classifier.yaml b/recipes/ESC50/classification/hparams/cnn14_classifier.yaml
index e8034bfdd169bd85bf10eda4504bd5f46a9a8656..bc0a83bbd7415221e3f4c5610363cb0fdfc0ffde 100644
--- a/recipes/ESC50/classification/hparams/cnn14_classifier.yaml
+++ b/recipes/ESC50/classification/hparams/cnn14_classifier.yaml
@@ -41,7 +41,7 @@ skip_manifest_creation: False
 
 ckpt_interval_minutes: 15 # save checkpoint every N min
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 200
 batch_size: 32
 lr: 0.0002
diff --git a/recipes/ESC50/classification/hparams/conv2d_classifier.yaml b/recipes/ESC50/classification/hparams/conv2d_classifier.yaml
index 2b0a49bcd36172932655a544e2c436d528392988..284d5681fc5799e1d84bd7ce3c865e60e5b92e46 100644
--- a/recipes/ESC50/classification/hparams/conv2d_classifier.yaml
+++ b/recipes/ESC50/classification/hparams/conv2d_classifier.yaml
@@ -41,7 +41,7 @@ skip_manifest_creation: False
 
 ckpt_interval_minutes: 15 # save checkpoint every N min
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 200
 batch_size: 32
 lr: 0.00002
diff --git a/recipes/ESC50/esc50_prepare.py b/recipes/ESC50/esc50_prepare.py
index 83637fa11f3e12079a7c0f1ba5659818ab88a0cb..f65eebbc05b3bd025af6b12f0442fe063ea210be 100644
--- a/recipes/ESC50/esc50_prepare.py
+++ b/recipes/ESC50/esc50_prepare.py
@@ -18,7 +18,7 @@ import logging
 import torchaudio
 from speechbrain.dataio.dataio import read_audio
 from speechbrain.dataio.dataio import load_data_csv
-from speechbrain.pretrained import fetch
+from speechbrain.utils.fetching import fetch
 
 logger = logging.getLogger(__name__)
 
diff --git a/recipes/ESC50/interpret/hparams/l2i_cnn14.yaml b/recipes/ESC50/interpret/hparams/l2i_cnn14.yaml
index cdd9bcb7e410ceda7f2c7ff6dc3e681646cce937..00acd1ff372cb8b3a97e13a11f531181dfcb3502 100644
--- a/recipes/ESC50/interpret/hparams/l2i_cnn14.yaml
+++ b/recipes/ESC50/interpret/hparams/l2i_cnn14.yaml
@@ -39,7 +39,7 @@ skip_manifest_creation: False
 
 ckpt_interval_minutes: 15 # save checkpoint every N min
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 200
 batch_size: 2
 lr: 0.0001
@@ -47,8 +47,6 @@ sample_rate: 44100
 interpret_period: 1
 relevance_th: 0.2
 
-device: "cuda"
-
 # Feature parameters
 n_mels: 80
 left_frames: 0
diff --git a/recipes/ESC50/interpret/hparams/l2i_conv2dclassifier.yaml b/recipes/ESC50/interpret/hparams/l2i_conv2dclassifier.yaml
index 318eaedfd6fb7d41c81ebddd7b4fd05a3a15c12d..4f6cb9b909871f749b60b8be054b9ed68b64c6e9 100644
--- a/recipes/ESC50/interpret/hparams/l2i_conv2dclassifier.yaml
+++ b/recipes/ESC50/interpret/hparams/l2i_conv2dclassifier.yaml
@@ -39,7 +39,7 @@ skip_manifest_creation: False
 
 ckpt_interval_minutes: 15 # save checkpoint every N min
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 200
 batch_size: 16
 lr: 0.0002
@@ -47,7 +47,6 @@ sample_rate: 16000
 interpret_period: 1
 relevance_th: 0.2
 
-device: "cuda"
 
 # Feature parameters
 n_mels: 80
diff --git a/recipes/ESC50/interpret/hparams/nmf.yaml b/recipes/ESC50/interpret/hparams/nmf.yaml
index af8ac369266e4ad60fed8fb7b5acc72925e2d22b..e4da313ba1a107174e8e584541fce92d2e811ba0 100644
--- a/recipes/ESC50/interpret/hparams/nmf.yaml
+++ b/recipes/ESC50/interpret/hparams/nmf.yaml
@@ -40,14 +40,12 @@ skip_manifest_creation: False
 
 ckpt_interval_minutes: 15 # save checkpoint every N min
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 100
 batch_size: 2
 lr: 0.0002
 sample_rate: 44100
 
-device: "cpu"
-
 shuffle: True
 dataloader_options:
     batch_size: !ref <batch_size>
diff --git a/recipes/ESC50/interpret/hparams/piq.yaml b/recipes/ESC50/interpret/hparams/piq.yaml
index 4ecf103076001256d307fd3475e1b72ff4661d67..68f8c06deb3f2ab2c13c19b93d31f119b650416a 100644
--- a/recipes/ESC50/interpret/hparams/piq.yaml
+++ b/recipes/ESC50/interpret/hparams/piq.yaml
@@ -42,7 +42,7 @@ skip_manifest_creation: False
 
 ckpt_interval_minutes: 15 # save checkpoint every N min
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 200
 batch_size: 16
 lr: 0.0002
@@ -52,7 +52,6 @@ rec_loss_coef: 1
 use_mask_output: True
 mask_th: 0.35
 
-device: "cuda"
 
 # Feature parameters
 n_mels: 80
diff --git a/recipes/ESC50/interpret/train_l2i.py b/recipes/ESC50/interpret/train_l2i.py
index 137d19766f8eb34776d6a8dd27582a7d3873930c..92d99f1879587ab4a7ccea49d0c200f4fb8995f9 100644
--- a/recipes/ESC50/interpret/train_l2i.py
+++ b/recipes/ESC50/interpret/train_l2i.py
@@ -655,9 +655,9 @@ if __name__ == "__main__":
         hparams["pretrained_esc50"].load_collected()
 
     # transfer the frozen parts to the model to the device
-    hparams["embedding_model"].to(hparams["device"])
-    hparams["classifier"].to(hparams["device"])
-    hparams["nmf_decoder"].to(hparams["device"])
+    hparams["embedding_model"].to(run_opts["device"])
+    hparams["classifier"].to(run_opts["device"])
+    hparams["nmf_decoder"].to(run_opts["device"])
     hparams["embedding_model"].eval()
 
     Interpreter_brain.fit(
diff --git a/recipes/ESC50/interpret/train_piq.py b/recipes/ESC50/interpret/train_piq.py
index 843b92d2e2cef34cdd37582fb51369c97608990e..d95a1ccd37f3bf20045be693e2af6c8c9a2c1e19 100644
--- a/recipes/ESC50/interpret/train_piq.py
+++ b/recipes/ESC50/interpret/train_piq.py
@@ -739,8 +739,8 @@ if __name__ == "__main__":
         run_on_main(hparams["pretrained_esc50"].collect_files)
         hparams["pretrained_esc50"].load_collected()
 
-    hparams["embedding_model"].to(hparams["device"])
-    hparams["classifier"].to(hparams["device"])
+    hparams["embedding_model"].to(run_opts["device"])
+    hparams["classifier"].to(run_opts["device"])
     hparams["embedding_model"].eval()
 
     Interpreter_brain.fit(
@@ -755,7 +755,6 @@ if __name__ == "__main__":
 
     Interpreter_brain.checkpointer.recover_if_possible(
         max_key="valid_top-3_fid",
-        device=torch.device(Interpreter_brain.device),
     )
 
     test_stats = Interpreter_brain.evaluate(
diff --git a/recipes/Fisher-Callhome-Spanish/ST/transformer/hparams/conformer.yaml b/recipes/Fisher-Callhome-Spanish/ST/transformer/hparams/conformer.yaml
index bf099397b88f4a5ea420699cae3aeda68b4e46f5..49a7321f79c6268d9058ab16b2a920543e707a9f 100644
--- a/recipes/Fisher-Callhome-Spanish/ST/transformer/hparams/conformer.yaml
+++ b/recipes/Fisher-Callhome-Spanish/ST/transformer/hparams/conformer.yaml
@@ -26,17 +26,6 @@ data_folder: !PLACEHOLDER # Path to the folder generated by the preparation scri
 
 tokenizer_file: !PLACEHOLDER # Path to the file of the Tokenizer model (.model)
 
-# Tokenier initialization
-tokenizer: !new:sentencepiece.SentencePieceProcessor
-
-# Pretrain the tokenizer
-pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
-    collect_in: ./tokenizer
-    loadables:
-        tokenizer: !ref <tokenizer>
-    paths:
-        tokenizer: !ref <tokenizer_file>
-
 # The train logger writes training statistics to a file, as well as stdout.
 train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
     save_file: !ref <train_log>
@@ -55,17 +44,21 @@ normalize: !new:speechbrain.processing.features.InputNormalization
     norm_type: global
     update_until_epoch: 4
 
-speed_perturb: !new:speechbrain.processing.speech_augmentation.SpeedPerturb
+# Speed perturbation
+speed_changes: [90, 100, 110]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
     orig_freq: !ref <sample_rate>
-    speeds: [90, 100, 110]
+    speeds: !ref <speed_changes>
 
 # Trainer settings
 number_of_epochs: 50
 valid_search_eopch: 10
 batch_size: 4 # this works for 2 GPUs with 11GB
-gradient_accumulation: 16
+grad_accumulation_factor: 16
 loss_reduction: batchmean
 sorting: random
+avg_checkpoints: 5 # Number of checkpoints to average for evaluation
 
 # stages related parameters
 stage_one_epochs: 100 # not gonna changing optimizer in this recipe
@@ -88,7 +81,7 @@ test_dataloader_opts:
     batch_size: !ref <batch_size>
     num_workers: !ref <num_workers>
 
-####################### Model parameters ###########################
+####################### Model Parameters ###########################
 # Transformer
 d_model: 256
 nhead: 4
@@ -202,29 +195,25 @@ SGD: !name:torch.optim.SGD
     momentum: 0.99
     nesterov: True
 
-valid_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules: [!ref <Transformer>, !ref <seq_lin>, null]
+valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <valid_beam_size>
     using_eos_threshold: False
-    length_normalization: False
+    length_normalization: True
 
-test_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules: [!ref <Transformer>, !ref <seq_lin>, null]
+test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <test_beam_size>
     using_eos_threshold: True
     length_normalization: True
-    ctc_weight: 0
-    lm_weight: 0
 
 log_softmax: !new:torch.nn.LogSoftmax
     dim: -1
@@ -242,6 +231,17 @@ noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler
     n_warmup_steps: 35000
     model_size: !ref <d_model>
 
+# Tokenier initialization
+tokenizer: !new:sentencepiece.SentencePieceProcessor
+
+# Pretrain the tokenizer
+pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
+    collect_in: !ref <save_folder>/tokenizer
+    loadables:
+        tokenizer: !ref <tokenizer>
+    paths:
+        tokenizer: !ref <tokenizer_file>
+
 # Checkpoint setting
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
     checkpoints_dir: !ref <save_folder>
diff --git a/recipes/Fisher-Callhome-Spanish/ST/transformer/hparams/transformer.yaml b/recipes/Fisher-Callhome-Spanish/ST/transformer/hparams/transformer.yaml
index e9ed5cc5dbde5e0b5f7c7bf41b5d5cd2b6b94d34..4310e2d6b5cc19f3fd8766f062b41a367c27cd1a 100644
--- a/recipes/Fisher-Callhome-Spanish/ST/transformer/hparams/transformer.yaml
+++ b/recipes/Fisher-Callhome-Spanish/ST/transformer/hparams/transformer.yaml
@@ -54,17 +54,21 @@ normalize: !new:speechbrain.processing.features.InputNormalization
     norm_type: global
     update_until_epoch: 4
 
-speed_perturb: !new:speechbrain.processing.speech_augmentation.SpeedPerturb
+# Speed perturbation
+speed_changes: [90, 100, 110]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
     orig_freq: !ref <sample_rate>
-    speeds: [90, 100, 110]
+    speeds: !ref <speed_changes>
 
 # Trainer settings
 number_of_epochs: 50
 valid_search_eopch: 10
 batch_size: 8 # this works for 2 GPUs with 11GB
-gradient_accumulation: 8
+grad_accumulation_factor: 8
 loss_reduction: batchmean
 sorting: random
+avg_checkpoints: 5 # Number of checkpoints to average for evaluation
 
 # stages related parameters
 stage_one_epochs: 100 # not gonna changing optimizer in this recipe
@@ -87,7 +91,7 @@ test_dataloader_opts:
     batch_size: !ref <batch_size>
     num_workers: !ref <num_workers>
 
-####################### Model parameters ###########################
+####################### Model Parameters ###########################
 # Transformer
 d_model: 256
 nhead: 4
@@ -196,29 +200,25 @@ SGD: !name:torch.optim.SGD
     momentum: 0.99
     nesterov: True
 
-valid_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules: [!ref <Transformer>, !ref <seq_lin>, null]
+valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <valid_beam_size>
     using_eos_threshold: False
-    length_normalization: False
+    length_normalization: True
 
-test_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules: [!ref <Transformer>, !ref <seq_lin>, null]
+test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <test_beam_size>
     using_eos_threshold: True
     length_normalization: True
-    ctc_weight: 0
-    lm_weight: 0
 
 log_softmax: !new:torch.nn.LogSoftmax
     dim: -1
diff --git a/recipes/Fisher-Callhome-Spanish/ST/transformer/train.py b/recipes/Fisher-Callhome-Spanish/ST/transformer/train.py
index 2c71cb6df69b66f52f7bc3b8d8d37c8eda631758..da8bc61210d1c45475887fe47d77e09c7994ee30 100644
--- a/recipes/Fisher-Callhome-Spanish/ST/transformer/train.py
+++ b/recipes/Fisher-Callhome-Spanish/ST/transformer/train.py
@@ -82,17 +82,20 @@ class ST(sb.core.Brain):
             mt_pred = self.modules.seq_lin(mt_pred)
             mt_p_seq = self.hparams.log_softmax(mt_pred)
 
-        # compute outputs
+        # Compute outputs
         hyps = None
-        if stage == sb.Stage.TRAIN:
-            hyps = None
-        elif stage == sb.Stage.VALID:
-            hyps = None
-            current_epoch = self.hparams.epoch_counter.current
-            if current_epoch % self.hparams.valid_search_interval == 0:
-                hyps, _ = self.hparams.valid_search(enc_out.detach(), wav_lens)
-        elif stage == sb.Stage.TEST:
-            hyps, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
+        current_epoch = self.hparams.epoch_counter.current
+        is_valid_search = (
+            stage == sb.Stage.VALID
+            and current_epoch % self.hparams.valid_search_interval == 0
+        )
+        is_test_search = stage == sb.Stage.TEST
+        if is_valid_search:
+            hyps, _, _, _ = self.hparams.valid_search(
+                enc_out.detach(), wav_lens
+            )
+        elif is_test_search:
+            hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
 
         return p_ctc, p_seq, asr_p_seq, mt_p_seq, wav_lens, hyps
 
@@ -202,29 +205,17 @@ class ST(sb.core.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
+    def on_fit_batch_start(self, batch, should_step):
+        """Gets called at the beginning of each fit_batch."""
         # check if we need to switch optimizer
         # if so change the optimizer from Adam to SGD
         self.check_and_reset_optimizer()
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-
-        # normalize the loss by gradient_accumulation step
-        (loss / self.hparams.gradient_accumulation).backward()
-
-        if self.step % self.hparams.gradient_accumulation == 0:
-            # gradient clipping & early stop if loss is not fini
-            self.check_gradients(loss)
 
-            self.optimizer.step()
-            self.optimizer.zero_grad()
-
-            # anneal lr every update
+    def on_fit_batch_end(self, batch, outputs, loss, should_step):
+        """At the end of the optimizer step, apply noam annealing."""
+        if should_step:
             self.hparams.noam_annealing(self.optimizer)
 
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
         if stage != sb.Stage.TRAIN:
@@ -251,7 +242,7 @@ class ST(sb.core.Brain):
                 stage_stats["BLEU"] = self.bleu_metric.summarize("BLEU")
 
         # log stats and save checkpoint at end-of-epoch
-        if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+        if stage == sb.Stage.VALID:
             current_epoch = self.hparams.epoch_counter.current
 
             # report different epoch stages according current stage
@@ -279,7 +270,7 @@ class ST(sb.core.Brain):
             self.checkpointer.save_and_keep_only(
                 meta={"ACC": stage_stats["ACC"], "epoch": epoch},
                 max_keys=["ACC"],
-                num_to_keep=5,
+                num_to_keep=self.hparams.avg_checkpoints,
             )
 
         elif stage == sb.Stage.TEST:
@@ -338,9 +329,7 @@ class ST(sb.core.Brain):
                 if "momentum" not in group:
                     return
 
-                self.checkpointer.recover_if_possible(
-                    device=torch.device(self.device)
-                )
+                self.checkpointer.recover_if_possible()
 
     def on_evaluate_start(self, max_key=None, min_key=None):
         """perform checkpoint averge if needed"""
@@ -350,7 +339,7 @@ class ST(sb.core.Brain):
             max_key=max_key, min_key=min_key
         )
         ckpt = sb.utils.checkpoints.average_checkpoints(
-            ckpts, recoverable_name="model", device=self.device
+            ckpts, recoverable_name="model",
         )
 
         self.hparams.model.load_state_dict(ckpt, strict=True)
@@ -603,7 +592,7 @@ if __name__ == "__main__":
 
     # transcription/translation tokenizer
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     # We can now directly create the datasets for training, valid, and test
     datasets = dataio_prepare(hparams)
diff --git a/recipes/Fisher-Callhome-Spanish/fisher_callhome_prepare.py b/recipes/Fisher-Callhome-Spanish/fisher_callhome_prepare.py
index 660c7d75d5115a41de0957e1726b81c12f325c16..876e74fa57f4351a850731ebe1beda20ad44a33f 100644
--- a/recipes/Fisher-Callhome-Spanish/fisher_callhome_prepare.py
+++ b/recipes/Fisher-Callhome-Spanish/fisher_callhome_prepare.py
@@ -22,7 +22,7 @@ import torchaudio
 from tqdm import tqdm
 from speechbrain.utils.data_utils import get_all_files
 from speechbrain.utils.torch_audio_backend import check_torchaudio_backend
-from speechbrain.processing.speech_augmentation import Resample
+from speechbrain.augment.time_domain import Resample
 
 try:
     from sacremoses import MosesPunctNormalizer, MosesTokenizer
diff --git a/recipes/Google-speech-commands/hparams/xvect.yaml b/recipes/Google-speech-commands/hparams/xvect.yaml
index a9d58fd7204d8e022b43018c0e33468d8a1623bb..417cecfdf3b8bd2be10529bcdd4c2515d930cb5b 100644
--- a/recipes/Google-speech-commands/hparams/xvect.yaml
+++ b/recipes/Google-speech-commands/hparams/xvect.yaml
@@ -14,14 +14,19 @@ output_folder: !ref results/xvect_v<number_of_commands>/<seed>
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
 
+# Data for augmentation
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+RIR_DATASET_URL: https://www.dropbox.com/scl/fi/linhy77c36mu10965a836/RIRs.zip?rlkey=pg9cu8vrpn2u173vhiqyu743u&dl=1
+
 # Data files
 data_folder: !PLACEHOLDER  # e.g. /path/to/GSC
-train_annotation: !ref <output_folder>/train.csv
-valid_annotation: !ref <output_folder>/valid.csv
-test_annotation: !ref <output_folder>/test.csv
-
-# Folder to extract data augmentation files
-rir_folder: !ref <data_folder> # Change it if needed
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+data_folder_rir: !ref <data_folder>/rir # The impulse responses used for data augmentation will automatically be downloaded here.
+train_annotation: !ref <save_folder>/train.csv
+valid_annotation: !ref <save_folder>/valid.csv
+test_annotation: !ref <save_folder>/test.csv
+noise_annotation: !ref <save_folder>/noise.csv
+rir_annotation: !ref <save_folder>/rir.csv
 
 # Percentage of files used for validation and test
 validation_percentage: 10
@@ -32,12 +37,10 @@ testing_percentage: 10
 percentage_unknown: 10 # Set this to 0 for the V2 35 task
 percentage_silence: 10 # Set this to 0 for the V2 35 task
 
-# Whether to use data augmentation
-apply_data_augmentation: True
 skip_prep: False
 ckpt_interval_minutes: 15 # save checkpoint every N min
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 100
 batch_size: 32
 lr: 0.001
@@ -56,10 +59,11 @@ deltas: False
 # Number of classes (i.e. different commands)
 out_n_neurons: !ref <number_of_commands>  #includes core commands & auxiliary words
 
+num_workers: 4
 dataloader_options:
     batch_size: !ref <batch_size>
     shuffle: !ref <shuffle>
-    num_workers: 2
+    num_workers: !ref <num_workers>
 
 # Functions
 compute_features: !new:speechbrain.lobes.features.Fbank
@@ -91,54 +95,80 @@ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
 
-augment_wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [100]
-
-augment_speed: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-add_rev: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    openrir_max_noise_len: 1.0  # seconds
-    reverb_prob: 1.0
-    noise_prob: 0.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-add_noise: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    openrir_max_noise_len: 1.0  # seconds
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-add_rev_noise: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    openrir_max_noise_len: 1.0  # seconds
-    reverb_prob: 1.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-
-# Definition of the augmentation pipeline.
-# If concat_augment = False, the augmentation techniques are applied
-# in sequence. If concat_augment = True, all the augmented signals
-# # are concatenated in a single big batch.
-augment_pipeline: [
-    !ref <augment_wavedrop>,
-    !ref <augment_speed>,
-    !ref <add_rev>,
-    !ref <add_noise>,
-    !ref <add_rev_noise>
-]
-concat_augment: True
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+# Add noise to input signal
+snr_low: 0  # Min SNR for noise augmentation
+snr_high: 15  # Max SNR for noise augmentation
+
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: !ref <snr_low>
+    snr_high: !ref <snr_high>
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+
+# Download and prepare the dataset of room impulse responses for augmentation
+prepare_rir_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <RIR_DATASET_URL>
+    dest_folder: !ref <data_folder_rir>
+    ext: wav
+    csv_file: !ref <rir_annotation>
+
+# Add reverberation to input signal
+add_reverb: !new:speechbrain.augment.time_domain.AddReverb
+    csv_file: !ref <rir_annotation>
+    reverb_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    parallel_augment: True
+    concat_original: True
+    repeat_augment: 1
+    shuffle_augmentations: False
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <add_noise>,
+        !ref <add_reverb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
 
 mean_var_norm: !new:speechbrain.processing.features.InputNormalization
     norm_type: sentence
@@ -146,11 +176,6 @@ mean_var_norm: !new:speechbrain.processing.features.InputNormalization
 
 modules:
     compute_features: !ref <compute_features>
-    augment_wavedrop: !ref <augment_wavedrop>
-    augment_speed: !ref <augment_speed>
-    add_rev: !ref <add_rev>
-    add_noise: !ref <add_noise>
-    add_rev_noise: !ref <add_rev_noise>
     embedding_model: !ref <embedding_model>
     classifier: !ref <classifier>
     softmax: !ref <softmax>
diff --git a/recipes/Google-speech-commands/hparams/xvect_leaf.yaml b/recipes/Google-speech-commands/hparams/xvect_leaf.yaml
index 1a9cf4e9791d8cb37de4643ee317e4dfd3f0df94..f2897af22c1252385255b99cf4dfa5e4fa0a03e3 100644
--- a/recipes/Google-speech-commands/hparams/xvect_leaf.yaml
+++ b/recipes/Google-speech-commands/hparams/xvect_leaf.yaml
@@ -15,14 +15,20 @@ output_folder: !ref results/xvect_leaf_legacy_complex_mvnorm_v<number_of_command
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
 
+# Data for augmentation
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+RIR_DATASET_URL: https://www.dropbox.com/scl/fi/linhy77c36mu10965a836/RIRs.zip?rlkey=pg9cu8vrpn2u173vhiqyu743u&dl=1
+
+# Data files
 # Data files
 data_folder: !PLACEHOLDER  # e.g. /path/to/GSC
-train_annotation: !ref <output_folder>/train.csv
-valid_annotation: !ref <output_folder>/valid.csv
-test_annotation: !ref <output_folder>/test.csv
-
-# Folder to extract data augmentation files
-rir_folder: !ref <data_folder> # Change it if needed
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+data_folder_rir: !ref <data_folder>/rir # The impulse responses used for data augmentation will automatically be downloaded here.
+train_annotation: !ref <save_folder>/train.csv
+valid_annotation: !ref <save_folder>/valid.csv
+test_annotation: !ref <save_folder>/test.csv
+noise_annotation: !ref <save_folder>/noise.csv
+rir_annotation: !ref <save_folder>/rir.csv
 
 # Percentage of files used for validation and test
 validation_percentage: 10
@@ -33,12 +39,10 @@ testing_percentage: 10
 percentage_unknown: 10 # Set this to 0 for the V2 35 task
 percentage_silence: 10 # Set this to 0 for the V2 35 task
 
-# Whether to use data augmentation
-apply_data_augmentation: True
 skip_prep: False
 ckpt_interval_minutes: 15 # save checkpoint every N min
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 100
 batch_size: 32
 lr: 0.001
@@ -54,10 +58,11 @@ n_features: 24
 # Number of classes (i.e. different commands)
 out_n_neurons: !ref <number_of_commands>  #includes core commands & auxiliary words
 
+num_workers: 4
 dataloader_options:
     batch_size: !ref <batch_size>
     shuffle: !ref <shuffle>
-    num_workers: 2
+    num_workers: !ref <num_workers>
 
 # Functions
 compute_features: !new:speechbrain.lobes.features.Leaf
@@ -90,54 +95,80 @@ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
 
-augment_wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [100]
-
-augment_speed: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-add_rev: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    openrir_max_noise_len: 1.0  # seconds
-    reverb_prob: 1.0
-    noise_prob: 0.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-add_noise: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    openrir_max_noise_len: 1.0  # seconds
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-add_rev_noise: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    openrir_max_noise_len: 1.0  # seconds
-    reverb_prob: 1.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-
-# Definition of the augmentation pipeline.
-# If concat_augment = False, the augmentation techniques are applied
-# in sequence. If concat_augment = True, all the augmented signals
-# # are concatenated in a single big batch.
-augment_pipeline: [
-    !ref <augment_wavedrop>,
-    !ref <augment_speed>,
-    !ref <add_rev>,
-    !ref <add_noise>,
-    !ref <add_rev_noise>
-]
-concat_augment: True
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+# Add noise to input signal
+snr_low: 0  # Min SNR for noise augmentation
+snr_high: 15  # Max SNR for noise augmentation
+
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: !ref <snr_low>
+    snr_high: !ref <snr_high>
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+
+# Download and prepare the dataset of room impulse responses for augmentation
+prepare_rir_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <RIR_DATASET_URL>
+    dest_folder: !ref <data_folder_rir>
+    ext: wav
+    csv_file: !ref <rir_annotation>
+
+# Add reverberation to input signal
+add_reverb: !new:speechbrain.augment.time_domain.AddReverb
+    csv_file: !ref <rir_annotation>
+    reverb_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    parallel_augment: True
+    concat_original: True
+    repeat_augment: 1
+    shuffle_augmentations: False
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <add_noise>,
+        !ref <add_reverb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
 
 mean_var_norm: !new:speechbrain.processing.features.InputNormalization
     norm_type: sentence
@@ -145,11 +176,6 @@ mean_var_norm: !new:speechbrain.processing.features.InputNormalization
 
 modules:
     compute_features: !ref <compute_features>
-    augment_wavedrop: !ref <augment_wavedrop>
-    augment_speed: !ref <augment_speed>
-    add_rev: !ref <add_rev>
-    add_noise: !ref <add_noise>
-    add_rev_noise: !ref <add_rev_noise>
     embedding_model: !ref <embedding_model>
     classifier: !ref <classifier>
     softmax: !ref <softmax>
diff --git a/recipes/Google-speech-commands/train.py b/recipes/Google-speech-commands/train.py
index d4f8e8ede55d782040c52d1e17d94b6f4cf0119a..5f524bf2b06f9bdc56d7414afb2e2aed05c28c77 100644
--- a/recipes/Google-speech-commands/train.py
+++ b/recipes/Google-speech-commands/train.py
@@ -38,33 +38,9 @@ class SpeakerBrain(sb.core.Brain):
         batch = batch.to(self.device)
         wavs, lens = batch.sig
 
-        if stage == sb.Stage.TRAIN and self.hparams.apply_data_augmentation:
-
-            # Applying the augmentation pipeline
-            wavs_aug_tot = []
-            wavs_aug_tot.append(wavs)
-            for count, augment in enumerate(self.hparams.augment_pipeline):
-
-                # Apply augment
-                wavs_aug = augment(wavs, lens)
-
-                # Managing speed change
-                if wavs_aug.shape[1] > wavs.shape[1]:
-                    wavs_aug = wavs_aug[:, 0 : wavs.shape[1]]
-                else:
-                    zero_sig = torch.zeros_like(wavs)
-                    zero_sig[:, 0 : wavs_aug.shape[1]] = wavs_aug
-                    wavs_aug = zero_sig
-
-                if self.hparams.concat_augment:
-                    wavs_aug_tot.append(wavs_aug)
-                else:
-                    wavs = wavs_aug
-                    wavs_aug_tot[0] = wavs
-
-            wavs = torch.cat(wavs_aug_tot, dim=0)
-            self.n_augment = len(wavs_aug_tot)
-            lens = torch.cat([lens] * self.n_augment)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, lens = self.hparams.wav_augment(wavs, lens)
 
         if isinstance(
             self.modules.compute_features, speechbrain.lobes.features.Leaf
@@ -96,8 +72,8 @@ class SpeakerBrain(sb.core.Brain):
         command, _ = batch.command_encoded
 
         # Concatenate labels (due to data augmentation)
-        if stage == sb.Stage.TRAIN and self.hparams.apply_data_augmentation:
-            command = torch.cat([command] * self.n_augment, dim=0)
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            command = self.hparams.wav_augment.replicate_labels(command)
 
         # compute the cost function
         loss = self.hparams.compute_cost(predictions, command, lens)
@@ -306,6 +282,8 @@ if __name__ == "__main__":
             "skip_prep": hparams["skip_prep"],
         },
     )
+    sb.utils.distributed.run_on_main(hparams["prepare_noise_data"])
+    sb.utils.distributed.run_on_main(hparams["prepare_rir_data"])
 
     # Dataset IO prep: creating Dataset objects and proper encodings for phones
     train_data, valid_data, test_data, label_encoder = dataio_prep(hparams)
diff --git a/recipes/IEMOCAP/emotion_recognition/hparams/train.yaml b/recipes/IEMOCAP/emotion_recognition/hparams/train.yaml
index 80935b552aea1a14c856b264c6f1ffee08c3e084..28f690675904b3a7362aca508408b502425e5a35 100644
--- a/recipes/IEMOCAP/emotion_recognition/hparams/train.yaml
+++ b/recipes/IEMOCAP/emotion_recognition/hparams/train.yaml
@@ -43,7 +43,7 @@ ckpt_interval_minutes: 15 # save checkpoint every N min
 # Training Parameters
 number_of_epochs: 30
 batch_size: 16
-gradient_accumulation: 2
+grad_accumulation_factor: 2
 lr: 0.0001
 weight_decay: 0.00002
 base_lr: 0.000001
diff --git a/recipes/IEMOCAP/emotion_recognition/hparams/train_with_wav2vec2.yaml b/recipes/IEMOCAP/emotion_recognition/hparams/train_with_wav2vec2.yaml
index 22dcb28904987d5d37568af77ab194eb65b0346a..d1b63d7bf65549ce4d301b89d86aa2c3686ff3da 100644
--- a/recipes/IEMOCAP/emotion_recognition/hparams/train_with_wav2vec2.yaml
+++ b/recipes/IEMOCAP/emotion_recognition/hparams/train_with_wav2vec2.yaml
@@ -38,7 +38,7 @@ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
 
 ckpt_interval_minutes: 15 # save checkpoint every N min
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 30
 batch_size: 4
 lr: 0.0001
@@ -50,7 +50,7 @@ freeze_wav2vec2: False
 # We see an improvement of 2% with freezing CNNs
 freeze_wav2vec2_conv: True
 
-# Model parameters
+####################### Model Parameters #######################################
 encoder_dim: 768
 
 # Number of emotions
@@ -63,7 +63,7 @@ dataloader_options:
     drop_last: False
 
 # Wav2vec2 encoder
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec2>
diff --git a/recipes/IEMOCAP/emotion_recognition/train.py b/recipes/IEMOCAP/emotion_recognition/train.py
index be4ae55e28c6f19e1924679b8c9ea38fb37cd083..fa1f6ef240ec0f930bea90ec0d89e8d44c52ed2c 100644
--- a/recipes/IEMOCAP/emotion_recognition/train.py
+++ b/recipes/IEMOCAP/emotion_recognition/train.py
@@ -39,23 +39,6 @@ class EmoIdBrain(sb.Brain):
 
         return outputs
 
-    def fit_batch(self, batch):
-        """Trains the parameters given a single batch in input"""
-
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-
-        # normalize the loss by gradient_accumulation step
-        (loss / self.hparams.gradient_accumulation).backward()
-
-        if self.step % self.hparams.gradient_accumulation == 0:
-            # gradient clipping & early stop if loss is not finite
-            self.check_gradients(loss)
-            self.optimizer.step()
-            self.optimizer.zero_grad()
-
-        return loss.detach()
-
     def compute_objectives(self, predictions, batch, stage):
         """Computes the loss using speaker-id as label.
         """
diff --git a/recipes/IEMOCAP/emotion_recognition/train_with_wav2vec2.py b/recipes/IEMOCAP/emotion_recognition/train_with_wav2vec2.py
index dc91dec7b9fdb439a6a8c2128b1f6f70fb5a8d84..f508ec8844f4ce69016db5ac8fcdeb31631c4352 100644
--- a/recipes/IEMOCAP/emotion_recognition/train_with_wav2vec2.py
+++ b/recipes/IEMOCAP/emotion_recognition/train_with_wav2vec2.py
@@ -47,21 +47,6 @@ class EmoIdBrain(sb.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Trains the parameters given a single batch in input"""
-
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-        loss.backward()
-        if self.check_gradients(loss):
-            self.wav2vec2_optimizer.step()
-            self.optimizer.step()
-
-        self.wav2vec2_optimizer.zero_grad()
-        self.optimizer.zero_grad()
-
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch=None):
         """Gets called at the beginning of each epoch.
         Arguments
@@ -152,9 +137,10 @@ class EmoIdBrain(sb.Brain):
             )
             self.checkpointer.add_recoverable("optimizer", self.optimizer)
 
-    def zero_grad(self, set_to_none=False):
-        self.wav2vec2_optimizer.zero_grad(set_to_none)
-        self.optimizer.zero_grad(set_to_none)
+        self.optimizers_dict = {
+            "model_optimizer": self.optimizer,
+            "wav2vec2_optimizer": self.wav2vec2_optimizer,
+        }
 
 
 def dataio_prep(hparams):
diff --git a/recipes/IEMOCAP/iemocap_prepare.py b/recipes/IEMOCAP/iemocap_prepare.py
new file mode 100644
index 0000000000000000000000000000000000000000..d42fcff19cb8ee245c45c2e38442834a98e43bb2
--- /dev/null
+++ b/recipes/IEMOCAP/iemocap_prepare.py
@@ -0,0 +1,341 @@
+"""
+Downloads and creates data manifest files for IEMOCAP
+(https://paperswithcode.com/dataset/iemocap).
+
+Authors:
+ * Mirco Ravanelli, 2021
+ * Modified by Pierre-Yves Yanni, 2021
+ * Abdel Heba, 2021
+ * Yingzhi Wang, 2022
+"""
+
+import os
+import re
+import json
+import random
+import logging
+from speechbrain.dataio.dataio import read_audio
+
+logger = logging.getLogger(__name__)
+SAMPLERATE = 16000
+NUMBER_UTT = 5531
+
+
+def prepare_data(
+    data_original,
+    save_json_train,
+    save_json_valid,
+    save_json_test,
+    split_ratio=[80, 10, 10],
+    different_speakers=False,
+    test_spk_id=1,
+    seed=12,
+):
+    """
+    Prepares the json files for the IEMOCAP dataset.
+
+    Arguments
+    ---------
+    data_original : str
+        Path to the folder where the original IEMOCAP dataset is stored.
+    save_json_train : str
+        Path where the train data specification file will be saved.
+    save_json_valid : str
+        Path where the validation data specification file will be saved.
+    save_json_test : str
+        Path where the test data specification file will be saved.
+    split_ratio: list
+        List composed of three integers that sets split ratios for train,
+        valid, and test sets, respecively.
+        For instance split_ratio=[80, 10, 10] will assign 80% of the sentences
+        to training, 10% for validation, and 10% for test.
+    test_spk_id: int
+        Id of speaker used for test set, 10 speakers in total.
+        Here a leave-two-speaker strategy is used for the split,
+        if one test_spk_id is selected for test, the other spk_id in the same
+        session is automatically used for validation.
+        To perform a 10-fold cross-validation,
+        10 experiments with test_spk_id from 1 to 10 should be done.
+    seed : int
+        Seed for reproducibility
+
+    Example
+    -------
+    >>> data_original = '/path/to/iemocap/IEMOCAP_full_release'
+    >>> prepare_data(data_original, 'train.json', 'valid.json', 'test.json')
+    """
+    data_original = data_original + "/Session"
+    # setting seeds for reproducible code.
+    random.seed(seed)
+
+    # Check if this phase is already done (if so, skip it)
+    if skip(save_json_train, save_json_valid, save_json_test):
+        logger.info("Preparation completed in previous run, skipping.")
+        return
+
+    speaker_dict = transform_data(data_original)
+
+    if sum([len(value) for value in speaker_dict.values()]) != NUMBER_UTT:
+        raise ValueError(
+            "Error: Number of utterances is not 5531, please check your IEMOCAP folder"
+        )
+
+    # List files and create manifest from list
+    logger.info(
+        f"Creating {save_json_train}, {save_json_valid}, and {save_json_test}"
+    )
+
+    if different_speakers:
+        data_split = split_different_speakers(speaker_dict, test_spk_id)
+    else:
+        data_split = split_sets(speaker_dict, split_ratio)
+
+    # Creating json files
+    create_json(data_split["train"], save_json_train)
+    create_json(data_split["valid"], save_json_valid)
+    create_json(data_split["test"], save_json_test)
+
+
+def create_json(wav_list, json_file):
+    """
+    Creates the json file given a list of wav information.
+
+    Arguments
+    ---------
+    wav_list : list of list
+        The list of wav information (path, label, gender).
+    json_file : str
+        The path of the output json file
+    """
+
+    json_dict = {}
+    for obj in wav_list:
+        wav_file = obj[0]
+        emo = obj[1]
+        # Read the signal (to retrieve duration in seconds)
+        signal = read_audio(wav_file)
+        duration = signal.shape[0] / SAMPLERATE
+
+        uttid = wav_file.split("/")[-1][:-4]
+
+        # Create entry for this utterance
+        json_dict[uttid] = {
+            "wav": wav_file,
+            "length": duration,
+            "emo": emo,
+        }
+
+    # Writing the dictionary to the json file
+    with open(json_file, mode="w") as json_f:
+        json.dump(json_dict, json_f, indent=2)
+
+    logger.info(f"{json_file} successfully created!")
+
+
+def skip(*filenames):
+    """
+    Detects if the data preparation has been already done.
+    If the preparation has been done, we can skip it.
+
+    Returns
+    -------
+    bool
+        if True, the preparation phase can be skipped.
+        if False, it must be done.
+    """
+    for filename in filenames:
+        if not os.path.isfile(filename):
+            return False
+    return True
+
+
+def split_different_speakers(speaker_dict, test_spk_id):
+    """Constructs train, validation and test sets that do not share common
+    speakers. There are two different speakers in each session. Train set is
+    constituted of 4 sessions (8 speakers), while validation set and test set
+    contain each 1 speaker. If test_spk_id is 1, then speaker 2 is selected
+    automatically for validation set, and training set contains other 8 speakers.
+    If test_spk_id is 2, then speaker 1 is selected for validation set.
+
+    Arguments
+    ---------
+    speaker_dict: dict
+        a dictionary of speaker id and its corresponding audio information
+    test_spk_id: int
+        Id of speaker used for test set, 10 speakers in total.
+        Session1 contains speaker 1&2, Session2 contains speaker 3&4, ...
+
+    Returns
+    ------
+    dictionary containing train, valid, and test splits.
+    """
+    data_split = {k: [] for k in ["train", "valid", "test"]}
+    data_split["test"].extend(speaker_dict[str(test_spk_id)])
+
+    # use the speaker in the same session as validation set
+    if test_spk_id % 2 == 0:
+        valid_spk_num = test_spk_id - 1
+    else:
+        valid_spk_num = test_spk_id + 1
+
+    data_split["valid"].extend(speaker_dict[str(valid_spk_num)])
+
+    for i in range(1, 11):
+        if i != valid_spk_num and i != test_spk_id:
+            data_split["train"].extend(speaker_dict[str(i)])
+
+    return data_split
+
+
+def split_sets(speaker_dict, split_ratio):
+    """Randomly splits the wav list into training, validation, and test lists.
+    Note that a better approach is to make sure that all the classes have the
+    same proportion of samples (e.g, spk01 should have 80% of samples in
+    training, 10% validation, 10% test, the same for speaker2 etc.). This
+    is the approach followed in some recipes such as the Voxceleb one. For
+    simplicity, we here simply split the full list without necessarly
+    respecting the split ratio within each class.
+
+    Arguments
+    ---------
+    speaker_dict : list
+        a dictionary of speaker id and its corresponding audio information
+    split_ratio: list
+        List composed of three integers that sets split ratios for train,
+        valid, and test sets, respectively.
+        For instance split_ratio=[80, 10, 10] will assign 80% of the sentences
+        to training, 10% for validation, and 10% for test.
+
+    Returns
+    ------
+    dictionary containing train, valid, and test splits.
+    """
+
+    wav_list = []
+    for key in speaker_dict.keys():
+        wav_list.extend(speaker_dict[key])
+
+    # Random shuffle of the list
+    random.shuffle(wav_list)
+    tot_split = sum(split_ratio)
+    tot_snts = len(wav_list)
+    data_split = {}
+    splits = ["train", "valid"]
+
+    for i, split in enumerate(splits):
+        n_snts = int(tot_snts * split_ratio[i] / tot_split)
+        data_split[split] = wav_list[0:n_snts]
+        del wav_list[0:n_snts]
+    data_split["test"] = wav_list
+
+    return data_split
+
+
+def transform_data(path_loadSession):
+    """
+    Create a dictionary that maps speaker id and corresponding wavs
+
+    Arguments
+    ---------
+    path_loadSession : str
+        Path to the folder where the original IEMOCAP dataset is stored.
+
+    Example
+    -------
+    >>> data_original = '/path/to/iemocap/IEMOCAP_full_release/Session'
+    >>> data_transformed = '/path/to/iemocap/IEMOCAP_ahsn_leave-two-speaker-out'
+    >>> transform_data(data_original, data_transformed)
+    """
+
+    speaker_dict = {str(i + 1): [] for i in range(10)}
+
+    speaker_count = 0
+    for k in range(5):
+        session = load_session("%s%s" % (path_loadSession, k + 1))
+        for idx in range(len(session)):
+            if session[idx][2] == "F":
+                speaker_dict[str(speaker_count + 1)].append(session[idx])
+            else:
+                speaker_dict[str(speaker_count + 2)].append(session[idx])
+        speaker_count += 2
+
+    return speaker_dict
+
+
+def load_utterInfo(inputFile):
+    """
+    Load utterInfo from original IEMOCAP database
+    """
+
+    # this regx allow to create a list with:
+    # [START_TIME - END_TIME] TURN_NAME EMOTION [V, A, D]
+    # [V, A, D] means [Valence, Arousal, Dominance]
+    pattern = re.compile(
+        "[\[]*[0-9]*[.][0-9]*[ -]*[0-9]*[.][0-9]*[\]][\t][a-z0-9_]*[\t][a-z]{3}[\t][\[][0-9]*[.][0-9]*[, ]+[0-9]*[.][0-9]*[, ]+[0-9]*[.][0-9]*[\]]",
+        re.IGNORECASE,
+    )  # noqa
+    with open(inputFile, "r") as myfile:
+        data = myfile.read().replace("\n", " ")
+    result = pattern.findall(data)
+    out = []
+    for i in result:
+        a = i.replace("[", "")
+        b = a.replace(" - ", "\t")
+        c = b.replace("]", "")
+        x = c.replace(", ", "\t")
+        out.append(x.split("\t"))
+    return out
+
+
+def load_session(pathSession):
+    """Load wav file from IEMOCAP session
+    and keep only the following 4 emotions:
+    [neural, happy, sad, anger].
+
+    Arguments
+    ---------
+        pathSession: str
+            Path folder of IEMOCAP session.
+    Returns
+    -------
+        improvisedUtteranceList: list
+            List of improvised utterancefor IEMOCAP session.
+    """
+    pathEmo = pathSession + "/dialog/EmoEvaluation/"
+    pathWavFolder = pathSession + "/sentences/wav/"
+
+    improvisedUtteranceList = []
+    for emoFile in [
+        f
+        for f in os.listdir(pathEmo)
+        if os.path.isfile(os.path.join(pathEmo, f))
+    ]:
+        for utterance in load_utterInfo(pathEmo + emoFile):
+            if (
+                (utterance[3] == "neu")
+                or (utterance[3] == "hap")
+                or (utterance[3] == "sad")
+                or (utterance[3] == "ang")
+                or (utterance[3] == "exc")
+            ):
+                path = (
+                    pathWavFolder
+                    + utterance[2][:-5]
+                    + "/"
+                    + utterance[2]
+                    + ".wav"
+                )
+
+                label = utterance[3]
+                if label == "exc":
+                    label = "hap"
+
+                if emoFile[7] != "i" and utterance[2][7] == "s":
+                    improvisedUtteranceList.append(
+                        [path, label, utterance[2][18]]
+                    )
+                else:
+                    improvisedUtteranceList.append(
+                        [path, label, utterance[2][15]]
+                    )
+    return improvisedUtteranceList
diff --git a/recipes/IEMOCAP/quantization/README.md b/recipes/IEMOCAP/quantization/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d72c564ec6ad3cc307d9fc12eb44ce8374508716
--- /dev/null
+++ b/recipes/IEMOCAP/quantization/README.md
@@ -0,0 +1,49 @@
+
+# K-means (Quantization)
+This folder contains recipes for training K-means clustering model for the IEMOCAP Dataset.
+The model serves to quantize self-supervised representations into discrete representation. Thus representations can be used as a discrete audio input for various tasks including classification, ASR and speech generation.
+It supports  kmeans model using the features from  HuBERT, WAVLM or Wav2Vec.
+
+You can download IEMOCAP at https://sail.usc.edu/iemocap/
+
+## Installing Extra Dependencies
+
+Before proceeding, ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal:
+
+```
+pip install -r extra_requirements.txt
+```
+
+# How to run:
+```shell
+python train.py hparams/train_with_{SSL_model}.yaml
+```
+
+# Results
+
+The output folders with checkpoints and logs can be found [here](https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0).
+
+The checkpoints can be also found at [this](https://huggingface.co/speechbrain/SSL_Quantization) HuggingFace repository.
+
+
+
+# **About SpeechBrain**
+- Website: https://speechbrain.github.io/
+- Code: https://github.com/speechbrain/speechbrain/
+- HuggingFace: https://huggingface.co/speechbrain/
+
+
+# **Citing SpeechBrain**
+Please, cite SpeechBrain if you use it for your research or business.
+
+```bibtex
+@misc{speechbrain,
+  title={{SpeechBrain}: A General-Purpose Speech Toolkit},
+  author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
+  year={2021},
+  eprint={2106.04624},
+  archivePrefix={arXiv},
+  primaryClass={eess.AS},
+  note={arXiv:2106.04624}
+}
+```
diff --git a/recipes/IEMOCAP/quantization/extra-requirements.txt b/recipes/IEMOCAP/quantization/extra-requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d5e06028d853376200623c16ebcf4992f4ae60c2
--- /dev/null
+++ b/recipes/IEMOCAP/quantization/extra-requirements.txt
@@ -0,0 +1 @@
+scikit-learn
diff --git a/recipes/IEMOCAP/quantization/hparams/train_with_hubert.yaml b/recipes/IEMOCAP/quantization/hparams/train_with_hubert.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..79bf2633eb5c33d2f4435a7e87b2a324664d34a7
--- /dev/null
+++ b/recipes/IEMOCAP/quantization/hparams/train_with_hubert.yaml
@@ -0,0 +1,61 @@
+################################
+# Recipe for Training K-Means Clustering on IEMOCAP Data
+# Using Self-Supervised Model-Based Representations
+#
+# It is used for creating discrete audio representations from IEMOCAP data.
+#
+# Author: Pooneh Mousavi (2023)
+################################
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1986
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/IEMOCAP/clustering/hubert/<seed>
+save_folder: !ref <output_folder>/save
+
+# Data files
+# Dataset will be downloaded to the `data_original`
+data_folder: !PLACEHOLDER  # e.g., /path/to/IEMOCAP_full_release
+
+# different speakers for train, valid and test sets
+different_speakers: False
+# which speaker is used for test set, value from 1 to 10
+test_spk_id: 1
+# Path where data manifest files will be stored
+train_annotation: !ref <output_folder>/train.json
+valid_annotation: !ref <output_folder>/valid.json
+test_annotation: !ref <output_folder>/test.json
+split_ratio: [80, 10, 10]
+skip_prep: False
+sample_rate: 16000
+
+ssl_hub: facebook/hubert-base-ls960
+freeze_feature_extractor: True
+freeze_ssl: True
+ssl_folder: !ref <save_folder>/hubert_checkpoint
+ssl_layer_num: 7
+batch_size: 128 # batch_size for loading and extracting features. It is different from kmeans_batch_size.
+
+# Dataloader options
+train_dataloader_opts:
+   batch_size: !ref <batch_size>
+
+ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT
+   source: !ref <ssl_hub>
+   output_norm: False
+   freeze: !ref <freeze_ssl>
+   freeze_feature_extractor: !ref <freeze_feature_extractor>
+   output_all_hiddens: True
+   save_path: !ref <ssl_folder>
+
+
+####################
+# Model Parameters #
+####################
+num_clusters: 128
+init: k-means++
+max_iter: 100
+kmeans_batch_size: 1000 # should be >= num_clusters
+tol: 0.0
+max_no_improvement: 100
+n_init: 20
+reassignment_ratio: 0.0
diff --git a/recipes/IEMOCAP/quantization/hparams/train_with_wav2vec.yaml b/recipes/IEMOCAP/quantization/hparams/train_with_wav2vec.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b21a7eafb8a5781019e0f1d78c1cb23c6d815e4d
--- /dev/null
+++ b/recipes/IEMOCAP/quantization/hparams/train_with_wav2vec.yaml
@@ -0,0 +1,61 @@
+################################
+# Recipe for Training K-Means Clustering on IEMOCAP Data
+# Using Self-Supervised Model-Based Representations
+#
+# It is used for creating discrete audio representations from IEMOCAP data.
+#
+# Author: Pooneh Mousavi (2023)
+################################
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1986
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/IEMOCAP/clustering/wav2vec/<seed>
+save_folder: !ref <output_folder>/save
+
+# Data files
+# Dataset will be downloaded to the `data_original`
+data_folder: !PLACEHOLDER  # e.g., /path/to/IEMOCAP_full_release
+
+# different speakers for train, valid and test sets
+different_speakers: False
+# which speaker is used for test set, value from 1 to 10
+test_spk_id: 1
+# Path where data manifest files will be stored
+train_annotation: !ref <output_folder>/train.json
+valid_annotation: !ref <output_folder>/valid.json
+test_annotation: !ref <output_folder>/test.json
+split_ratio: [80, 10, 10]
+skip_prep: False
+sample_rate: 16000
+
+ssl_hub: facebook/wav2vec2-large-960h-lv60-self
+freeze_feature_extractor: True
+freeze_ssl: True
+ssl_folder: !ref <save_folder>/wav2vec_checkpoint
+ssl_layer_num: 7
+batch_size: 64 # batch_size for loading and extracting features. It is different from kmeans_batch_size.
+
+# Dataloader options
+train_dataloader_opts:
+   batch_size: !ref <batch_size>
+
+ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
+   source: !ref <ssl_hub>
+   output_norm: False
+   freeze: !ref <freeze_ssl>
+   freeze_feature_extractor: !ref <freeze_feature_extractor>
+   output_all_hiddens: True
+   save_path: !ref <ssl_folder>
+
+
+####################
+# Model Parameters #
+####################
+num_clusters: 128
+init: k-means++
+max_iter: 100
+kmeans_batch_size: 1000 # should be >= num_clusters
+tol: 0.0
+max_no_improvement: 100
+n_init: 20
+reassignment_ratio: 0.0
diff --git a/recipes/IEMOCAP/quantization/hparams/train_with_wavlm.yaml b/recipes/IEMOCAP/quantization/hparams/train_with_wavlm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2682a0610718a731a31ca66305127149c1a47ff3
--- /dev/null
+++ b/recipes/IEMOCAP/quantization/hparams/train_with_wavlm.yaml
@@ -0,0 +1,61 @@
+################################
+# Recipe for Training K-Means Clustering on IEMOCAP Data
+# Using Self-Supervised Model-Based Representations
+#
+# It is used for creating discrete audio representations from IEMOCAP data.
+#
+# Author: Pooneh Mousavi (2023)
+################################
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1986
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/IEMOCAP/clustering/wavlm/<seed>
+save_folder: !ref <output_folder>/save
+
+# Data files
+# Dataset will be downloaded to the `data_original`
+data_folder: !PLACEHOLDER  # e.g., /path/to/IEMOCAP_full_release
+
+# different speakers for train, valid and test sets
+different_speakers: False
+# which speaker is used for test set, value from 1 to 10
+test_spk_id: 1
+# Path where data manifest files will be stored
+train_annotation: !ref <output_folder>/train.json
+valid_annotation: !ref <output_folder>/valid.json
+test_annotation: !ref <output_folder>/test.json
+split_ratio: [80, 10, 10]
+skip_prep: False
+sample_rate: 16000
+
+ssl_hub: microsoft/wavlm-large
+freeze_feature_extractor: True
+freeze_ssl: True
+ssl_folder: !ref <save_folder>/wavlm_checkpoint
+ssl_layer_num: 7
+batch_size: 32 # batch_size for loading and extracting features. It is different from kmeans_batch_size.
+
+# Dataloader options
+train_dataloader_opts:
+   batch_size: !ref <batch_size>
+
+ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM
+   source: !ref <ssl_hub>
+   output_norm: False
+   freeze: !ref <freeze_ssl>
+   freeze_feature_extractor: !ref <freeze_feature_extractor>
+   output_all_hiddens: True
+   save_path: !ref <ssl_folder>
+
+
+####################
+# Model Parameters #
+####################
+num_clusters: 128
+init: k-means++
+max_iter: 100
+kmeans_batch_size: 1000 # should be >= num_clusters
+tol: 0.0
+max_no_improvement: 100
+n_init: 20
+reassignment_ratio: 0.0
diff --git a/recipes/IEMOCAP/quantization/iemocap_prepare.py b/recipes/IEMOCAP/quantization/iemocap_prepare.py
new file mode 120000
index 0000000000000000000000000000000000000000..bd20ddfb84553dea941d41c27d952c7d62d7a27f
--- /dev/null
+++ b/recipes/IEMOCAP/quantization/iemocap_prepare.py
@@ -0,0 +1 @@
+../iemocap_prepare.py
\ No newline at end of file
diff --git a/recipes/IEMOCAP/quantization/train.py b/recipes/IEMOCAP/quantization/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..35a0c080f88b0cf4babecc2e4e02628a2607d5e1
--- /dev/null
+++ b/recipes/IEMOCAP/quantization/train.py
@@ -0,0 +1,150 @@
+"""
+Recipe  to train K-means clustering model on self-supervised representations.
+
+To run this recipe, do the following:
+> python train.py hparams/train_with_[SSL-model].yaml --data_folder=/path/to/LibriSPeech
+Author
+ * Pooneh Mousavi 2023
+"""
+
+import os
+import sys
+import logging
+import speechbrain as sb
+from speechbrain.utils.distributed import run_on_main
+from hyperpyyaml import load_hyperpyyaml
+from torch.utils.data import DataLoader
+from speechbrain.dataio.dataloader import LoopedLoader
+from speechbrain.utils.kmeans import fetch_kmeans_model, train, save_model
+import torchaudio
+
+logger = logging.getLogger(__name__)
+
+
+def dataio_prepare(hparams):
+    """This function prepares the datasets to be used in the brain class.
+    It also defines the data processing pipeline through user-defined
+    functions. We expect `prepare_mini_librispeech` to have been called before
+    this, so that the `train.json`, `valid.json`,  and `valid.json` manifest
+    files are available.
+
+    Arguments
+    ---------
+    hparams : dict
+        This dictionary is loaded from the `train.yaml` file, and it includes
+        all the hyperparameters needed for dataset construction and loading.
+
+    Returns
+    -------
+    datasets : dict
+        Contains two keys, "train" and "valid" that correspond
+        to the appropriate DynamicItemDataset object.
+    """
+
+    # Define audio pipeline
+    @sb.utils.data_pipeline.takes("wav")
+    @sb.utils.data_pipeline.provides("sig")
+    def audio_pipeline(wav):
+        """Load the signal, and pass it and its length to the corruption class.
+        This is done on the CPU in the `collate_fn`."""
+        sig = sb.dataio.dataio.read_audio(wav)
+        info = torchaudio.info(wav)
+        resampled = torchaudio.transforms.Resample(
+            info.sample_rate, hparams["sample_rate"],
+        )(sig)
+        return resampled
+
+    # Define datasets. We also connect the dataset with the data processing
+    # functions defined above.
+    datasets = {}
+    data_info = {
+        "train": hparams["train_annotation"],
+    }
+    for dataset in data_info:
+        datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
+            json_path=data_info[dataset],
+            replacements={"data_root": hparams["data_folder"]},
+            dynamic_items=[audio_pipeline],
+            output_keys=["id", "sig"],
+        )
+    # Load or compute the label encoder (with multi-GPU DDP support)
+    # Please, take a look into the lab_enc_file to see the label to index
+    # mappinng.
+
+    return datasets["train"]
+
+
+if __name__ == "__main__":
+    # Load hyperparameters file with command-line overrides
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    # Dataset prep (parsing Librispeech)
+    from iemocap_prepare import prepare_data  # noqa E402
+
+    # Data preparation, to be run on only one process.
+    if not hparams["skip_prep"]:
+        run_on_main(
+            prepare_data,
+            kwargs={
+                "data_original": hparams["data_folder"],
+                "save_json_train": hparams["train_annotation"],
+                "save_json_valid": hparams["valid_annotation"],
+                "save_json_test": hparams["test_annotation"],
+                "split_ratio": hparams["split_ratio"],
+                "different_speakers": hparams["different_speakers"],
+                "test_spk_id": hparams["test_spk_id"],
+                "seed": hparams["seed"],
+            },
+        )
+
+    # Load SSL model
+    hparams["ssl_model"] = hparams["ssl_model"].to(run_opts["device"])
+
+    # Make training Dataloader
+    train_set = dataio_prepare(hparams)
+    if not (
+        isinstance(train_set, DataLoader) or isinstance(train_set, LoopedLoader)
+    ):
+        train_set = sb.dataio.dataloader.make_dataloader(
+            train_set, **hparams["train_dataloader_opts"]
+        )
+
+    # Load pretrained KMeans model if it exists. Otherwise,  create new one.
+    checkpoint_path = os.path.join(
+        hparams["save_folder"], f"kmeans_{hparams['num_clusters']}.pt"
+    )
+    kmeans_model = fetch_kmeans_model(
+        n_clusters=hparams["num_clusters"],
+        init=hparams["init"],
+        max_iter=hparams["max_iter"],
+        batch_size=hparams["batch_size"],
+        tol=hparams["tol"],
+        max_no_improvement=hparams["max_no_improvement"],
+        n_init=hparams["n_init"],
+        reassignment_ratio=hparams["reassignment_ratio"],
+        random_state=hparams["seed"],
+        checkpoint_path=checkpoint_path,
+    )
+
+    # Train and save Kmeans model
+    train(
+        kmeans_model,
+        train_set,
+        hparams["ssl_model"],
+        hparams["ssl_layer_num"],
+        kmeans_batch_size=hparams["kmeans_batch_size"],
+        device=run_opts["device"],
+    )
+
+    logger.info(f"Saving kmeans model at {checkpoint_path}.")
+    save_model(kmeans_model, checkpoint_path)
diff --git a/recipes/IWSLT22_lowresource/AST/transformer/README.md b/recipes/IWSLT22_lowresource/AST/transformer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8447ce3fe13b9dac8ff4286e822f025abc6ae888
--- /dev/null
+++ b/recipes/IWSLT22_lowresource/AST/transformer/README.md
@@ -0,0 +1,107 @@
+# IWSLT 2022 Low-resource Task: Tamasheq-French end-to-end Speech Translation
+
+
+## Description
+
+This is the recipe for the best system from the IWSLT 2022 low-resource task, as described in the original paper.
+The speech translation model comprises a wav2vec 2.0 encoder and a Transformer decoder. It is trained end-to-end without any auxiliary loss. The recipe allows for removing the last layers of the Transformer Encoder inside the wav2vec 2.0 in order to reduce the number of training parameters.
+
+This recipe also provides a flexible use of text-based sequence-to-sequence models, such as [mBART](https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt) or [NLLB](https://huggingface.co/facebook/nllb-200-1.3B) model, to initialize the decoder of the speech translation model. This pratice has been proven more effective in a wide range of settings in comparison with the randomly initialized decoder.
+
+## Data Downloading
+
+For downloading the dataset used for this experiment, please run the following command.
+
+```
+git clone https://github.com/mzboito/IWSLT2022_Tamasheq_data.git
+```
+
+## Installing Extra Dependencies
+
+Before proceeding, ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal:
+
+```
+pip install -r extra_requirements.txt
+```
+
+## Training
+
+For training the model, please update the variables at hparams/train_w2v2_st.yaml.
+
+Note that in order to drop the last layers of the wav2vec 2.0 module, it is necessary to update the parameter "keep_n_layers".
+For instance: Using ``keep_n_layers: 10'' means that only the first 10 layers inside the wav2vec 2.0 Transformer encoder will be used for training. The remaining layers are removed.
+
+For launching training:
+```
+python train.py hparams/train_w2v2_st.yaml --root_data_folder=your/data/path # e.g., /workspace/speechbrain/recipes/IWSLT22_lowresource/IWSLT2022_Tamasheq_data/taq_fra_clean/
+
+```
+
+## Training with mBART/NLLB
+
+For training the model with the mBART/NLLB model, please refer to the hparams/train_w2v2_mbart_st.yaml or hparams/train_w2v2_nllb_st.yaml file.
+
+For launching training:
+```
+python train_with_w2v_mbart.py hparams/train_w2v2_mbart_st.yaml --root_data_folder=your/data/path # e.g., /workspace/speechbrain/recipes/IWSLT22_lowresource/IWSLT2022_Tamasheq_data/taq_fra_clean
+```
+
+One should change hparams/train_w2v2_mbart_st.yaml to hparams/train_w2v2_nllb_st.yaml in the above training command for using NLLB model instead.
+
+## Pre-training Semantically-Aligned Multimodal Utterance-level (SAMU) wav2vec
+
+Inspired by [SAMU-XLSR](https://arxiv.org/abs/2205.08180), a model that unifies speech and text modality for making the pre-trained speech foundation model more semantically aware, we introduce here a recipe for fine-tuning a pre-trained wav2vec 2.0 model in the same manner. Training data can be paired speech/text data of the kind used by ASR or AST. In this recipe, we use directly the IWSLT2022_Tamasheq_data AST data.
+
+For launching SAMU training:
+```
+python train_samu.py hparams/train_samu.yaml --root_data_folder=your/data/path # e.g., /workspace/speechbrain/recipes/IWSLT22_lowresource/IWSLT2022_Tamasheq_data/taq_fra_clean
+```
+
+After the SAMU model is pre-trained, one can use it in the same manner as wav2vec 2.0 model. We found that using SAMU model as speech encoder coupled with a decoder from mBART or NLLB helps further improve BLEU scores on this challenging dataset.
+
+For launching AST training:
+```
+train_with_samu_mbart.py hparams/train_samu_mbart_st.yaml --root_data_folder=your/data/path --pre_trained_samu=your/samu/ckpt
+```
+
+Examples of the two parameters:
+--root_data_folder=/workspace/speechbrain/recipes/IWSLT22_lowresource/IWSLT2022_Tamasheq_data/taq_fra_clean
+--pre_trained_samu=/workspace/speechbrain/recipes/IWSLT22_lowresource/results/samu_pretraining/7777/save/CKPT+checkpoint_epoch100/wav2vec2.ckpt
+
+One should change hparams/train_samu_mbart_st.yaml to hparams/train_samu_nllb_st.yaml in the above training command for using NLLB model instead.
+
+# Results
+
+| No. | hyperparams file |  dev BLEU | test BLEU | Model Link |
+| --- |:----------------:|:---------:|:--------:|:--------:|
+| 1 | train_w2v2_st.yaml | 7.63 | 5.38 | Not avail. | Not avail. |
+| 2 | train_w2v2_mbart_st.yaml | 9.62 | 7.73 | [DropBox](https://www.dropbox.com/sh/xjo0ou739oksnus/AAAgyrCwywmDRRuUiDnUva2za?dl=0) |
+| 3 | train_w2v2_nllb_st.yaml | 11.09 | 8.70 | [DropBox](https://www.dropbox.com/sh/spp2ijgfdbzuz26/AABkJ97e72D7aKzNLTm1qmWEa?dl=0) |
+| 4 | train_samu_mbart_st.yaml | 13.41 | 10.28 | [DropBox](https://www.dropbox.com/sh/98s1xyc3chreaw6/AABom3FnwY5SsIvg4en9tWC2a?dl=0) |
+| 5 | train_samu_nllb_st.yaml | 13.89 | 11.32 | [DropBox](https://www.dropbox.com/sh/ekkpl9c3kxsgllj/AABa0q2LrJe_o7JF-TTbfxZ-a?dl=0) |
+
+## Citation
+```
+@inproceedings{boito-etal-2022-trac,
+    title = "{ON}-{TRAC} Consortium Systems for the {IWSLT} 2022 Dialect and Low-resource Speech Translation Tasks",
+    author = {Boito, Marcely Zanon  and
+      Ortega, John  and
+      Riguidel, Hugo  and
+      Laurent, Antoine  and
+      Barrault, Lo{\"\i}c  and
+      Bougares, Fethi  and
+      Chaabani, Firas  and
+      Nguyen, Ha  and
+      Barbier, Florentin  and
+      Gahbiche, Souhir  and
+      Est{\`e}ve, Yannick},
+    booktitle = "Proceedings of the 19th International Conference on Spoken Language Translation (IWSLT 2022)",
+    month = may,
+    year = "2022",
+    address = "Dublin, Ireland (in-person and online)",
+    publisher = "Association for Computational Linguistics",
+    url = "https://aclanthology.org/2022.iwslt-1.28",
+    doi = "10.18653/v1/2022.iwslt-1.28",
+    pages = "308--318"
+}
+```
diff --git a/recipes/IWSLT22_lowresource/extra_requirements.txt b/recipes/IWSLT22_lowresource/AST/transformer/extra_requirements.txt
similarity index 55%
rename from recipes/IWSLT22_lowresource/extra_requirements.txt
rename to recipes/IWSLT22_lowresource/AST/transformer/extra_requirements.txt
index b340e2d6bae9db91610eac32d9c2a32fb985218c..506764464faa3a259b2dc0bed4e7759abea487e5 100644
--- a/recipes/IWSLT22_lowresource/extra_requirements.txt
+++ b/recipes/IWSLT22_lowresource/AST/transformer/extra_requirements.txt
@@ -1,2 +1,2 @@
+protobuf
 sacremoses
-
diff --git a/recipes/IWSLT22_lowresource/AST/transformer/hparams/train_samu.yaml b/recipes/IWSLT22_lowresource/AST/transformer/hparams/train_samu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3901391a5480ca4f6ad9ddba1a951ef442262402
--- /dev/null
+++ b/recipes/IWSLT22_lowresource/AST/transformer/hparams/train_samu.yaml
@@ -0,0 +1,121 @@
+# ############################################################################
+# Model: SAMU model
+# losses: cosine similarity
+# Training: Tamasheq-French corpus
+# Author:  Ha Nguyen, 2023
+# ############################################################################
+
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 7777
+__set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
+debug: False
+output_folder: !ref results/samu_pretraining/<seed>
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+wer_file: !ref <output_folder>/wer.txt
+
+# root data folder points to 17h version inside the github folder (IWSLT2022_Tamasheq_data/taq_fra_clean/)
+root_data_folder: !PLACEHOLDER # e.g., /users/hnguyen/IWSLT2022_Tamasheq_data/taq_fra_clean
+# data folder is the place where the json files will be stored prior to training
+data_folder: !ref <root_data_folder>/json_version/
+# Data files
+train_set: !ref <data_folder>/train.json
+valid_set: !ref <data_folder>/valid.json
+test_set: !ref <data_folder>/test.json
+skip_prep: False
+
+# URL for the HuggingFace model we want to load (BASE here)
+wav2vec2_hub: LIA-AvignonUniversity/IWSLT2022-tamasheq-only
+
+# wav2vec 2.0 specific parameters
+wav2vec2_frozen: False
+
+####################### Training Parameters ####################################
+number_of_epochs: 100
+lr: 0.001
+lr_wav2vec: 0.00001
+lr_labse: 0.00001
+sorting: ascending
+batch_size: 2
+test_batch_size: 1
+ckpt_interval_minutes: 15 # save checkpoint every N min
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+dataloader_options:
+    batch_size: !ref <batch_size>
+    num_workers: 4
+
+test_dataloader_options:
+    batch_size: !ref <test_batch_size>
+    num_workers: 4
+
+# Transformer
+d_model: 768
+loss_scale: 50
+
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
+    source: !ref <wav2vec2_hub>
+    output_norm: False
+    freeze: !ref <wav2vec2_frozen>
+    save_path: !ref <save_folder>/wav2vec2_checkpoint
+
+attn_pooling: !new:speechbrain.nnet.pooling.AttentionPooling
+    input_dim: !ref <d_model>
+
+#LaBSE
+labse_path: setu4993/LaBSE
+labse_frozen: True
+LaBSE: !new:speechbrain.lobes.models.huggingface_transformers.labse.LaBSE
+    source: !ref <labse_path>
+    freeze: !ref <labse_frozen>
+    output_norm: True
+    save_path: !ref <save_folder>/labse_checkpoint
+
+modules:
+    wav2vec2: !ref <wav2vec2>
+    attn_pooling: !ref <attn_pooling>
+    LaBSE: !ref <LaBSE>
+
+model: !new:torch.nn.ModuleList
+    - [!ref <attn_pooling>, !ref <attn_pooling>]
+
+adam_opt_class: !name:torch.optim.Adam
+    lr: !ref <lr>
+
+wav2vec_opt_class: !name:torch.optim.Adam
+    lr: !ref <lr_wav2vec>
+
+labse_opt_class: !name:torch.optim.Adam
+    lr: !ref <lr_labse>
+
+lr_annealing_adam: !new:speechbrain.nnet.schedulers.NewBobScheduler
+    initial_value: !ref <lr>
+    improvement_threshold: 0.0025
+    annealing_factor: 0.5
+    patient: 2
+
+lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
+    initial_value: !ref <lr_wav2vec>
+    improvement_threshold: 0.0025
+    annealing_factor: 0.9
+
+lr_annealing_labse: !new:speechbrain.nnet.schedulers.NewBobScheduler
+    initial_value: !ref <lr_labse>
+    improvement_threshold: 0.0025
+    annealing_factor: 0.9
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        model: !ref <model>
+        wav2vec2: !ref <wav2vec2>
+        LaBSE: !ref <LaBSE>
+        lr_annealing_adam: !ref <lr_annealing_adam>
+        lr_annealing_wav2vec: !ref <lr_annealing_wav2vec>
+        lr_annealing_labse: !ref <lr_annealing_labse>
+        counter: !ref <epoch_counter>
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
diff --git a/recipes/IWSLT22_lowresource/AST/transformer/hparams/train_samu_mbart_st.yaml b/recipes/IWSLT22_lowresource/AST/transformer/hparams/train_samu_mbart_st.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6887c3a4084cd6f4d8eff06de83bcf620d2aa939
--- /dev/null
+++ b/recipes/IWSLT22_lowresource/AST/transformer/hparams/train_samu_mbart_st.yaml
@@ -0,0 +1,198 @@
+# ############################################################################
+# Model: E2E ST with SAMU encoder and mBART decoder
+# Encoder: SAMU
+# Decoder: mBART decoder
+# losses: NLL
+# Training: Tamasheq-French corpus
+# Author:  Ha Nguyen, 2023
+# ############################################################################
+
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1337 #7777
+__set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
+debug: False
+output_folder: !ref results/samu_mbart/<seed>
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+wer_file: !ref <output_folder>/wer.txt
+bleu_file: !ref <output_folder>/bleu.txt
+
+# root data folder points to 17h version inside the github folder (IWSLT2022_Tamasheq_data/taq_fra_clean/)
+root_data_folder: !PLACEHOLDER # e.g., /users/hnguyen/IWSLT2022_Tamasheq_data/taq_fra_clean
+# data folder is the place where the json files will be stored prior to training
+data_folder: !ref <root_data_folder>/json_version/
+lang: "fr" #for the BLEU score detokenization
+target_lang: "fr_XX" # for mbart initialization
+
+annotation_train: !ref <data_folder>/train.json
+annotation_valid: !ref <data_folder>/valid.json
+annotation_test: !ref <data_folder>/test.json
+skip_prep: False
+
+# URL for the HuggingFace model we want to load (BASE here)
+wav2vec2_hub: LIA-AvignonUniversity/IWSLT2022-tamasheq-only
+wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
+
+# wav2vec 2.0 specific parameters
+wav2vec2_frozen: False
+
+####################### Training Parameters ####################################
+number_of_epochs: 500
+lr: 0.001
+lr_wav2vec: 0.0001
+lr_mbart: 0.0001
+batch_size: 2
+test_batch_size: 1
+grad_accumulation_factor: 6
+valid_search_interval: 4
+loss_reduction: batchmean
+ckpt_interval_minutes: 15 # save checkpoint every N min
+
+# Data sorting parameters: sorting_debug_duration replaces sorting_min_duration in debug mode
+sorting: ascending
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+dataloader_options:
+    batch_size: !ref <batch_size>
+    num_workers: 4
+
+test_dataloader_options:
+    batch_size: !ref <test_batch_size>
+    num_workers: 4
+
+# Feature parameters (W2V2 etc)
+features_dim: 768 # base wav2vec output dimension, for large replace by 1024
+
+#projection for w2v
+enc_dnn_layers: 1
+enc_dnn_neurons: 1024 #256
+
+# Transformer
+activation: !name:torch.nn.GELU
+
+# Outputs
+label_smoothing: 0.1
+pad_index: 1      # pad_index defined by mbart model
+bos_index: 250008 # fr_XX bos_index defined by mbart model
+eos_index: 2
+
+# Decoding parameters
+# Be sure that the bos and eos index match with the BPEs ones
+min_decode_ratio: 0.0
+max_decode_ratio: 0.25
+valid_beam_size: 5
+
+############################## models ################################
+#wav2vec model
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
+    source: !ref <wav2vec2_hub>
+    output_norm: True
+    freeze: !ref <wav2vec2_frozen>
+    save_path: !ref <wav2vec2_folder>
+
+#linear projection
+enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
+    input_shape: [null, null, !ref <features_dim>]
+    activation: !ref <activation>
+    dnn_blocks: !ref <enc_dnn_layers>
+    dnn_neurons: !ref <enc_dnn_neurons>
+
+#mBART
+mbart_path: facebook/mbart-large-50-many-to-many-mmt
+mbart_frozen: False
+vocab_size: 250054
+mBART: !new:speechbrain.lobes.models.huggingface_transformers.mbart.mBART
+    source: !ref <mbart_path>
+    freeze: !ref <mbart_frozen>
+    save_path: !ref <save_folder>/mbart_checkpoint
+    target_lang: !ref <target_lang>
+
+log_softmax: !new:speechbrain.nnet.activations.Softmax
+    apply_log: True
+
+modules:
+    wav2vec2: !ref <wav2vec2>
+    enc: !ref <enc>
+    mBART: !ref <mBART>
+
+model: !new:torch.nn.ModuleList
+    - [!ref <enc>]
+
+adam_opt_class: !name:torch.optim.Adam
+    lr: !ref <lr>
+
+wav2vec_opt_class: !name:torch.optim.Adam
+    lr: !ref <lr_wav2vec>
+
+mbart_opt_class: !name:torch.optim.Adam
+    lr: !ref <lr_mbart>
+
+seq_cost: !name:speechbrain.nnet.losses.nll_loss
+    label_smoothing: !ref <label_smoothing>
+    reduction: !ref <loss_reduction>
+
+lr_annealing_adam: !new:speechbrain.nnet.schedulers.NewBobScheduler
+    initial_value: !ref <lr>
+    improvement_threshold: 0.0025
+    annealing_factor: 0.5
+    patient: 2
+
+warmup: 8000
+hold: 32000
+cooldown: 40000
+optimizer_step_limit: 80000
+
+lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.TriStageLRSchedule
+    lr: !ref <lr_wav2vec>
+    warmup_steps: !ref <warmup>
+    hold_steps: !ref <hold>
+    decay_steps: !ref <cooldown>
+    total_steps: !ref <optimizer_step_limit>
+
+lr_annealing_mbart: !new:speechbrain.nnet.schedulers.TriStageLRSchedule
+    lr: !ref <lr_mbart>
+    warmup_steps: !ref <warmup>
+    hold_steps: !ref <hold>
+    decay_steps: !ref <cooldown>
+    total_steps: !ref <optimizer_step_limit>
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        model: !ref <model>
+        wav2vec2: !ref <wav2vec2>
+        mBART: !ref <mBART>
+        lr_annealing_wav2vec: !ref <lr_annealing_wav2vec>
+        lr_annealing_mbart: !ref <lr_annealing_mbart>
+        counter: !ref <epoch_counter>
+
+valid_search: !new:speechbrain.decoders.S2SHFTextBasedBeamSearcher
+    modules: [!ref <mBART>, null, null]
+    vocab_size: !ref <vocab_size>
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <valid_beam_size>
+    using_eos_threshold: True
+    length_normalization: True
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+bleu_computer: !name:speechbrain.utils.bleu.BLEUStats
+    merge_words: False
+    lang: !ref <lang>
+
+acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats
+
+# Path to the samu checkpoint
+pre_trained_samu: !PLACEHOLDER # e.g., /users/hnguyen/output_samu_pretraining/7777/save/CKPT+checkpoint_epoch100/wav2vec2.ckpt
+pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
+    collect_in: !ref <save_folder>
+    loadables:
+        wav2vec: !ref <wav2vec2>
+    paths:
+        wav2vec: !ref <pre_trained_samu>
diff --git a/recipes/IWSLT22_lowresource/AST/transformer/hparams/train_samu_nllb_st.yaml b/recipes/IWSLT22_lowresource/AST/transformer/hparams/train_samu_nllb_st.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b86cef685336bf528a8e1bd65941870283121c8e
--- /dev/null
+++ b/recipes/IWSLT22_lowresource/AST/transformer/hparams/train_samu_nllb_st.yaml
@@ -0,0 +1,198 @@
+# ############################################################################
+# Model: E2E ST with SAMU encoder and mBART decoder
+# Encoder: SAMU
+# Decoder: mBART decoder
+# losses: NLL
+# Training: Tamasheq-French corpus
+# Author:  Ha Nguyen, 2023
+# ############################################################################
+
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1337 #7777
+__set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
+debug: False
+output_folder: !ref results/samu_nllb1.3B/<seed>
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+wer_file: !ref <output_folder>/wer.txt
+bleu_file: !ref <output_folder>/bleu.txt
+
+# root data folder points to 17h version inside the github folder (IWSLT2022_Tamasheq_data/taq_fra_clean/)
+root_data_folder: !PLACEHOLDER # e.g., /users/hnguyen/IWSLT2022_Tamasheq_data/taq_fra_clean
+# data folder is the place where the json files will be stored prior to training
+data_folder: !ref <root_data_folder>/json_version/
+lang: "fr" #for the BLEU score detokenization
+target_lang: "fra_Latn" # for nllb initialization
+
+annotation_train: !ref <data_folder>/train.json
+annotation_valid: !ref <data_folder>/valid.json
+annotation_test: !ref <data_folder>/test.json
+skip_prep: False
+
+# URL for the HuggingFace model we want to load (BASE here)
+wav2vec2_hub: LIA-AvignonUniversity/IWSLT2022-tamasheq-only
+wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
+
+# wav2vec 2.0 specific parameters
+wav2vec2_frozen: False
+
+####################### Training Parameters ####################################
+number_of_epochs: 500
+lr: 0.001
+lr_wav2vec: 0.0001
+lr_mbart: 0.0001
+batch_size: 2
+test_batch_size: 1
+grad_accumulation_factor: 6
+valid_search_interval: 4
+loss_reduction: batchmean
+ckpt_interval_minutes: 15 # save checkpoint every N min
+
+# Data sorting parameters: sorting_debug_duration replaces sorting_min_duration in debug mode
+sorting: ascending
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+dataloader_options:
+    batch_size: !ref <batch_size>
+    num_workers: 4
+
+test_dataloader_options:
+    batch_size: !ref <test_batch_size>
+    num_workers: 4
+
+# Feature parameters (W2V2 etc)
+features_dim: 768 # base wav2vec output dimension, for large replace by 1024
+
+#projection for w2v
+enc_dnn_layers: 1
+enc_dnn_neurons: 1024 #256
+
+# Transformer
+activation: !name:torch.nn.GELU
+
+# Outputs
+label_smoothing: 0.1
+pad_index: 1      # pad_index defined by nllb model
+bos_index: 256057 # fra_Latn bos_index defined by nllb model
+eos_index: 2
+
+# Decoding parameters
+# Be sure that the bos and eos index match with the BPEs ones
+min_decode_ratio: 0.0
+max_decode_ratio: 0.25
+valid_beam_size: 5
+
+############################## models ################################
+#wav2vec model
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
+    source: !ref <wav2vec2_hub>
+    output_norm: True
+    freeze: !ref <wav2vec2_frozen>
+    save_path: !ref <wav2vec2_folder>
+
+#linear projection
+enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
+    input_shape: [null, null, !ref <features_dim>]
+    activation: !ref <activation>
+    dnn_blocks: !ref <enc_dnn_layers>
+    dnn_neurons: !ref <enc_dnn_neurons>
+
+#mBART
+mbart_path: facebook/nllb-200-1.3B
+mbart_frozen: False
+vocab_size: 256206
+mBART: !new:speechbrain.lobes.models.huggingface_transformers.nllb.NLLB
+    source: !ref <mbart_path>
+    freeze: !ref <mbart_frozen>
+    save_path: !ref <save_folder>/mbart_checkpoint
+    target_lang: !ref <target_lang>
+
+log_softmax: !new:speechbrain.nnet.activations.Softmax
+    apply_log: True
+
+modules:
+    wav2vec2: !ref <wav2vec2>
+    enc: !ref <enc>
+    mBART: !ref <mBART>
+
+model: !new:torch.nn.ModuleList
+    - [!ref <enc>]
+
+adam_opt_class: !name:torch.optim.Adam
+    lr: !ref <lr>
+
+wav2vec_opt_class: !name:torch.optim.Adam
+    lr: !ref <lr_wav2vec>
+
+mbart_opt_class: !name:torch.optim.Adam
+    lr: !ref <lr_mbart>
+
+seq_cost: !name:speechbrain.nnet.losses.nll_loss
+    label_smoothing: !ref <label_smoothing>
+    reduction: !ref <loss_reduction>
+
+lr_annealing_adam: !new:speechbrain.nnet.schedulers.NewBobScheduler
+    initial_value: !ref <lr>
+    improvement_threshold: 0.0025
+    annealing_factor: 0.5
+    patient: 2
+
+warmup: 8000
+hold: 32000
+cooldown: 40000
+optimizer_step_limit: 80000
+
+lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.TriStageLRSchedule
+    lr: !ref <lr_wav2vec>
+    warmup_steps: !ref <warmup>
+    hold_steps: !ref <hold>
+    decay_steps: !ref <cooldown>
+    total_steps: !ref <optimizer_step_limit>
+
+lr_annealing_mbart: !new:speechbrain.nnet.schedulers.TriStageLRSchedule
+    lr: !ref <lr_mbart>
+    warmup_steps: !ref <warmup>
+    hold_steps: !ref <hold>
+    decay_steps: !ref <cooldown>
+    total_steps: !ref <optimizer_step_limit>
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        model: !ref <model>
+        wav2vec2: !ref <wav2vec2>
+        mBART: !ref <mBART>
+        lr_annealing_wav2vec: !ref <lr_annealing_wav2vec>
+        lr_annealing_mbart: !ref <lr_annealing_mbart>
+        counter: !ref <epoch_counter>
+
+valid_search: !new:speechbrain.decoders.S2SHFTextBasedBeamSearcher
+    modules: [!ref <mBART>, null, null]
+    vocab_size: !ref <vocab_size>
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <valid_beam_size>
+    using_eos_threshold: True
+    length_normalization: True
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+bleu_computer: !name:speechbrain.utils.bleu.BLEUStats
+    merge_words: False
+    lang: !ref <lang>
+
+acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats
+
+# Path to the samu checkpoint
+pre_trained_samu: !PLACEHOLDER # e.g., /users/hnguyen/output_samu_pretraining/7777/save/CKPT+checkpoint_epoch100/wav2vec2.ckpt
+pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
+    collect_in: !ref <save_folder>
+    loadables:
+        wav2vec: !ref <wav2vec2>
+    paths:
+        wav2vec: !ref <pre_trained_samu>
diff --git a/recipes/IWSLT22_lowresource/AST/transformer/hparams/train_w2v2_mbart_st.yaml b/recipes/IWSLT22_lowresource/AST/transformer/hparams/train_w2v2_mbart_st.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..77b7c8cd6ceb36b004fb81e798a65a08dce43025
--- /dev/null
+++ b/recipes/IWSLT22_lowresource/AST/transformer/hparams/train_w2v2_mbart_st.yaml
@@ -0,0 +1,189 @@
+# ############################################################################
+# Model: E2E ST with wav2vec 2.0 encoder and mBART decoder
+# Encoder: wav2vec 2.0
+# Decoder: mBART decoder
+# losses: NLL
+# Training: Tamasheq-French corpus
+# Author:  Ha Nguyen, 2023
+# ############################################################################
+
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1337 #7777
+__set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
+debug: False
+output_folder: !ref results/w2v2_mbart/<seed>
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+wer_file: !ref <output_folder>/wer.txt
+bleu_file: !ref <output_folder>/bleu.txt
+
+# root data folder points to 17h version inside the github folder (IWSLT2022_Tamasheq_data/taq_fra_clean/)
+root_data_folder: !PLACEHOLDER # e.g., /users/hnguyen/IWSLT2022_Tamasheq_data/taq_fra_clean
+# data folder is the place where the json files will be stored prior to training
+data_folder: !ref <root_data_folder>/json_version/
+lang: "fr" #for the BLEU score detokenization
+target_lang: "fr_XX" # for mbart initialization
+
+annotation_train: !ref <data_folder>/train.json
+annotation_valid: !ref <data_folder>/valid.json
+annotation_test: !ref <data_folder>/test.json
+skip_prep: False
+
+# URL for the HuggingFace model we want to load (BASE here)
+wav2vec2_hub: LIA-AvignonUniversity/IWSLT2022-tamasheq-only
+wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
+
+# wav2vec 2.0 specific parameters
+wav2vec2_frozen: False
+
+####################### Training Parameters ####################################
+number_of_epochs: 500
+lr: 0.001
+lr_wav2vec: 0.0001
+lr_mbart: 0.0001
+batch_size: 2
+test_batch_size: 1
+grad_accumulation_factor: 6
+valid_search_interval: 4
+loss_reduction: batchmean
+ckpt_interval_minutes: 15 # save checkpoint every N min
+
+# Data sorting parameters: sorting_debug_duration replaces sorting_min_duration in debug mode
+sorting: ascending
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+dataloader_options:
+    batch_size: !ref <batch_size>
+    num_workers: 4
+
+test_dataloader_options:
+    batch_size: !ref <test_batch_size>
+    num_workers: 4
+
+# Feature parameters (W2V2 etc)
+features_dim: 768 # base wav2vec output dimension, for large replace by 1024
+
+#projection for w2v
+enc_dnn_layers: 1
+enc_dnn_neurons: 1024 #256
+
+# Transformer
+activation: !name:torch.nn.GELU
+
+# Outputs
+label_smoothing: 0.1
+pad_index: 1      # pad_index defined by mbart model
+bos_index: 250008 # fr_XX bos_index defined by mbart model
+eos_index: 2
+
+# Decoding parameters
+# Be sure that the bos and eos index match with the BPEs ones
+min_decode_ratio: 0.0
+max_decode_ratio: 0.25
+valid_beam_size: 5
+
+############################## models ################################
+#wav2vec model
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
+    source: !ref <wav2vec2_hub>
+    output_norm: True
+    freeze: !ref <wav2vec2_frozen>
+    save_path: !ref <wav2vec2_folder>
+
+#linear projection
+enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
+    input_shape: [null, null, !ref <features_dim>]
+    activation: !ref <activation>
+    dnn_blocks: !ref <enc_dnn_layers>
+    dnn_neurons: !ref <enc_dnn_neurons>
+
+#mBART
+mbart_path: facebook/mbart-large-50-many-to-many-mmt
+mbart_frozen: False
+vocab_size: 250054
+mBART: !new:speechbrain.lobes.models.huggingface_transformers.mbart.mBART
+    source: !ref <mbart_path>
+    freeze: !ref <mbart_frozen>
+    save_path: !ref <save_folder>/mbart_checkpoint
+    target_lang: !ref <target_lang>
+
+log_softmax: !new:speechbrain.nnet.activations.Softmax
+    apply_log: True
+
+modules:
+    wav2vec2: !ref <wav2vec2>
+    enc: !ref <enc>
+    mBART: !ref <mBART>
+
+model: !new:torch.nn.ModuleList
+    - [!ref <enc>]
+
+adam_opt_class: !name:torch.optim.Adam
+    lr: !ref <lr>
+
+wav2vec_opt_class: !name:torch.optim.Adam
+    lr: !ref <lr_wav2vec>
+
+mbart_opt_class: !name:torch.optim.Adam
+    lr: !ref <lr_mbart>
+
+seq_cost: !name:speechbrain.nnet.losses.nll_loss
+    label_smoothing: !ref <label_smoothing>
+    reduction: !ref <loss_reduction>
+
+lr_annealing_adam: !new:speechbrain.nnet.schedulers.NewBobScheduler
+    initial_value: !ref <lr>
+    improvement_threshold: 0.0025
+    annealing_factor: 0.5
+    patient: 2
+
+warmup: 8000
+hold: 32000
+cooldown: 40000
+optimizer_step_limit: 80000
+
+lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.TriStageLRSchedule
+    lr: !ref <lr_wav2vec>
+    warmup_steps: !ref <warmup>
+    hold_steps: !ref <hold>
+    decay_steps: !ref <cooldown>
+    total_steps: !ref <optimizer_step_limit>
+
+lr_annealing_mbart: !new:speechbrain.nnet.schedulers.TriStageLRSchedule
+    lr: !ref <lr_mbart>
+    warmup_steps: !ref <warmup>
+    hold_steps: !ref <hold>
+    decay_steps: !ref <cooldown>
+    total_steps: !ref <optimizer_step_limit>
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        model: !ref <model>
+        wav2vec2: !ref <wav2vec2>
+        mBART: !ref <mBART>
+        lr_annealing_wav2vec: !ref <lr_annealing_wav2vec>
+        lr_annealing_mbart: !ref <lr_annealing_mbart>
+        counter: !ref <epoch_counter>
+
+valid_search: !new:speechbrain.decoders.S2SHFTextBasedBeamSearcher
+    modules: [!ref <mBART>, null, null]
+    vocab_size: !ref <vocab_size>
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <valid_beam_size>
+    using_eos_threshold: True
+    length_normalization: True
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+bleu_computer: !name:speechbrain.utils.bleu.BLEUStats
+    merge_words: False
+    lang: !ref <lang>
+
+acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats
diff --git a/recipes/IWSLT22_lowresource/AST/transformer/hparams/train_w2v2_nllb_st.yaml b/recipes/IWSLT22_lowresource/AST/transformer/hparams/train_w2v2_nllb_st.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d384bf3a86cbbe391c645bc4eabd49d1d74dc1cd
--- /dev/null
+++ b/recipes/IWSLT22_lowresource/AST/transformer/hparams/train_w2v2_nllb_st.yaml
@@ -0,0 +1,189 @@
+# ############################################################################
+# Model: E2E ST with wav2vec 2.0 encoder and NLLB decoder
+# Encoder: wav2vec 2.0
+# Decoder: NLLB decoder
+# losses: NLL
+# Training: Tamasheq-French corpus
+# Author:  Ha Nguyen, 2023
+# ############################################################################
+
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1337 #7777
+__set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
+debug: False
+output_folder: !ref results/w2v2_nllb1.3B/<seed>
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+wer_file: !ref <output_folder>/wer.txt
+bleu_file: !ref <output_folder>/bleu.txt
+
+# root data folder points to 17h version inside the github folder (IWSLT2022_Tamasheq_data/taq_fra_clean/)
+root_data_folder: !PLACEHOLDER # e.g., /users/hnguyen/IWSLT2022_Tamasheq_data/taq_fra_clean
+# data folder is the place where the json files will be stored prior to training
+data_folder: !ref <root_data_folder>/json_version/
+lang: "fr" #for the BLEU score detokenization
+target_lang: "fra_Latn" # for nllb initialization
+
+annotation_train: !ref <data_folder>/train.json
+annotation_valid: !ref <data_folder>/valid.json
+annotation_test: !ref <data_folder>/test.json
+skip_prep: False
+
+# URL for the HuggingFace model we want to load (BASE here)
+wav2vec2_hub: LIA-AvignonUniversity/IWSLT2022-tamasheq-only
+wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
+
+# wav2vec 2.0 specific parameters
+wav2vec2_frozen: False
+
+####################### Training Parameters ####################################
+number_of_epochs: 500
+lr: 0.001
+lr_wav2vec: 0.0001
+lr_mbart: 0.0001
+batch_size: 2
+test_batch_size: 1
+grad_accumulation_factor: 6
+valid_search_interval: 4
+loss_reduction: batchmean
+ckpt_interval_minutes: 15 # save checkpoint every N min
+
+# Data sorting parameters: sorting_debug_duration replaces sorting_min_duration in debug mode
+sorting: ascending
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+dataloader_options:
+    batch_size: !ref <batch_size>
+    num_workers: 4
+
+test_dataloader_options:
+    batch_size: !ref <test_batch_size>
+    num_workers: 4
+
+# Feature parameters (W2V2 etc)
+features_dim: 768 # base wav2vec output dimension, for large replace by 1024
+
+#projection for w2v
+enc_dnn_layers: 1
+enc_dnn_neurons: 1024 #256
+
+# Transformer
+activation: !name:torch.nn.GELU
+
+# Outputs
+label_smoothing: 0.1
+pad_index: 1      # pad_index defined by nllb model
+bos_index: 256057 # fra_Latn bos_index defined by nllb model
+eos_index: 2
+
+# Decoding parameters
+# Be sure that the bos and eos index match with the BPEs ones
+min_decode_ratio: 0.0
+max_decode_ratio: 0.25
+valid_beam_size: 5
+
+############################## models ################################
+#wav2vec model
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
+    source: !ref <wav2vec2_hub>
+    output_norm: True
+    freeze: !ref <wav2vec2_frozen>
+    save_path: !ref <wav2vec2_folder>
+
+#linear projection
+enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
+    input_shape: [null, null, !ref <features_dim>]
+    activation: !ref <activation>
+    dnn_blocks: !ref <enc_dnn_layers>
+    dnn_neurons: !ref <enc_dnn_neurons>
+
+#mBART
+mbart_path: facebook/nllb-200-1.3B
+mbart_frozen: False
+vocab_size: 256206
+mBART: !new:speechbrain.lobes.models.huggingface_transformers.nllb.NLLB
+    source: !ref <mbart_path>
+    freeze: !ref <mbart_frozen>
+    save_path: !ref <save_folder>/mbart_checkpoint
+    target_lang: !ref <target_lang>
+
+log_softmax: !new:speechbrain.nnet.activations.Softmax
+    apply_log: True
+
+modules:
+    wav2vec2: !ref <wav2vec2>
+    enc: !ref <enc>
+    mBART: !ref <mBART>
+
+model: !new:torch.nn.ModuleList
+    - [!ref <enc>]
+
+adam_opt_class: !name:torch.optim.Adam
+    lr: !ref <lr>
+
+wav2vec_opt_class: !name:torch.optim.Adam
+    lr: !ref <lr_wav2vec>
+
+mbart_opt_class: !name:torch.optim.Adam
+    lr: !ref <lr_mbart>
+
+seq_cost: !name:speechbrain.nnet.losses.nll_loss
+    label_smoothing: !ref <label_smoothing>
+    reduction: !ref <loss_reduction>
+
+lr_annealing_adam: !new:speechbrain.nnet.schedulers.NewBobScheduler
+    initial_value: !ref <lr>
+    improvement_threshold: 0.0025
+    annealing_factor: 0.5
+    patient: 2
+
+warmup: 8000
+hold: 32000
+cooldown: 40000
+optimizer_step_limit: 80000
+
+lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.TriStageLRSchedule
+    lr: !ref <lr_wav2vec>
+    warmup_steps: !ref <warmup>
+    hold_steps: !ref <hold>
+    decay_steps: !ref <cooldown>
+    total_steps: !ref <optimizer_step_limit>
+
+lr_annealing_mbart: !new:speechbrain.nnet.schedulers.TriStageLRSchedule
+    lr: !ref <lr_mbart>
+    warmup_steps: !ref <warmup>
+    hold_steps: !ref <hold>
+    decay_steps: !ref <cooldown>
+    total_steps: !ref <optimizer_step_limit>
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        model: !ref <model>
+        wav2vec2: !ref <wav2vec2>
+        mBART: !ref <mBART>
+        lr_annealing_wav2vec: !ref <lr_annealing_wav2vec>
+        lr_annealing_mbart: !ref <lr_annealing_mbart>
+        counter: !ref <epoch_counter>
+
+valid_search: !new:speechbrain.decoders.S2SHFTextBasedBeamSearcher
+    modules: [!ref <mBART>, null, null]
+    vocab_size: !ref <vocab_size>
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <valid_beam_size>
+    using_eos_threshold: True
+    length_normalization: True
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+bleu_computer: !name:speechbrain.utils.bleu.BLEUStats
+    merge_words: False
+    lang: !ref <lang>
+
+acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats
diff --git a/recipes/IWSLT22_lowresource/hparams/train_w2v2_st.yaml b/recipes/IWSLT22_lowresource/AST/transformer/hparams/train_w2v2_st.yaml
similarity index 93%
rename from recipes/IWSLT22_lowresource/hparams/train_w2v2_st.yaml
rename to recipes/IWSLT22_lowresource/AST/transformer/hparams/train_w2v2_st.yaml
index 2356aa996e75dde8c9455071ed9431a6042bd2fb..beafeba864cf0113598622bfe70b390bafc6f18d 100644
--- a/recipes/IWSLT22_lowresource/hparams/train_w2v2_st.yaml
+++ b/recipes/IWSLT22_lowresource/AST/transformer/hparams/train_w2v2_st.yaml
@@ -12,7 +12,7 @@
 seed: 5988
 __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
 debug: False
-output_folder: !ref output/<seed>
+output_folder: !ref results/w2v2_st/<seed>
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
 
@@ -36,7 +36,7 @@ wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
 wav2vec2_frozen: False
 keep_n_layers: 6 # keep first N layers from the Transformer Encoder stack inside the wav2vec 2.0 model
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 100
 lr: 0.001
 lr_wav2vec: 0.00001
@@ -81,7 +81,6 @@ output_neurons: !ref <vocab_size> # /!\ needs to be changed accordingly to the v
 attention_type: "regularMHA" # "RelPosMHAXL" or "regularMHA"
 
 # Outputs
-blank_index: 0
 label_smoothing: 0.1
 pad_index: 0
 bos_index: 1
@@ -96,7 +95,7 @@ test_beam_size: 5
 
 ############################## models ################################
 #wav2vec model
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <wav2vec2_frozen>
@@ -174,22 +173,20 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
 train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
     save_file: !ref <train_log>
 
-valid_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules: [!ref <Transformer>, !ref <seq_lin>, null]
+valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <valid_beam_size>
     using_eos_threshold: False
-    length_normalization: False
+    length_normalization: True
 
-test_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules: [!ref <Transformer>, !ref <seq_lin>, null]
+test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <test_beam_size>
diff --git a/recipes/IWSLT22_lowresource/AST/transformer/prepare_iwslt22.py b/recipes/IWSLT22_lowresource/AST/transformer/prepare_iwslt22.py
new file mode 120000
index 0000000000000000000000000000000000000000..93e733e48940a36dbde7a47982b5c632118d5824
--- /dev/null
+++ b/recipes/IWSLT22_lowresource/AST/transformer/prepare_iwslt22.py
@@ -0,0 +1 @@
+../../prepare_iwslt22.py
\ No newline at end of file
diff --git a/recipes/IWSLT22_lowresource/train.py b/recipes/IWSLT22_lowresource/AST/transformer/train.py
similarity index 91%
rename from recipes/IWSLT22_lowresource/train.py
rename to recipes/IWSLT22_lowresource/AST/transformer/train.py
index 5108b7d2c1ae8fde3662c74859a92d8a503b439b..e799f8cee74e65421d77263c9eed31ead80b3f6e 100644
--- a/recipes/IWSLT22_lowresource/train.py
+++ b/recipes/IWSLT22_lowresource/AST/transformer/train.py
@@ -48,9 +48,10 @@ class ST(sb.core.Brain):
         hyps = None
         if stage == sb.Stage.VALID:
             # the output of the encoder (enc) is used for valid search
-            hyps, _ = self.hparams.valid_search(src.detach(), wav_lens)
+            hyps, _, _, _ = self.hparams.valid_search(src.detach(), wav_lens)
+
         elif stage == sb.Stage.TEST:
-            hyps, _ = self.hparams.test_search(src.detach(), wav_lens)
+            hyps, _, _, _ = self.hparams.test_search(src.detach(), wav_lens)
 
         return p_seq, wav_lens, hyps
 
@@ -88,43 +89,28 @@ class ST(sb.core.Brain):
         return loss
 
     def init_optimizers(self):
-        # Initializes the wav2vec2 optimizer if the model is not wav2vec2_frozen
-        if not self.hparams.wav2vec2_frozen:
-            self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
-                self.modules.wav2vec2.parameters()
-            )
         self.adam_optimizer = self.hparams.adam_opt_class(
             self.hparams.model.parameters()
         )
 
-    def zero_grad(self, set_to_none=False):
-        if not self.hparams.wav2vec2_frozen:
-            self.wav2vec_optimizer.zero_grad(set_to_none)
-        self.adam_optimizer.zero_grad(set_to_none)
-
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-        loss.backward()
-
-        if self.check_gradients(loss):
-            if not self.hparams.wav2vec2_frozen:  # if wav2vec2 is not frozen
-                self.wav2vec_optimizer.step()
-            self.adam_optimizer.step()
+        self.optimizers_dict = {"model_optimizer": self.adam_optimizer}
 
+        # Initializes the wav2vec2 optimizer if the model is not wav2vec2_frozen
         if not self.hparams.wav2vec2_frozen:
-            self.wav2vec_optimizer.zero_grad()
-        self.adam_optimizer.zero_grad()
-
-        return loss.detach().cpu()
+            self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
+                self.modules.wav2vec2.parameters()
+            )
+            self.optimizers_dict["wav2vec_optimizer"] = self.wav2vec_optimizer
 
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        with torch.no_grad():
-            loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
+    def freeze_optimizers(self, optimizers):
+        """Freezes the wav2vec2 optimizer according to the warmup steps"""
+        valid_optimizers = {}
+        if not self.hparams.wav2vec2_frozen:
+            valid_optimizers["wav2vec_optimizer"] = optimizers[
+                "wav2vec_optimizer"
+            ]
+        valid_optimizers["model_optimizer"] = optimizers["model_optimizer"]
+        return valid_optimizers
 
     def on_stage_start(self, stage, epoch):
         """Gets called when a stage (either training, validation, test) starts."""
@@ -148,7 +134,7 @@ class ST(sb.core.Brain):
             current_epoch = self.hparams.epoch_counter.current
 
         # log stats and save checkpoint at end-of-epoch
-        if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+        if stage == sb.Stage.VALID:
             current_epoch = self.hparams.epoch_counter.current
             old_lr_adam, new_lr_adam = self.hparams.lr_annealing_adam(
                 stage_stats["BLEU"]
@@ -254,7 +240,7 @@ def dataio_prepare(hparams):
     # 2. load data and tokenize with trained tokenizer
     datasets = {}
     for dataset in ["train", "valid"]:
-        json_path = f"{data_folder}/{dataset}.json"
+        json_path = hparams[f"annotation_{dataset}"]
 
         is_use_sp = dataset == "train" and "speed_perturb" in hparams
         audio_pipeline_func = sp_audio_pipeline if is_use_sp else audio_pipeline
@@ -376,7 +362,6 @@ if __name__ == "__main__":
     # creates a logger
     logger = logging.getLogger(__name__)
 
-    # If distributed_launch=True then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
diff --git a/recipes/IWSLT22_lowresource/AST/transformer/train_samu.py b/recipes/IWSLT22_lowresource/AST/transformer/train_samu.py
new file mode 100644
index 0000000000000000000000000000000000000000..64bedc1270b10dfad8df5a37084f9c95be4288d4
--- /dev/null
+++ b/recipes/IWSLT22_lowresource/AST/transformer/train_samu.py
@@ -0,0 +1,342 @@
+#!/usr/bin/env python3
+"""Recipe for fine-tuning a wav2vec model for semantically enriching: https://arxiv.org/abs/2205.08180.
+
+Author
+ * Ha Nguyen, 2023
+"""
+
+import sys
+import torch
+import logging
+from hyperpyyaml import load_hyperpyyaml
+import speechbrain as sb
+import torch.nn.functional as F
+from speechbrain.utils.distributed import run_on_main
+
+logger = logging.getLogger(__name__)
+
+
+# Define training procedure
+class ST(sb.core.Brain):
+    def compute_forward(self, batch, stage):
+        """Forward computations from the waveform batches to the output probabilities."""
+
+        batch = batch.to(self.device)
+        wavs, wav_lens = batch.sig  # audio
+
+        # wav2vec module
+        feats = self.modules.wav2vec2(wavs)
+
+        # self-attention pooling
+        uttr_embeddings = self.modules.attn_pooling(feats)
+
+        # norm
+        uttr_embeddings = F.normalize(uttr_embeddings, p=2)
+
+        # LaBSE
+        text_embeddings = self.modules.LaBSE(batch.trans)
+
+        return uttr_embeddings, text_embeddings
+
+    def compute_objectives(self, predictions, batch, stage):
+        """Computes the loss given predictions and targets."""
+        (uttr_embeddings, text_embeddings,) = predictions
+
+        B, S = uttr_embeddings.shape
+        loss = 0.0
+        for b in range(B):
+            cosine_sim = torch.dot(
+                uttr_embeddings[b].float(), text_embeddings[b].float()
+            )
+            loss += 1.0 - cosine_sim
+        loss *= self.hparams.loss_scale
+        return loss
+
+    def init_optimizers(self):
+        self.adam_optimizer = self.hparams.adam_opt_class(
+            self.hparams.model.parameters()
+        )
+
+        self.optimizers_dict = {"model_optimizer": self.adam_optimizer}
+
+        # Initializes the wav2vec2 optimizer if the model is not wav2vec2_frozen
+        if not self.hparams.wav2vec2_frozen:
+            self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
+                self.modules.wav2vec2.parameters()
+            )
+            self.optimizers_dict["wav2vec_optimizer"] = self.wav2vec_optimizer
+
+        # Initializes the labse optimizer if the model is not labse_frozen
+        if not self.hparams.labse_frozen:
+            self.labse_optimizer = self.hparams.labse_opt_class(
+                self.modules.LaBSE.parameters()
+            )
+            self.optimizers_dict["labse_optimizer"] = self.labse_optimizer
+
+    def freeze_optimizers(self, optimizers):
+        """Freezes the wav2vec2 optimizer according to the warmup steps"""
+        valid_optimizers = {}
+        if not self.hparams.wav2vec2_frozen:
+            valid_optimizers["wav2vec_optimizer"] = optimizers[
+                "wav2vec_optimizer"
+            ]
+        if not self.hparams.labse_frozen:
+            valid_optimizers["labse_optimizer"] = optimizers["labse_optimizer"]
+        valid_optimizers["model_optimizer"] = optimizers["model_optimizer"]
+        return valid_optimizers
+
+    def on_stage_start(self, stage, epoch):
+        """Gets called when a stage (either training, validation, test) starts."""
+        return
+
+    def on_stage_end(self, stage, stage_loss, epoch):
+        """Gets called at the end of a epoch."""
+        # Compute/store important stats
+        if stage == sb.Stage.TRAIN:
+            self.train_stats = stage_loss
+
+        else:  # valid or test
+            stage_stats = {"loss": stage_loss}
+            current_epoch = self.hparams.epoch_counter.current
+
+        # log stats and save checkpoint at end-of-epoch
+        if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+            current_epoch = self.hparams.epoch_counter.current
+            old_lr_adam, new_lr_adam = self.hparams.lr_annealing_adam(
+                stage_stats["loss"]
+            )
+            sb.nnet.schedulers.update_learning_rate(
+                self.adam_optimizer, new_lr_adam
+            )
+
+            stats_meta = {
+                "epoch": current_epoch,
+                "lr_adam": old_lr_adam,
+            }
+
+            if not self.hparams.wav2vec2_frozen:
+                (
+                    old_lr_wav2vec,
+                    new_lr_wav2vec,
+                ) = self.hparams.lr_annealing_wav2vec(stage_stats["loss"])
+                sb.nnet.schedulers.update_learning_rate(
+                    self.wav2vec_optimizer, new_lr_wav2vec
+                )
+                stats_meta["lr_wav2vec"] = old_lr_wav2vec
+
+            if not self.hparams.labse_frozen:
+                (old_lr_labse, new_lr_labse,) = self.hparams.lr_annealing_labse(
+                    stage_stats["loss"]
+                )
+                sb.nnet.schedulers.update_learning_rate(
+                    self.labse_optimizer, new_lr_labse
+                )
+                stats_meta["lr_labse"] = old_lr_labse
+
+            self.hparams.train_logger.log_stats(
+                stats_meta=stats_meta,
+                train_stats={"loss": self.train_stats},
+                valid_stats=stage_stats,
+            )
+
+            # create checkpoing
+            meta = {"loss": stage_stats["loss"], "epoch": current_epoch}
+            name = "checkpoint_epoch" + str(current_epoch)
+
+            self.checkpointer.save_and_keep_only(
+                meta=meta, name=name, num_to_keep=10, min_keys=["loss"]
+            )
+
+        elif stage == sb.Stage.TEST:
+            self.hparams.train_logger.log_stats(
+                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
+                test_stats=stage_stats,
+            )
+
+
+# Define custom data procedure
+def dataio_prepare(hparams):
+    """This function prepares the datasets to be used in the brain class.
+    It also defines the data processing pipeline through user-defined functions."""
+
+    # Define audio pipeline. In this case, we simply read the path contained
+    # in the variable wav with the audio reader.
+    @sb.utils.data_pipeline.takes("path")
+    @sb.utils.data_pipeline.provides("sig")
+    def audio_pipeline(wav):
+        """Load the audio signal. This is done on the CPU in the `collate_fn`."""
+        sig = sb.dataio.dataio.read_audio(wav)
+        return sig
+
+    @sb.utils.data_pipeline.takes("path")
+    @sb.utils.data_pipeline.provides("sig")
+    def sp_audio_pipeline(wav):
+        """Load the audio signal. This is done on the CPU in the `collate_fn`."""
+        sig = sb.dataio.dataio.read_audio(wav)
+        sig = sig.unsqueeze(0)
+        sig = hparams["speed_perturb"](sig)
+        sig = sig.squeeze(0)
+        return sig
+
+    # 3. Define text pipeline:
+    @sb.utils.data_pipeline.takes("trans")
+    @sb.utils.data_pipeline.provides("trans")
+    def reference_text_pipeline(wrd):
+        yield wrd
+
+    datasets = {}
+    data_folder = hparams["data_folder"]
+    for dataset in ["train", "valid"]:
+        json_path = hparams[f"{dataset}_set"]
+
+        is_use_sp = dataset == "train" and "speed_perturb" in hparams
+        audio_pipeline_func = sp_audio_pipeline if is_use_sp else audio_pipeline
+
+        datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
+            json_path=json_path,
+            replacements={"data_root": data_folder},
+            dynamic_items=[audio_pipeline_func, reference_text_pipeline],
+            output_keys=["id", "sig", "duration", "trans"],
+        )
+
+    for dataset in ["test"]:
+        json_path = hparams[f"{dataset}_set"]
+        datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
+            json_path=json_path,
+            replacements={"data_root": data_folder},
+            dynamic_items=[audio_pipeline, reference_text_pipeline],
+            output_keys=["id", "sig", "duration", "trans"],
+        )
+
+    # Sorting training data with ascending order makes the code  much
+    # faster  because we minimize zero-padding. In most of the cases, this
+    # does not harm the performance.
+    if hparams["sorting"] == "ascending":
+        # use smaller dataset to debug the model
+        if hparams["debug"]:
+            datasets["train"] = datasets["train"].filtered_sorted(
+                key_min_value={"duration": 1},
+                key_max_value={"duration": 5},
+                sort_key="duration",
+                reverse=True,
+            )
+            datasets["valid"] = datasets["valid"].filtered_sorted(
+                key_min_value={"duration": 1},
+                key_max_value={"duration": 5},
+                sort_key="duration",
+                reverse=True,
+            )
+        else:
+            datasets["train"] = datasets["train"].filtered_sorted(
+                sort_key="duration"
+            )
+            datasets["valid"] = datasets["valid"].filtered_sorted(
+                sort_key="duration"
+            )
+
+        hparams["dataloader_options"]["shuffle"] = False
+        hparams["dataloader_options"]["shuffle"] = False
+    elif hparams["sorting"] == "descending":
+        # use smaller dataset to debug the model
+        if hparams["debug"]:
+            datasets["train"] = datasets["train"].filtered_sorted(
+                key_min_value={"duration": 1},
+                key_max_value={"duration": 5},
+                sort_key="duration",
+                reverse=True,
+            )
+            datasets["valid"] = datasets["valid"].filtered_sorted(
+                key_min_value={"duration": 1},
+                key_max_value={"duration": 5},
+                sort_key="duration",
+                reverse=True,
+            )
+        else:
+            datasets["train"] = datasets["train"].filtered_sorted(
+                sort_key="duration", reverse=True
+            )
+            datasets["valid"] = datasets["valid"].filtered_sorted(
+                sort_key="duration", reverse=True
+            )
+
+        hparams["dataloader_options"]["shuffle"] = False
+        hparams["dataloader_options"]["shuffle"] = False
+    elif hparams["sorting"] == "random":
+        # use smaller dataset to debug the model
+        if hparams["debug"]:
+            datasets["train"] = datasets["train"].filtered_sorted(
+                key_min_value={"duration": 3},
+                key_max_value={"duration": 5},
+                sort_key="duration",
+            )
+            datasets["valid"] = datasets["valid"].filtered_sorted(
+                key_min_value={"duration": 1}, key_max_value={"duration": 5},
+            )
+
+        hparams["dataloader_options"]["shuffle"] = True
+    else:
+        raise NotImplementedError(
+            "sorting must be random, ascending or descending"
+        )
+
+    return datasets
+
+
+if __name__ == "__main__":
+
+    # Load hyperparameters file with command-line overrides
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # create ddp_group with the right communication protocol
+    sb.utils.distributed.ddp_init_group(run_opts)
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    # Create main experiment class
+    st_brain = ST(
+        modules=hparams["modules"],
+        hparams=hparams,
+        run_opts=run_opts,
+        checkpointer=hparams["checkpointer"],
+    )
+
+    # Data preparation
+    import prepare_iwslt22
+
+    if not hparams["skip_prep"]:
+        run_on_main(
+            prepare_iwslt22.data_proc,
+            kwargs={
+                "dataset_folder": hparams["root_data_folder"],
+                "output_folder": hparams["data_folder"],
+            },
+        )
+
+    # We can now directly create the datasets for training, valid, and test
+    datasets = dataio_prepare(hparams)
+
+    # Training
+    st_brain.fit(
+        st_brain.hparams.epoch_counter,
+        datasets["train"],
+        datasets["valid"],
+        train_loader_kwargs=hparams["dataloader_options"],
+        valid_loader_kwargs=hparams["test_dataloader_options"],
+    )
+
+    # Test
+    for dataset in ["valid", "test"]:
+        st_brain.hparams.wer_file = (
+            hparams["output_folder"] + "/wer_test" + ".txt"
+        )
+        st_brain.evaluate(
+            datasets[dataset],
+            test_loader_kwargs=hparams["test_dataloader_options"],
+        )
diff --git a/recipes/IWSLT22_lowresource/AST/transformer/train_with_samu_mbart.py b/recipes/IWSLT22_lowresource/AST/transformer/train_with_samu_mbart.py
new file mode 100644
index 0000000000000000000000000000000000000000..1499d9efabf8bd7d6a8d8f927fb5fb79ba0fe25c
--- /dev/null
+++ b/recipes/IWSLT22_lowresource/AST/transformer/train_with_samu_mbart.py
@@ -0,0 +1,448 @@
+#!/usr/bin/env python3
+"""Recipe for fine-tuning a samu model and mBART/NLLB model for the ST task (no transcriptions).
+
+Author
+ * Ha Nguyen, 2023
+"""
+
+import sys
+import torch
+import logging
+
+import speechbrain as sb
+from speechbrain.utils.distributed import run_on_main
+from hyperpyyaml import load_hyperpyyaml
+from sacremoses import MosesDetokenizer
+from torch.nn.parallel import DistributedDataParallel
+
+logger = logging.getLogger(__name__)
+
+
+# Define training procedure
+class ST(sb.core.Brain):
+    def compute_forward(self, batch, stage):
+        """Forward computations from the waveform batches to the output probabilities."""
+
+        batch = batch.to(self.device)
+        wavs, wav_lens = batch.sig  # audio
+        tokens_bos, _ = batch.tokens_bos  # translation
+
+        src = self.modules.wav2vec2(wavs, wav_lens)
+
+        # dimensionality reduction
+        src = self.modules.enc(src)
+
+        dec_out = self.modules.mBART(
+            src, tokens_bos, pad_idx=self.hparams.pad_index
+        )
+
+        # logits and softmax
+        p_seq = self.hparams.log_softmax(dec_out)
+        if hparams["mbart_frozen"] and not p_seq.requires_grad:
+            p_seq.requires_grad = True
+
+        # compute outputs
+        hyps = None
+        if stage == sb.Stage.VALID and self.optimizer_step >= 1000:
+            # the output of the encoder (enc) is used for valid search
+            current_epoch = self.hparams.epoch_counter.current
+            if current_epoch % self.hparams.valid_search_interval == 0:
+                if isinstance(self.modules.mBART, DistributedDataParallel):
+                    self.modules.mBART = self.modules.mBART.module
+                hyps, _, _, _ = self.hparams.valid_search(
+                    src.detach(), wav_lens
+                )
+
+        elif stage == sb.Stage.TEST:
+            if isinstance(self.modules.mBART, DistributedDataParallel):
+                self.modules.mBART = self.modules.mBART.module
+            hyps, _, _, _ = self.hparams.valid_search(src.detach(), wav_lens)
+
+        return p_seq, wav_lens, hyps
+
+    def compute_objectives(self, predictions, batch, stage):
+        """Computes the loss given predictions and targets."""
+        (p_seq, wav_lens, hyps) = predictions
+        ids = batch.id
+        tokens_eos, tokens_eos_lens = batch.tokens_eos
+
+        # st loss
+        tokens_eos = self.modules.mBART.custom_padding(
+            tokens_eos,
+            0,
+            self.modules.mBART.model.model.decoder.config.pad_token_id,
+        )
+        loss = self.hparams.seq_cost(p_seq, tokens_eos, length=tokens_eos_lens)
+
+        fr_detokenizer = MosesDetokenizer(lang=self.hparams.lang)
+
+        if stage != sb.Stage.TRAIN:
+            current_epoch = self.hparams.epoch_counter.current
+            valid_search_interval = self.hparams.valid_search_interval
+            if (
+                current_epoch % valid_search_interval == 0
+                and self.optimizer_step >= 1000
+                or (stage == sb.Stage.TEST)
+            ):
+                detokenized_translation = [
+                    fr_detokenizer.detokenize(translation.split(" "))
+                    for translation in batch.trans
+                ]
+                # it needs to be a list of list due to the extend on the bleu implementation
+                targets = [detokenized_translation]
+
+                predictions = [
+                    fr_detokenizer.detokenize(hyp.split(" "))
+                    for hyp in self.modules.mBART.tokenizer.batch_decode(
+                        hyps, skip_special_tokens=True
+                    )
+                ]
+
+                self.bleu_metric.append(ids, predictions, targets)
+
+            # compute the accuracy of the one-step-forward prediction
+            self.acc_metric.append(p_seq, tokens_eos, tokens_eos_lens)
+
+        return loss
+
+    def init_optimizers(self):
+        self.adam_optimizer = self.hparams.adam_opt_class(
+            self.hparams.model.parameters()
+        )
+
+        self.optimizers_dict = {"model_optimizer": self.adam_optimizer}
+
+        # Initializes the wav2vec2 optimizer if the model is not wav2vec2_frozen
+        if not self.hparams.wav2vec2_frozen:
+            self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
+                self.modules.wav2vec2.parameters()
+            )
+            self.optimizers_dict["wav2vec_optimizer"] = self.wav2vec_optimizer
+
+        # Initializes the mbart optimizer if the model is not mbart_frozen
+        if not self.hparams.mbart_frozen:
+            self.mbart_optimizer = self.hparams.mbart_opt_class(
+                self.modules.mBART.parameters()
+            )
+            self.optimizers_dict["mbart_optimizer"] = self.mbart_optimizer
+
+    def on_fit_batch_end(self, batch, outputs, loss, should_step):
+        """At the end of the optimizer step, apply noam annealing."""
+        if should_step:
+            if not self.hparams.wav2vec2_frozen:
+                self.hparams.lr_annealing_wav2vec(
+                    self.wav2vec_optimizer, self.optimizer_step
+                )
+            if not self.hparams.mbart_frozen:
+                self.hparams.lr_annealing_mbart(
+                    self.mbart_optimizer, self.optimizer_step
+                )
+
+    def on_stage_start(self, stage, epoch):
+        """Gets called when a stage (either training, validation, test) starts."""
+        self.bleu_metric = self.hparams.bleu_computer()
+
+        if stage != sb.Stage.TRAIN:
+            self.acc_metric = self.hparams.acc_computer()
+            self.bleu_metric = self.hparams.bleu_computer()
+
+    def on_stage_end(self, stage, stage_loss, epoch):
+        """Gets called at the end of a epoch."""
+        # Compute/store important stats
+        if stage == sb.Stage.TRAIN:
+            self.train_stats = stage_loss
+
+        else:  # valid or test
+            stage_stats = {"loss": stage_loss}
+            stage_stats["ACC"] = self.acc_metric.summarize()
+            current_epoch = self.hparams.epoch_counter.current
+            valid_search_interval = self.hparams.valid_search_interval
+            if (
+                current_epoch % valid_search_interval == 0
+                and self.optimizer_step >= 1000
+                or stage == sb.Stage.TEST
+            ):
+                stage_stats["BLEU"] = self.bleu_metric.summarize(field="BLEU")
+                stage_stats["BLEU_extensive"] = self.bleu_metric.summarize()
+                self.anneal_bleu = stage_stats["BLEU"]
+
+        # log stats and save checkpoint at end-of-epoch
+        if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+            current_epoch = self.hparams.epoch_counter.current
+            old_lr_adam, new_lr_adam = self.hparams.lr_annealing_adam(
+                self.anneal_bleu  # stage_stats["BLEU"]
+            )
+            sb.nnet.schedulers.update_learning_rate(
+                self.adam_optimizer, new_lr_adam
+            )
+
+            stats_meta = {
+                "epoch": current_epoch,
+                "steps": self.optimizer_step,
+                "lr_adam": old_lr_adam,
+            }
+
+            if not self.hparams.wav2vec2_frozen:
+                self.hparams.lr_annealing_wav2vec(
+                    self.wav2vec_optimizer, self.optimizer_step
+                )
+                stats_meta["lr_wav2vec"] = self.wav2vec_optimizer.param_groups[
+                    0
+                ]["lr"]
+            if not self.hparams.mbart_frozen:
+                self.hparams.lr_annealing_mbart(
+                    self.mbart_optimizer, self.optimizer_step
+                )
+                stats_meta["lr_mbart"] = self.mbart_optimizer.param_groups[0][
+                    "lr"
+                ]
+            self.hparams.train_logger.log_stats(
+                stats_meta=stats_meta,
+                train_stats={"loss": self.train_stats},
+                valid_stats=stage_stats,
+            )
+
+            # create checkpoing
+            valid_search_interval = self.hparams.valid_search_interval
+            if (
+                current_epoch % valid_search_interval == 0
+                and self.optimizer_step >= 1000
+            ):
+                meta = {"BLEU": stage_stats["BLEU"], "epoch": current_epoch}
+                name = "checkpoint_epoch" + str(current_epoch)
+
+                self.checkpointer.save_and_keep_only(
+                    meta=meta, name=name, num_to_keep=10, max_keys=["BLEU"]
+                )
+
+        elif stage == sb.Stage.TEST:
+            self.hparams.train_logger.log_stats(
+                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
+                test_stats=stage_stats,
+            )
+
+            with open(self.hparams.bleu_file, "w") as w:
+                self.bleu_metric.write_stats(w)
+
+
+# Define custom data procedure
+def dataio_prepare(hparams, tokenizer):
+    """This function prepares the datasets to be used in the brain class.
+    It also defines the data processing pipeline through user-defined functions."""
+
+    # Define audio pipeline. In this case, we simply read the path contained
+    # in the variable wav with the audio reader.
+    @sb.utils.data_pipeline.takes("path")
+    @sb.utils.data_pipeline.provides("sig")
+    def audio_pipeline(wav):
+        """Load the audio signal. This is done on the CPU in the `collate_fn`."""
+        sig = sb.dataio.dataio.read_audio(wav)
+        return sig
+
+    @sb.utils.data_pipeline.takes("path")
+    @sb.utils.data_pipeline.provides("sig")
+    def sp_audio_pipeline(wav):
+        """Load the audio signal. This is done on the CPU in the `collate_fn`."""
+        sig = sb.dataio.dataio.read_audio(wav)
+        sig = sig.unsqueeze(0)
+        sig = hparams["speed_perturb"](sig)
+        sig = sig.squeeze(0)
+        return sig
+
+    # Define text processing pipeline. We start from the raw text and then
+    # encode it using the tokenizer. The tokens with BOS are used for feeding
+    # decoder during training, the tokens with EOS for computing the cost function.
+    @sb.utils.data_pipeline.takes("trans")
+    @sb.utils.data_pipeline.provides(
+        "trans", "tokens_list", "tokens_bos", "tokens_eos",
+    )
+    def reference_text_pipeline(translation):
+        """Processes the transcriptions to generate proper labels"""
+        yield translation
+        labels = tokenizer(
+            text_target=translation.replace("\n", ""), return_tensors="pt"
+        )
+        tokens_list = labels["input_ids"].tolist()[-1]
+        yield tokens_list
+        tokens_bos = torch.LongTensor(tokens_list[0:-1])
+        yield tokens_bos
+        tokens_eos = torch.LongTensor(tokens_list[1:])
+        yield tokens_eos
+
+    datasets = {}
+    data_folder = hparams["data_folder"]
+    for dataset in ["train", "valid"]:
+        json_path = hparams[f"annotation_{dataset}"]
+
+        is_use_sp = dataset == "train" and "speed_perturb" in hparams
+        audio_pipeline_func = sp_audio_pipeline if is_use_sp else audio_pipeline
+
+        datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
+            json_path=json_path,
+            replacements={"data_root": data_folder},
+            dynamic_items=[audio_pipeline_func, reference_text_pipeline],
+            output_keys=[
+                "id",
+                "sig",
+                "duration",
+                "trans",
+                "tokens_list",
+                "tokens_bos",
+                "tokens_eos",
+            ],
+        )
+
+    for dataset in ["test"]:
+        json_path = hparams[f"annotation_{dataset}"]
+        datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
+            json_path=json_path,
+            replacements={"data_root": data_folder},
+            dynamic_items=[audio_pipeline, reference_text_pipeline],
+            output_keys=[
+                "id",
+                "sig",
+                "duration",
+                "trans",
+                "tokens_list",
+                "tokens_bos",
+                "tokens_eos",
+            ],
+        )
+
+    # Sorting training data with ascending order makes the code  much
+    # faster  because we minimize zero-padding. In most of the cases, this
+    # does not harm the performance.
+    if hparams["sorting"] == "ascending":
+        # use smaller dataset to debug the model
+        if hparams["debug"]:
+            datasets["train"] = datasets["train"].filtered_sorted(
+                key_min_value={"duration": 1},
+                key_max_value={"duration": 3},
+                sort_key="duration",
+                reverse=True,
+            )
+            datasets["valid"] = datasets["valid"].filtered_sorted(
+                key_min_value={"duration": 1},
+                key_max_value={"duration": 3},
+                sort_key="duration",
+                reverse=True,
+            )
+        else:
+            datasets["train"] = datasets["train"].filtered_sorted(
+                sort_key="duration"
+            )
+            datasets["valid"] = datasets["valid"].filtered_sorted(
+                sort_key="duration"
+            )
+
+        hparams["dataloader_options"]["shuffle"] = False
+        hparams["dataloader_options"]["shuffle"] = False
+    elif hparams["sorting"] == "descending":
+        # use smaller dataset to debug the model
+        if hparams["debug"]:
+            datasets["train"] = datasets["train"].filtered_sorted(
+                key_min_value={"duration": 1},
+                key_max_value={"duration": 3},
+                sort_key="duration",
+                reverse=True,
+            )
+            datasets["valid"] = datasets["valid"].filtered_sorted(
+                key_min_value={"duration": 1},
+                key_max_value={"duration": 3},
+                sort_key="duration",
+                reverse=True,
+            )
+        else:
+            datasets["train"] = datasets["train"].filtered_sorted(
+                sort_key="duration", reverse=True
+            )
+            datasets["valid"] = datasets["valid"].filtered_sorted(
+                sort_key="duration", reverse=True
+            )
+
+        hparams["dataloader_options"]["shuffle"] = False
+        hparams["dataloader_options"]["shuffle"] = False
+    elif hparams["sorting"] == "random":
+        # use smaller dataset to debug the model
+        if hparams["debug"]:
+            datasets["train"] = datasets["train"].filtered_sorted(
+                key_min_value={"duration": 1},
+                key_max_value={"duration": 3},
+                sort_key="duration",
+            )
+            datasets["valid"] = datasets["valid"].filtered_sorted(
+                key_min_value={"duration": 1}, key_max_value={"duration": 3},
+            )
+
+        hparams["dataloader_options"]["shuffle"] = True
+    else:
+        raise NotImplementedError(
+            "sorting must be random, ascending or descending"
+        )
+
+    return datasets
+
+
+if __name__ == "__main__":
+
+    # Load hyperparameters file with command-line overrides
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # create ddp_group with the right communication protocol
+    sb.utils.distributed.ddp_init_group(run_opts)
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    run_on_main(hparams["pretrainer"].collect_files)
+    hparams["pretrainer"].load_collected()
+
+    # Create main experiment class
+    st_brain = ST(
+        modules=hparams["modules"],
+        hparams=hparams,
+        run_opts=run_opts,
+        checkpointer=hparams["checkpointer"],
+    )
+
+    st_brain.anneal_bleu = 0
+
+    # Data preparation
+    import prepare_iwslt22
+
+    if not hparams["skip_prep"]:
+        run_on_main(
+            prepare_iwslt22.data_proc,
+            kwargs={
+                "dataset_folder": hparams["root_data_folder"],
+                "output_folder": hparams["data_folder"],
+            },
+        )
+
+    # We can now directly create the datasets for training, valid, and test
+    datasets = dataio_prepare(hparams, st_brain.modules.mBART.tokenizer)
+
+    # Training
+    st_brain.fit(
+        st_brain.hparams.epoch_counter,
+        datasets["train"],
+        datasets["valid"],
+        train_loader_kwargs=hparams["dataloader_options"],
+        valid_loader_kwargs=hparams["test_dataloader_options"],
+    )
+
+    # Test
+    for dataset in ["valid", "test"]:
+        st_brain.hparams.wer_file = (
+            hparams["output_folder"] + "/wer_test" + ".txt"
+        )
+        st_brain.evaluate(
+            datasets[dataset],
+            test_loader_kwargs=hparams["test_dataloader_options"],
+        )
diff --git a/recipes/IWSLT22_lowresource/AST/transformer/train_with_w2v_mbart.py b/recipes/IWSLT22_lowresource/AST/transformer/train_with_w2v_mbart.py
new file mode 100644
index 0000000000000000000000000000000000000000..a00edf8035fbdd98c82572348eb843c2eaf7d03c
--- /dev/null
+++ b/recipes/IWSLT22_lowresource/AST/transformer/train_with_w2v_mbart.py
@@ -0,0 +1,445 @@
+#!/usr/bin/env python3
+"""Recipe for fine-tuning a wav2vec model and mBART/NLLB model for the ST task (no transcriptions).
+
+Author
+ * Ha Nguyen, 2023
+"""
+
+import sys
+import torch
+import logging
+
+import speechbrain as sb
+from speechbrain.utils.distributed import run_on_main
+from hyperpyyaml import load_hyperpyyaml
+from sacremoses import MosesDetokenizer
+from torch.nn.parallel import DistributedDataParallel
+
+logger = logging.getLogger(__name__)
+
+
+# Define training procedure
+class ST(sb.core.Brain):
+    def compute_forward(self, batch, stage):
+        """Forward computations from the waveform batches to the output probabilities."""
+
+        batch = batch.to(self.device)
+        wavs, wav_lens = batch.sig  # audio
+        tokens_bos, _ = batch.tokens_bos  # translation
+
+        src = self.modules.wav2vec2(wavs, wav_lens)
+
+        # dimensionality reduction
+        src = self.modules.enc(src)
+
+        dec_out = self.modules.mBART(
+            src, tokens_bos, pad_idx=self.hparams.pad_index
+        )
+
+        # logits and softmax
+        p_seq = self.hparams.log_softmax(dec_out)
+        if hparams["mbart_frozen"] and not p_seq.requires_grad:
+            p_seq.requires_grad = True
+
+        # compute outputs
+        hyps = None
+        if stage == sb.Stage.VALID and self.optimizer_step >= 1000:
+            # the output of the encoder (enc) is used for valid search
+            current_epoch = self.hparams.epoch_counter.current
+            if current_epoch % self.hparams.valid_search_interval == 0:
+                if isinstance(self.modules.mBART, DistributedDataParallel):
+                    self.modules.mBART = self.modules.mBART.module
+                hyps, _, _, _ = self.hparams.valid_search(
+                    src.detach(), wav_lens
+                )
+
+        elif stage == sb.Stage.TEST:
+            if isinstance(self.modules.mBART, DistributedDataParallel):
+                self.modules.mBART = self.modules.mBART.module
+            hyps, _, _, _ = self.hparams.valid_search(src.detach(), wav_lens)
+
+        return p_seq, wav_lens, hyps
+
+    def compute_objectives(self, predictions, batch, stage):
+        """Computes the loss given predictions and targets."""
+        (p_seq, wav_lens, hyps) = predictions
+        ids = batch.id
+        tokens_eos, tokens_eos_lens = batch.tokens_eos
+
+        # st loss
+        tokens_eos = self.modules.mBART.custom_padding(
+            tokens_eos,
+            0,
+            self.modules.mBART.model.model.decoder.config.pad_token_id,
+        )
+        loss = self.hparams.seq_cost(p_seq, tokens_eos, length=tokens_eos_lens)
+
+        fr_detokenizer = MosesDetokenizer(lang=self.hparams.lang)
+
+        if stage != sb.Stage.TRAIN:
+            current_epoch = self.hparams.epoch_counter.current
+            valid_search_interval = self.hparams.valid_search_interval
+            if (
+                current_epoch % valid_search_interval == 0
+                and self.optimizer_step >= 1000
+                or (stage == sb.Stage.TEST)
+            ):
+                detokenized_translation = [
+                    fr_detokenizer.detokenize(translation.split(" "))
+                    for translation in batch.trans
+                ]
+                # it needs to be a list of list due to the extend on the bleu implementation
+                targets = [detokenized_translation]
+
+                predictions = [
+                    fr_detokenizer.detokenize(hyp.split(" "))
+                    for hyp in self.modules.mBART.tokenizer.batch_decode(
+                        hyps, skip_special_tokens=True
+                    )
+                ]
+
+                self.bleu_metric.append(ids, predictions, targets)
+
+            # compute the accuracy of the one-step-forward prediction
+            self.acc_metric.append(p_seq, tokens_eos, tokens_eos_lens)
+
+        return loss
+
+    def on_fit_batch_end(self, batch, outputs, loss, should_step):
+        """At the end of the optimizer step, apply noam annealing."""
+        if should_step:
+            if not self.hparams.wav2vec2_frozen:
+                self.hparams.lr_annealing_wav2vec(
+                    self.wav2vec_optimizer, self.optimizer_step
+                )
+            if not self.hparams.mbart_frozen:
+                self.hparams.lr_annealing_mbart(
+                    self.mbart_optimizer, self.optimizer_step
+                )
+
+    def init_optimizers(self):
+        self.adam_optimizer = self.hparams.adam_opt_class(
+            self.hparams.model.parameters()
+        )
+
+        self.optimizers_dict = {"model_optimizer": self.adam_optimizer}
+
+        # Initializes the wav2vec2 optimizer if the model is not wav2vec2_frozen
+        if not self.hparams.wav2vec2_frozen:
+            self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
+                self.modules.wav2vec2.parameters()
+            )
+            self.optimizers_dict["wav2vec_optimizer"] = self.wav2vec_optimizer
+
+        # Initializes the mbart optimizer if the model is not mbart_frozen
+        if not self.hparams.mbart_frozen:
+            self.mbart_optimizer = self.hparams.mbart_opt_class(
+                self.modules.mBART.parameters()
+            )
+            self.optimizers_dict["mbart_optimizer"] = self.mbart_optimizer
+
+    def on_stage_start(self, stage, epoch):
+        """Gets called when a stage (either training, validation, test) starts."""
+        self.bleu_metric = self.hparams.bleu_computer()
+
+        if stage != sb.Stage.TRAIN:
+            self.acc_metric = self.hparams.acc_computer()
+            self.bleu_metric = self.hparams.bleu_computer()
+
+    def on_stage_end(self, stage, stage_loss, epoch):
+        """Gets called at the end of a epoch."""
+        # Compute/store important stats
+        if stage == sb.Stage.TRAIN:
+            self.train_stats = stage_loss
+
+        else:  # valid or test
+            stage_stats = {"loss": stage_loss}
+            stage_stats["ACC"] = self.acc_metric.summarize()
+            current_epoch = self.hparams.epoch_counter.current
+            valid_search_interval = self.hparams.valid_search_interval
+            if (
+                current_epoch % valid_search_interval == 0
+                and self.optimizer_step >= 1000
+                or stage == sb.Stage.TEST
+            ):
+                stage_stats["BLEU"] = self.bleu_metric.summarize(field="BLEU")
+                stage_stats["BLEU_extensive"] = self.bleu_metric.summarize()
+                self.anneal_bleu = stage_stats["BLEU"]
+
+        # log stats and save checkpoint at end-of-epoch
+        if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+            current_epoch = self.hparams.epoch_counter.current
+            old_lr_adam, new_lr_adam = self.hparams.lr_annealing_adam(
+                self.anneal_bleu  # stage_stats["BLEU"]
+            )
+            sb.nnet.schedulers.update_learning_rate(
+                self.adam_optimizer, new_lr_adam
+            )
+
+            stats_meta = {
+                "epoch": current_epoch,
+                "steps": self.optimizer_step,
+                "lr_adam": old_lr_adam,
+            }
+
+            if not self.hparams.wav2vec2_frozen:
+                self.hparams.lr_annealing_wav2vec(
+                    self.wav2vec_optimizer, self.optimizer_step
+                )
+                stats_meta["lr_wav2vec"] = self.wav2vec_optimizer.param_groups[
+                    0
+                ]["lr"]
+            if not self.hparams.mbart_frozen:
+                self.hparams.lr_annealing_mbart(
+                    self.mbart_optimizer, self.optimizer_step
+                )
+                stats_meta["lr_mbart"] = self.mbart_optimizer.param_groups[0][
+                    "lr"
+                ]
+            self.hparams.train_logger.log_stats(
+                stats_meta=stats_meta,
+                train_stats={"loss": self.train_stats},
+                valid_stats=stage_stats,
+            )
+
+            # create checkpoing
+            valid_search_interval = self.hparams.valid_search_interval
+            if (
+                current_epoch % valid_search_interval == 0
+                and self.optimizer_step >= 1000
+            ):
+                meta = {"BLEU": stage_stats["BLEU"], "epoch": current_epoch}
+                name = "checkpoint_epoch" + str(current_epoch)
+
+                self.checkpointer.save_and_keep_only(
+                    meta=meta, name=name, num_to_keep=10, max_keys=["BLEU"]
+                )
+
+        elif stage == sb.Stage.TEST:
+            self.hparams.train_logger.log_stats(
+                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
+                test_stats=stage_stats,
+            )
+
+            with open(self.hparams.bleu_file, "w") as w:
+                self.bleu_metric.write_stats(w)
+
+
+# Define custom data procedure
+def dataio_prepare(hparams, tokenizer):
+    """This function prepares the datasets to be used in the brain class.
+    It also defines the data processing pipeline through user-defined functions."""
+
+    # Define audio pipeline. In this case, we simply read the path contained
+    # in the variable wav with the audio reader.
+    @sb.utils.data_pipeline.takes("path")
+    @sb.utils.data_pipeline.provides("sig")
+    def audio_pipeline(wav):
+        """Load the audio signal. This is done on the CPU in the `collate_fn`."""
+        sig = sb.dataio.dataio.read_audio(wav)
+        return sig
+
+    @sb.utils.data_pipeline.takes("path")
+    @sb.utils.data_pipeline.provides("sig")
+    def sp_audio_pipeline(wav):
+        """Load the audio signal. This is done on the CPU in the `collate_fn`."""
+        sig = sb.dataio.dataio.read_audio(wav)
+        sig = sig.unsqueeze(0)
+        sig = hparams["speed_perturb"](sig)
+        sig = sig.squeeze(0)
+        return sig
+
+    # Define text processing pipeline. We start from the raw text and then
+    # encode it using the tokenizer. The tokens with BOS are used for feeding
+    # decoder during training, the tokens with EOS for computing the cost function.
+    @sb.utils.data_pipeline.takes("trans")
+    @sb.utils.data_pipeline.provides(
+        "trans", "tokens_list", "tokens_bos", "tokens_eos",
+    )
+    def reference_text_pipeline(translation):
+        """Processes the transcriptions to generate proper labels"""
+        yield translation
+        labels = tokenizer(
+            text_target=translation.replace("\n", ""), return_tensors="pt"
+        )
+        tokens_list = labels["input_ids"].tolist()[-1]
+        yield tokens_list
+        tokens_bos = torch.LongTensor(tokens_list[0:-1])
+        yield tokens_bos
+        tokens_eos = torch.LongTensor(tokens_list[1:])
+        yield tokens_eos
+
+    datasets = {}
+    data_folder = hparams["data_folder"]
+    for dataset in ["train", "valid"]:
+        json_path = hparams[f"annotation_{dataset}"]
+
+        is_use_sp = dataset == "train" and "speed_perturb" in hparams
+        audio_pipeline_func = sp_audio_pipeline if is_use_sp else audio_pipeline
+
+        datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
+            json_path=json_path,
+            replacements={"data_root": data_folder},
+            dynamic_items=[audio_pipeline_func, reference_text_pipeline],
+            output_keys=[
+                "id",
+                "sig",
+                "duration",
+                "trans",
+                "tokens_list",
+                "tokens_bos",
+                "tokens_eos",
+            ],
+        )
+
+    for dataset in ["test"]:
+        json_path = hparams[f"annotation_{dataset}"]
+        datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
+            json_path=json_path,
+            replacements={"data_root": data_folder},
+            dynamic_items=[audio_pipeline, reference_text_pipeline],
+            output_keys=[
+                "id",
+                "sig",
+                "duration",
+                "trans",
+                "tokens_list",
+                "tokens_bos",
+                "tokens_eos",
+            ],
+        )
+
+    # Sorting training data with ascending order makes the code  much
+    # faster  because we minimize zero-padding. In most of the cases, this
+    # does not harm the performance.
+    if hparams["sorting"] == "ascending":
+        # use smaller dataset to debug the model
+        if hparams["debug"]:
+            datasets["train"] = datasets["train"].filtered_sorted(
+                key_min_value={"duration": 1},
+                key_max_value={"duration": 3},
+                sort_key="duration",
+                reverse=True,
+            )
+            datasets["valid"] = datasets["valid"].filtered_sorted(
+                key_min_value={"duration": 1},
+                key_max_value={"duration": 3},
+                sort_key="duration",
+                reverse=True,
+            )
+        else:
+            datasets["train"] = datasets["train"].filtered_sorted(
+                sort_key="duration"
+            )
+            datasets["valid"] = datasets["valid"].filtered_sorted(
+                sort_key="duration"
+            )
+
+        hparams["dataloader_options"]["shuffle"] = False
+        hparams["dataloader_options"]["shuffle"] = False
+    elif hparams["sorting"] == "descending":
+        # use smaller dataset to debug the model
+        if hparams["debug"]:
+            datasets["train"] = datasets["train"].filtered_sorted(
+                key_min_value={"duration": 1},
+                key_max_value={"duration": 3},
+                sort_key="duration",
+                reverse=True,
+            )
+            datasets["valid"] = datasets["valid"].filtered_sorted(
+                key_min_value={"duration": 1},
+                key_max_value={"duration": 3},
+                sort_key="duration",
+                reverse=True,
+            )
+        else:
+            datasets["train"] = datasets["train"].filtered_sorted(
+                sort_key="duration", reverse=True
+            )
+            datasets["valid"] = datasets["valid"].filtered_sorted(
+                sort_key="duration", reverse=True
+            )
+
+        hparams["dataloader_options"]["shuffle"] = False
+        hparams["dataloader_options"]["shuffle"] = False
+    elif hparams["sorting"] == "random":
+        # use smaller dataset to debug the model
+        if hparams["debug"]:
+            datasets["train"] = datasets["train"].filtered_sorted(
+                key_min_value={"duration": 1},
+                key_max_value={"duration": 3},
+                sort_key="duration",
+            )
+            datasets["valid"] = datasets["valid"].filtered_sorted(
+                key_min_value={"duration": 1}, key_max_value={"duration": 3},
+            )
+
+        hparams["dataloader_options"]["shuffle"] = True
+    else:
+        raise NotImplementedError(
+            "sorting must be random, ascending or descending"
+        )
+
+    return datasets
+
+
+if __name__ == "__main__":
+
+    # Load hyperparameters file with command-line overrides
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # create ddp_group with the right communication protocol
+    sb.utils.distributed.ddp_init_group(run_opts)
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    # Create main experiment class
+    st_brain = ST(
+        modules=hparams["modules"],
+        hparams=hparams,
+        run_opts=run_opts,
+        checkpointer=hparams["checkpointer"],
+    )
+
+    st_brain.anneal_bleu = 0
+
+    # Data preparation
+    import prepare_iwslt22
+
+    if not hparams["skip_prep"]:
+        run_on_main(
+            prepare_iwslt22.data_proc,
+            kwargs={
+                "dataset_folder": hparams["root_data_folder"],
+                "output_folder": hparams["data_folder"],
+            },
+        )
+
+    # We can now directly create the datasets for training, valid, and test
+    datasets = dataio_prepare(hparams, st_brain.modules.mBART.tokenizer)
+
+    # Training
+    st_brain.fit(
+        st_brain.hparams.epoch_counter,
+        datasets["train"],
+        datasets["valid"],
+        train_loader_kwargs=hparams["dataloader_options"],
+        valid_loader_kwargs=hparams["test_dataloader_options"],
+    )
+
+    # Test
+    for dataset in ["valid", "test"]:
+        st_brain.hparams.wer_file = (
+            hparams["output_folder"] + "/wer_test" + ".txt"
+        )
+        st_brain.evaluate(
+            datasets[dataset],
+            test_loader_kwargs=hparams["test_dataloader_options"],
+        )
diff --git a/recipes/IWSLT22_lowresource/README.md b/recipes/IWSLT22_lowresource/README.md
deleted file mode 100644
index 6a54df97eb41fcf78b7df8ab3b848c49ed1e3a82..0000000000000000000000000000000000000000
--- a/recipes/IWSLT22_lowresource/README.md
+++ /dev/null
@@ -1,64 +0,0 @@
-# IWSLT 2022 Low-resource Task: Tamasheq-French end-to-end Speech Translation
-
-
-## Description
-
-This is the recipe for the best system from the IWSLT 2022 low-resource task, as described in the original paper.
-The speech translation model comprises a wav2vec 2.0 encoder and a Transformer decoder. It is trained end-to-end without any auxiliary loss. The recipe allows for removing the last layers of the Transformer Encoder inside the wav2vec 2.0 in order to reduce the number of training parameters.
-
-## Data Downloading
-
-For downloading the dataset used for this experiment, please run the following command.
-
-```
-git clone https://github.com/mzboito/IWSLT2022_Tamasheq_data.git
-```
-
-## Installing Extra Dependencies
-
-Before proceeding, ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal:
-
-```
-pip install -r extra_requirements.txt
-```
-
-## Training
-
-For training the model, please update the variables at hparams/train_w2v2_st.yaml.
-
-Note that in order to drop the last layers of the wav2vec 2.0 module, it is necessary to update the parameter "keep_n_layers".
-For instance: Using ``keep_n_layers: 10'' means that only the first 10 layers inside the wav2vec 2.0 Transformer encoder will be used for training. The remaining layers are removed.
-
-For launching training:
-```
-python train.py hparams/train_w2v2_st.yaml
-
-```
-
-## Citation
-```
-@inproceedings{boito-etal-2022-trac,
-    title = "{ON}-{TRAC} Consortium Systems for the {IWSLT} 2022 Dialect and Low-resource Speech Translation Tasks",
-    author = {Boito, Marcely Zanon  and
-      Ortega, John  and
-      Riguidel, Hugo  and
-      Laurent, Antoine  and
-      Barrault, Lo{\"\i}c  and
-      Bougares, Fethi  and
-      Chaabani, Firas  and
-      Nguyen, Ha  and
-      Barbier, Florentin  and
-      Gahbiche, Souhir  and
-      Est{\`e}ve, Yannick},
-    booktitle = "Proceedings of the 19th International Conference on Spoken Language Translation (IWSLT 2022)",
-    month = may,
-    year = "2022",
-    address = "Dublin, Ireland (in-person and online)",
-    publisher = "Association for Computational Linguistics",
-    url = "https://aclanthology.org/2022.iwslt-1.28",
-    doi = "10.18653/v1/2022.iwslt-1.28",
-    pages = "308--318"
-}
-```
-
-
diff --git a/recipes/KsponSpeech/ASR/transformer/hparams/conformer_medium.yaml b/recipes/KsponSpeech/ASR/transformer/hparams/conformer_medium.yaml
index f79f3786fa435ac3933afa5b7e1e2ff8b13b6d85..3c0d43e2ad36b0ee39ca69b1c4b967d530e4fb18 100644
--- a/recipes/KsponSpeech/ASR/transformer/hparams/conformer_medium.yaml
+++ b/recipes/KsponSpeech/ASR/transformer/hparams/conformer_medium.yaml
@@ -34,9 +34,9 @@ test_csv:
 
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
-# Training parameters
+####################### Training Parameters ####################################
 # To make Transformers converge, the global bath size should be large enough.
-# The global batch size is computed as batch_size * n_gpus * gradient_accumulation.
+# The global batch size is computed as batch_size * n_gpus * grad_accumulation_factor.
 # Empirically, we found that this value should be >= 128.
 # Please, set your parameters accordingly.
 number_of_epochs: 60
@@ -46,6 +46,7 @@ grad_accumulation_factor: 2
 max_grad_norm: 5.0
 loss_reduction: 'batchmean'
 sorting: random
+avg_checkpoints: 5 # Number of checkpoints to average for evaluation
 
 dynamic_batching: False
 
@@ -77,7 +78,7 @@ valid_dataloader_opts:
 test_dataloader_opts:
     batch_size: 1
 
-####################### Model parameters ###########################
+####################### Model Parameters ###########################
 # Transformer
 d_model: 256
 nhead: 4
@@ -172,34 +173,50 @@ Adam: !name:torch.optim.Adam
     eps: 0.000000001
 
 
-valid_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
-    bos_index: !ref <bos_index>
+# Scorer
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
     eos_index: !ref <eos_index>
     blank_index: !ref <blank_index>
+    ctc_fc: !ref <ctc_lin>
+
+transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer
+    language_model: !ref <lm_model>
+    temperature: 1.30
+
+valid_scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+
+test_scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <transformerlm_scorer>,
+                   !ref <ctc_scorer>]
+    weights:
+        transformerlm: !ref <lm_weight>
+        ctc: !ref <ctc_weight_decode>
+
+valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <valid_beam_size>
-    ctc_weight: !ref <ctc_weight_decode>
     using_eos_threshold: False
-    length_normalization: False
-
+    length_normalization: True
+    scorer: !ref <valid_scorer>
 
-test_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
+test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <test_beam_size>
-    ctc_weight: !ref <ctc_weight_decode>
-    lm_weight: !ref <lm_weight>
-    lm_modules: !ref <lm_model>
     temperature: 1.30
-    temperature_lm: 1.30
     using_eos_threshold: False
     length_normalization: True
+    scorer: !ref <test_scorer>
 
 log_softmax: !new:torch.nn.LogSoftmax
     dim: -1
@@ -232,19 +249,65 @@ normalize: !new:speechbrain.processing.features.InputNormalization
     norm_type: global
     update_until_epoch: 4
 
-augmentation: !new:speechbrain.lobes.augment.SpecAugment
-    time_warp: False
-    time_warp_window: 5
-    time_warp_mode: bicubic
-    freq_mask: True
-    n_freq_mask: 4
-    time_mask: True
-    n_time_mask: 4
-    replace_with_zero: False
-    freq_mask_width: 15
-    time_mask_width: 20
-
-speed_perturb: True
+# Time Drop
+time_drop_length_low: 15  # Min length for temporal chunk to drop in spectrogram
+time_drop_length_high: 25  # Max length for temporal chunk to drop in spectrogram
+time_drop_count_low: 5  # Min number of chunks to drop in time in the spectrogram
+time_drop_count_high: 5  # Max number of chunks to drop in time in the spectrogram
+time_drop_replace: "zeros"  # Method of dropping chunks
+
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: !ref <time_drop_length_low>
+    drop_length_high: !ref <time_drop_length_high>
+    drop_count_low: !ref <time_drop_count_low>
+    drop_count_high: !ref <time_drop_count_high>
+    replace: !ref <time_drop_replace>
+    dim: 1
+
+# Frequency Drop
+freq_drop_length_low: 25  # Min length for chunks to drop in frequency in the spectrogram
+freq_drop_length_high: 35  # Max length for chunks to drop in frequency in the spectrogram
+freq_drop_count_low: 2  # Min number of chunks to drop in frequency in the spectrogram
+freq_drop_count_high: 2  # Max number of chunks to drop in frequency in the spectrogram
+freq_drop_replace: "zeros"  # Method of dropping chunks
+
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: !ref <freq_drop_length_low>
+    drop_length_high: !ref <freq_drop_length_high>
+    drop_count_low: !ref <freq_drop_count_low>
+    drop_count_high: !ref <freq_drop_count_high>
+    replace: !ref <freq_drop_replace>
+    dim: 2
+
+# Time warp
+time_warp_window: 5  # Length of time warping window
+time_warp_mode: "bicubic"  # Time warping method
+
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+    warp_window: !ref <time_warp_window>
+    warp_mode: !ref <time_warp_mode>
+    dim: 1
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    parallel_augment: False
+    concat_original: False
+    repeat_augment: 1
+    shuffle_augmentations: False
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>]
+
+# Speed perturbation
+do_speed_perturb: True
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
 
 compute_features: !new:speechbrain.lobes.features.Fbank
     sample_rate: !ref <sample_rate>
diff --git a/recipes/KsponSpeech/ASR/transformer/train.py b/recipes/KsponSpeech/ASR/transformer/train.py
index 6ca58c096473bdd87e958a020aa26a80b651bbd4..92ab54fb13fa87fcbaff7164488ac030bc5fb491 100644
--- a/recipes/KsponSpeech/ASR/transformer/train.py
+++ b/recipes/KsponSpeech/ASR/transformer/train.py
@@ -53,22 +53,15 @@ class ASR(sb.core.Brain):
         wavs, wav_lens = batch.sig
         tokens_bos, _ = batch.tokens_bos
 
-        # Add augmentation if specified
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.modules, "env_corrupt"):
-                wavs_noise = self.modules.env_corrupt(wavs, wav_lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                wav_lens = torch.cat([wav_lens, wav_lens])
-                tokens_bos = torch.cat([tokens_bos, tokens_bos], dim=0)
-
         # compute features
         feats = self.hparams.compute_features(wavs)
         current_epoch = self.hparams.epoch_counter.current
         feats = self.modules.normalize(feats, wav_lens, epoch=current_epoch)
 
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "augmentation"):
-                feats = self.hparams.augmentation(feats)
+        # Add feature augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "fea_augment"):
+            feats, fea_lens = self.hparams.fea_augment(feats, wav_lens)
+            tokens_bos = self.hparams.fea_augment.replicate_labels(tokens_bos)
 
         # forward modules
         src = self.modules.CNN(feats)
@@ -85,19 +78,20 @@ class ASR(sb.core.Brain):
         pred = self.modules.seq_lin(pred)
         p_seq = self.hparams.log_softmax(pred)
 
-        # Compute outputs
         hyps = None
-        if stage == sb.Stage.TRAIN:
-            hyps = None
-        elif stage == sb.Stage.VALID:
-            hyps = None
-            current_epoch = self.hparams.epoch_counter.current
-            if current_epoch % self.hparams.valid_search_interval == 0:
-                # for the sake of efficiency, we only perform beamsearch with limited capacity
-                # and no LM to give user some idea of how the AM is doing
-                hyps, _ = self.hparams.valid_search(enc_out.detach(), wav_lens)
-        elif stage == sb.Stage.TEST:
-            hyps, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
+        current_epoch = self.hparams.epoch_counter.current
+        is_valid_search = (
+            stage == sb.Stage.VALID
+            and current_epoch % self.hparams.valid_search_interval == 0
+        )
+        is_test_search = stage == sb.Stage.TEST
+
+        if is_valid_search:
+            hyps, _, _, _ = self.hparams.valid_search(
+                enc_out.detach(), wav_lens
+            )
+        elif is_test_search:
+            hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
 
         return p_ctc, p_seq, wav_lens, hyps
 
@@ -110,13 +104,18 @@ class ASR(sb.core.Brain):
         tokens_eos, tokens_eos_lens = batch.tokens_eos
         tokens, tokens_lens = batch.tokens
 
-        if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
-            tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
-            tokens_eos_lens = torch.cat(
-                [tokens_eos_lens, tokens_eos_lens], dim=0
-            )
-            tokens = torch.cat([tokens, tokens], dim=0)
-            tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0)
+        if stage == sb.Stage.TRAIN:
+            if hasattr(self.hparams, "fea_augment"):
+                tokens = self.hparams.fea_augment.replicate_labels(tokens)
+                tokens_lens = self.hparams.fea_augment.replicate_labels(
+                    tokens_lens
+                )
+                tokens_eos = self.hparams.fea_augment.replicate_labels(
+                    tokens_eos
+                )
+                tokens_eos_lens = self.hparams.fea_augment.replicate_labels(
+                    tokens_eos_lens
+                )
 
         loss_seq = self.hparams.seq_cost(
             p_seq, tokens_eos, length=tokens_eos_lens
@@ -155,47 +154,6 @@ class ASR(sb.core.Brain):
             self.acc_metric.append(p_seq, tokens_eos, tokens_eos_lens)
         return loss
 
-    def fit_batch(self, batch):
-
-        should_step = self.step % self.grad_accumulation_factor == 0
-        # Managing automatic mixed precision
-        if self.auto_mix_prec:
-            self.optimizer.zero_grad()
-            with torch.cuda.amp.autocast():
-                outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-                loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-            self.scaler.scale(loss / self.grad_accumulation_factor).backward()
-            if should_step:
-                self.scaler.unscale_(self.optimizer)
-                if self.check_gradients(loss):
-                    self.scaler.step(self.optimizer)
-                self.scaler.update()
-                self.optimizer_step += 1
-
-                # anneal lr every update
-                self.hparams.noam_annealing(self.optimizer)
-        else:
-            outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-            loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-            (loss / self.grad_accumulation_factor).backward()
-            if should_step:
-                if self.check_gradients(loss):
-                    self.optimizer.step()
-                self.optimizer.zero_grad()
-                self.optimizer_step += 1
-
-                # anneal lr every update
-                self.hparams.noam_annealing(self.optimizer)
-
-        return loss.detach().cpu()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        with torch.no_grad():
-            predictions = self.compute_forward(batch, stage=stage)
-            loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
         if stage != sb.Stage.TRAIN:
@@ -221,7 +179,7 @@ class ASR(sb.core.Brain):
                 stage_stats["CER"] = self.cer_metric.summarize("error_rate")
 
         # log stats and save checkpoint at end-of-epoch
-        if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+        if stage == sb.Stage.VALID:
 
             lr = self.hparams.noam_annealing.current_lr
             steps = self.optimizer_step
@@ -241,7 +199,7 @@ class ASR(sb.core.Brain):
             self.checkpointer.save_and_keep_only(
                 meta={"ACC": stage_stats["ACC"], "epoch": epoch},
                 max_keys=["ACC"],
-                num_to_keep=5,
+                num_to_keep=self.hparams.avg_checkpoints,
             )
 
         elif stage == sb.Stage.TEST:
@@ -271,7 +229,7 @@ class ASR(sb.core.Brain):
             max_key=max_key, min_key=min_key
         )
         ckpt = sb.utils.checkpoints.average_checkpoints(
-            ckpts, recoverable_name="model", device=self.device
+            ckpts, recoverable_name="model",
         )
 
         self.hparams.model.load_state_dict(ckpt, strict=True)
@@ -344,16 +302,9 @@ def dataio_prepare(hparams):
     def audio_pipeline_train(wav):
         # Speed Perturb is done here so it is multi-threaded with the
         # workers of the dataloader (faster).
-        if hparams["speed_perturb"]:
-            sig = sb.dataio.dataio.read_audio(wav)
-            # factor = np.random.uniform(0.95, 1.05)
-            # sig = resample(sig.numpy(), 16000, int(16000*factor))
-            speed = sb.processing.speech_augmentation.SpeedPerturb(
-                16000, [x for x in range(95, 105)]
-            )
-            sig = speed(sig.unsqueeze(0)).squeeze(0)  # torch.from_numpy(sig)
-        else:
-            sig = sb.dataio.dataio.read_audio(wav)
+        sig = sb.dataio.dataio.read_audio(wav)
+        if hparams["do_speed_perturb"]:
+            sig = hparams["speed_perturb"](sig.unsqueeze(0)).squeeze(0)
         return sig
 
     sb.dataio.dataset.add_dynamic_item([train_data], audio_pipeline_train)
@@ -424,7 +375,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -466,7 +416,7 @@ if __name__ == "__main__":
     # We download the pretrained LM from HuggingFace (or elsewhere depending on
     # the path given in the YAML file). The tokenizer is loaded at the same time.
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     # Trainer initialization
     asr_brain = ASR(
diff --git a/recipes/KsponSpeech/LM/hparams/transformer.yaml b/recipes/KsponSpeech/LM/hparams/transformer.yaml
index 1dd28921eb4a0c8286919df8020a0c01370dc681..5b64cc196c4d87d6453f152c3b537be168ab1e6f 100644
--- a/recipes/KsponSpeech/LM/hparams/transformer.yaml
+++ b/recipes/KsponSpeech/LM/hparams/transformer.yaml
@@ -24,11 +24,11 @@ test_csv:
 # Tokenizer model
 tokenizer_file: ddwkim/asr-conformer-transformerlm-ksponspeech/tokenizer.ckpt
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 30
 batch_size: 256
 lr: 0.1
-accu_steps: 4 # Gradient accumulation to simulate large batch training
+grad_accumulation_factor: 4 # Gradient accumulation to simulate large batch training
 ckpt_interval_minutes: 15 # save checkpoint every N min
 
 # Dataloader options
diff --git a/recipes/KsponSpeech/LM/train.py b/recipes/KsponSpeech/LM/train.py
index 7f4e63756b0ad2f6137fb02de43fc995d3c8b275..e5a8e86d0e75c330cc03e6efe6f9f695ace7bd68 100644
--- a/recipes/KsponSpeech/LM/train.py
+++ b/recipes/KsponSpeech/LM/train.py
@@ -44,20 +44,9 @@ class LM(sb.core.Brain):
         )
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-
-        (loss / self.hparams.accu_steps).backward()
-
-        if self.step % self.hparams.accu_steps == 0:
-            # gradient clipping & early stop if loss is not fini
-            self.check_gradients(loss)
-
-            self.optimizer.step()
-            self.optimizer.zero_grad()
-
+    def on_fit_batch_end(self, batch, outputs, loss, should_step):
+        """At the end of the optimizer step, apply noam annealing and logging."""
+        if should_step:
             if isinstance(
                 self.hparams.lr_annealing, sb.nnet.schedulers.NoamScheduler
             ) or isinstance(
@@ -73,15 +62,13 @@ class LM(sb.core.Brain):
                 stats_meta={"step": self.step}, train_stats={"loss": loss},
             )
 
-        return loss
-
     def on_stage_end(self, stage, stage_loss, epoch):
         """Gets called at the end of a epoch."""
         stage_stats = {"loss": stage_loss}
         if stage == sb.Stage.TRAIN:
             self.train_stats = stage_stats
 
-        if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+        if stage == sb.Stage.VALID:
             if not (
                 isinstance(
                     self.hparams.lr_annealing, sb.nnet.schedulers.NoamScheduler
@@ -167,7 +154,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -185,7 +171,7 @@ if __name__ == "__main__":
     # We download the tokenizer from HuggingFace (or elsewhere depending on
     # the path given in the YAML file).
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     lm_brain = LM(
         modules=hparams["modules"],
diff --git a/recipes/KsponSpeech/Tokenizer/hparams/5K_unigram_subword_bpe.yaml b/recipes/KsponSpeech/Tokenizer/hparams/5K_unigram_subword_bpe.yaml
index dd7cd490646b4b900fe74bc61867ee3cd7135d3a..04ef0ebfd9cf041c20668e89801433d6d5518387 100644
--- a/recipes/KsponSpeech/Tokenizer/hparams/5K_unigram_subword_bpe.yaml
+++ b/recipes/KsponSpeech/Tokenizer/hparams/5K_unigram_subword_bpe.yaml
@@ -16,7 +16,7 @@ skip_prep: False
 train_csv: !ref <output_folder>/train.csv
 valid_csv: !ref <output_folder>/dev.csv
 
-# Training parameters
+####################### Training Parameters ####################################
 token_type: unigram  # ["unigram", "bpe", "char"]
 token_output: 5000  # index(blank/eos/bos/unk) = 0
 character_coverage: 1.0
diff --git a/recipes/KsponSpeech/Tokenizer/train.py b/recipes/KsponSpeech/Tokenizer/train.py
index edf2c533f6962894ecaf32111c0f414838b07b65..4e2a26856a4a8fdbad5048ed470eb5d9e62930cf 100644
--- a/recipes/KsponSpeech/Tokenizer/train.py
+++ b/recipes/KsponSpeech/Tokenizer/train.py
@@ -28,7 +28,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
diff --git a/recipes/LJSpeech/TTS/README.md b/recipes/LJSpeech/TTS/README.md
index 1a7d2e16ca64f3ca286f4711ebcd1b038e6309f5..6009decf4b5a1a4981ba1f447ecaaa84c4908f17 100644
--- a/recipes/LJSpeech/TTS/README.md
+++ b/recipes/LJSpeech/TTS/README.md
@@ -5,7 +5,7 @@ This folder contains the recipes for training TTS systems (including vocoders) w
 The dataset can be downloaded from here:
 https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
 
-## Installing Extra Dependencies
+# Installing Extra Dependencies
 
 Before proceeding, ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal:
 
@@ -26,9 +26,10 @@ The training logs are available [here](https://www.dropbox.com/sh/1npvo1g1ncafip
 You can find the pre-trained model with an easy-inference function on [HuggingFace](https://huggingface.co/speechbrain/tts-tacotron2-ljspeech).
 
 # FastSpeech2
-The subfolder "fastspeech2" contains the recipe for training the non-autoregressive transformer based TTS model [FastSpeech2](https://arxiv.org/abs/2006.04558).
+The subfolder "fastspeech2" contains the recipes for training the non-autoregressive transformer based TTS model [FastSpeech2](https://arxiv.org/abs/2006.04558).
 
-Training FastSpeech2 requires pre-extracted phoneme alignments (durations). The LJSpeech phoneme alignments from Montreal Forced Aligner can be automatically downloaded, decompressed and stored at this location: ```/your_folder/LJSpeech-1.1/TextGrid```.
+### FastSpeech2 with pre-extracted durations from a forced aligner
+Training FastSpeech2 requires pre-extracted phoneme alignments (durations). The LJSpeech phoneme alignments from Montreal Forced Aligner are automatically downloaded, decompressed and stored at this location: ```/your_folder/LJSpeech-1.1/TextGrid```.
 
 To run this recipe, please first install the extra-dependencies :
 
@@ -43,10 +44,27 @@ python train.py --data_folder=/your_folder/LJSpeech-1.1 hparams/train.yaml
 ```
 Training takes about 3 minutes/epoch on 1 * V100 32G.
 
-The training logs are available [here](https://www.dropbox.com/sh/tqyp58ogejqfres/AAAtmq7cRoOR3XTsq0iSgyKBa?dl=0).
+The training logs are available [here](https://www.dropbox.com/scl/fo/vtgbltqdrvw9r0vs7jz67/h?rlkey=cm2mwh5rce5ad9e90qaciypox&dl=0).
 
 You can find the pre-trained model with an easy-inference function on [HuggingFace](https://huggingface.co/speechbrain/tts-fastspeech2-ljspeech).
 
+### FastSpeech2 with internal alignment
+This recipe allows training FastSpeech2 without forced aligner referring to [One TTS Alignment To Rule Them All](https://arxiv.org/pdf/2108.10447.pdf). The alignment can be learnt by an internal alignment network that is added to FastSpeech2. This recipe aims to simplify training when using custom data and provide better alignments for punctuations.
+
+To run this recipe, please first install the extra-requirements:
+```
+pip install -r extra_requirements.txt
+```
+Then go into the "fastspeech2" folder and run:
+```
+python train_internal_alignment.py hparams/train_internal_alignment.yaml --data_folder=/your_folder/LJSpeech-1.1
+```
+The data preparation includes a grapheme-to-phoneme process for the entire corpus which may take several hours. Training takes about 5 minutes/epoch on 1 * V100 32G.
+
+The training logs are available [here](https://www.dropbox.com/scl/fo/4ctkc6jjas3uij9dzcwta/h?rlkey=i0k086d77flcsdx40du1ppm2d&dl=0).
+
+You can find the pre-trained model with an easy-inference function on [HuggingFace](https://huggingface.co/speechbrain/tts-fastspeech2-internal-alignment-ljspeech).
+
 # HiFiGAN (Vocoder)
 The subfolder "vocoder/hifi_gan/" contains the [HiFiGAN vocoder](https://arxiv.org/pdf/2010.05646.pdf).
 The vocoder is a neural network that converts a spectrogram into a waveform (it can be used on top of Tacotron2/FastSpeech2).
@@ -90,6 +108,40 @@ For inference, by setting `fast_sampling: True` , a fast sampling can be realize
 
 You can find the pre-trained model with an easy-inference function on [HuggingFace](https://huggingface.co/speechbrain/tts-diffwave-ljspeech).
 
+# K-means (Quantization)
+The subfolder "quantization" contains K-means clustering model. The model serves to quantize self-supervised representations into discrete representation. Thus representations can be used as target for speech-to-speech translation or as input for HiFiGAN Unit. By default, we use the 6th layer of HuBERT and set `k=100`.
+
+To run this recipe, please first install the extra-dependencies :
+```
+pip install -r extra_requirements.txt
+```
+Then go into the "quantization" folder and run:
+```
+python train.py hparams/kmeans.yaml --data_folder=/path/to/LJspeech
+```
+
+# HiFiGAN Unit Vocoder
+The subfolder "vocoder/hifi_gan_unit/" contains the [HiFiGAN Unit vocoder](https://arxiv.org/abs/2104.00355). This vocoder is a neural network designed to transform discrete self-supervised representations into waveform data and is suitable for speech-to-speech translation on top of CVSS/S2ST models. The discrete representations required by the vocoder are learned using k-means quantization, as previously described. Please ensure that you have executed the quantization step before proceeding with this script.
+
+To run this recipe successfully, start by installing the necessary extra dependencies:
+
+```bash
+pip install -r extra_requirements.txt
+```
+
+Then, navigate to the "vocoder/hifi_gan_unit/" folder and run the following command:
+
+```bash
+python train.py hparams/train.yaml --kmeans_folder=/path/to/Kmeans/ckpt --data_folder=/path/to/LJspeech
+```
+
+The `kmeans_folder` should be specified based on the results of the previous quantization step (e.g., ../../quantization/results/kmeans/4321/save).
+
+Training typically takes around 4 minutes per epoch when using an NVIDIA A100 40G.
+
+You can access the pre-trained model, along with an easy-to-use inference function, on [HuggingFace](https://huggingface.co/speechbrain/tts-hifigan-unit-hubert-l6-k100-ljspeech).
+
+
 # **About SpeechBrain**
 - Website: https://speechbrain.github.io/
 - Code: https://github.com/speechbrain/speechbrain/
diff --git a/recipes/LJSpeech/TTS/extra_requirements.txt b/recipes/LJSpeech/TTS/extra_requirements.txt
index dc4f0dad87145b63127745707127f4feab7a1614..78ddcc94f2a6c3dbea83b1c5399ed77c94773196 100644
--- a/recipes/LJSpeech/TTS/extra_requirements.txt
+++ b/recipes/LJSpeech/TTS/extra_requirements.txt
@@ -1,3 +1,5 @@
+# Needed only for quantization
+scikit-learn
 # Needed only with use_tensorboard=True
 # torchvision is needed to save spectrograms
 tensorboard
diff --git a/recipes/LJSpeech/TTS/fastspeech2/hparams/train.yaml b/recipes/LJSpeech/TTS/fastspeech2/hparams/train.yaml
index df83de5abfa9f328ef7c5d523f9aa233301dec03..cf20e364f163578f749b40fa0feb3bac80f5a2a3 100644
--- a/recipes/LJSpeech/TTS/fastspeech2/hparams/train.yaml
+++ b/recipes/LJSpeech/TTS/fastspeech2/hparams/train.yaml
@@ -13,7 +13,7 @@ __set_seed: !apply:torch.manual_seed [!ref <seed>]
 output_folder: !ref results/fastspeech2/<seed>
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
-epochs: 1000
+epochs: 500
 train_spn_predictor_epochs: 8
 progress_samples: True
 progress_sample_path: !ref <output_folder>/samples
diff --git a/recipes/LJSpeech/TTS/fastspeech2/hparams/train_internal_alignment.yaml b/recipes/LJSpeech/TTS/fastspeech2/hparams/train_internal_alignment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a203a93e484644d5056ad7b13edc8b5d04e699e7
--- /dev/null
+++ b/recipes/LJSpeech/TTS/fastspeech2/hparams/train_internal_alignment.yaml
@@ -0,0 +1,284 @@
+############################################################################
+# Model: FastSpeech2 with internal alignment
+# Tokens: Phonemes (ARPABET)
+# Dataset: LJSpeech
+# Authors: Yingzhi Wang 2023
+# ############################################################################
+
+###################################
+# Experiment Parameters and setup #
+###################################
+seed: 1234
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/fastspeech2_internal_alignment/<seed>
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+epochs: 500
+progress_samples: True
+progress_sample_path: !ref <output_folder>/samples
+progress_samples_min_run: 10
+progress_samples_interval: 10
+progress_batch_sample_size: 4
+
+#################################
+# Data files and pre-processing #
+#################################
+data_folder: !PLACEHOLDER # e.g., /data/Database/LJSpeech-1.1
+
+train_json: !ref <save_folder>/train.json
+valid_json: !ref <save_folder>/valid.json
+test_json: !ref <save_folder>/test.json
+
+splits: ["train", "valid"]
+split_ratio: [90, 10]
+
+skip_prep: False
+
+################################
+# Audio Parameters             #
+################################
+sample_rate: 22050
+hop_length: 256
+win_length: null
+n_mel_channels: 80
+n_fft: 1024
+mel_fmin: 0.0
+mel_fmax: 8000.0
+power: 1
+norm: "slaney"
+mel_scale: "slaney"
+dynamic_range_compression: True
+mel_normalized: False
+min_max_energy_norm: True
+min_f0: 65  #(torchaudio pyin values)
+max_f0: 2093 #(torchaudio pyin values)
+
+################################
+# Optimization Hyperparameters #
+################################
+learning_rate: 0.0001
+weight_decay: 0.000001
+max_grad_norm: 1.0
+batch_size: 16 #minimum 2
+betas: [0.9, 0.998]
+num_workers_train: 16
+num_workers_valid: 4
+
+################################
+# Model Parameters and model   #
+################################
+# Input parameters
+lexicon:
+    - "AA"
+    - "AE"
+    - "AH"
+    - "AO"
+    - "AW"
+    - "AY"
+    - "B"
+    - "CH"
+    - "D"
+    - "DH"
+    - "EH"
+    - "ER"
+    - "EY"
+    - "F"
+    - "G"
+    - "HH"
+    - "IH"
+    - "IY"
+    - "JH"
+    - "K"
+    - "L"
+    - "M"
+    - "N"
+    - "NG"
+    - "OW"
+    - "OY"
+    - "P"
+    - "R"
+    - "S"
+    - "SH"
+    - "T"
+    - "TH"
+    - "UH"
+    - "UW"
+    - "V"
+    - "W"
+    - "Y"
+    - "Z"
+    - "ZH"
+    - "-"
+    - "!"
+    - "'"
+    - "("
+    - ")"
+    - ","
+    - "."
+    - ":"
+    - ";"
+    - "?"
+    - " "
+
+n_symbols: 52 #fixed depending on symbols in the lexicon (+1 for a dummy symbol used for padding, +1 for unknown)
+padding_idx: 0
+
+hidden_channels: 512
+# Encoder parameters
+enc_num_layers: 4
+enc_num_head: 2
+enc_d_model: !ref <hidden_channels>
+enc_ffn_dim: 1024
+enc_k_dim: !ref <hidden_channels>
+enc_v_dim: !ref <hidden_channels>
+enc_dropout: 0.2
+
+# Aligner parameters
+in_query_channels: 80
+in_key_channels: !ref <hidden_channels> # 512 in the paper
+attn_channels: 80
+temperature: 0.0005
+
+# Decoder parameters
+dec_num_layers: 4
+dec_num_head: 2
+dec_d_model: !ref <hidden_channels>
+dec_ffn_dim: 1024
+dec_k_dim: !ref <hidden_channels>
+dec_v_dim: !ref <hidden_channels>
+dec_dropout: 0.2
+
+# Postnet parameters
+postnet_embedding_dim: 512
+postnet_kernel_size: 5
+postnet_n_convolutions: 5
+postnet_dropout: 0.2
+
+# common
+normalize_before: True
+ffn_type: 1dcnn #1dcnn or ffn
+ffn_cnn_kernel_size_list: [9, 1]
+
+# variance predictor
+dur_pred_kernel_size: 3
+pitch_pred_kernel_size: 3
+energy_pred_kernel_size: 3
+variance_predictor_dropout: 0.5
+
+#model
+model: !new:speechbrain.lobes.models.FastSpeech2.FastSpeech2WithAlignment
+    enc_num_layers: !ref <enc_num_layers>
+    enc_num_head: !ref <enc_num_head>
+    enc_d_model: !ref <enc_d_model>
+    enc_ffn_dim: !ref <enc_ffn_dim>
+    enc_k_dim: !ref <enc_k_dim>
+    enc_v_dim: !ref <enc_v_dim>
+    enc_dropout: !ref <enc_dropout>
+    in_query_channels: !ref <in_query_channels>
+    in_key_channels: !ref <in_key_channels>
+    attn_channels: !ref <attn_channels>
+    temperature: !ref <temperature>
+    dec_num_layers: !ref <dec_num_layers>
+    dec_num_head: !ref <dec_num_head>
+    dec_d_model: !ref <dec_d_model>
+    dec_ffn_dim: !ref <dec_ffn_dim>
+    dec_k_dim: !ref <dec_k_dim>
+    dec_v_dim: !ref <dec_v_dim>
+    dec_dropout: !ref <dec_dropout>
+    normalize_before: !ref <normalize_before>
+    ffn_type: !ref <ffn_type>
+    ffn_cnn_kernel_size_list: !ref <ffn_cnn_kernel_size_list>
+    n_char: !ref <n_symbols>
+    n_mels: !ref <n_mel_channels>
+    postnet_embedding_dim: !ref <postnet_embedding_dim>
+    postnet_kernel_size: !ref <postnet_kernel_size>
+    postnet_n_convolutions: !ref <postnet_n_convolutions>
+    postnet_dropout: !ref <postnet_dropout>
+    padding_idx: !ref <padding_idx>
+    dur_pred_kernel_size: !ref <dur_pred_kernel_size>
+    pitch_pred_kernel_size: !ref <pitch_pred_kernel_size>
+    energy_pred_kernel_size: !ref <energy_pred_kernel_size>
+    variance_predictor_dropout: !ref <variance_predictor_dropout>
+
+mel_spectogram: !name:speechbrain.lobes.models.FastSpeech2.mel_spectogram
+    sample_rate: !ref <sample_rate>
+    hop_length: !ref <hop_length>
+    win_length: !ref <win_length>
+    n_fft: !ref <n_fft>
+    n_mels: !ref <n_mel_channels>
+    f_min: !ref <mel_fmin>
+    f_max: !ref <mel_fmax>
+    power: !ref <power>
+    normalized: !ref <mel_normalized>
+    min_max_energy_norm: !ref <min_max_energy_norm>
+    norm: !ref <norm>
+    mel_scale: !ref <mel_scale>
+    compression: !ref <dynamic_range_compression>
+
+criterion: !new:speechbrain.lobes.models.FastSpeech2.LossWithAlignment
+    log_scale_durations: True
+    duration_loss_weight: 1.0
+    pitch_loss_weight: 1.0
+    energy_loss_weight: 1.0
+    ssim_loss_weight: 1.0
+    mel_loss_weight: 1.0
+    postnet_mel_loss_weight: 1.0
+    aligner_loss_weight: 1.0
+    binary_alignment_loss_weight: 0.2
+    binary_alignment_loss_warmup_epochs: 1
+    binary_alignment_loss_max_epochs: 80
+
+vocoder: "hifi-gan"
+pretrained_vocoder: True
+vocoder_source: speechbrain/tts-hifigan-ljspeech
+vocoder_download_path: tmpdir_vocoder
+
+modules:
+    model: !ref <model>
+
+train_dataloader_opts:
+    batch_size: !ref <batch_size>
+    drop_last: False  #True #False
+    num_workers: !ref <num_workers_train>
+    shuffle: True
+    collate_fn: !new:speechbrain.lobes.models.FastSpeech2.TextMelCollateWithAlignment
+
+valid_dataloader_opts:
+    batch_size: !ref <batch_size>
+    num_workers: !ref <num_workers_valid>
+    shuffle: False
+    collate_fn: !new:speechbrain.lobes.models.FastSpeech2.TextMelCollateWithAlignment
+
+#optimizer
+opt_class: !name:torch.optim.Adam
+    lr: !ref <learning_rate>
+    weight_decay: !ref <weight_decay>
+    betas: !ref <betas>
+
+noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler
+    lr_initial: !ref <learning_rate>
+    n_warmup_steps: 4000
+
+
+#epoch object
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <epochs>
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+#checkpointer
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        model: !ref <model>
+        lr_annealing: !ref <noam_annealing>
+        counter: !ref <epoch_counter>
+
+input_encoder: !new:speechbrain.dataio.encoder.TextEncoder
+
+progress_sample_logger: !new:speechbrain.utils.train_logger.ProgressSampleLogger
+    output_path: !ref <progress_sample_path>
+    batch_sample_size: !ref <progress_batch_sample_size>
+    formats:
+        raw_batch: raw
diff --git a/recipes/LJSpeech/TTS/fastspeech2/ljspeech_prepare.py b/recipes/LJSpeech/TTS/fastspeech2/ljspeech_prepare.py
index 2de5a21a8daef2535958e950a9d95e0853bf6ba7..2f703273cb4b6cd6e3e08358a18ab94d1fe0746a 120000
--- a/recipes/LJSpeech/TTS/fastspeech2/ljspeech_prepare.py
+++ b/recipes/LJSpeech/TTS/fastspeech2/ljspeech_prepare.py
@@ -1 +1 @@
-../ljspeech_prepare.py
\ No newline at end of file
+../../ljspeech_prepare.py
\ No newline at end of file
diff --git a/recipes/LJSpeech/TTS/fastspeech2/train.py b/recipes/LJSpeech/TTS/fastspeech2/train.py
index b3a16bd1aef305e87b3d53495032629a260ddcd8..d8efe4666fd8339e119821b316cbc9bd678feddf 100644
--- a/recipes/LJSpeech/TTS/fastspeech2/train.py
+++ b/recipes/LJSpeech/TTS/fastspeech2/train.py
@@ -18,11 +18,11 @@ import logging
 import torchaudio
 import numpy as np
 import speechbrain as sb
-from speechbrain.pretrained import HIFIGAN
+from speechbrain.inference.vocoders import HIFIGAN
 from pathlib import Path
 from hyperpyyaml import load_hyperpyyaml
 from speechbrain.utils.data_utils import scalarize
-from speechbrain.pretrained import GraphemeToPhoneme
+from speechbrain.inference.text import GraphemeToPhoneme
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 logger = logging.getLogger(__name__)
@@ -31,7 +31,8 @@ logger = logging.getLogger(__name__)
 class FastSpeech2Brain(sb.Brain):
     def on_fit_start(self):
         """Gets called at the beginning of ``fit()``, on multiple processes
-        if ``distributed_count > 0`` and backend is ddp and initializes statistics"""
+        if ``distributed_count > 0`` and backend is ddp and initializes statistics
+        """
         self.hparams.progress_sample_logger.reset()
         self.last_epoch = 0
         self.last_batch = None
@@ -97,20 +98,10 @@ class FastSpeech2Brain(sb.Brain):
             spn_preds,
         )
 
-    def fit_batch(self, batch):
-        """Fits a single batch
-        Arguments
-        ---------
-        batch: tuple
-            a training batch
-        Returns
-        -------
-        loss: torch.Tensor
-            detached loss
-        """
-        result = super().fit_batch(batch)
-        self.hparams.noam_annealing(self.optimizer)
-        return result
+    def on_fit_batch_end(self, batch, outputs, loss, should_step):
+        """At the end of the optimizer step, apply noam annealing."""
+        if should_step:
+            self.hparams.noam_annealing(self.optimizer)
 
     def compute_objectives(self, predictions, batch, stage):
         """Computes the loss given the predicted and targeted outputs.
@@ -264,8 +255,7 @@ class FastSpeech2Brain(sb.Brain):
             )
 
     def run_inference(self):
-        """Produces a sample in inference mode with predicted durations.
-        """
+        """Produces a sample in inference mode with predicted durations."""
         if self.last_batch is None:
             return
         tokens, *_, labels, _ = self.last_batch
@@ -420,17 +410,17 @@ class FastSpeech2Brain(sb.Brain):
 
     def batch_to_device(self, batch, return_metadata=False):
         """Transfers the batch to the target device
-            Arguments
-            ---------
-            batch: tuple
-                the batch to use
-            return_metadata: bool
-                indicates whether the metadata should be returned
-            Returns
-            -------
-            batch: tuple
-                the batch on the correct device
-            """
+        Arguments
+        ---------
+        batch: tuple
+            the batch to use
+        return_metadata: bool
+            indicates whether the metadata should be returned
+        Returns
+        -------
+        batch: tuple
+            the batch on the correct device
+        """
 
         (
             text_padded,
@@ -510,7 +500,6 @@ def dataio_prepare(hparams):
         spn_labels,
         last_phoneme_flags,
     ):
-
         durs = np.load(dur)
         durs_seq = torch.from_numpy(durs).int()
         label_phoneme = label_phoneme.strip()
@@ -583,7 +572,6 @@ def main():
         overrides=overrides,
     )
 
-    sys.path.append("../")
     from ljspeech_prepare import prepare_ljspeech
 
     sb.utils.distributed.run_on_main(
diff --git a/recipes/LJSpeech/TTS/fastspeech2/train_internal_alignment.py b/recipes/LJSpeech/TTS/fastspeech2/train_internal_alignment.py
new file mode 100644
index 0000000000000000000000000000000000000000..91dee98e80681498b18cfc0313fdf097fd3adb18
--- /dev/null
+++ b/recipes/LJSpeech/TTS/fastspeech2/train_internal_alignment.py
@@ -0,0 +1,401 @@
+"""
+Recipe for training the FastSpeech2 Text-To-Speech model
+Instead of using pre-extracted phoneme durations from MFA,
+This recipe trains an internal alignment from scratch, as introduced in:
+https://arxiv.org/pdf/2108.10447.pdf (One TTS Alignment To Rule Them All)
+To run this recipe, do the following:
+# python train_internal_alignment.py hparams/train_internal_alignment.yaml
+
+Authors
+* Yingzhi Wang 2023
+"""
+
+import os
+import sys
+import torch
+import logging
+import torchaudio
+import numpy as np
+import speechbrain as sb
+from speechbrain.inference.vocoders import HIFIGAN
+from pathlib import Path
+from hyperpyyaml import load_hyperpyyaml
+from speechbrain.utils.data_utils import scalarize
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+logger = logging.getLogger(__name__)
+
+
+class FastSpeech2Brain(sb.Brain):
+    def on_fit_start(self):
+        """Gets called at the beginning of ``fit()``, on multiple processes
+        if ``distributed_count > 0`` and backend is ddp and initializes statistics"""
+        self.hparams.progress_sample_logger.reset()
+        self.last_epoch = 0
+        self.last_batch = None
+        self.last_loss_stats = {}
+        return super().on_fit_start()
+
+    def compute_forward(self, batch, stage):
+        """Computes the forward pass
+        Arguments
+        ---------
+        batch: str
+            a single batch
+        stage: speechbrain.Stage
+            the training stage
+        Returns
+        -------
+        the model output
+        """
+        inputs, _ = self.batch_to_device(batch)
+        return self.hparams.model(*inputs)
+
+    def on_fit_batch_end(self, batch, outputs, loss, should_step):
+        """At the end of the optimizer step, apply noam annealing and logging."""
+        if should_step:
+            self.hparams.noam_annealing(self.optimizer)
+
+    def compute_objectives(self, predictions, batch, stage):
+        """Computes the loss given the predicted and targeted outputs.
+        Arguments
+        ---------
+        predictions : torch.Tensor
+            The model generated spectrograms and other metrics from `compute_forward`.
+        batch : PaddedBatch
+            This batch object contains all the relevant tensors for computation.
+        stage : sb.Stage
+            One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST.
+        Returns
+        -------
+        loss : torch.Tensor
+            A one-element tensor used for backpropagating the gradient.
+        """
+        x, y, metadata = self.batch_to_device(batch, return_metadata=True)
+        self.last_batch = [x[0], y[-1], y[-2], predictions[0], *metadata]
+        self._remember_sample([x[0], *y, *metadata], predictions)
+        loss = self.hparams.criterion(
+            predictions, y, self.hparams.epoch_counter.current
+        )
+        self.last_loss_stats[stage] = scalarize(loss)
+        return loss["total_loss"]
+
+    def _remember_sample(self, batch, predictions):
+        """Remembers samples of spectrograms and the batch for logging purposes
+        Arguments
+        ---------
+        batch: tuple
+            a training batch
+        predictions: tuple
+            predictions (raw output of the FastSpeech2
+             model)
+        """
+        (
+            phoneme_padded,
+            mel_padded,
+            pitch,
+            energy,
+            output_lengths,
+            input_lengths,
+            labels,
+            wavs,
+        ) = batch
+
+        (
+            mel_post,
+            postnet_mel_out,
+            predict_durations,
+            predict_pitch,
+            average_pitch,
+            predict_energy,
+            average_energy,
+            predict_mel_lens,
+            alignment_durations,
+            alignment_soft,
+            alignment_logprob,
+            alignment_mas,
+        ) = predictions
+        self.hparams.progress_sample_logger.remember(
+            target=self.process_mel(mel_padded, output_lengths),
+            output=self.process_mel(postnet_mel_out, output_lengths),
+            raw_batch=self.hparams.progress_sample_logger.get_batch_sample(
+                {
+                    "tokens": phoneme_padded,
+                    "input_lengths": input_lengths,
+                    "mel_target": mel_padded,
+                    "mel_out": postnet_mel_out,
+                    "mel_lengths": predict_mel_lens,
+                    "durations": alignment_durations,
+                    "predict_durations": predict_durations,
+                    "labels": labels,
+                    "wavs": wavs,
+                }
+            ),
+        )
+
+    def process_mel(self, mel, len, index=0):
+        """Converts a mel spectrogram to one that can be saved as an image
+        sample  = sqrt(exp(mel))
+        Arguments
+        ---------
+        mel: torch.Tensor
+            the mel spectrogram (as used in the model)
+        len: int
+            length of the mel spectrogram
+        index: int
+            batch index
+        Returns
+        -------
+        mel: torch.Tensor
+            the spectrogram, for image saving purposes
+        """
+        assert mel.dim() == 3
+        return torch.sqrt(torch.exp(mel[index][: len[index]]))
+
+    def on_stage_end(self, stage, stage_loss, epoch):
+        """Gets called at the end of an epoch.
+        Arguments
+        ---------
+        stage : sb.Stage
+            One of sb.Stage.TRAIN, sb.Stage.VALID, sb.Stage.TEST
+        stage_loss : float
+            The average loss for all of the data processed in this stage.
+        epoch : int
+            The currently-starting epoch. This is passed
+            `None` during the test stage.
+        """
+        # At the end of validation, we can write
+        if stage == sb.Stage.VALID:
+            # Update learning rate
+            self.last_epoch = epoch
+            lr = self.hparams.noam_annealing.current_lr
+
+            # The train_logger writes a summary to stdout and to the logfile.
+            self.hparams.train_logger.log_stats(  # 1#2#
+                stats_meta={"Epoch": epoch, "lr": lr},
+                train_stats=self.last_loss_stats[sb.Stage.TRAIN],
+                valid_stats=self.last_loss_stats[sb.Stage.VALID],
+            )
+            output_progress_sample = (
+                self.hparams.progress_samples
+                and epoch % self.hparams.progress_samples_interval == 0
+                and epoch >= self.hparams.progress_samples_min_run
+            )
+
+            if output_progress_sample:
+                logger.info("Saving predicted samples")
+                inference_mel, mel_lens = self.run_inference()
+                self.hparams.progress_sample_logger.save(epoch)
+                self.run_vocoder(inference_mel, mel_lens)
+            # Save the current checkpoint and delete previous checkpoints.
+            # UNCOMMENT THIS
+            self.checkpointer.save_and_keep_only(
+                meta=self.last_loss_stats[stage], min_keys=["total_loss"],
+            )
+        # We also write statistics about test data spectogramto stdout and to the logfile.
+        if stage == sb.Stage.TEST:
+            self.hparams.train_logger.log_stats(
+                {"Epoch loaded": self.hparams.epoch_counter.current},
+                test_stats=self.last_loss_stats[sb.Stage.TEST],
+            )
+
+    def run_inference(self):
+        """Produces a sample in inference mode with predicted durations."""
+        if self.last_batch is None:
+            return
+        tokens, *_ = self.last_batch
+
+        (
+            _,
+            postnet_mel_out,
+            _,
+            _,
+            _,
+            _,
+            _,
+            predict_mel_lens,
+            _,
+            _,
+            _,
+            _,
+        ) = self.hparams.model(tokens)
+        self.hparams.progress_sample_logger.remember(
+            infer_output=self.process_mel(
+                postnet_mel_out, [len(postnet_mel_out[0])]
+            )
+        )
+        return postnet_mel_out, predict_mel_lens
+
+    def run_vocoder(self, inference_mel, mel_lens):
+        """Uses a pretrained vocoder to generate audio from predicted mel
+        spectogram. By default, uses speechbrain hifigan.
+        Arguments
+        ---------
+        inference_mel: torch.Tensor
+            predicted mel from fastspeech2 inference
+        mel_lens: torch.Tensor
+            predicted mel lengths from fastspeech2 inference
+            used to mask the noise from padding
+        """
+        if self.last_batch is None:
+            return
+        *_, wavs = self.last_batch
+
+        inference_mel = inference_mel[: self.hparams.progress_batch_sample_size]
+        mel_lens = mel_lens[0 : self.hparams.progress_batch_sample_size]
+        assert (
+            self.hparams.vocoder == "hifi-gan"
+            and self.hparams.pretrained_vocoder is True
+        ), "Specified vocoder not supported yet"
+        logger.info(
+            f"Generating audio with pretrained {self.hparams.vocoder_source} vocoder"
+        )
+        hifi_gan = HIFIGAN.from_hparams(
+            source=self.hparams.vocoder_source,
+            savedir=self.hparams.vocoder_download_path,
+        )
+        waveforms = hifi_gan.decode_batch(
+            inference_mel.transpose(2, 1), mel_lens, self.hparams.hop_length
+        )
+        for idx, wav in enumerate(waveforms):
+
+            path = os.path.join(
+                self.hparams.progress_sample_path,
+                str(self.last_epoch),
+                f"pred_{Path(wavs[idx]).stem}.wav",
+            )
+            torchaudio.save(path, wav, self.hparams.sample_rate)
+
+    def batch_to_device(self, batch, return_metadata=False):
+        """Transfers the batch to the target device
+        Arguments
+        ---------
+        batch: tuple
+            the batch to use
+        Returns
+        -------
+        batch: tuple
+            the batch on the correct device
+        """
+
+        (
+            phoneme_padded,
+            input_lengths,
+            mel_padded,
+            pitch_padded,
+            energy_padded,
+            output_lengths,
+            # len_x,
+            labels,
+            wavs,
+        ) = batch
+
+        # durations = durations.to(self.device, non_blocking=True).long()
+        phonemes = phoneme_padded.to(self.device, non_blocking=True).long()
+        input_lengths = input_lengths.to(self.device, non_blocking=True).long()
+        spectogram = mel_padded.to(self.device, non_blocking=True).float()
+        pitch = pitch_padded.to(self.device, non_blocking=True).float()
+        energy = energy_padded.to(self.device, non_blocking=True).float()
+        mel_lengths = output_lengths.to(self.device, non_blocking=True).long()
+        x = (phonemes, spectogram, pitch, energy)
+        y = (spectogram, pitch, energy, mel_lengths, input_lengths)
+        metadata = (labels, wavs)
+        if return_metadata:
+            return x, y, metadata
+        return x, y
+
+
+def dataio_prepare(hparams):
+    "Creates the datasets and their data processing pipelines."
+    # Load lexicon
+    lexicon = hparams["lexicon"]
+    input_encoder = hparams.get("input_encoder")
+
+    # add a dummy symbol for idx 0 - used for padding.
+    lexicon = ["@@"] + lexicon
+    input_encoder.update_from_iterable(lexicon, sequence_input=False)
+    input_encoder.add_unk()
+
+    # load audio, text and durations on the fly; encode audio and text.
+    @sb.utils.data_pipeline.takes("wav", "phonemes", "pitch")
+    @sb.utils.data_pipeline.provides("mel_text_pair")
+    def audio_pipeline(wav, phonemes, pitch):
+        phoneme_seq = input_encoder.encode_sequence_torch(phonemes).int()
+
+        audio, fs = torchaudio.load(wav)
+        audio = audio.squeeze()
+        mel, energy = hparams["mel_spectogram"](audio=audio)
+
+        pitch = np.load(pitch)
+        pitch = torch.from_numpy(pitch)
+        pitch = pitch[: mel.shape[-1]]
+        return phoneme_seq, mel, pitch, energy, len(phoneme_seq), len(mel)
+
+    # define splits and load it as sb dataset
+    datasets = {}
+
+    for dataset in hparams["splits"]:
+        datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
+            json_path=hparams[f"{dataset}_json"],
+            replacements={"data_root": hparams["data_folder"]},
+            dynamic_items=[audio_pipeline],
+            output_keys=["mel_text_pair", "wav", "label", "pitch"],
+        )
+    return datasets
+
+
+def main():
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+    sb.utils.distributed.ddp_init_group(run_opts)
+
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    from ljspeech_prepare import prepare_ljspeech
+
+    sb.utils.distributed.run_on_main(
+        prepare_ljspeech,
+        kwargs={
+            "data_folder": hparams["data_folder"],
+            "save_folder": hparams["save_folder"],
+            "splits": hparams["splits"],
+            "split_ratio": hparams["split_ratio"],
+            "model_name": hparams["model"].__class__.__name__,
+            "seed": hparams["seed"],
+            "pitch_n_fft": hparams["n_fft"],
+            "pitch_hop_length": hparams["hop_length"],
+            "pitch_min_f0": hparams["min_f0"],
+            "pitch_max_f0": hparams["max_f0"],
+            "skip_prep": hparams["skip_prep"],
+            "use_custom_cleaner": True,
+            "device": "cuda",
+        },
+    )
+
+    datasets = dataio_prepare(hparams)
+
+    # Brain class initialization
+    fastspeech2_brain = FastSpeech2Brain(
+        modules=hparams["modules"],
+        opt_class=hparams["opt_class"],
+        hparams=hparams,
+        run_opts=run_opts,
+        checkpointer=hparams["checkpointer"],
+    )
+    # Training
+    fastspeech2_brain.fit(
+        fastspeech2_brain.hparams.epoch_counter,
+        datasets["train"],
+        datasets["valid"],
+        train_loader_kwargs=hparams["train_dataloader_opts"],
+        valid_loader_kwargs=hparams["valid_dataloader_opts"],
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/recipes/LJSpeech/TTS/quantization/hparams/kmeans.yaml b/recipes/LJSpeech/TTS/quantization/hparams/kmeans.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e92a54f49e43d09ff6ac93717b2575ffffd94770
--- /dev/null
+++ b/recipes/LJSpeech/TTS/quantization/hparams/kmeans.yaml
@@ -0,0 +1,46 @@
+# ############################################################################
+# Model: K-means
+# Training: LJSpeech
+# Authors:  Duret, Jarod 2023
+# ############################################################################
+
+
+###################################
+# Experiment Parameters and setup #
+###################################
+seed: 4321
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+no_cuda: False
+output_folder: !ref ./results/kmeans/<seed>
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+#################################
+# Data files and pre-processing #
+#################################
+data_folder: !PLACEHOLDER
+train_json: !ref <save_folder>/train.json
+valid_json: !ref <save_folder>/valid.json
+splits: [train, valid]
+split_ratio: [90, 10]
+skip_prep: False
+sample_pct: 0.2
+sample_rate: 16000
+
+# URL for the HuggingFace model we want to load
+encoder_hub: facebook/hubert-base-ls960
+encoder_folder: !ref <save_folder>/pretrained_models
+layer: 6
+
+####################
+# Model Parameters #
+####################
+num_clusters: 100
+init: k-means++
+max_iter: 100
+batch_size: 10000
+tol: 0.0
+max_no_improvement: 100
+n_init: 20
+reassignment_ratio: 0.0
+out_kmeans_model_path: !ref <save_folder>/kmeans.ckpt
diff --git a/recipes/LJSpeech/TTS/quantization/ljspeech_prepare.py b/recipes/LJSpeech/TTS/quantization/ljspeech_prepare.py
new file mode 120000
index 0000000000000000000000000000000000000000..2f703273cb4b6cd6e3e08358a18ab94d1fe0746a
--- /dev/null
+++ b/recipes/LJSpeech/TTS/quantization/ljspeech_prepare.py
@@ -0,0 +1 @@
+../../ljspeech_prepare.py
\ No newline at end of file
diff --git a/recipes/LJSpeech/TTS/quantization/train.py b/recipes/LJSpeech/TTS/quantization/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffd3a9211a974a66feb9b2c59a8a7af2870f3aba
--- /dev/null
+++ b/recipes/LJSpeech/TTS/quantization/train.py
@@ -0,0 +1,219 @@
+"""
+Script to train K-means clustering model on self-supervised representations.
+
+To run this recipe, do the following:
+> python train.py hparams/kmeans.yaml --data_folder=/path/to/LJspeech
+
+Authors
+ * Jarod Duret 2023
+"""
+
+
+import sys
+import logging
+import time
+import random
+import itertools
+import pathlib as pl
+
+import joblib
+import torch
+import torchaudio
+import tqdm
+import numpy as np
+from sklearn.cluster import MiniBatchKMeans
+from hyperpyyaml import load_hyperpyyaml
+import speechbrain as sb
+from ljspeech_prepare import prepare_ljspeech
+from speechbrain.lobes.models.huggingface_wav2vec import HuggingFaceWav2Vec2
+
+
+def setup_logger():
+    """Set up a logger with a log format and logging level."""
+    log_format = "[%(asctime)s] [%(levelname)s]: %(message)s"
+    logging.basicConfig(format=log_format, level=logging.INFO)
+    logger = logging.getLogger(__name__)
+    return logger
+
+
+def get_device(use_cuda):
+    """Determine and return the appropriate device for computation."""
+    use_cuda = use_cuda and torch.cuda.is_available()
+    print("\n" + "=" * 30)
+    print("USE_CUDA SET TO: {}".format(use_cuda))
+    print("CUDA AVAILABLE?: {}".format(torch.cuda.is_available()))
+    print("=" * 30 + "\n")
+    return torch.device("cuda" if use_cuda else "cpu")
+
+
+def np_array(tensor):
+    """Convert a Pytorch tensor to a Numpy array."""
+    tensor = tensor.squeeze(0)
+    tensor = tensor.detach().cpu()
+    return tensor.numpy()
+
+
+def fetch_data(splits, sample_pct, seed=1234):
+    """Fetch data from specified splits for k-means training."""
+    ds_splits = {}
+    for split in splits:
+        key = f"{split.parent}_{split.stem}"
+        ds_splits[key] = sb.dataio.dataset.DynamicItemDataset.from_json(
+            json_path=split, output_keys=["id", "wav"],
+        )
+
+    data = list(itertools.chain(*ds_splits.values()))
+    random.seed(seed)
+    if sample_pct < 1.0:
+        data = random.sample(data, int(sample_pct * len(data)))
+    return iter(data), len(data)
+
+
+def extract_features(
+    model, layer, splits, sample_pct, flatten, device="cpu", sample_rate=16000
+):
+    """Extract features from audio using a pre-trained model."""
+    data, num_files = fetch_data(splits, sample_pct)
+    features_list = []
+    id_list = []
+
+    for item in tqdm.tqdm(data, total=num_files):
+        wav = item["wav"]
+        with torch.no_grad():
+            info = torchaudio.info(wav)
+            audio = sb.dataio.dataio.read_audio(wav)
+            audio = torchaudio.transforms.Resample(
+                info.sample_rate, sample_rate,
+            )(audio)
+            audio = audio.unsqueeze(0).to(device)
+            feats = model.extract_features(audio)
+            feats = feats[layer]
+            feats = np_array(feats)
+        features_list.append(feats)
+        id_list.append(item["id"])
+
+    if flatten:
+        return np.concatenate(features_list), id_list
+
+    return features_list, id_list
+
+
+def fetch_kmeans_model(
+    n_clusters,
+    init,
+    max_iter,
+    batch_size,
+    tol,
+    max_no_improvement,
+    n_init,
+    reassignment_ratio,
+    random_state,
+):
+    """Return a k-means clustering model with specified parameters."""
+    return MiniBatchKMeans(
+        n_clusters=n_clusters,
+        init=init,
+        max_iter=max_iter,
+        batch_size=batch_size,
+        tol=tol,
+        max_no_improvement=max_no_improvement,
+        n_init=n_init,
+        reassignment_ratio=reassignment_ratio,
+        random_state=random_state,
+        verbose=1,
+        compute_labels=True,
+        init_size=None,
+    )
+
+
+def train_kmeans(kmeans_model, features_batch):
+    """Train a k-means clustering model using the provided features."""
+    start_time = time.time()
+    kmeans_model.fit(features_batch)
+    time_taken = round((time.time() - start_time) // 60, 2)
+    return kmeans_model, time_taken
+
+
+if __name__ == "__main__":
+    logger = setup_logger()
+
+    # Load hyperparameters file with command-line overrides
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    sb.utils.distributed.run_on_main(
+        prepare_ljspeech,
+        kwargs={
+            "data_folder": hparams["data_folder"],
+            "save_folder": hparams["save_folder"],
+            "splits": hparams["splits"],
+            "seed": hparams["seed"],
+            "skip_prep": hparams["skip_prep"],
+        },
+    )
+
+    # Fetch device
+    device = get_device(not hparams["no_cuda"])
+
+    logger.info(f"Loading encoder model from HF hub: {hparams['encoder_hub']}")
+    encoder = HuggingFaceWav2Vec2(
+        hparams["encoder_hub"],
+        hparams["encoder_folder"],
+        output_all_hiddens=True,
+        output_norm=False,
+        freeze_feature_extractor=True,
+        freeze=True,
+    ).to(device)
+
+    splits = []
+    data_folder = pl.Path(hparams["save_folder"])
+    for split in hparams["splits"]:
+        splits.append(data_folder / f"{split}.json")
+
+    # Features loading/extraction for K-means
+    logger.info("Extracting acoustic features ...")
+
+    (features_batch, idx) = extract_features(
+        model=encoder,
+        layer=hparams["layer"],
+        splits=splits,
+        sample_pct=hparams["sample_pct"],
+        flatten=True,
+        device=device,
+        sample_rate=hparams["sample_rate"],
+    )
+
+    logger.info(f"Features shape = {features_batch.shape}\n")
+
+    # Train and save Kmeans model
+    kmeans_model = fetch_kmeans_model(
+        n_clusters=hparams["num_clusters"],
+        init=hparams["init"],
+        max_iter=hparams["max_iter"],
+        batch_size=hparams["batch_size"],
+        tol=hparams["tol"],
+        max_no_improvement=hparams["max_no_improvement"],
+        n_init=hparams["n_init"],
+        reassignment_ratio=hparams["reassignment_ratio"],
+        random_state=hparams["seed"],
+    )
+
+    logger.info("Starting k-means training...")
+    kmeans_model, time_taken = train_kmeans(
+        kmeans_model=kmeans_model, features_batch=features_batch
+    )
+    logger.info(f"k-means model trained in {time_taken} minutes")
+    inertia = -kmeans_model.score(features_batch) / len(features_batch)
+    logger.info(f"Total intertia: {round(inertia, 2)}\n")
+
+    logger.info(f"Saving k-means model to {hparams['out_kmeans_model_path']}")
+    joblib.dump(kmeans_model, open(hparams["out_kmeans_model_path"], "wb"))
diff --git a/recipes/LJSpeech/TTS/tacotron2/ljspeech_prepare.py b/recipes/LJSpeech/TTS/tacotron2/ljspeech_prepare.py
index 2de5a21a8daef2535958e950a9d95e0853bf6ba7..2f703273cb4b6cd6e3e08358a18ab94d1fe0746a 120000
--- a/recipes/LJSpeech/TTS/tacotron2/ljspeech_prepare.py
+++ b/recipes/LJSpeech/TTS/tacotron2/ljspeech_prepare.py
@@ -1 +1 @@
-../ljspeech_prepare.py
\ No newline at end of file
+../../ljspeech_prepare.py
\ No newline at end of file
diff --git a/recipes/LJSpeech/TTS/tacotron2/train.py b/recipes/LJSpeech/TTS/tacotron2/train.py
index c7166e41cf20bd4a2078dcb219fb390e7c7796d5..2cb370b15616a46e830dc204cd420ab2935f1d19 100644
--- a/recipes/LJSpeech/TTS/tacotron2/train.py
+++ b/recipes/LJSpeech/TTS/tacotron2/train.py
@@ -33,7 +33,8 @@ class Tacotron2Brain(sb.Brain):
 
     def on_fit_start(self):
         """Gets called at the beginning of ``fit()``, on multiple processes
-        if ``distributed_count > 0`` and backend is ddp and initializes statistics"""
+        if ``distributed_count > 0`` and backend is ddp and initializes statistics
+        """
         self.hparams.progress_sample_logger.reset()
         self.last_epoch = 0
         self.last_batch = None
@@ -62,22 +63,10 @@ class Tacotron2Brain(sb.Brain):
         max_input_length = input_lengths.max().item()
         return self.modules.model(inputs, alignments_dim=max_input_length)
 
-    def fit_batch(self, batch):
-        """Fits a single batch and applies annealing
-
-        Arguments
-        ---------
-        batch: tuple
-            a training batch
-
-        Returns
-        -------
-        loss: torch.Tensor
-            detached loss
-        """
-        result = super().fit_batch(batch)
-        self.hparams.lr_annealing(self.optimizer)
-        return result
+    def on_fit_batch_end(self, batch, outputs, loss, should_step):
+        """At the end of the optimizer step, apply noam annealing."""
+        if should_step:
+            self.hparams.lr_annealing(self.optimizer)
 
     def compute_objectives(self, predictions, batch, stage):
         """Computes the loss given the predicted and targeted outputs.
@@ -339,14 +328,12 @@ def dataio_prepare(hparams):
 
 
 if __name__ == "__main__":
-
     # Load hyperparameters file with command-line overrides
     hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
 
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -357,7 +344,6 @@ if __name__ == "__main__":
         overrides=overrides,
     )
 
-    sys.path.append("../")
     from ljspeech_prepare import prepare_ljspeech
 
     sb.utils.distributed.run_on_main(
diff --git a/recipes/LJSpeech/TTS/vocoder/diffwave/hparams/train.yaml b/recipes/LJSpeech/TTS/vocoder/diffwave/hparams/train.yaml
index 6fea32e4cce9cb8050aa185b63dd3b64d0856c8f..82483a9e92aa22ff3898919b6c22c02f01c50d8a 100644
--- a/recipes/LJSpeech/TTS/vocoder/diffwave/hparams/train.yaml
+++ b/recipes/LJSpeech/TTS/vocoder/diffwave/hparams/train.yaml
@@ -66,9 +66,6 @@ test_dataloader_opts:
     batch_size: 1
     num_workers: !ref <num_workers>
 
-dataloader_options:
-    batch_size: !ref <batch_size>
-
 use_tensorboard: False
 tensorboard_logs: !ref <output_folder>/logs/
 
diff --git a/recipes/LJSpeech/TTS/vocoder/diffwave/ljspeech_prepare.py b/recipes/LJSpeech/TTS/vocoder/diffwave/ljspeech_prepare.py
index 2f703273cb4b6cd6e3e08358a18ab94d1fe0746a..069e475ec698877801c96112bf9b7637fdbd82d0 120000
--- a/recipes/LJSpeech/TTS/vocoder/diffwave/ljspeech_prepare.py
+++ b/recipes/LJSpeech/TTS/vocoder/diffwave/ljspeech_prepare.py
@@ -1 +1 @@
-../../ljspeech_prepare.py
\ No newline at end of file
+../../../ljspeech_prepare.py
\ No newline at end of file
diff --git a/recipes/LJSpeech/TTS/vocoder/diffwave/train.py b/recipes/LJSpeech/TTS/vocoder/diffwave/train.py
index db87a56f13731c774bac40ca30249cdb4876ef37..b920abfca60922af3ed6bc3071b8919156179d73 100644
--- a/recipes/LJSpeech/TTS/vocoder/diffwave/train.py
+++ b/recipes/LJSpeech/TTS/vocoder/diffwave/train.py
@@ -77,18 +77,6 @@ class DiffWaveBrain(sb.Brain):
         self.last_loss_stats[stage] = {"loss": loss}
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        loss = super().fit_batch(batch)
-        return loss.detach().cpu()
-
-    def evaluate_batch(self, batch, stage):
-        """Evaluate one batch
-        """
-        out = self.compute_forward(batch, stage=stage)
-        loss = self.compute_objectives(out, batch, stage=stage)
-        return loss.detach().cpu()
-
     def on_fit_start(self):
         """Gets called at the beginning of ``fit()``, on multiple processes
         if ``distributed_count > 0`` and backend is ddp and initializes statistics
@@ -364,5 +352,5 @@ if __name__ == "__main__":
         test_stats = diffusion_brain.evaluate(
             test_set=datasets["test"],
             min_key="error",
-            test_loader_kwargs=hparams["dataloader_options"],
+            test_loader_kwargs=hparams["test_dataloader_opts"],
         )
diff --git a/recipes/LJSpeech/TTS/vocoder/hifi_gan/ljspeech_prepare.py b/recipes/LJSpeech/TTS/vocoder/hifi_gan/ljspeech_prepare.py
index 2f703273cb4b6cd6e3e08358a18ab94d1fe0746a..069e475ec698877801c96112bf9b7637fdbd82d0 120000
--- a/recipes/LJSpeech/TTS/vocoder/hifi_gan/ljspeech_prepare.py
+++ b/recipes/LJSpeech/TTS/vocoder/hifi_gan/ljspeech_prepare.py
@@ -1 +1 @@
-../../ljspeech_prepare.py
\ No newline at end of file
+../../../ljspeech_prepare.py
\ No newline at end of file
diff --git a/recipes/LJSpeech/TTS/vocoder/hifi_gan/train.py b/recipes/LJSpeech/TTS/vocoder/hifi_gan/train.py
index 6c10c64f64dce906223371db8af615d8281b5c94..4c1dcfeabf43a3ed1a9c19d71f06fc7aba409873 100644
--- a/recipes/LJSpeech/TTS/vocoder/hifi_gan/train.py
+++ b/recipes/LJSpeech/TTS/vocoder/hifi_gan/train.py
@@ -48,8 +48,7 @@ class HifiGanBrain(sb.Brain):
         return (y_g_hat, scores_fake, feats_fake, scores_real, feats_real)
 
     def compute_objectives(self, predictions, batch, stage):
-        """Computes and combines generator and discriminator losses
-        """
+        """Computes and combines generator and discriminator losses"""
         batch = batch.to(self.device)
         x, _ = batch.mel
         y, _ = batch.sig
@@ -64,7 +63,7 @@ class HifiGanBrain(sb.Brain):
 
         y_hat, scores_fake, feats_fake, scores_real, feats_real = predictions
         loss_g = self.hparams.generator_loss(
-            y_hat, y, scores_fake, feats_fake, feats_real
+            stage, y_hat, y, scores_fake, feats_fake, feats_real
         )
         loss_d = self.hparams.discriminator_loss(scores_fake, scores_real)
         loss = {**loss_g, **loss_d}
@@ -72,8 +71,7 @@ class HifiGanBrain(sb.Brain):
         return loss
 
     def fit_batch(self, batch):
-        """Train discriminator and generator adversarially
-        """
+        """Train discriminator and generator adversarially"""
 
         batch = batch.to(self.device)
         y, _ = batch.sig
@@ -104,8 +102,7 @@ class HifiGanBrain(sb.Brain):
         return loss_g.detach().cpu()
 
     def evaluate_batch(self, batch, stage):
-        """Evaluate one batch
-        """
+        """Evaluate one batch"""
         out = self.compute_forward(batch, stage=stage)
         loss = self.compute_objectives(out, batch, stage=stage)
         loss_g = loss["G_loss"]
@@ -172,8 +169,7 @@ class HifiGanBrain(sb.Brain):
         y_hat, scores_fake, feats_fake, scores_real, feats_real = predictions
 
     def on_stage_end(self, stage, stage_loss, epoch):
-        """Gets called at the end of a stage (TRAIN, VALID, Or TEST)
-        """
+        """Gets called at the end of a stage (TRAIN, VALID, Or TEST)"""
         if stage == sb.Stage.VALID:
             # Update learning rate
             self.scheduler_g.step()
@@ -334,14 +330,12 @@ def dataio_prepare(hparams):
 
 
 if __name__ == "__main__":
-
     # Load hyperparameters file with command-line overrides
     hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
 
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -352,7 +346,6 @@ if __name__ == "__main__":
         overrides=overrides,
     )
 
-    sys.path.append("../../")
     from ljspeech_prepare import prepare_ljspeech
 
     sb.utils.distributed.run_on_main(
diff --git a/recipes/LJSpeech/TTS/vocoder/hifi_gan_unit/extract_code.py b/recipes/LJSpeech/TTS/vocoder/hifi_gan_unit/extract_code.py
new file mode 100644
index 0000000000000000000000000000000000000000..0035b5f3bf0cecff032aaec7c6032492403f997b
--- /dev/null
+++ b/recipes/LJSpeech/TTS/vocoder/hifi_gan_unit/extract_code.py
@@ -0,0 +1,203 @@
+"""
+Apply K-means clustering over acoustic features to extract speech units for HiFi-GAN training.
+
+Authors
+ * Jarod Duret 2023
+"""
+
+import logging
+import json
+import pathlib as pl
+
+import joblib
+import torch
+import torchaudio
+import numpy as np
+from tqdm import tqdm
+import speechbrain as sb
+from speechbrain.dataio.dataio import (
+    load_pkl,
+    save_pkl,
+)
+from speechbrain.lobes.models.huggingface_transformers.wav2vec2 import Wav2Vec2
+
+OPT_FILE = "opt_ljspeech_extract.pkl"
+TRAIN_JSON = "train.json"
+VALID_JSON = "valid.json"
+TEST_JSON = "test.json"
+
+
+def setup_logger():
+    """Set up a logger with a log format and logging level."""
+    log_format = "[%(asctime)s] [%(levelname)s]: %(message)s"
+    logging.basicConfig(format=log_format, level=logging.INFO)
+    logger = logging.getLogger(__name__)
+    return logger
+
+
+def get_device(use_cuda):
+    """Determine and return the appropriate device for computation."""
+    use_cuda = use_cuda and torch.cuda.is_available()
+    print("\n" + "=" * 30)
+    print("USE_CUDA SET TO: {}".format(use_cuda))
+    print("CUDA AVAILABLE?: {}".format(torch.cuda.is_available()))
+    print("=" * 30 + "\n")
+    return torch.device("cuda" if use_cuda else "cpu")
+
+
+def np_array(tensor):
+    """Convert a Pytorch tensor to a Numpy array."""
+    tensor = tensor.squeeze(0)
+    tensor = tensor.detach().cpu()
+    return tensor.numpy()
+
+
+def skip(splits, save_folder, conf):
+    """
+    Detects if the ljspeech data_extraction has been already done.
+    If the extraction has been done, we can skip it.
+
+    Returns
+    -------
+    bool
+        if True, the preparation phase can be skipped.
+        if False, it must be done.
+    """
+    # Checking json files
+    skip = True
+
+    split_files = {
+        "train": TRAIN_JSON,
+        "valid": VALID_JSON,
+        "test": TEST_JSON,
+    }
+
+    for split in splits:
+        if not (save_folder / split_files[split]).exists():
+            skip = False
+
+    #  Checking saved options
+    save_opt = save_folder / OPT_FILE
+    if skip is True:
+        if save_opt.is_file():
+            opts_old = load_pkl(save_opt.as_posix())
+            if opts_old == conf:
+                skip = True
+            else:
+                skip = False
+        else:
+            skip = False
+    return skip
+
+
+def extract_ljspeech(
+    data_folder,
+    splits,
+    kmeans_folder,
+    encoder,
+    layer,
+    save_folder,
+    sample_rate=16000,
+    skip_extract=False,
+):
+    """
+    Extract speech units for HiFi-GAN training on the LJspeech datasets.
+
+    Arguments
+    ---------
+    data_folder : str
+        Path to the folder where the original LJspeech dataset is stored.
+    splits : list
+        List of splits to prepare.
+    kmeans_folder: str
+        Path to the folder where the k-means model checkpoint is stored.
+    encoder: str
+        Url to the model used as feature extractor.
+    layer: int
+        Layer from which features are extracted.
+    save_folder: str
+        Path to the folder where the speech units are stored.
+    sample_rate: int
+        LjSpeech dataset sample rate
+    skip_extract: Bool
+        If True, skip extraction.
+
+    Example
+    -------
+    >>> from recipes.LJSpeech.S2ST.extract_code import extract_ljspeech
+    >>> data_folder = 'data/LJspeech/'
+    >>> splits = ['train', 'valid']
+    >>> kmeans_folder = ./Quantization/results/kmeans/4321/save
+    >>> encoder = facebook/hubert-base-ls960
+    >>> layer = 6
+    >>> save_folder = 'save/'
+    >>> extract_ljspeech(data_folder, splits, kmeans_folder, encoder, layer, save_folder)
+    """
+    logger = setup_logger()
+
+    if skip_extract:
+        return
+    # Create configuration for easily skipping code extraction stage
+    conf = {
+        "data_folder": data_folder,
+        "splits": splits,
+        "save_folder": save_folder,
+        "kmeans_folder": kmeans_folder,
+        "encoder": encoder,
+        "layer": layer,
+    }
+
+    save_folder = pl.Path(save_folder)
+    # Check if this phase is already done (if so, skip it)
+    if skip(splits, save_folder, conf):
+        logger.info("Skipping code extraction, completed in previous run.")
+        return
+
+    # Fetch device
+    device = get_device(use_cuda=True)
+
+    save_opt = save_folder / OPT_FILE
+    data_folder = pl.Path(data_folder)
+    kmeans_folder = pl.Path(kmeans_folder)
+    kmeans_ckpt = kmeans_folder / "kmeans.ckpt"
+    encoder_save_path = kmeans_folder / "pretrained_models"
+    code_folder = save_folder / "codes"
+    code_folder.mkdir(parents=True, exist_ok=True)
+
+    logger.info(f"Loading encoder: {encoder} ...")
+    encoder = Wav2Vec2(
+        encoder,
+        encoder_save_path.as_posix(),
+        output_all_hiddens=True,
+        output_norm=False,
+        freeze_feature_extractor=True,
+        freeze=True,
+    ).to(device)
+
+    # K-means model
+    logger.info(f"Loading K-means model from {kmeans_ckpt} ...")
+    kmeans_model = joblib.load(open(kmeans_ckpt, "rb"))
+    kmeans_model.verbose = False
+
+    for split in splits:
+        dataset_path = data_folder / f"{split}.json"
+        logger.info(f"Reading dataset from {dataset_path} ...")
+        meta_json = json.load(open(dataset_path))
+        for key in tqdm(meta_json.keys()):
+            item = meta_json[key]
+            wav = item["wav"]
+            with torch.no_grad():
+                info = torchaudio.info(wav)
+                audio = sb.dataio.dataio.read_audio(wav)
+                audio = torchaudio.transforms.Resample(
+                    info.sample_rate, sample_rate,
+                )(audio)
+                audio = audio.unsqueeze(0).to(device)
+                feats = encoder.extract_features(audio)
+                feats = feats[layer]
+                feats = np_array(feats)
+            pred = kmeans_model.predict(feats)
+            np.save(code_folder / f"{key}.npy", pred)
+
+    logger.info("Extraction completed.")
+    save_pkl(conf, save_opt)
diff --git a/recipes/LJSpeech/TTS/vocoder/hifi_gan_unit/hparams/train.yaml b/recipes/LJSpeech/TTS/vocoder/hifi_gan_unit/hparams/train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..68516efa2095f86267668fdc4279b14a8a4609f2
--- /dev/null
+++ b/recipes/LJSpeech/TTS/vocoder/hifi_gan_unit/hparams/train.yaml
@@ -0,0 +1,221 @@
+############################################################################
+# Model: Unit HiFi-GAN
+# Tokens: discrete speech units (K-means)
+# Training: LJSpeech (English)
+# Authors: Jarod Duret, Yingzhi Wang
+# ############################################################################
+
+
+###################################
+# Experiment Parameters and setup #
+###################################
+seed: 4321
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref ./results/hifi_gan/<seed>
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+progress_sample_path: !ref <output_folder>/samples
+epochs: 200
+keep_checkpoint_interval: 50
+use_tensorboard: False
+
+#################################
+# Data files and pre-processing #
+#################################
+data_folder: !PLACEHOLDER # e.g, /datasets/ljspeech 16k!
+train_json: !ref <save_folder>/train.json
+valid_json: !ref <save_folder>/valid.json
+test_json: !ref <save_folder>/test.json
+
+splits: ["train", "valid"]
+split_ratio: [90, 10]
+skip_prep: False
+
+kmeans_folder: !PLACEHOLDER # e.g, ../../quantization/results/kmeans/4321/save
+codes_folder: !ref <save_folder>/codes
+encoder_hub: facebook/hubert-base-ls960
+layer: 6
+
+################################
+# Audio Parameters             #
+################################
+
+segment_size: 8960
+code_hop_size: 320
+sample_rate: 16000
+
+
+hop_length: 256
+win_length: 1024
+n_mel_channels: 80
+n_fft: 1024
+mel_fmin: 0.0
+mel_fmax: 8000
+mel_normalized: False
+power: 1
+norm: "slaney"
+mel_scale: "slaney"
+dynamic_range_compression: True
+
+################################
+# Optimization Hyperparameters #
+################################
+learning_rate: 0.0002
+weight_decay: 0.9999
+adam_b1: 0.8
+adam_b2: 0.99
+batch_size: 32 #minimum 32
+
+train_dataloader_opts:
+  batch_size: !ref <batch_size>
+  drop_last: False
+  num_workers: 8
+
+valid_dataloader_opts:
+  batch_size: 1
+  num_workers: 8
+
+test_dataloader_opts:
+  batch_size: 1
+  num_workers: 8
+
+################################
+# Model Parameters and model   #
+################################
+duration_predictor: True
+
+# embedding params
+num_embeddings: 101 # K-means size + 1 for padding
+embedding_dim: 128
+
+# generator params
+in_channels: 128
+out_channels: 1
+
+var_pred_hidden_dim: 128
+var_pred_kernel_size: 3
+var_pred_dropout: 0.5
+
+###########################################################################################################################################################
+# version | resblock_type | upsample_kernel_sizes | upsample_factors | resblock_kernel_sizes | upsample_initial_channel | resblock_dilation_sizes
+#    1    |      "1"      |      [16,16,4,4]      |    [8, 8, 2, 2]  |     [3, 7, 11]        |           512            | [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+#    2    |      "1"      |      [16,16,4,4]      |    [8, 8, 2, 2]  |     [3, 7, 11]        |           128            | [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+#    3    |      "2"      |       [16,16,8]       |      [8,8,4]     |       [3,5,7]         |           256            |     [[1,2], [2,6], [3,12]]
+###########################################################################################################################################################
+resblock_type: "1"
+resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+resblock_kernel_sizes: [3, 7, 11]
+upsample_kernel_sizes: [11, 8, 8, 4, 4]
+upsample_initial_channel: 512
+upsample_factors: [5, 4, 4, 2, 2]
+
+inference_padding: 5
+cond_channels: 0
+conv_post_bias: True
+
+mel_spectogram: !name:speechbrain.lobes.models.HifiGAN.mel_spectogram
+  sample_rate: !ref <sample_rate>
+  hop_length: !ref <hop_length>
+  win_length: !ref <win_length>
+  n_fft: !ref <n_fft>
+  n_mels: !ref <n_mel_channels>
+  f_min: !ref <mel_fmin>
+  f_max: !ref <mel_fmax>
+  power: !ref <power>
+  normalized: !ref <mel_normalized>
+  norm: !ref <norm>
+  mel_scale: !ref <mel_scale>
+  compression: !ref <dynamic_range_compression>
+
+generator: !new:speechbrain.lobes.models.HifiGAN.UnitHifiganGenerator
+  in_channels: !ref <in_channels>
+  out_channels: !ref <out_channels>
+  resblock_type: !ref <resblock_type>
+  resblock_dilation_sizes: !ref <resblock_dilation_sizes>
+  resblock_kernel_sizes: !ref <resblock_kernel_sizes>
+  upsample_kernel_sizes: !ref <upsample_kernel_sizes>
+  upsample_initial_channel: !ref <upsample_initial_channel>
+  upsample_factors: !ref <upsample_factors>
+  inference_padding: !ref <inference_padding>
+  cond_channels: !ref <cond_channels>
+  conv_post_bias: !ref <conv_post_bias>
+  num_embeddings: !ref <num_embeddings>
+  embedding_dim: !ref <embedding_dim>
+  duration_predictor: !ref <duration_predictor>
+  var_pred_hidden_dim: !ref <var_pred_hidden_dim>
+  var_pred_kernel_size: !ref <var_pred_kernel_size>
+  var_pred_dropout: !ref <var_pred_dropout>
+
+discriminator: !new:speechbrain.lobes.models.HifiGAN.HifiganDiscriminator
+
+modules:
+  generator: !ref <generator>
+  discriminator: !ref <discriminator>
+
+#generator loss
+stft_loss: null
+mseg_loss: !new:speechbrain.lobes.models.HifiGAN.MSEGLoss
+feat_match_loss: !new:speechbrain.lobes.models.HifiGAN.MelganFeatureLoss
+l1_spec_loss: !new:speechbrain.lobes.models.HifiGAN.L1SpecLoss
+  sample_rate: !ref <sample_rate>
+  hop_length: !ref <hop_length>
+  win_length: !ref <win_length>
+  n_mel_channels: !ref <n_mel_channels>
+  n_fft: !ref <n_fft>
+  n_stft: !ref <n_fft> // 2 + 1
+  mel_fmin: !ref <mel_fmin>
+  mel_fmax: null
+  mel_normalized: !ref <mel_normalized>
+  power: !ref <power>
+  dynamic_range_compression: !ref <dynamic_range_compression>
+mseg_dur_loss: True
+
+generator_loss: !new:speechbrain.lobes.models.HifiGAN.GeneratorLoss
+  stft_loss: !ref <stft_loss>
+  stft_loss_weight: 0
+  mseg_loss: !ref <mseg_loss>
+  mseg_loss_weight: 1
+  feat_match_loss: !ref <feat_match_loss>
+  feat_match_loss_weight: 10
+  l1_spec_loss: !ref  <l1_spec_loss>
+  l1_spec_loss_weight: 45
+  mseg_dur_loss: !ref <mseg_dur_loss>
+  mseg_dur_loss_weight: 1
+
+#discriminator loss
+msed_loss: !new:speechbrain.lobes.models.HifiGAN.MSEDLoss
+
+discriminator_loss: !new:speechbrain.lobes.models.HifiGAN.DiscriminatorLoss
+  msed_loss: !ref <msed_loss>
+
+#optimizer
+opt_class_generator: !name:torch.optim.AdamW
+  lr: !ref <learning_rate>
+  betas: [!ref <adam_b1>, !ref <adam_b2>]
+
+opt_class_discriminator: !name:torch.optim.AdamW
+  lr: !ref <learning_rate>
+  betas: [!ref <adam_b1>, !ref <adam_b2>]
+
+sch_class_generator: !name:torch.optim.lr_scheduler.ExponentialLR
+  gamma: !ref <weight_decay>
+  last_epoch: -1
+
+sch_class_discriminator: !name:torch.optim.lr_scheduler.ExponentialLR
+  gamma: !ref <weight_decay>
+  last_epoch: -1
+
+#epoch object
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+  limit: !ref <epochs>
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+  save_file: !ref <train_log>
+
+#checkpointer
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+  checkpoints_dir: !ref <save_folder>
+  recoverables:
+    generator: !ref <generator>
+    discriminator: !ref <discriminator>
+    counter: !ref <epoch_counter>
diff --git a/recipes/LJSpeech/TTS/vocoder/hifi_gan_unit/ljspeech_prepare.py b/recipes/LJSpeech/TTS/vocoder/hifi_gan_unit/ljspeech_prepare.py
new file mode 120000
index 0000000000000000000000000000000000000000..069e475ec698877801c96112bf9b7637fdbd82d0
--- /dev/null
+++ b/recipes/LJSpeech/TTS/vocoder/hifi_gan_unit/ljspeech_prepare.py
@@ -0,0 +1 @@
+../../../ljspeech_prepare.py
\ No newline at end of file
diff --git a/recipes/LJSpeech/TTS/vocoder/hifi_gan_unit/train.py b/recipes/LJSpeech/TTS/vocoder/hifi_gan_unit/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..19d4aec8136cc50a89475f0e3c37708f1c03f8d3
--- /dev/null
+++ b/recipes/LJSpeech/TTS/vocoder/hifi_gan_unit/train.py
@@ -0,0 +1,529 @@
+#!/usr/bin/env python3
+"""Recipe for training a hifi-gan vocoder on self-supervised representations.
+For more details about hifi-gan: https://arxiv.org/pdf/2010.05646.pdf
+For more details about speech synthesis using self-supervised representations: https://arxiv.org/pdf/2104.00355.pdf
+
+To run this recipe, do the following:
+> python train.py hparams/train.yaml --kmeans_folder=/path/to/Kmeans/ckpt --data_folder=/path/to/LJspeech
+
+Authors
+ * Jarod Duret 2023
+ * Yingzhi WANG 2022
+"""
+
+import sys
+import copy
+import random
+import pathlib as pl
+
+from hyperpyyaml import load_hyperpyyaml
+import speechbrain as sb
+from speechbrain.utils.data_utils import scalarize
+import torch
+import torchaudio
+import numpy as np
+
+
+class HifiGanBrain(sb.Brain):
+    def compute_forward(self, batch, stage):
+        """The forward function, generates synthesized waveforms,
+        calculates the scores and the features of the discriminator
+        for synthesized waveforms and real waveforms.
+
+        Arguments
+        ---------
+        batch : torch.Tensor or tensors
+            An element from the dataloader, including inputs for processing.
+        stage : Stage
+            The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST
+
+        """
+        batch = batch.to(self.device)
+
+        x, _ = batch.code
+        y, _ = batch.sig
+
+        # generate sythesized waveforms
+        y_g_hat, (log_dur_pred, log_dur) = self.modules.generator(x)
+        y_g_hat = y_g_hat[:, :, : y.size(2)]
+
+        # get scores and features from discriminator for real and synthesized waveforms
+        scores_fake, feats_fake = self.modules.discriminator(y_g_hat.detach())
+        scores_real, feats_real = self.modules.discriminator(y)
+
+        return (
+            y_g_hat,
+            scores_fake,
+            feats_fake,
+            scores_real,
+            feats_real,
+            log_dur_pred,
+            log_dur,
+        )
+
+    def compute_objectives(self, predictions, batch, stage):
+        """Computes the loss given the predicted and targeted outputs.
+        Arguments
+        ---------
+        predictions : torch.Tensor
+            The model generated spectrograms and other metrics from `compute_forward`.
+        batch : PaddedBatch
+            This batch object contains all the relevant tensors for computation.
+        stage : sb.Stage
+            One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST.
+        Returns
+        -------
+        loss : torch.Tensor
+            A one-element tensor used for backpropagating the gradient.
+        """
+        batch = batch.to(self.device)
+
+        x, _ = batch.code
+        y, _ = batch.sig
+
+        # Hold on to the batch for the inference sample. This is needed because
+        # the infernece sample is run from on_stage_end only, where
+        # batch information is not available
+        self.last_batch = (x, y)
+
+        (
+            y_hat,
+            scores_fake,
+            feats_fake,
+            scores_real,
+            feats_real,
+            log_dur_pred,
+            log_dur,
+        ) = predictions
+
+        loss_g = self.hparams.generator_loss(
+            stage,
+            y_hat,
+            y,
+            scores_fake,
+            feats_fake,
+            feats_real,
+            log_dur_pred,
+            log_dur,
+        )
+        loss_d = self.hparams.discriminator_loss(scores_fake, scores_real)
+        loss = {**loss_g, **loss_d}
+        self.last_loss_stats[stage] = scalarize(loss)
+        return loss
+
+    def fit_batch(self, batch):
+        """Fits a single batch.
+        Arguments
+        ---------
+        batch: tuple
+            a training batch
+        Returns
+        -------
+        loss: torch.Tensor
+            detached loss
+        """
+        batch = batch.to(self.device)
+        y, _ = batch.sig
+
+        outputs = self.compute_forward(batch, sb.core.Stage.TRAIN)
+        (
+            y_g_hat,
+            scores_fake,
+            feats_fake,
+            scores_real,
+            feats_real,
+            log_dur_pred,
+            log_dur,
+        ) = outputs
+        # calculate discriminator loss with the latest updated generator
+        loss_d = self.compute_objectives(outputs, batch, sb.core.Stage.TRAIN)[
+            "D_loss"
+        ]
+        # First train the discriminator
+        self.optimizer_d.zero_grad()
+        loss_d.backward()
+        self.optimizer_d.step()
+
+        # calculate generator loss with the latest updated discriminator
+        scores_fake, feats_fake = self.modules.discriminator(y_g_hat)
+        scores_real, feats_real = self.modules.discriminator(y)
+        outputs = (
+            y_g_hat,
+            scores_fake,
+            feats_fake,
+            scores_real,
+            feats_real,
+            log_dur_pred,
+            log_dur,
+        )
+        loss_g = self.compute_objectives(outputs, batch, sb.core.Stage.TRAIN)[
+            "G_loss"
+        ]
+        # Then train the generator
+        self.optimizer_g.zero_grad()
+        loss_g.backward()
+        self.optimizer_g.step()
+
+        return loss_g.detach().cpu()
+
+    def evaluate_batch(self, batch, stage):
+        """Evaluate one batch.
+
+        Arguments
+        ---------
+        batch : list of torch.Tensors
+            Batch of data to use for evaluation. Default implementation assumes
+            this batch has two elements: inputs and targets.
+        stage : Stage
+            The stage of the experiment: Stage.VALID, Stage.TEST
+
+        Returns
+        -------
+        detached loss
+        """
+        out = self.compute_forward(batch, stage=stage)
+        loss = self.compute_objectives(out, batch, stage=stage)
+        loss_g = loss["G_loss"]
+        return loss_g.detach().cpu()
+
+    def on_fit_start(self):
+        """Gets called at the beginning of ``fit()``, on multiple processes
+        if ``distributed_count > 0`` and backend is ddp and initializes statistics.
+        """
+        self.last_epoch = 0
+        self.last_batch = None
+        self.last_loss_stats = {}
+        return super().on_fit_start()
+
+    def init_optimizers(self):
+        """Called during ``on_fit_start()``, initialize optimizers
+        after parameters are fully configured (e.g. DDP, jit).
+        """
+        if self.opt_class is not None:
+            (
+                opt_g_class,
+                opt_d_class,
+                sch_g_class,
+                sch_d_class,
+            ) = self.opt_class
+
+            self.optimizer_g = opt_g_class(self.modules.generator.parameters())
+            self.optimizer_d = opt_d_class(
+                self.modules.discriminator.parameters()
+            )
+
+            self.optimizers_dict = {
+                "optimizer_g": self.optimizer_g,
+                "optimizer_d": self.optimizer_d,
+            }
+            self.scheduler_g = sch_g_class(self.optimizer_g)
+            self.scheduler_d = sch_d_class(self.optimizer_d)
+
+            if self.checkpointer is not None:
+                self.checkpointer.add_recoverable(
+                    "optimizer_g", self.optimizer_g
+                )
+                self.checkpointer.add_recoverable(
+                    "optimizer_d", self.optimizer_d
+                )
+                self.checkpointer.add_recoverable(
+                    "scheduler_g", self.scheduler_d
+                )
+                self.checkpointer.add_recoverable(
+                    "scheduler_d", self.scheduler_d
+                )
+
+    def on_stage_end(self, stage, stage_loss, epoch):
+        """Gets called at the end of an epoch.
+
+        Arguments
+        ---------
+        stage : sb.Stage
+            One of sb.Stage.TRAIN, sb.Stage.VALID, sb.Stage.TEST
+        stage_loss : float
+            The average loss for all of the data processed in this stage.
+        epoch : int
+            The currently-starting epoch. This is passed
+            `None` during the test stage.
+        """
+        if stage == sb.Stage.VALID:
+            # Update learning rate
+            self.scheduler_g.step()
+            self.scheduler_d.step()
+            lr_g = self.optimizer_g.param_groups[-1]["lr"]
+            lr_d = self.optimizer_d.param_groups[-1]["lr"]
+
+            self.hparams.train_logger.log_stats(  # 1#2#
+                stats_meta={"Epoch": epoch, "lr_g": lr_g, "lr_d": lr_d},
+                train_stats=self.last_loss_stats[sb.Stage.TRAIN],
+                valid_stats=self.last_loss_stats[sb.Stage.VALID],
+            )
+            # The tensorboard_logger writes a summary to stdout and to the logfile.
+            if self.hparams.use_tensorboard:
+                self.tensorboard_logger.log_stats(
+                    stats_meta={"Epoch": epoch, "lr_g": lr_g, "lr_d": lr_d},
+                    train_stats=self.last_loss_stats[sb.Stage.TRAIN],
+                    valid_stats=self.last_loss_stats[sb.Stage.VALID],
+                )
+
+            # Save the current checkpoint and delete previous checkpoints.
+            epoch_metadata = {
+                **{"epoch": epoch},
+                **self.last_loss_stats[sb.Stage.VALID],
+            }
+            if self.checkpointer is not None:
+                self.checkpointer.save_and_keep_only(
+                    meta=epoch_metadata,
+                    end_of_epoch=True,
+                    min_keys=["loss"],
+                    ckpt_predicate=(
+                        lambda ckpt: (
+                            ckpt.meta["epoch"]
+                            % self.hparams.keep_checkpoint_interval
+                            != 0
+                        )
+                    )
+                    if self.hparams.keep_checkpoint_interval is not None
+                    else None,
+                )
+
+            self.run_inference_sample("Valid", epoch)
+
+        # We also write statistics about test data to stdout and to the TensorboardLogger.
+        if stage == sb.Stage.TEST:
+            self.hparams.train_logger.log_stats(  # 1#2#
+                {"Epoch loaded": self.hparams.epoch_counter.current},
+                test_stats=self.last_loss_stats[sb.Stage.TEST],
+            )
+            if self.hparams.use_tensorboard:
+                self.tensorboard_logger.log_stats(
+                    {"Epoch loaded": self.hparams.epoch_counter.current},
+                    test_stats=self.last_loss_stats[sb.Stage.TEST],
+                )
+            self.run_inference_sample("Test")
+
+    def run_inference_sample(self, name, epoch):
+        """Produces a sample in inference mode.
+        This is called when producing samples.
+
+        Arguments
+        ---------
+        name: str
+            the name of the saved audio folder
+        epoch: int or str
+            the epoch number (used in file path calculations)
+            or "test" for test stage
+        """
+        with torch.no_grad():
+            if self.last_batch is None:
+                return
+            x, y = self.last_batch
+
+            # Preparing model for inference by removing weight norm
+            inference_generator = copy.deepcopy(self.hparams.generator)
+            inference_generator.remove_weight_norm()
+            if inference_generator.duration_predictor:
+                x = torch.unique_consecutive(x, dim=1)
+            sig_out = inference_generator.inference(x)
+            spec_out = self.hparams.mel_spectogram(
+                audio=sig_out.squeeze(0).cpu()
+            )
+        if self.hparams.use_tensorboard:
+            self.tensorboard_logger.log_audio(
+                f"{name}/audio_target", y.squeeze(0), self.hparams.sample_rate
+            )
+            self.tensorboard_logger.log_audio(
+                f"{name}/audio_pred",
+                sig_out.squeeze(0),
+                self.hparams.sample_rate,
+            )
+            self.tensorboard_logger.log_figure(f"{name}/mel_target", x)
+            self.tensorboard_logger.log_figure(f"{name}/mel_pred", spec_out)
+        else:
+            # folder name is the current epoch for validation and "test" for test
+            folder = (
+                self.hparams.epoch_counter.current
+                if name == "Valid"
+                else "test"
+            )
+            self.save_audio("target", y.squeeze(0), folder)
+            self.save_audio("synthesized", sig_out.squeeze(0), folder)
+
+    def save_audio(self, name, data, epoch):
+        """Saves a single wav file.
+
+        Arguments
+        ---------
+        name: str
+            the name of the saved audio
+        data: torch.Tensor
+            the  wave data to save
+        epoch: int or str
+            the epoch number (used in file path calculations)
+            or "test" for test stage
+        """
+        target_path = pl.Path(self.hparams.progress_sample_path) / str(epoch)
+        target_path.mkdir(parents=True, exist_ok=True)
+        file_name = target_path / f"{name}.wav"
+        torchaudio.save(file_name, data.cpu(), 16000)
+
+
+def sample_interval(seqs, segment_size):
+    "This function sample an interval of audio and code according to segment size."
+    N = max([v.shape[-1] for v in seqs])
+    seq_len = segment_size if segment_size > 0 else N
+    hops = [N // v.shape[-1] for v in seqs]
+    lcm = np.lcm.reduce(hops)
+    interval_start = 0
+    interval_end = N // lcm - seq_len // lcm
+    start_step = random.randint(interval_start, interval_end)
+
+    new_seqs = []
+    for i, v in enumerate(seqs):
+        start = start_step * (lcm // hops[i])
+        end = (start_step + seq_len // lcm) * (lcm // hops[i])
+        new_seqs += [v[..., start:end]]
+
+    return new_seqs
+
+
+def dataio_prepare(hparams):
+    """This function prepares the datasets to be used in the brain class.
+    It also defines the data processing pipeline through user-defined functions.
+    """
+    segment_size = hparams["segment_size"]
+    code_hop_size = hparams["code_hop_size"]
+    code_folder = pl.Path(hparams["codes_folder"])
+
+    # Define audio pipeline:
+    @sb.utils.data_pipeline.takes("id", "wav", "segment")
+    @sb.utils.data_pipeline.provides("code", "sig")
+    def audio_pipeline(utt_id, wav, segment):
+        info = torchaudio.info(wav)
+        audio = sb.dataio.dataio.read_audio(wav)
+        audio = torchaudio.transforms.Resample(
+            info.sample_rate, hparams["sample_rate"],
+        )(audio)
+
+        code = np.load(code_folder / f"{utt_id}.npy")
+        code = torch.IntTensor(code)
+
+        # Maps indices from the range [0, k] to [1, k+1]
+        code = code + 1
+
+        # Trim end of audio
+        code_length = min(audio.shape[0] // code_hop_size, code.shape[0])
+        code = code[:code_length]
+        audio = audio[: code_length * code_hop_size]
+
+        while audio.shape[0] < segment_size:
+            audio = torch.hstack([audio, audio])
+            code = torch.hstack([code, code])
+
+        audio = audio.unsqueeze(0)
+        if segment:
+            audio, code = sample_interval([audio, code], segment_size)
+
+        return code, audio
+
+    datasets = {}
+    data_info = {
+        "train": hparams["train_json"],
+        "valid": hparams["valid_json"],
+        "test": hparams["test_json"],
+    }
+    for dataset in hparams["splits"]:
+        datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
+            json_path=data_info[dataset],
+            replacements={"data_root": hparams["data_folder"]},
+            dynamic_items=[audio_pipeline],
+            output_keys=["id", "code", "sig"],
+        )
+
+    return datasets
+
+
+if __name__ == "__main__":
+    # Load hyperparameters file with command-line overrides
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # If --distributed_launch then
+    # create ddp_group with the right communication protocol
+    sb.utils.distributed.ddp_init_group(run_opts)
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    from ljspeech_prepare import prepare_ljspeech
+
+    sb.utils.distributed.run_on_main(
+        prepare_ljspeech,
+        kwargs={
+            "data_folder": hparams["data_folder"],
+            "save_folder": hparams["save_folder"],
+            "splits": hparams["splits"],
+            "split_ratio": hparams["split_ratio"],
+            "seed": hparams["seed"],
+            "skip_prep": hparams["skip_prep"],
+        },
+    )
+
+    from extract_code import extract_ljspeech
+
+    sb.utils.distributed.run_on_main(
+        extract_ljspeech,
+        kwargs={
+            "data_folder": hparams["save_folder"],
+            "splits": hparams["splits"],
+            "kmeans_folder": hparams["kmeans_folder"],
+            "encoder": hparams["encoder_hub"],
+            "layer": hparams["layer"],
+            "save_folder": hparams["save_folder"],
+            "sample_rate": hparams["sample_rate"],
+            "skip_extract": hparams["skip_prep"],
+        },
+    )
+
+    datasets = dataio_prepare(hparams)
+
+    # Brain class initialization
+    hifi_gan_brain = HifiGanBrain(
+        modules=hparams["modules"],
+        opt_class=[
+            hparams["opt_class_generator"],
+            hparams["opt_class_discriminator"],
+            hparams["sch_class_generator"],
+            hparams["sch_class_discriminator"],
+        ],
+        hparams=hparams,
+        run_opts=run_opts,
+        checkpointer=hparams["checkpointer"],
+    )
+
+    if hparams["use_tensorboard"]:
+        hifi_gan_brain.tensorboard_logger = sb.utils.train_logger.TensorboardLogger(
+            save_dir=hparams["output_folder"] + "/tensorboard"
+        )
+
+    # Training
+    hifi_gan_brain.fit(
+        hifi_gan_brain.hparams.epoch_counter,
+        train_set=datasets["train"],
+        valid_set=datasets["valid"],
+        train_loader_kwargs=hparams["train_dataloader_opts"],
+        valid_loader_kwargs=hparams["valid_dataloader_opts"],
+    )
+
+    # Test
+    if "test" in datasets:
+        hifi_gan_brain.evaluate(
+            datasets["test"],
+            test_loader_kwargs=hparams["test_dataloader_opts"],
+        )
diff --git a/recipes/LJSpeech/TTS/ljspeech_prepare.py b/recipes/LJSpeech/ljspeech_prepare.py
similarity index 80%
rename from recipes/LJSpeech/TTS/ljspeech_prepare.py
rename to recipes/LJSpeech/ljspeech_prepare.py
index f2186ecff95915bb21e23d96bd29505b2cc138b7..98c367eafe4cee37659246b2aca6f3c112ee94ee 100644
--- a/recipes/LJSpeech/TTS/ljspeech_prepare.py
+++ b/recipes/LJSpeech/ljspeech_prepare.py
@@ -13,15 +13,17 @@ import csv
 import json
 import random
 import logging
+import torch
 import torchaudio
 import numpy as np
 from tqdm import tqdm
 from speechbrain.utils.data_utils import download_file
 from speechbrain.dataio.dataio import load_pkl, save_pkl
 import tgt
-from speechbrain.pretrained import GraphemeToPhoneme
+from speechbrain.inference.text import GraphemeToPhoneme
 import re
 from unidecode import unidecode
+from speechbrain.utils.text_to_sequence import _g2p_keep_punctuations
 
 
 logger = logging.getLogger(__name__)
@@ -47,7 +49,7 @@ def prepare_ljspeech(
     pitch_n_fft=1024,
     pitch_hop_length=256,
     pitch_min_f0=65,
-    pitch_max_f0=2093,
+    pitch_max_f0=400,
     skip_prep=False,
     use_custom_cleaner=False,
     device="cpu",
@@ -128,10 +130,10 @@ def prepare_ljspeech(
     duration_folder = None
     pitch_folder = None
     # Setting up additional folders required for FastSpeech2
-    if model_name == "FastSpeech2":
+    if model_name is not None and "FastSpeech2" in model_name:
         # This step requires phoneme alignements to be present in the data_folder
         # We automatically donwload the alignments from https://www.dropbox.com/s/v28x5ldqqa288pu/LJSpeech.zip
-        # Download and unzip LJSpeech phoneme alignments from here: https://www.dropbox.com/sh/647h69vuarms5zj/AABeCQxeyD4AiqIss5eJoX4Qa?dl=0
+        # Download and unzip LJSpeech phoneme alignments from here: https://drive.google.com/drive/folders/1DBRkALpPd6FL9gjHMmMEdHODmkgNIIK4
         alignment_URL = (
             "https://www.dropbox.com/s/v28x5ldqqa288pu/LJSpeech.zip?dl=1"
         )
@@ -146,6 +148,7 @@ def prepare_ljspeech(
         if not os.path.exists(duration_folder):
             os.makedirs(duration_folder)
 
+        # extract pitch for both Fastspeech2 and FastSpeech2WithAligner models
         pitch_folder = os.path.join(data_folder, "pitch")
         if not os.path.exists(pitch_folder):
             os.makedirs(pitch_folder)
@@ -379,27 +382,26 @@ def prepare_json(
     """
 
     logger.info(f"preparing {json_file}.")
-    if model_name == "Tacotron2":
+    if model_name in ["Tacotron2", "FastSpeech2WithAlignment"]:
         logger.info(
             "Computing phonemes for LJSpeech labels using SpeechBrain G2P. This may take a while."
         )
         g2p = GraphemeToPhoneme.from_hparams(
             "speechbrain/soundchoice-g2p", run_opts={"device": device}
         )
-    if model_name == "FastSpeech2":
+    if model_name is not None and "FastSpeech2" in model_name:
         logger.info(
             "Computing pitch as required for FastSpeech2. This may take a while."
         )
 
     json_dict = {}
     for index in tqdm(seg_lst):
-
         # Common data preparation
         id = list(csv_reader)[index][0]
         wav = os.path.join(wavs_folder, f"{id}.wav")
         label = list(csv_reader)[index][2]
         if use_custom_cleaner:
-            label = custom_clean(label)
+            label = custom_clean(label, model_name)
 
         json_dict[id] = {
             "uttid": id,
@@ -408,16 +410,8 @@ def prepare_json(
             "segment": True if "train" in json_file else False,
         }
 
-        # Tacotron2 specific data preparation
-        if model_name == "Tacotron2":
-            # Computes phoneme labels using SpeechBrain G2P for Tacotron2
-            label_phoneme_list = g2p(label)
-            label_phoneme = " ".join(label_phoneme_list)
-            json_dict[id].update({"label_phoneme": label_phoneme})
-
         # FastSpeech2 specific data preparation
         if model_name == "FastSpeech2":
-
             audio, fs = torchaudio.load(wav)
 
             # Parses phoneme alignments
@@ -465,14 +459,24 @@ def prepare_json(
                 wavs_folder, pitch_folder
             )
             if not os.path.isfile(pitch_file):
-                pitch = torchaudio.functional.compute_kaldi_pitch(
+                pitch = torchaudio.functional.detect_pitch_frequency(
                     waveform=audio,
                     sample_rate=fs,
-                    frame_length=(pitch_n_fft / fs * 1000),
-                    frame_shift=(pitch_hop_length / fs * 1000),
-                    min_f0=pitch_min_f0,
-                    max_f0=pitch_max_f0,
-                )[0, :, 0]
+                    frame_time=(pitch_hop_length / fs),
+                    win_length=3,
+                    freq_low=pitch_min_f0,
+                    freq_high=pitch_max_f0,
+                ).squeeze(0)
+
+                # Concatenate last element to match duration.
+                pitch = torch.cat([pitch, pitch[-1].unsqueeze(0)])
+
+                # Mean and Variance Normalization
+                mean = 256.1732939688805
+                std = 328.319759158607
+
+                pitch = (pitch - mean) / std
+
                 pitch = pitch[: sum(duration)]
                 np.save(pitch_file, pitch)
 
@@ -487,6 +491,50 @@ def prepare_json(
                 {"last_phoneme_flags": trimmed_last_phoneme_flags}
             )
 
+        # FastSpeech2WithAlignment specific data preparation
+        if model_name == "FastSpeech2WithAlignment":
+            audio, fs = torchaudio.load(wav)
+            # Computes pitch
+            pitch_file = wav.replace(".wav", ".npy").replace(
+                wavs_folder, pitch_folder
+            )
+            if not os.path.isfile(pitch_file):
+
+                if torchaudio.__version__ < "2.1":
+                    pitch = torchaudio.functional.compute_kaldi_pitch(
+                        waveform=audio,
+                        sample_rate=fs,
+                        frame_length=(pitch_n_fft / fs * 1000),
+                        frame_shift=(pitch_hop_length / fs * 1000),
+                        min_f0=pitch_min_f0,
+                        max_f0=pitch_max_f0,
+                    )[0, :, 0]
+                else:
+                    pitch = torchaudio.functional.detect_pitch_frequency(
+                        waveform=audio,
+                        sample_rate=fs,
+                        frame_time=(pitch_hop_length / fs),
+                        win_length=3,
+                        freq_low=pitch_min_f0,
+                        freq_high=pitch_max_f0,
+                    ).squeeze(0)
+
+                    # Concatenate last element to match duration.
+                    pitch = torch.cat([pitch, pitch[-1].unsqueeze(0)])
+
+                    # Mean and Variance Normalization
+                    mean = 256.1732939688805
+                    std = 328.319759158607
+
+                    pitch = (pitch - mean) / std
+
+                np.save(pitch_file, pitch)
+
+            phonemes = _g2p_keep_punctuations(g2p, label)
+            # Updates data for the utterance
+            json_dict[id].update({"phonemes": phonemes})
+            json_dict[id].update({"pitch": pitch_file})
+
     # Writing the dictionary to the json file
     with open(json_file, mode="w") as json_f:
         json.dump(json_dict, json_f, indent=2)
@@ -496,26 +544,26 @@ def prepare_json(
 
 def get_alignment(tier, sampling_rate, hop_length, last_phoneme_flags):
     """
-  Returns phonemes, phoneme durations (in frames), start time (in seconds), end time (in seconds).
-  This function is adopted from https://github.com/ming024/FastSpeech2/blob/master/preprocessor/preprocessor.py
-
-  Arguments
-  ---------
-  tier : tgt.core.IntervalTier
-      For an utterance, contains Interval objects for phonemes and their start time and end time in seconds
-  sampling_rate : int
-      Sample rate if audio signal
-  hop_length : int
-      Hop length for duration computation
-  last_phoneme_flags : list
-      List of (phoneme, flag) tuples with flag=1 if the phoneme is the last phoneme else flag=0
-
-
-  Returns
-  -------
-  (phones, durations, start_time, end_time) : tuple
-      The phonemes, durations, start time, and end time for an utterance
-  """
+    Returns phonemes, phoneme durations (in frames), start time (in seconds), end time (in seconds).
+    This function is adopted from https://github.com/ming024/FastSpeech2/blob/master/preprocessor/preprocessor.py
+
+    Arguments
+    ---------
+    tier : tgt.core.IntervalTier
+        For an utterance, contains Interval objects for phonemes and their start time and end time in seconds
+    sampling_rate : int
+        Sample rate if audio signal
+    hop_length : int
+        Hop length for duration computation
+    last_phoneme_flags : list
+        List of (phoneme, flag) tuples with flag=1 if the phoneme is the last phoneme else flag=0
+
+
+    Returns
+    -------
+    (phones, durations, start_time, end_time) : tuple
+        The phonemes, durations, start time, and end time for an utterance
+    """
 
     sil_phones = ["sil", "sp", "spn", ""]
 
@@ -570,23 +618,23 @@ def get_alignment(tier, sampling_rate, hop_length, last_phoneme_flags):
 
 def get_last_phoneme_info(words_seq, phones_seq):
     """This function takes word and phoneme tiers from a TextGrid file as input
-  and provides a list of tuples for the phoneme sequence indicating whether
-  each of the phonemes is the last phoneme of a word or not.
+    and provides a list of tuples for the phoneme sequence indicating whether
+    each of the phonemes is the last phoneme of a word or not.
 
-  Each tuple of the returned list has this format: (phoneme, flag)
+    Each tuple of the returned list has this format: (phoneme, flag)
 
 
-  Arguments
-  ---------
-  words_seq :
-      word tier from a TextGrid file
-  phones_seq :
-      phoneme tier from a TextGrid file
+    Arguments
+    ---------
+    words_seq :
+        word tier from a TextGrid file
+    phones_seq :
+        phoneme tier from a TextGrid file
 
-  Returns
-  -------
-  last_phoneme_flags : list
-      each tuple of the returned list has this format: (phoneme, flag)
+    Returns
+    -------
+    last_phoneme_flags : list
+        each tuple of the returned list has this format: (phoneme, flag)
     """
 
     # Gets all phoneme objects for the entire sequence
@@ -613,7 +661,7 @@ def get_last_phoneme_info(words_seq, phones_seq):
     return last_phoneme_flags
 
 
-def custom_clean(text):
+def custom_clean(text, model_name):
     """
     Uses custom criteria to clean text.
 
@@ -621,6 +669,8 @@ def custom_clean(text):
     ---------
     text : str
         Input text to be cleaned
+    model_name : str
+        whether to treat punctuations
 
     Returns
     -------
@@ -652,10 +702,12 @@ def custom_clean(text):
         ]
     ]
     text = unidecode(text.lower())
-    text = re.sub("[:;]", " - ", text)
-    text = re.sub(r'[)(\[\]"]', " ", text)
+    if model_name != "FastSpeech2WithAlignment":
+        text = re.sub("[:;]", " - ", text)
+        text = re.sub(r'[)(\[\]"]', " ", text)
+        text = text.strip().strip().strip("-")
+
     text = re.sub(" +", " ", text)
     for regex, replacement in _abbreviations:
         text = re.sub(regex, replacement, text)
-    text = text.strip().strip().strip("-")
     return text
diff --git a/recipes/LJSpeech/quantization/README.md b/recipes/LJSpeech/quantization/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ab9c190efe0fbc49ae5a10b6a2adc94e11346453
--- /dev/null
+++ b/recipes/LJSpeech/quantization/README.md
@@ -0,0 +1,49 @@
+
+# K-means (Quantization)
+This folder contains recipes for training K-means clustering model for the LJSpeech Dataset.
+The model serves to quantize self-supervised representations into discrete representation. Thus representations can be used as a discrete audio input for various tasks including classification, ASR and speech generation.
+It supports  kmeans model using the features from  HuBERT, WAVLM or Wav2Vec.
+
+You can download LibriSpeech at http://www.openslr.org/12
+
+## Installing Extra Dependencies
+
+Before proceeding, ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal:
+
+```
+pip install -r extra_requirements.txt
+```
+
+# How to run:
+```shell
+python train.py hparams/train_with_{SSL_model}.yaml
+```
+
+# Results
+
+The output folders with checkpoints and logs can be found [here](https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0).
+
+The checkpoints can be also found at [this](https://huggingface.co/speechbrain/SSL_Quantization) HuggingFace repository.
+
+
+
+# **About SpeechBrain**
+- Website: https://speechbrain.github.io/
+- Code: https://github.com/speechbrain/speechbrain/
+- HuggingFace: https://huggingface.co/speechbrain/
+
+
+# **Citing SpeechBrain**
+Please, cite SpeechBrain if you use it for your research or business.
+
+```bibtex
+@misc{speechbrain,
+  title={{SpeechBrain}: A General-Purpose Speech Toolkit},
+  author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
+  year={2021},
+  eprint={2106.04624},
+  archivePrefix={arXiv},
+  primaryClass={eess.AS},
+  note={arXiv:2106.04624}
+}
+```
diff --git a/recipes/LJSpeech/quantization/extra-requirements.txt b/recipes/LJSpeech/quantization/extra-requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..bd6f06aa66c3b5fcfb0e48905e727393af7627cd
--- /dev/null
+++ b/recipes/LJSpeech/quantization/extra-requirements.txt
@@ -0,0 +1,3 @@
+scikit-learn
+tgt
+unidecode
diff --git a/recipes/LJSpeech/quantization/hparams/train_with_hubert.yaml b/recipes/LJSpeech/quantization/hparams/train_with_hubert.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4eee7151e6bf519a55eaa0b64526e7403c232ec2
--- /dev/null
+++ b/recipes/LJSpeech/quantization/hparams/train_with_hubert.yaml
@@ -0,0 +1,56 @@
+################################
+# Recipe for Training K-Means Clustering on LJSpeech Data
+# Using Self-Supervised Model-Based Representations
+#
+# It is used for creating discrete audio representations from LJSpeech data.
+#
+# Author: Pooneh Mousavi (2023)
+################################
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1986
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/LJSpeech/clustering/hubert/<seed>
+save_folder: !ref <output_folder>/save
+
+# Data files
+data_folder: !PLACEHOLDER # e,g./path/to/LJSpeech-1.1
+
+train_json: !ref <save_folder>/train.json
+
+splits: ["train"]
+split_ratio: [100]
+skip_prep: False
+sample_rate: 16000
+
+ssl_hub: facebook/hubert-base-ls960
+freeze_feature_extractor: True
+freeze_ssl: True
+ssl_folder: !ref <save_folder>/hubert_checkpoint
+ssl_layer_num: 7
+batch_size: 128 # batch_size for loading and extracting features. It is different from kmeans_batch_size.
+
+
+# Dataloader options
+train_dataloader_opts:
+   batch_size: !ref <batch_size>
+
+ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT
+   source: !ref <ssl_hub>
+   output_norm: False
+   freeze: !ref <freeze_ssl>
+   freeze_feature_extractor: !ref <freeze_feature_extractor>
+   output_all_hiddens: True
+   save_path: !ref <ssl_folder>
+
+
+####################
+# Model Parameters #
+####################
+num_clusters: 128
+init: k-means++
+max_iter: 100
+kmeans_batch_size: 1000 # should be >= num_clusters
+tol: 0.0
+max_no_improvement: 100
+n_init: 20
+reassignment_ratio: 0.0
diff --git a/recipes/LJSpeech/quantization/hparams/train_with_wav2vec.yaml b/recipes/LJSpeech/quantization/hparams/train_with_wav2vec.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b12c67442e56c806bf731cd9f98075b14d19e5ea
--- /dev/null
+++ b/recipes/LJSpeech/quantization/hparams/train_with_wav2vec.yaml
@@ -0,0 +1,55 @@
+################################
+# Recipe for Training K-Means Clustering on LJSpeech Data
+# Using Self-Supervised Model-Based Representations
+#
+# It is used for creating discrete audio representations from LJSpeech data.
+#
+# Author: Pooneh Mousavi (2023)
+################################
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1986
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/LJSpeech/clustering/wav2vec/<seed>
+save_folder: !ref <output_folder>/save
+
+# Data files
+# Data files
+data_folder: !PLACEHOLDER # e,g./path/to/LJSpeech-1.1
+
+train_json: !ref <save_folder>/train.json
+splits: ["train"]
+split_ratio: [100]
+skip_prep: False
+sample_rate: 16000
+
+ssl_hub: facebook/wav2vec2-large-960h-lv60-self
+freeze_feature_extractor: True
+freeze_ssl: True
+ssl_folder: !ref <save_folder>/wav2vec_checkpoint
+ssl_layer_num: 7
+batch_size: 64 # batch_size for loading and extracting features. It is different from kmeans_batch_size.
+
+# Dataloader options
+train_dataloader_opts:
+   batch_size: !ref <batch_size>
+
+ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
+   source: !ref <ssl_hub>
+   output_norm: False
+   freeze: !ref <freeze_ssl>
+   freeze_feature_extractor: !ref <freeze_feature_extractor>
+   output_all_hiddens: True
+   save_path: !ref <ssl_folder>
+
+
+####################
+# Model Parameters #
+####################
+num_clusters: 128
+init: k-means++
+max_iter: 100
+kmeans_batch_size: 1000 # should be >= num_clusters
+tol: 0.0
+max_no_improvement: 100
+n_init: 20
+reassignment_ratio: 0.0
diff --git a/recipes/LJSpeech/quantization/hparams/train_with_wavlm.yaml b/recipes/LJSpeech/quantization/hparams/train_with_wavlm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6a767bafc294dee87f550b67883c5cae0b8c4cc7
--- /dev/null
+++ b/recipes/LJSpeech/quantization/hparams/train_with_wavlm.yaml
@@ -0,0 +1,56 @@
+################################
+# Recipe for Training K-Means Clustering on LJSpeech Data
+# Using Self-Supervised Model-Based Representations
+#
+# It is used for creating discrete audio representations from LJSpeech data.
+#
+# Author: Pooneh Mousavi (2023)
+################################
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1986
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/LJSpeech/clustering/wavlm/<seed>
+save_folder: !ref <output_folder>/save
+
+# Data files
+# Data files
+data_folder: !PLACEHOLDER # e,g./path/to/LJSpeech-1.1
+
+train_json: !ref <save_folder>/train.json
+
+splits: ["train"]
+split_ratio: [100]
+skip_prep: False
+sample_rate: 16000
+
+ssl_hub: microsoft/wavlm-large
+freeze_feature_extractor: True
+freeze_ssl: True
+ssl_folder: !ref <save_folder>/wavlm_checkpoint
+ssl_layer_num: 7
+batch_size: 32 # batch_size for loading and extracting features. It is different from kmeans_batch_size.
+
+# Dataloader options
+train_dataloader_opts:
+   batch_size: !ref <batch_size>
+
+ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM
+   source: !ref <ssl_hub>
+   output_norm: False
+   freeze: !ref <freeze_ssl>
+   freeze_feature_extractor: !ref <freeze_feature_extractor>
+   output_all_hiddens: True
+   save_path: !ref <ssl_folder>
+
+
+####################
+# Model Parameters #
+####################
+num_clusters: 128
+init: k-means++
+max_iter: 100
+kmeans_batch_size: 1000 # should be >= num_clusters
+tol: 0.0
+max_no_improvement: 100
+n_init: 20
+reassignment_ratio: 0.0
diff --git a/recipes/LJSpeech/quantization/ljspeech_prepare.py b/recipes/LJSpeech/quantization/ljspeech_prepare.py
new file mode 120000
index 0000000000000000000000000000000000000000..2de5a21a8daef2535958e950a9d95e0853bf6ba7
--- /dev/null
+++ b/recipes/LJSpeech/quantization/ljspeech_prepare.py
@@ -0,0 +1 @@
+../ljspeech_prepare.py
\ No newline at end of file
diff --git a/recipes/LJSpeech/quantization/train.py b/recipes/LJSpeech/quantization/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..340688327ef92b65320b7756cae0f56e997329e5
--- /dev/null
+++ b/recipes/LJSpeech/quantization/train.py
@@ -0,0 +1,125 @@
+"""
+Recipe  to train K-means clustering model on self-supervised representations.
+
+To run this recipe, do the following:
+> python train.py hparams/train_with_[SSL-model].yaml --data_folder=/path/to/LJSpeech
+Author
+ * Pooneh Mousavi 2023
+"""
+
+import os
+import sys
+import logging
+import speechbrain as sb
+from speechbrain.utils.distributed import run_on_main
+from hyperpyyaml import load_hyperpyyaml
+from torch.utils.data import DataLoader
+from speechbrain.dataio.dataloader import LoopedLoader
+from speechbrain.utils.kmeans import fetch_kmeans_model, train, save_model
+import torchaudio
+
+
+logger = logging.getLogger(__name__)
+
+
+def dataio_prepare(hparams):
+
+    # Define audio pipeline:
+    @sb.utils.data_pipeline.takes("wav")
+    @sb.utils.data_pipeline.provides("sig")
+    def audio_pipeline(wav):
+        sig = sb.dataio.dataio.read_audio(wav)
+        info = torchaudio.info(wav)
+        resampled = torchaudio.transforms.Resample(
+            info.sample_rate, hparams["sample_rate"],
+        )(sig)
+        return resampled
+
+    datasets = {}
+    data_info = {
+        "train": hparams["train_json"],
+    }
+    for dataset in hparams["splits"]:
+        datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
+            json_path=data_info[dataset],
+            replacements={"data_root": hparams["data_folder"]},
+            dynamic_items=[audio_pipeline],
+            output_keys=["id", "sig"],
+        )
+
+    return datasets
+
+    return datasets
+
+
+if __name__ == "__main__":
+    # Load hyperparameters file with command-line overrides
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    # Dataset prep (parsing Librispeech)
+    from ljspeech_prepare import prepare_ljspeech  # noqa
+
+    # multi-gpu (ddp) save data preparation
+    run_on_main(
+        prepare_ljspeech,
+        kwargs={
+            "data_folder": hparams["data_folder"],
+            "save_folder": hparams["save_folder"],
+            "splits": hparams["splits"],
+            "split_ratio": hparams["split_ratio"],
+            "seed": hparams["seed"],
+            "skip_prep": hparams["skip_prep"],
+        },
+    )
+
+    # Load SSL model
+    hparams["ssl_model"] = hparams["ssl_model"].to(run_opts["device"])
+
+    # Make training Dataloader
+    train_set = dataio_prepare(hparams)["train"]
+    if not (
+        isinstance(train_set, DataLoader) or isinstance(train_set, LoopedLoader)
+    ):
+        train_set = sb.dataio.dataloader.make_dataloader(
+            train_set, **hparams["train_dataloader_opts"]
+        )
+
+    # Load pretrained KMeans model if it exists. Otherwise,  create new one.
+    checkpoint_path = os.path.join(
+        hparams["save_folder"], f"kmeans_{hparams['num_clusters']}.pt"
+    )
+    kmeans_model = fetch_kmeans_model(
+        n_clusters=hparams["num_clusters"],
+        init=hparams["init"],
+        max_iter=hparams["max_iter"],
+        batch_size=hparams["batch_size"],
+        tol=hparams["tol"],
+        max_no_improvement=hparams["max_no_improvement"],
+        n_init=hparams["n_init"],
+        reassignment_ratio=hparams["reassignment_ratio"],
+        random_state=hparams["seed"],
+        checkpoint_path=checkpoint_path,
+    )
+
+    # Train and save Kmeans model
+    train(
+        kmeans_model,
+        train_set,
+        hparams["ssl_model"],
+        hparams["ssl_layer_num"],
+        kmeans_batch_size=hparams["kmeans_batch_size"],
+        device=run_opts["device"],
+    )
+
+    logger.info(f"Saving kmeans model at {checkpoint_path}.")
+    save_model(kmeans_model, checkpoint_path)
diff --git a/recipes/LibriMix/separation/README.md b/recipes/LibriMix/separation/README.md
index 39c4b6c32f22680901c3bc908fbe36bc78486d3d..17094d5b61ee4df7c0b5e0b6847ab624b85d648b 100644
--- a/recipes/LibriMix/separation/README.md
+++ b/recipes/LibriMix/separation/README.md
@@ -70,8 +70,8 @@ The output folder with the trained model and the logs can be found [here](https:
 
 You can run the following command to train the model using Distributed Data Parallel (DDP) with 2 GPUs:
 
-```
- python -m torch.distributed.launch --nproc_per_node=2 train.py hparams/sepformer-libri2mix.yaml --data_folder /yourdatapath --distributed_launch --distributed_backend='nccl'
+```bash
+torchrun --nproc_per_node=2 train.py hparams/sepformer-libri2mix.yaml --data_folder /yourdatapath
 ```
 You can add the other runtime options as appropriate. For more complete information on multi-GPU usage, take a look at this [tutorial](https://colab.research.google.com/drive/13pBUacPiotw1IvyffvGZ-HrtBr9T6l15).
 
diff --git a/recipes/LibriMix/separation/hparams/sepformer-libri2mix.yaml b/recipes/LibriMix/separation/hparams/sepformer-libri2mix.yaml
index 5913359617f06fe38e17a37001b87a4b6eb52422..ffa5a1ef2edde5bb435f812248d6b78ee278155a 100644
--- a/recipes/LibriMix/separation/hparams/sepformer-libri2mix.yaml
+++ b/recipes/LibriMix/separation/hparams/sepformer-libri2mix.yaml
@@ -31,13 +31,13 @@ skip_prep: False
 ckpt_interval_minutes: 60
 
 # Experiment params
-auto_mix_prec: True # Set it to True for mixed precision
+precision: fp16 # bf16, fp16 or fp32 # Set it to True for mixed precision
 num_spks: 2
 noprogressbar: False
 save_audio: False # Save estimated sources on disk
 sample_rate: 8000
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -60,18 +60,38 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/LibriMix/separation/hparams/sepformer-libri3mix.yaml b/recipes/LibriMix/separation/hparams/sepformer-libri3mix.yaml
index 7451b447cdf7e7480718af048d2b4dec9e3791a5..abc9c76c7e07df2847eea7e551cd7270fa14ea00 100644
--- a/recipes/LibriMix/separation/hparams/sepformer-libri3mix.yaml
+++ b/recipes/LibriMix/separation/hparams/sepformer-libri3mix.yaml
@@ -31,13 +31,13 @@ skip_prep: False
 ckpt_interval_minutes: 60
 
 # Experiment params
-auto_mix_prec: True # Set it to True for mixed precision
+precision: fp16 # bf16, fp16 or fp32 # Set it to True for mixed precision
 num_spks: 3
 noprogressbar: False
 save_audio: False # Save estimated sources on disk
 sample_rate: 8000
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -60,18 +60,38 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/LibriMix/separation/train.py b/recipes/LibriMix/separation/train.py
index b84ac79db2cfe98350f1da5be9675b875b2d73fb..803d8bc483acf9ccba24e036479af8f8162b5870 100755
--- a/recipes/LibriMix/separation/train.py
+++ b/recipes/LibriMix/separation/train.py
@@ -29,12 +29,12 @@ import torchaudio
 import speechbrain as sb
 import speechbrain.nnet.schedulers as schedulers
 from speechbrain.utils.distributed import run_on_main
-from torch.cuda.amp import autocast
 from hyperpyyaml import load_hyperpyyaml
 import numpy as np
 from tqdm import tqdm
 import csv
 import logging
+from speechbrain.core import AMPConfig
 
 logger = logging.getLogger(__name__)
 
@@ -75,7 +75,8 @@ class Separation(sb.Brain):
                         targets = targets[:, :min_len, :]
 
                 if self.hparams.use_wavedrop:
-                    mix = self.hparams.wavedrop(mix, mix_lens)
+                    mix = self.hparams.drop_chunk(mix, mix_lens)
+                    mix = self.hparams.drop_freq(mix)
 
                 if self.hparams.limit_training_signal_len:
                     mix, targets = self.cut_signals(mix, targets)
@@ -111,6 +112,9 @@ class Separation(sb.Brain):
 
     def fit_batch(self, batch):
         """Trains one batch"""
+        amp = AMPConfig.from_name(self.precision)
+        should_step = (self.step % self.grad_accumulation_factor) == 0
+
         # Unpacking batch list
         mixture = batch.mix_sig
         targets = [batch.s1_sig, batch.s2_sig]
@@ -122,14 +126,51 @@ class Separation(sb.Brain):
         if self.hparams.num_spks == 3:
             targets.append(batch.s3_sig)
 
-        if self.auto_mix_prec:
-            with autocast():
+        with self.no_sync(not should_step):
+            if self.use_amp:
+                with torch.autocast(
+                    dtype=amp.dtype, device_type=torch.device(self.device).type,
+                ):
+                    predictions, targets = self.compute_forward(
+                        mixture, targets, sb.Stage.TRAIN, noise
+                    )
+                    loss = self.compute_objectives(predictions, targets)
+
+                    # hard threshold the easy dataitems
+                    if self.hparams.threshold_byloss:
+                        th = self.hparams.threshold
+                        loss = loss[loss > th]
+                        if loss.nelement() > 0:
+                            loss = loss.mean()
+                    else:
+                        loss = loss.mean()
+
+                if (
+                    loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
+                ):  # the fix for computational problems
+                    self.scaler.scale(loss).backward()
+                    if self.hparams.clip_grad_norm >= 0:
+                        self.scaler.unscale_(self.optimizer)
+                        torch.nn.utils.clip_grad_norm_(
+                            self.modules.parameters(),
+                            self.hparams.clip_grad_norm,
+                        )
+                    self.scaler.step(self.optimizer)
+                    self.scaler.update()
+                else:
+                    self.nonfinite_count += 1
+                    logger.info(
+                        "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
+                            self.nonfinite_count
+                        )
+                    )
+                    loss.data = torch.tensor(0).to(self.device)
+            else:
                 predictions, targets = self.compute_forward(
                     mixture, targets, sb.Stage.TRAIN, noise
                 )
                 loss = self.compute_objectives(predictions, targets)
 
-                # hard threshold the easy dataitems
                 if self.hparams.threshold_byloss:
                     th = self.hparams.threshold
                     loss = loss[loss > th]
@@ -138,56 +179,24 @@ class Separation(sb.Brain):
                 else:
                     loss = loss.mean()
 
-            if (
-                loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
-            ):  # the fix for computational problems
-                self.scaler.scale(loss).backward()
-                if self.hparams.clip_grad_norm >= 0:
-                    self.scaler.unscale_(self.optimizer)
-                    torch.nn.utils.clip_grad_norm_(
-                        self.modules.parameters(), self.hparams.clip_grad_norm,
-                    )
-                self.scaler.step(self.optimizer)
-                self.scaler.update()
-            else:
-                self.nonfinite_count += 1
-                logger.info(
-                    "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
-                        self.nonfinite_count
-                    )
-                )
-                loss.data = torch.tensor(0).to(self.device)
-        else:
-            predictions, targets = self.compute_forward(
-                mixture, targets, sb.Stage.TRAIN, noise
-            )
-            loss = self.compute_objectives(predictions, targets)
-
-            if self.hparams.threshold_byloss:
-                th = self.hparams.threshold
-                loss = loss[loss > th]
-                if loss.nelement() > 0:
-                    loss = loss.mean()
-            else:
-                loss = loss.mean()
-
-            if (
-                loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
-            ):  # the fix for computational problems
-                loss.backward()
-                if self.hparams.clip_grad_norm >= 0:
-                    torch.nn.utils.clip_grad_norm_(
-                        self.modules.parameters(), self.hparams.clip_grad_norm
-                    )
-                self.optimizer.step()
-            else:
-                self.nonfinite_count += 1
-                logger.info(
-                    "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
-                        self.nonfinite_count
+                if (
+                    loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
+                ):  # the fix for computational problems
+                    loss.backward()
+                    if self.hparams.clip_grad_norm >= 0:
+                        torch.nn.utils.clip_grad_norm_(
+                            self.modules.parameters(),
+                            self.hparams.clip_grad_norm,
+                        )
+                    self.optimizer.step()
+                else:
+                    self.nonfinite_count += 1
+                    logger.info(
+                        "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
+                            self.nonfinite_count
+                        )
                     )
-                )
-                loss.data = torch.tensor(0).to(self.device)
+                    loss.data = torch.tensor(0).to(self.device)
         self.optimizer.zero_grad()
 
         return loss.detach().cpu()
@@ -263,9 +272,7 @@ class Separation(sb.Brain):
             recombine = True
 
             for i in range(targets.shape[-1]):
-                new_target = self.hparams.speedperturb(
-                    targets[:, :, i], targ_lens
-                )
+                new_target = self.hparams.speed_perturb(targets[:, :, i])
                 new_targets.append(new_target)
                 if i == 0:
                     min_len = new_target.shape[-1]
@@ -574,6 +581,10 @@ if __name__ == "__main__":
             "Please, specify a valid base_folder_dm folder when using dynamic mixing"
         )
 
+    # Update precision to bf16 if the device is CPU and precision is fp16
+    if run_opts.get("device") == "cpu" and hparams.get("precision") == "fp16":
+        hparams["precision"] = "bf16"
+
     # Data preparation
     from prepare_data import prepare_librimix
 
@@ -648,9 +659,7 @@ if __name__ == "__main__":
     # Load pretrained model if pretrained_separator is present in the yaml
     if "pretrained_separator" in hparams:
         run_on_main(hparams["pretrained_separator"].collect_files)
-        hparams["pretrained_separator"].load_collected(
-            device=run_opts["device"]
-        )
+        hparams["pretrained_separator"].load_collected()
 
     # Brain class initialization
     separator = Separation(
diff --git a/recipes/LibriParty/VAD/hparams/train.yaml b/recipes/LibriParty/VAD/hparams/train.yaml
index 846cf1160795966b44b0eeedb89f5ba4b276afe7..be91916855eb846b941aecf5d2b648ce26acbf2a 100644
--- a/recipes/LibriParty/VAD/hparams/train.yaml
+++ b/recipes/LibriParty/VAD/hparams/train.yaml
@@ -22,8 +22,12 @@ train_log: !ref <output_folder>/train_log.txt
 # LibriParty (main data)
 data_folder: !PLACEHOLDER  # e.g. /path/to/LibriParty
 
+# Openrir Dataset for augmentation
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+noise_csv_openrir: !ref <save_folder>/noise_openrir.csv #The data manifest files are created by the data preparation script
+
 # Additional data (for augmentation)
-open_rir_folder: !ref <data_folder> # where to store noisy +ris from open_rir
 musan_folder: !PLACEHOLDER  # e.g, /path/to/musan (download it from the web before)
 commonlanguage_folder: !PLACEHOLDER  # e.g, /path/to/commonlang (download it from the web before)
 
@@ -37,7 +41,7 @@ speech_csv: !ref <save_folder>/speech.csv
 multilang_speech_csv: !ref <save_folder>/multilang_speech.csv
 skip_prep: False # Skip data preparation
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 100
 lr: 1.0
 lr_final: 0.1
@@ -45,18 +49,23 @@ batch_size: 2
 example_length: 5 # in seconds
 sample_rate: 16000
 time_resolution: 0.01 # in seconds
+
+num_workers: 4
 train_dataloader_opts:
     batch_size: !ref <batch_size>
+    num_workers: !ref <num_workers>
 valid_dataloader_opts:
     batch_size: !ref <batch_size>
+    num_workers: !ref <num_workers>
 test_dataloader_opts:
     batch_size: !ref <batch_size>
+    num_workers: !ref <num_workers>
 
 # Feature parameters
 n_fft: 400
 n_mels: 40
 
-# Model parameters
+####################### Model Parameters #######################################
 # activation: !name:torch.nn.LeakyReLU
 # dropout: 0.15
 # cnn_blocks: 2
@@ -71,54 +80,52 @@ output_neurons: 1
 
 
 # Data augmentation
-add_noise: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <open_rir_folder>
-    babble_prob: 0.0
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: -5
-    noise_snr_high: -15
-
-# noise_corruption: !new:speechbrain.lobes.augment.EnvCorrupt
-#    openrir_folder: !ref <open_rir_folder>
-#    babble_prob: 0.0
-#    reverb_prob: 0.0
-#    noise_prob: 1.0
-#    noise_snr_low: 5
-#    noise_snr_high: 15
-
-add_noise_musan: !new:speechbrain.lobes.augment.EnvCorrupt
-    noise_csv: !ref <noise_csv>
-    babble_prob: 0.0
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: -15
-    noise_snr_high: -20
-
-add_music_musan: !new:speechbrain.lobes.augment.EnvCorrupt
-    noise_csv: !ref <music_csv>
-    babble_prob: 0.0
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: -15
-    noise_snr_high: -20
-
-add_speech_musan: !new:speechbrain.lobes.augment.EnvCorrupt
-    noise_csv: !ref <speech_csv>
-    babble_prob: 0.0
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: -15
-    noise_snr_high: -20
-
-# add_speech_multilang: !new:speechbrain.lobes.augment.EnvCorrupt
-#    noise_csv: !ref <multilang_speech_csv>
-#    babble_prob: 0.0
-#    reverb_prob: 0.0
-#    noise_prob: 1.0
-#    noise_snr_low: -15
-#    noise_snr_high: -20
-
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_csv_openrir>
+
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_csv_openrir>
+    snr_low: -5
+    snr_high: 15
+    noise_sample_rate: 16000
+    clean_sample_rate: 16000
+    num_workers: !ref <num_workers>
+
+add_noise_musan: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_csv>
+    snr_low: -5
+    snr_high: 15
+    noise_sample_rate: 16000
+    clean_sample_rate: 16000
+    num_workers: !ref <num_workers>
+
+add_music_musan: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <music_csv>
+    snr_low: -5
+    snr_high: 15
+    noise_sample_rate: 16000
+    clean_sample_rate: 16000
+    num_workers: !ref <num_workers>
+
+add_speech_musan: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <speech_csv>
+    snr_low: -5
+    snr_high: 15
+    noise_sample_rate: 16000
+    clean_sample_rate: 16000
+    num_workers: !ref <num_workers>
+
+#add_speech_multilang: !new:speechbrain.augment.time_domain.AddNoise
+#    csv_file: !ref <multilang_speech_csv>
+#    snr_low: -5
+#    snr_high: 15
+#    noise_sample_rate: 16000
+#    clean_sample_rate: 16000
+#    num_workers: !ref <num_workers>
 
 # Models
 compute_features: !new:speechbrain.lobes.features.Fbank
diff --git a/recipes/LibriParty/VAD/train.py b/recipes/LibriParty/VAD/train.py
index 7bb640c4eb9d872b1b3e1922463e2ba1dc0e782d..2e712826489d27614462446fba6b22bfdbebed43 100644
--- a/recipes/LibriParty/VAD/train.py
+++ b/recipes/LibriParty/VAD/train.py
@@ -233,6 +233,9 @@ if __name__ == "__main__":
         },
     )
 
+    # Prepare openrir
+    run_on_main(hparams["prepare_noise_data"])
+
     # Prepare Musan
     from musan_prepare import prepare_musan
 
diff --git a/recipes/LibriParty/generate_dataset/README.md b/recipes/LibriParty/generate_dataset/README.md
index fbb7d997a3602bc00a5dd50732b57a7d4805fa79..75ac905ffeabfb726df5126a078be3f49f276e95 100644
--- a/recipes/LibriParty/generate_dataset/README.md
+++ b/recipes/LibriParty/generate_dataset/README.md
@@ -38,7 +38,7 @@ It also requires background QUT-TIMIT noises. The metadata are downloaded from t
         You need to specify *metadata_folder*, *out_folder* and paths to downloaded source datasets:
         Librispeech, noises and impulse responses and QUT noise.
 
-- step 3: run *get_dataset_from_metadata.py*
+- step 3: run *python get_dataset_from_metadata.py dataset.yaml*
 
 #### Custom:
 Follow the next steps to create a novel LibriParty datasets.
diff --git a/recipes/LibriParty/generate_dataset/dataset.yaml b/recipes/LibriParty/generate_dataset/dataset.yaml
index 11d908a0d6db229cf7558b5ef524462b7272bc01..3dcddb9bf94e025e1daf434b0b79a593fbdac5e1 100644
--- a/recipes/LibriParty/generate_dataset/dataset.yaml
+++ b/recipes/LibriParty/generate_dataset/dataset.yaml
@@ -17,7 +17,7 @@ save_dry_sources: False
 # Source datasets paths #
 #########################
 
-librispeech_root: !PLACEHOLDER
+librispeech_root: !PLACEHOLDER # e.g., /workspace/LibriParty/LibriSpeech/
 # /media/sam/bx500/LibriSpeech
 # root path to librispeech: download from https://openslr.org/12/
 
@@ -29,13 +29,13 @@ librispeech_folders: # folders one wants to use for the train dataset.
   eval:
     - !ref <librispeech_root>/test-clean/
 
-rirs_noises_root: /media/sam/bx500/LibriParty/RIRS_NOISES/
+rirs_noises_root: !PLACEHOLDER #e.g., /workspace/LibriParty/RIRS_NOISES/
 rirs_folders:
   - !ref <rirs_noises_root>/simulated_rirs/
   - !ref <rirs_noises_root>/real_rirs_isotropic_noises
 noises_folders:
   - !ref <rirs_noises_root>/pointsource_noises/
-backgrounds_root: /media/sam/bx500/LibriParty/QUT_NOISE_16kHz/
+backgrounds_root: !PLACEHOLDER # e.g., /workspace/LibriParty/QUT_NOISE_16kHz/
 # optional background noise from QUT (required for "official" dataset)
 # One can use also other background noises.
 
diff --git a/recipes/LibriParty/generate_dataset/local/create_mixtures_metadata.py b/recipes/LibriParty/generate_dataset/local/create_mixtures_metadata.py
index ee90313bd065693806a0b335fb843cd343ec2091..7f671b3a56285c215621a0a20b1f888264a31e18 100644
--- a/recipes/LibriParty/generate_dataset/local/create_mixtures_metadata.py
+++ b/recipes/LibriParty/generate_dataset/local/create_mixtures_metadata.py
@@ -11,7 +11,6 @@ Samuele Cornell, 2020
 import numpy as np
 from pathlib import Path
 import json
-import os
 from tqdm import tqdm
 import torchaudio
 
@@ -203,7 +202,5 @@ def create_metadata(
 
         dataset_metadata["session_{}".format(n_sess)] = activity
 
-    with open(
-        os.path.join(configs["out_folder"], output_filename + ".json"), "w"
-    ) as f:
+    with open(output_filename + ".json", "w") as f:
         json.dump(dataset_metadata, f, indent=4)
diff --git a/recipes/LibriParty/generate_dataset/local/resample_folder.py b/recipes/LibriParty/generate_dataset/local/resample_folder.py
index 481ebf09e17ed1c11446c6bf4079d9cf332c2f1d..801e3213798bbc9e5285a813b32b21c601772033 100644
--- a/recipes/LibriParty/generate_dataset/local/resample_folder.py
+++ b/recipes/LibriParty/generate_dataset/local/resample_folder.py
@@ -31,15 +31,12 @@ parser.add_argument("--regex", type=str, default="*.wav")
 def resample_folder(input_folder, output_folder, fs, regex):
 
     files = get_all_files(input_folder, match_and=[regex])
-    torchaudio.initialize_sox()
     for f in tqdm.tqdm(files):
 
         # we use sox because torchaudio.Resample uses too much RAM.
-        resample = torchaudio.sox_effects.SoxEffectsChain()
-        resample.append_effect_to_chain("rate", [fs])
-        resample.set_input_file(f)
-
-        audio, fs = resample.sox_build_flow_effects()
+        audio, fs = torchaudio.sox_effects.apply_effects_file(
+            f, [["rate", str(fs)]]
+        )
 
         audio = (
             audio / torch.max(torch.abs(audio), dim=-1, keepdim=True)[0]
@@ -59,7 +56,6 @@ def resample_folder(input_folder, output_folder, fs, regex):
             audio,
             fs,
         )
-    torchaudio.shutdown_sox()
 
 
 if __name__ == "__main__":
diff --git a/recipes/LibriSpeech/ASR/CTC/README.md b/recipes/LibriSpeech/ASR/CTC/README.md
index 5c816737a697030b12c50ce2230bf7e51c889dfe..8b0743446ccb174d40db108f2c702cf2e32de4f6 100644
--- a/recipes/LibriSpeech/ASR/CTC/README.md
+++ b/recipes/LibriSpeech/ASR/CTC/README.md
@@ -1,6 +1,9 @@
 # LibriSpeech ASR with CTC and pre-trained wav2vec2 or whisper models.
 This folder contains the scripts to finetune a wav2vec2 or a whisper based system using LibriSpeech.
 You can download LibriSpeech at http://www.openslr.org/12.
+The loss function is the CTC loss and it is implemented in two different ways:
+- Using the [CTCLoss](https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html) from PyTorch.
+- Using the [CTC implementation](https://github.com/k2-fsa/k2/blob/master/k2/python/k2/ctc_loss.py) from K2 (WFST-based). For an example of such recipe, check the `train_with_wav2vec_k2.py` file.
 
 **Supported pre-trained wav2vec2:** [SpeechBrain](https://github.com/speechbrain/speechbrain/tree/develop/recipes/LibriSpeech/self-supervised-learning/wav2vec2) and [HuggingFace](https://github.com/speechbrain/speechbrain/tree/develop/recipes/CommonVoice/self-supervised-learning/wav2vec2)
 
@@ -25,16 +28,62 @@ To run a fine-tuning of "WavLM" with signal downsampled inputs (for faster train
 python train_with_wav2vec.py hparams/downsampled/train_hf_wavlm_signal_downsampling.yaml --downsampling_factor 2
 ```
 
+# WFST-based CTC loss
+To fine-tune a wav2vec 2.0 model with the WFST-based CTC loss, you can use the `train_with_wav2vec_k2.py` script. This will create a `lang` directory inside your output folder, which will contain the files required to build a lexicon FST. The tokenization method used here is a very basic character-based tokenization (e.g. `hello -> h e l l o`).
+
+To use this script, you will first need to install `k2`. The integration has been tested with `k2==1.24.4` and `torch==2.0.1`, although it should also work with any `torch` version as long as `k2` supports it (compatibility list [here](https://k2-fsa.github.io/k2/installation/pre-compiled-cuda-wheels-linux/index.html)). You can install `k2` by following the instructions [here](https://k2-fsa.github.io/k2/installation/from_wheels.html#linux-cuda-example).
+
+Using a lexicon FST (L) while training can help guide the model to better predictions. When decoding, you can either use a simple HL decoding graph (where H is the ctc topology), or use an HLG graph (where G is usually a 3-gram language model) to further improve the results. In addition, whole lattice rescoring is also supported. This typically happens with a 4-gram language model. See `hparams/train_with_wav2vec_k2.yaml`` for more details.
+
+If you choose to use a 3-gram or a 4-gram language model, you can either supply pre-existing ARPA LMs for both cases, including the option to train your own, or you can specify the name in the YAML docstring for automatic downloading. Comprehensive instructions are provided in `train_hf_wav2vec_k2.yaml`.
+
+For those interested in training their own language model, please consult our recipe at LibriSpeech/LM/train_ngram.py.
+
+Example usage:
+```
+python train_with_wav2vec_k2.py hparams/train_hf_wav2vec_k2.yaml --data_folder=/path/to/LibriSpeech
+```
+
+To use the HLG graph (instead of the default HL), pass `--compose_HL_with_G=True`. To use the 4-gram LM for rescoring, pass the `--decoding_method=whole-lattice-rescoring` argument. Note that this will require more memory, as the whole lattice will be kept in memory during the decoding. In this recipe, the `lm_scale` used by default is 0.4. This is the value that gave the best results in our HL-graph experiments after trying scales of `[0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4]`. When rescoring is used alongside the HLG graph, the 4-gram seems to not bring any improvement. The best lm scale in that case was 0.2 (the lowest value we tried).
+
 # KenLM n-gram CTC rescoring
-To enable n-gram rescoring during the decoding, you can download the LibriSpeech official LM from [here](https://www.openslr.org/11/). Please make sure to install the extra dependencies first. Any KenLM language model may be used with this rescoring technique. Results are reported without rescoring.
+To enable n-gram rescoring during the decoding, you can download the LibriSpeech official LM from [here](https://www.openslr.org/11/). Please make sure to install the extra dependencies first. Any KenLM language model may be used with this rescoring technique. The n-gram can either be a binary or an arpa file, but note that the binary format is faster to load. The following command shows how to use the official LibriSpeech 4-gram LM with SpeechBrain:
+```bash
+wget https://openslr.elda.org/resources/11/4-gram.arpa.gz
+gzip -d 4-gram.arpa.gz
+python train_with_wav2vec.py hparams/file.yaml --kenlm_model_path='4-gram.arpa'
+```
+
+# Rescoring with a Neural Language Model
+Two yamls do support LM rescoring: `train_hf_wav2vec_rnn_rescoring.yaml` and `train_hf_wav2vec_transformer_rescoring.yaml`. The first one uses a RNN LM, while the second one uses a Transformer LM. Both LMs are already pretrained on LibriSpeech (see [RNNLM](https://huggingface.co/speechbrain/asr-crdnn-rnnlm-librispeech) and [TransformerLM](https://huggingface.co/speechbrain/asr-transformer-transformerlm-librispeech)). The acoustic model (wav2vec2) generates a list of hypotheses (called n-best), which are then rescored (aka re-ranked) by the LM. The LM rescores by computing the score of each hypothesis by summing the log-probabilities of each tokens with respect to the previous tokens. The LM score is then added to the acoustic model score to obtain the final score. Using this technique, will results in better WERs. For instance, we went from 1.95 to 1.57 of WER. However, note that the inference time will be slower.
+
+Two parameters need to be tuned: `topk` (and `beam_size` to have enough topk) and `lm_weight`. Increasing `topk` will increase the number of hypotheses to be rescored, and ultimately the inference time. Increasing `lm_weight` will increase the importance of the LM score in the final score. The following command shows how to use the RNN LM with SpeechBrain:
+```bash
+python train_with_wav2vec.py hparams/train_hf_wav2vec_rnn_rescoring.yaml --data_folder=/path/to/LibriSpeech/ --topk=50 --beam_size=50 --lm_weight=0.5
+```
+Note: by default, `topk` is set to 20 as it gives a good trade-off between WER and inference time.
 
 # Results
 
-| Release | Hyperparams file | Finetuning Split | Test Clean WER | HuggingFace link | Full model link | GPUs |
-|:-------------:|:---------------------------:| :-----:| :-----:| :-----:| :-----:| :--------:|
-| 09-09-21 | train_hf_wav2vec.yaml | 960h | 1.90 | [Link](https://huggingface.co/speechbrain/asr-wav2vec2-librispeech) | [Link](https://www.dropbox.com/sh/qj2ps85g8oiicrj/AAAxlkQw5Pfo0M9EyHMi8iAra?dl=0) | 1xRTX8000 48GB |
-| 22-09-22 | train_sb_wav2vec.yaml | 960h | 4.2 | Not Avail. | Not Avail. | 2xTesla V100 32GB |
-| 06-12-23 | train_hf_whisper.yaml (small) | 960h | 4.89 | Not Avail. | Not Avail. | 4xRTX 2080 Ti |
+| Release | Hyperparams file | Decoding method | Finetuning Split | Test-clean WER | GPU- Test-clean Inference Time | Test-other WER | GPU- Test-other Inference Time |  HuggingFace link | Full model link | Inference GPUs | Training GPUs |
+|:-------------:|:---------------------------:|  :----------:|  :-----:| :-----:| :-----:| :-----:| :-----:| :-----:| :-----:| :--------:| :--------:|
+| 05-08-23 | train_hf_wav2vec.yaml | GreedySearch | 960h  | 2.12 | 1min30s | 4.31| 1min24s | [Link](https://huggingface.co/speechbrain/asr-wav2vec2-librispeech) | [Link](https://www.dropbox.com/sh/qj2ps85g8oiicrj/AAAxlkQw5Pfo0M9EyHMi8iAra?dl=0) | 1xRTX3090 24GB | 1xA100 40GB |
+| 05-08-23 | train_hf_wav2vec.yaml | GreedySearch  + test batch size = 1| 960h  | 1.95 | 2min09s | 3.97| 2min21s | Not Avail. | [Link](https://www.dropbox.com/sh/8zqufkmegbgpsa8/AACB6MMJ_efbGDvTi5ZhB4pQa?dl=0) | 1xRTX3090 24GB | 1xA100 40GB |
+| 05-08-23 | train_hf_wav2vec.yaml | CTCBeamSearch  + test batch size = 1| 960h  | 1.92 | 2min22s | 3.97 | 2min16s | Not Avail. | [Link](https://www.dropbox.com/sh/8zqufkmegbgpsa8/AACB6MMJ_efbGDvTi5ZhB4pQa?dl=0) | 1xRTX3090 24GB | 1xA100 40GB |
+| 05-08-23 | train_hf_wav2vec.yaml | CTCPrefixBeamSearch  + test batch size = 1| 960h | 1.92 | 2min45s | 3.97 | 2min21s | Not Avail. | [Link](https://www.dropbox.com/sh/8zqufkmegbgpsa8/AACB6MMJ_efbGDvTi5ZhB4pQa?dl=0) | 1xRTX3090 24GB | 1xA100 40GB |
+| 05-08-23 | train_hf_wav2vec.yaml | CTCBeamSearch + 4-gram  + test batch size = 1| 960h  | 1.75  | 2min37s | 3.67 | 2min20s | Not Avail. | [Link](https://www.dropbox.com/sh/8zqufkmegbgpsa8/AACB6MMJ_efbGDvTi5ZhB4pQa?dl=0) | 1xRTX3090 24GB | 1xA100 40GB |
+| 05-08-23 | train_hf_wav2vec.yaml | CTCPrefixBeamSearch + 4-gram  + test batch size = 1| 960h  | 1.80 | 2min38s | 3.78 | 2min25s |Not Avail. | [Link](https://www.dropbox.com/sh/8zqufkmegbgpsa8/AACB6MMJ_efbGDvTi5ZhB4pQa?dl=0) | 1xRTX3090 24GB | 1xA100 40GB |
+| 22-09-22 | train_sb_wav2vec.yaml | GreedySearch | 960h | 4.2 | Not Avail. | Not Avail. | Not Avail. | Not Avail. | Not Avail. | Not Avail.| 2xTesla V100 32GB |
+| 08-12-23 | train_hf_whisper.yaml (small) | CTCBeamSearch  + test batch size = 1 | 960h | 4.72 | 3.08 | 12.66 |3.30 | Not Avail. | [Link](https://www.dropbox.com/sh/zmtp13huxn02fot/AADyKL5q0MwRhEG1-WbSXDWda?dl=0) |  1xRTX3090 24GB | 2xTesla V100 32GB |
+| 08-12-23 | train_hf_whisper.yaml (small) | CTCPrefixBeamSearch  + test batch size = 1 | 960h | 4.73 | 3.19 | 12.65 |3.39 | Not Avail. | [Link](https://www.dropbox.com/sh/zmtp13huxn02fot/AADyKL5q0MwRhEG1-WbSXDWda?dl=0) |  1xRTX3090 24GB | 2xTesla V100 32GB |
+| 08-12-23 | train_hf_whisper.yaml (small) | CTCBeamSearch + 4-gram  + test batch size = 1 | 960h | 4.37 | 3.16 | 11.76 | 3.43 | Not Avail. | [Link](https://www.dropbox.com/sh/zmtp13huxn02fot/AADyKL5q0MwRhEG1-WbSXDWda?dl=0) |  1xRTX3090 24GB | 2xTesla V100 32GB |
+| 08-12-23 | train_hf_whisper.yaml (small) | CTCPrefixBeamSearch + 4-gram  + test batch size = 1 | 960h | 4.44 | 3.30 | 11.89 | 3.47 | Not Avail. | [Link](https://www.dropbox.com/sh/zmtp13huxn02fot/AADyKL5q0MwRhEG1-WbSXDWda?dl=0) |  1xRTX3090 24GB | 2xTesla V100 32GB |
+| 23-01-24 | train_hf_wav2vec_k2.yaml | k2CTC + HL graph + 1best decoding + test batch size = 1 | 960h | 1.83 | Not Avail. | 3.82 | Not Avail. | Not Avail. | [Link](https://www.dropbox.com/scl/fo/678rj1a44jt4zrxjwaetu/h?rlkey=x0xwz31nkl01qwr3k5ivtywvz&dl=0) |  1xRTX2080Ti 12GB | 1xRTX2080Ti 12GB |
+| 23-01-24 | train_hf_wav2vec_k2.yaml | k2CTC + HLG graph + 1best decoding + test batch size = 1 | 960h | 1.69 | Not Avail. | 3.44 | Not Avail. | Not Avail. | [Link](https://www.dropbox.com/scl/fo/c91vqlr8ase90x0m7u3v3/h?rlkey=duh55n0qzlfnfhy4auu0a4f8g&dl=0) |  1xRTX2080Ti 12GB | 1xRTX2080Ti 12GB |
+| 23-01-24 | train_hf_wav2vec_k2.yaml | k2CTC + HL graph + whole lattice rescoring + test batch size = 1 | 960h | 1.72 | Not Avail. | 3.51 | Not Avail. | Not Avail. | [Link](https://www.dropbox.com/scl/fo/mx6hd4zc0iyzqvixxre6q/h?rlkey=xxbpb949btmeiecw30be5qwhj&dl=0) |  1xRTX2080Ti 12GB | 1xRTX2080Ti 12GB |
+| 23-01-24 | train_hf_wav2vec_k2.yaml | k2CTC + HLG graph + whole lattice rescoring + test batch size = 1 | 960h | 1.81 | Not Avail. | 3.57 | Not Avail. | Not Avail. | [Link](https://www.dropbox.com/scl/fo/kj2ujqj3votq7ue6ydh0l/h?rlkey=mibyoria19zasvuxs0iwx6plt&dl=0) |  1xRTX2080Ti 12GB | 1xRTX2080Ti 12GB |
+| 08-12-23 | train_hf_wav2vec.yaml | CTCBeamSearch + RNNLM Rescorer  + test batch size = 1 + topk = 100  | 960h | 1.69 | 26mins15 | 3.55 | 32min44s | Not Avail. | [Link](https://www.dropbox.com/sh/k4ixa211yp5b1tm/AAD85sgYw2CH7NKk_qKMO9Tja?dl=0) |  1x A100 40GB | 2xTesla V100 40GB |
+| 08-12-23 | train_hf_wav2vec.yaml | CTCBeamSearch + TransformerLM Rescorer + test batch size = 1 + topk = 100 | 960h | 1.57 | 26mins56s | 3.37 | 32min46 | Not Avail. | [Link](https://www.dropbox.com/sh/ijqalvre7mm08ng/AAD_hsN-8dBneUMMkELsOOxga?dl=0) |  1x A100 40GB | 2xTesla V100 32GB |
 
 # Downsampling inputs for faster fine-tuning and inferences using SSL Models
 This repository contains the code allowing to reproduce part of the results obtained in the paper : "Fine-tuning Strategies for Faster Inference using Speech Self-Supervised Models:  A Comparative Study"
diff --git a/recipes/LibriSpeech/ASR/CTC/extra_requirements.txt b/recipes/LibriSpeech/ASR/CTC/extra_requirements.txt
index 620aa58fa1a02d04ad16975841ec4b8bed0553c9..c7f57d1f875c46f466d68c1ba55962551f5c8de5 100644
--- a/recipes/LibriSpeech/ASR/CTC/extra_requirements.txt
+++ b/recipes/LibriSpeech/ASR/CTC/extra_requirements.txt
@@ -1,3 +1,3 @@
+# k2 # It is better to install k2 with the procedure listed here: https://k2-fsa.github.io/k2/installation/from_wheels.html
+kaldilm==1.15
 kenlm
-pyctcdecode
-
diff --git a/recipes/LibriSpeech/ASR/CTC/hparams/downsampled/train_hf_wavlm_average_downsampling.yaml b/recipes/LibriSpeech/ASR/CTC/hparams/downsampled/train_hf_wavlm_average_downsampling.yaml
index 09d29407f326ac5d666d6327c79d1278b1689985..fdbd7e86d8364dc02fecae7127a94927800c8b13 100644
--- a/recipes/LibriSpeech/ASR/CTC/hparams/downsampled/train_hf_wavlm_average_downsampling.yaml
+++ b/recipes/LibriSpeech/ASR/CTC/hparams/downsampled/train_hf_wavlm_average_downsampling.yaml
@@ -1,5 +1,6 @@
 # ################################
 # Model: downsampling + wavlm + DNN + CTC
+# Decoding AM: Greedy for validation, and Beam search for testing
 # Augmentation: SpecAugment
 # Authors: Sung-Lin Yeh 2021
 # Salah Zaiem 2023
@@ -32,20 +33,19 @@ test_csv:
    - !ref <output_folder>/test-clean.csv
    - !ref <output_folder>/test-other.csv
 
-# Training parameters
+####################### Training Parameters ####################################
+
 number_of_epochs: 1
 lr: 0.9
 lr_wav2vec: 0.0001
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 
 #Downsampling parameters
 downsampling_factor: 2
 downsampling_kernel_size: 21
 upsampling: False
-use_language_modelling: False
-ngram_lm_path: /path/to/lm # Download from https://www.openslr.org/11/
 
 # With data_parallel batch_size is split into N jobs
 # With DDP batch_size is multiplied by N jobs
@@ -63,7 +63,8 @@ valid_dataloader_opts:
 test_dataloader_opts:
    batch_size: !ref <test_batch_size>
 
-# Model parameters
+####################### Model Parameters #######################################
+
 activation: !name:torch.nn.LeakyReLU
 dnn_layers: 2
 dnn_neurons: 1024
@@ -72,8 +73,6 @@ freeze_wav2vec: True
 # Outputs
 ctc_neurons: 29
 output_neurons: 29  # Characters size, index(blank/eos/bos) = 0
-
-# Decoding parameters
 blank_index: 0
 
 #
@@ -82,17 +81,13 @@ blank_index: 0
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-   sample_rate: !ref <sample_rate>
-   speeds: [95, 100, 105]
-
 enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
    input_shape: [null, null, 1024]
    activation: !ref <activation>
    dnn_blocks: !ref <dnn_layers>
    dnn_neurons: !ref <dnn_neurons>
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM
    source: !ref <wav2vec2_hub>
    output_norm: True
    freeze_feature_extractor: True
@@ -154,8 +149,60 @@ lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
    annealing_factor: 0.9
    patient: 0
 
+############################## Decoding ########################################
+
+test_beam_search:
+   beam_size: 200
+   topk: 1
+   blank_index: !ref <blank_index>
+   space_token: ' ' # make sure this is the same as the one used in the tokenizer
+   beam_prune_logp: -10.0
+   token_prune_min_logp: -5
+   prune_history: True
+   alpha: 0.5
+   beta: 1.5
+   # can be downloaded from here https://www.openslr.org/11/ or trained with kenLM
+   # It can either be a .bin or .arpa ; note: .arpa is much slower at loading
+   # If you don't want to use an LM, comment it out or set it to null
+   kenlm_model_path: null
+
 label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder
 
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+   orig_freq: !ref <sample_rate>
+   speeds: [95, 100, 105]
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+   drop_freq_low: 0
+   drop_freq_high: 1
+   drop_freq_count_low: 1
+   drop_freq_count_high: 3
+   drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+   drop_length_low: 1000
+   drop_length_high: 2000
+   drop_count_low: 1
+   drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+   concat_original: True
+   min_augmentations: 4
+   max_augmentations: 4
+   augment_prob: 1.0
+   augmentations: [
+      !ref <speed_perturb>,
+      !ref <drop_freq>,
+      !ref <drop_chunk>]
+
+############################## Logging and Pretrainer ##########################
+
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
diff --git a/recipes/LibriSpeech/ASR/CTC/hparams/downsampled/train_hf_wavlm_conv_downsampling.yaml b/recipes/LibriSpeech/ASR/CTC/hparams/downsampled/train_hf_wavlm_conv_downsampling.yaml
index 38516590f2c6ae255c3750c6c0ce909292ca889e..1b84596dcd99a075af3d2266b44d04395031d619 100644
--- a/recipes/LibriSpeech/ASR/CTC/hparams/downsampled/train_hf_wavlm_conv_downsampling.yaml
+++ b/recipes/LibriSpeech/ASR/CTC/hparams/downsampled/train_hf_wavlm_conv_downsampling.yaml
@@ -1,5 +1,6 @@
 # ################################
 # Model: downsampling + wavlm + DNN + CTC
+# Decoding AM: Greedy for validation, and Beam search for testing
 # Augmentation: SpecAugment
 # Authors: Sung-Lin Yeh 2021
 # Salah Zaiem 2023
@@ -33,20 +34,19 @@ test_csv:
    - !ref <output_folder>/test-clean.csv
    - !ref <output_folder>/test-other.csv
 
-# Training parameters
+####################### Training Parameters ####################################
+
 number_of_epochs: 1
 lr: 0.9
 lr_wav2vec: 0.0001
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 
 #Downsampling parameters
 downsampling_factor: 2
 downsampling_kernel_size: 81
 upsampling: False
-use_language_modelling: False
-ngram_lm_path: /path/to/lm # Download from https://www.openslr.org/11/
 
 # With data_parallel batch_size is split into N jobs
 # With DDP batch_size is multiplied by N jobs
@@ -64,7 +64,8 @@ valid_dataloader_opts:
 test_dataloader_opts:
    batch_size: !ref <test_batch_size>
 
-# Model parameters
+####################### Model Parameters #######################################
+
 activation: !name:torch.nn.LeakyReLU
 dnn_layers: 2
 dnn_neurons: 1024
@@ -73,8 +74,6 @@ freeze_wav2vec: True
 # Outputs
 ctc_neurons: 29
 output_neurons: 29  # Characters size, index(blank/eos/bos) = 0
-
-# Decoding parameters
 blank_index: 0
 
 #
@@ -83,9 +82,6 @@ blank_index: 0
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-   sample_rate: !ref <sample_rate>
-   speeds: [95, 100, 105]
 
 enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
    input_shape: [null, null, 1024]
@@ -93,7 +89,7 @@ enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
    dnn_blocks: !ref <dnn_layers>
    dnn_neurons: !ref <dnn_neurons>
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM
    source: !ref <wav2vec2_hub>
    output_norm: True
    freeze_feature_extractor: True
@@ -156,8 +152,60 @@ lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
    annealing_factor: 0.9
    patient: 0
 
+############################## Decoding ########################################
+
+test_beam_search:
+   beam_size: 200
+   topk: 1
+   blank_index: !ref <blank_index>
+   space_token: ' ' # make sure this is the same as the one used in the tokenizer
+   beam_prune_logp: -10.0
+   token_prune_min_logp: -5
+   prune_history: True
+   alpha: 0.5
+   beta: 1.5
+   # can be downloaded from here https://www.openslr.org/11/ or trained with kenLM
+   # It can either be a .bin or .arpa ; note: .arpa is much slower at loading
+   # If you don't want to use an LM, comment it out or set it to null
+   kenlm_model_path: null
+
 label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder
 
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+   orig_freq: !ref <sample_rate>
+   speeds: [95, 100, 105]
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+   drop_freq_low: 0
+   drop_freq_high: 1
+   drop_freq_count_low: 1
+   drop_freq_count_high: 3
+   drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+   drop_length_low: 1000
+   drop_length_high: 2000
+   drop_count_low: 1
+   drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+   concat_original: True
+   min_augmentations: 4
+   max_augmentations: 4
+   augment_prob: 1.0
+   augmentations: [
+      !ref <speed_perturb>,
+      !ref <drop_freq>,
+      !ref <drop_chunk>]
+
+############################## Logging and Pretrainer ##########################
+
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
diff --git a/recipes/LibriSpeech/ASR/CTC/hparams/downsampled/train_hf_wavlm_signal_downsampling.yaml b/recipes/LibriSpeech/ASR/CTC/hparams/downsampled/train_hf_wavlm_signal_downsampling.yaml
index ed562cba48703ad361c755e373cb9d011062d3cb..d0daf5b77759c5b87410893abd71e77ea00725c2 100644
--- a/recipes/LibriSpeech/ASR/CTC/hparams/downsampled/train_hf_wavlm_signal_downsampling.yaml
+++ b/recipes/LibriSpeech/ASR/CTC/hparams/downsampled/train_hf_wavlm_signal_downsampling.yaml
@@ -1,5 +1,6 @@
 ## ################################
 # Model: downsampling + wavlm + DNN + CTC
+# Decoding AM: Greedy for validation, and Beam search for testing
 # Augmentation: SpecAugment
 # Authors: Sung-Lin Yeh 2021
 # Salah Zaiem 2023
@@ -32,19 +33,18 @@ test_csv:
    - !ref <output_folder>/test-clean.csv
    - !ref <output_folder>/test-other.csv
 
-# Training parameters
+####################### Training Parameters ####################################
+
 number_of_epochs: 1
 lr: 0.9
 lr_wav2vec: 0.0001
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 
 #Downsampling parameters
 downsampling_factor: 3
 upsampling: True
-use_language_modelling: False
-ngram_lm_path: /path/to/lm # Download from https://www.openslr.org/11/
 
 # With data_parallel batch_size is split into N jobs
 # With DDP batch_size is multiplied by N jobs
@@ -62,7 +62,8 @@ valid_dataloader_opts:
 test_dataloader_opts:
    batch_size: !ref <test_batch_size>
 
-# Model parameters
+####################### Model Parameters #######################################
+
 activation: !name:torch.nn.LeakyReLU
 dnn_layers: 2
 dnn_neurons: 1024
@@ -71,8 +72,6 @@ freeze_wav2vec: True
 # Outputs
 ctc_neurons: 58 # Twice bigger than the  number of characters for upsampling
 output_neurons: 29  # Characters size, index(blank/eos/bos) = 0
-
-# Decoding parameters
 blank_index: 0
 
 #
@@ -81,9 +80,6 @@ blank_index: 0
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-   sample_rate: !ref <sample_rate>
-   speeds: [95, 100, 105]
 
 enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
    input_shape: [null, null, 1024]
@@ -91,7 +87,7 @@ enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
    dnn_blocks: !ref <dnn_layers>
    dnn_neurons: !ref <dnn_neurons>
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM
    source: !ref <wav2vec2_hub>
    output_norm: True
    freeze_feature_extractor: True
@@ -153,8 +149,58 @@ lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
    annealing_factor: 0.9
    patient: 0
 
+############################## Decoding ########################################
+
+test_beam_search:
+   beam_size: 200
+   topk: 1
+   blank_index: !ref <blank_index>
+   space_token: ' ' # make sure this is the same as the one used in the tokenizer
+   beam_prune_logp: -10.0
+   token_prune_min_logp: -5
+   prune_history: True
+   alpha: 0.5
+   beta: 1.5
+   # can be downloaded from here https://www.openslr.org/11/ or trained with kenLM
+   # It can either be a .bin or .arpa ; note: .arpa is much slower at loading
+   # If you don't want to use an LM, comment it out or set it to null
+   kenlm_model_path: null
+
 label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder
 
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+   orig_freq: !ref <sample_rate>
+   speeds: [95, 100, 105]
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+   drop_freq_low: 0
+   drop_freq_high: 1
+   drop_freq_count_low: 1
+   drop_freq_count_high: 3
+   drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+   drop_length_low: 1000
+   drop_length_high: 2000
+   drop_count_low: 1
+   drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+   concat_original: True
+   min_augmentations: 4
+   max_augmentations: 4
+   augment_prob: 1.0
+   augmentations: [
+      !ref <speed_perturb>,
+      !ref <drop_freq>,
+      !ref <drop_chunk>]
+
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
diff --git a/recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec.yaml b/recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec.yaml
index 7a4075c63087209a29b7132422ccd7027003daaa..1d860a29f1cf0955b277e934748cde06db9d516f 100644
--- a/recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec.yaml
+++ b/recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec.yaml
@@ -1,7 +1,8 @@
 # ################################
 # Model: wav2vec2 + DNN + CTC
+# Decoding AM: Greedy for validation, and Beam search for testing
 # Augmentation: SpecAugment
-# Authors: Sung-Lin Yeh 2021
+# Authors: Sung-Lin Yeh 2021, Adel Moumen 2023
 # ################################
 
 # Seed needs to be set at top of yaml, before objects with parameters are made
@@ -31,12 +32,13 @@ test_csv:
    - !ref <output_folder>/test-clean.csv
    - !ref <output_folder>/test-other.csv
 
-# Training parameters
+####################### Training Parameters ####################################
+
 number_of_epochs: 1
 lr: 0.9
 lr_wav2vec: 0.0001
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 
 # With data_parallel batch_size is split into N jobs
@@ -55,7 +57,7 @@ valid_dataloader_opts:
 test_dataloader_opts:
    batch_size: !ref <test_batch_size>
 
-# Model parameters
+####################### Model Parameters #######################################
 activation: !name:torch.nn.LeakyReLU
 dnn_layers: 2
 dnn_neurons: 1024
@@ -63,29 +65,24 @@ freeze_wav2vec: True
 
 # Outputs
 output_neurons: 29  # BPE size, index(blank/eos/bos) = 0
-
-# Decoding parameters
 blank_index: 0
-use_language_modelling: False
-# ngram_lm_path: !PLACEHOLDER
 
 #
 # Functions and classes
 #
+
+label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder
+
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-   sample_rate: !ref <sample_rate>
-   speeds: [95, 100, 105]
-
 enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
    input_shape: [null, null, 1024]
    activation: !ref <activation>
    dnn_blocks: !ref <dnn_layers>
    dnn_neurons: !ref <dnn_neurons>
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
    source: !ref <wav2vec2_hub>
    output_norm: True
    freeze: !ref <freeze_wav2vec>
@@ -141,7 +138,58 @@ lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
    annealing_factor: 0.9
    patient: 0
 
-label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+   orig_freq: !ref <sample_rate>
+   speeds: [95, 100, 105]
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+   drop_freq_low: 0
+   drop_freq_high: 1
+   drop_freq_count_low: 1
+   drop_freq_count_high: 3
+   drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+   drop_length_low: 1000
+   drop_length_high: 2000
+   drop_count_low: 1
+   drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+   concat_original: True
+   min_augmentations: 4
+   max_augmentations: 4
+   augment_prob: 1.0
+   augmentations: [
+      !ref <speed_perturb>,
+      !ref <drop_freq>,
+      !ref <drop_chunk>]
+
+############################## Decoding ########################################
+
+# Decoding parameters
+test_beam_search:
+   beam_size: 143
+   topk: 1
+   blank_index: !ref <blank_index>
+   space_token: ' ' # make sure this is the same as the one used in the tokenizer
+   beam_prune_logp: -12.0
+   token_prune_min_logp: -1.2
+   prune_history: True
+   alpha: 0.8
+   beta: 1.2
+   # can be downloaded from here https://www.openslr.org/11/ or trained with kenLM
+   # It can either be a .bin or .arpa ; note: .arpa is much slower at loading
+   # If you don't want to use an LM, comment it out or set it to null
+   kenlm_model_path: null
+
+############################## Logging and Pretrainer ##########################
 
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
diff --git a/recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec_k2.yaml b/recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec_k2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c97968f9a25e4203bf35e47e2e19c4e5aaa56acc
--- /dev/null
+++ b/recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec_k2.yaml
@@ -0,0 +1,255 @@
+# ################################
+# Model: wav2vec2 + DNN + CTC + LM (k2)
+# Augmentation: SpecAugment
+#
+# This recipe trains a wav2vec2 model with a DNN and DWFST-based CTC loss.
+# To use this recipe you need to have the following:
+#  - A folder with the LibriSpeech dataset (see `datafolder`)
+#  - A folder with a small, and (optionally) a big LM (see `lm_dir`)
+#    These can be downloaded in ARPA format from: http://www.openslr.org/resources/11/.
+#  - A working installation of k2 (and kaldilm if you want to use ARPA LMs).
+#
+# Authors: Zeyu Zhao 2023
+#          Georgios Karakasidis 2023
+#          Pierre Champion 2023
+# ################################
+
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1111
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/train_wav2vec2_char_k2/<seed>
+output_wer_folder: !ref <output_folder>/
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# URL for the biggest Fairseq english wav2vec2 model.
+wav2vec2_hub: facebook/wav2vec2-large-960h-lv60-self
+wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
+
+# Data files
+data_folder: !PLACEHOLDER # e,g./path/to/LibriSpeech
+# noise/ris dataset will automatically be downloaded
+# data_folder_rirs: !ref <data_folder>
+train_splits: ["train-clean-100", "train-clean-360", "train-other-500"]
+dev_splits: ["dev-clean", "dev-other"]
+test_splits: ["test-clean", "test-other"]
+skip_prep: False
+ckpt_interval_minutes: 25 # save checkpoint every N min
+train_csv: !ref <output_folder>/train.csv
+valid_csv: !ref <output_folder>/dev-clean.csv
+test_csv:
+   - !ref <output_folder>/test-clean.csv
+   - !ref <output_folder>/test-other.csv
+   - !ref <output_folder>/dev-clean.csv
+   - !ref <output_folder>/dev-other.csv
+
+# For k2 CTC training
+lang_dir: !ref <output_folder>/lang
+vocab_file: !ref <data_folder>/librispeech-vocab.txt
+sil_prob: 0.
+add_word_boundary: True
+# For k2 decoding
+test_search_beam: 32
+# Beam size (for decoding)
+test_output_beam: 8
+test_min_active_state: 300
+test_max_active_state: 3000
+# Acoustic scale (mutliplied by the log probs)
+ac_scale: 1.5
+compose_HL_with_G: False
+# 1best or whole-lattice-rescoring
+# decoding_method: whole-lattice-rescoring
+decoding_method: 1best
+# LM scale to be used for rescoring. Only used if rescoring
+rescoring_lm_scale: 0.4
+# This is where the 3gram and (optionally) 4gram LM are stored
+# They can be in either ARPA or FST format. If the former, then
+# the FST equivalent will be created in the same directory by
+# using kaldilm.
+lm_dir: !ref <output_folder>/lm
+# The ARPA LM files are located under the lm_dir.
+# - Use (recommended):
+#     - 3-gram_sb.arpa
+#     - 4-gram_sb.arpa
+#     To downloads speechbrain pretrained models (trained on train-960+librispeech-lm-norm.txt, 214k words)
+# - Use:
+#    - 3-gram.arpa
+#    - 3-gram.pruned.1e-7.arpa
+#    - 3-gram.pruned.3e-7.arpa
+#    - 4-gram.arpa
+#    To downloads http://www.openslr.org/resources/11/ pretrained models (trained on librispeech-lm-norm.txt, 200k words)
+# - Use another name for a model you trained yourself.
+#    If the arpa does not exist in the lm_dir, you'll need to train it yourself.
+#    Please see LibriSpeech/LM/README.md for instructions.
+# Using one of the above name will automatically download the corresponding model.
+# You can speciy a different name, but you'll need to make sure the file exists in the lm_dir.
+# Make sure to use enough RAM and CPUs as the conversion to FST can be quite demanding.
+G_arpa: 3-gram_sb.arpa
+G_rescoring_arpa: 4-gram_sb.arpa
+# caching: False
+
+# Training parameters
+number_of_epochs: 1
+lr: 0.9
+lr_wav2vec: 0.0001
+sorting: ascending  # only ascending and descending are supported currently
+precision: fp32
+sample_rate: 16000
+
+# With data_parallel batch_size is split into N jobs
+# With DDP batch_size is multiplied by N jobs
+# Must be 3 per GPU to fit 32GB of VRAM
+batch_size: 6
+test_batch_size: 1
+num_workers: 10
+
+# Dataloader options
+train_dataloader_opts:
+   batch_size: !ref <batch_size>
+   num_workers: !ref <num_workers>
+
+valid_dataloader_opts:
+   batch_size: !ref <batch_size>
+   num_workers: !ref <num_workers>
+
+test_dataloader_opts:
+   batch_size: !ref <test_batch_size>
+   num_workers: !ref <num_workers>
+
+# Model parameters
+activation: !name:torch.nn.LeakyReLU
+dnn_layers: 2
+dnn_neurons: 1024
+freeze_wav2vec: True
+
+# Outputs
+output_neurons: 30  # BPE size, index(blank/eos/bos) = 0
+
+#
+# Functions and classes
+#
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+   limit: !ref <number_of_epochs>
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+   orig_freq: !ref <sample_rate>
+   speeds: [95, 100, 105]
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+   drop_freq_low: !ref <drop_freq_low>
+   drop_freq_high: !ref <drop_freq_high>
+   drop_freq_count_low: !ref <drop_freq_count_low>
+   drop_freq_count_high: !ref <drop_freq_count_high>
+   drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+   drop_length_low: !ref <drop_chunk_length_low>
+   drop_length_high: !ref <drop_chunk_length_high>
+   drop_count_low: !ref <drop_chunk_count_low>
+   drop_count_high: !ref <drop_chunk_count_high>
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+   parallel_augment: False
+   repeat_augment: 1
+   shuffle_augmentations: False
+   min_augmentations: 4
+   max_augmentations: 4
+   augment_prob: 1.0
+   augmentations: [
+      !ref <speed_perturb>,
+      !ref <drop_freq>,
+      !ref <drop_chunk>]
+
+enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
+   input_shape: [null, null, 1024]
+   activation: !ref <activation>
+   dnn_blocks: !ref <dnn_layers>
+   dnn_neurons: !ref <dnn_neurons>
+
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.Wav2Vec2
+   source: !ref <wav2vec2_hub>
+   output_norm: True
+   freeze: !ref <freeze_wav2vec>
+   save_path: !ref <wav2vec2_folder>
+
+#####
+# Uncomment this block if you prefer to use a Fairseq pretrained model instead
+# of a HuggingFace one. Here, we provide an URL that is obtained from the
+# Fairseq github for the multilingual XLSR.
+#
+#wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt
+#wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2
+#    pretrained_path: !ref <wav2vec2_url>
+#    output_norm: True
+#    freeze: False
+#    save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
+
+ctc_lin: !new:speechbrain.nnet.linear.Linear
+   input_size: !ref <dnn_neurons>
+   n_neurons: !ref <output_neurons>
+
+log_softmax: !new:speechbrain.nnet.activations.Softmax
+   apply_log: True
+
+ctc_cost: !name:speechbrain.k2_integration.losses.ctc_k2
+   reduction: mean
+   beam_size: 10
+
+modules:
+   wav2vec2: !ref <wav2vec2>
+   enc: !ref <enc>
+   ctc_lin: !ref <ctc_lin>
+
+model: !new:torch.nn.ModuleList
+   - [!ref <enc>, !ref <ctc_lin>]
+
+model_opt_class: !name:torch.optim.Adadelta
+   lr: !ref <lr>
+   rho: 0.95
+   eps: 1.e-8
+
+wav2vec_opt_class: !name:torch.optim.Adam
+   lr: !ref <lr_wav2vec>
+
+lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
+   initial_value: !ref <lr>
+   improvement_threshold: 0.0025
+   annealing_factor: 0.8
+   patient: 0
+
+lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
+   initial_value: !ref <lr_wav2vec>
+   improvement_threshold: 0.0025
+   annealing_factor: 0.9
+   patient: 0
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+   checkpoints_dir: !ref <save_folder>
+   recoverables:
+      wav2vec2: !ref <wav2vec2>
+      model: !ref <model>
+      scheduler_model: !ref <lr_annealing_model>
+      scheduler_wav2vec: !ref <lr_annealing_wav2vec>
+      counter: !ref <epoch_counter>
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+   save_file: !ref <train_log>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+
+cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+   split_tokens: True
diff --git a/recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec_rnn_rescoring.yaml b/recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec_rnn_rescoring.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1bb9b6af31dad6404525566e4001a0130785e89d
--- /dev/null
+++ b/recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec_rnn_rescoring.yaml
@@ -0,0 +1,256 @@
+# ################################
+# Model: wav2vec2 + DNN + CTC
+# Decoding AM: Greedy for validation, and Rescoring + Beam search for testing.
+# Augmentation: SpecAugment
+# Authors: Adel Moumen 2023
+# ################################
+
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1986
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/train_wav2vec2_char_rnn_rescoring/<seed>
+output_wer_folder: !ref <output_folder>/
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# URL for the biggest Fairseq english wav2vec2 model.
+wav2vec2_hub: facebook/wav2vec2-large-960h-lv60-self
+wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
+
+# Data files
+data_folder: !PLACEHOLDER # e,g./path/to/LibriSpeech
+# noise/ris dataset will automatically be downloaded
+# data_folder_rirs: !ref <data_folder>
+train_splits: ["train-clean-100", "train-clean-360", "train-other-500"]
+dev_splits: ["dev-clean"]
+test_splits: ["test-clean", "test-other"]
+skip_prep: False
+ckpt_interval_minutes: 25 # save checkpoint every N min
+train_csv: !ref <output_folder>/train.csv
+valid_csv: !ref <output_folder>/dev-clean.csv
+test_csv:
+   - !ref <output_folder>/test-clean.csv
+   - !ref <output_folder>/test-other.csv
+
+####################### Training Parameters ####################################
+
+number_of_epochs: 1
+lr: 0.9
+lr_wav2vec: 0.0001
+sorting: ascending
+precision: fp32 # bf16, fp16 or fp32
+sample_rate: 16000
+
+# With data_parallel batch_size is split into N jobs
+# With DDP batch_size is multiplied by N jobs
+# Must be 3 per GPU to fit 32GB of VRAM
+batch_size: 6
+test_batch_size: 1
+
+# Dataloader options
+train_dataloader_opts:
+   batch_size: !ref <batch_size>
+
+valid_dataloader_opts:
+   batch_size: !ref <batch_size>
+
+test_dataloader_opts:
+   batch_size: !ref <test_batch_size>
+
+####################### Model Parameters #######################################
+
+activation: !name:torch.nn.LeakyReLU
+dnn_layers: 2
+dnn_neurons: 1024
+freeze_wav2vec: True
+
+# Outputs
+output_neurons: 29  # BPE size, index(blank/eos/bos) = 0
+
+
+pretrained_lm_tokenizer_path: speechbrain/asr-crdnn-rnnlm-librispeech
+
+
+#
+# Functions and classes
+#
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+   limit: !ref <number_of_epochs>
+
+
+enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
+   input_shape: [null, null, 1024]
+   activation: !ref <activation>
+   dnn_blocks: !ref <dnn_layers>
+   dnn_neurons: !ref <dnn_neurons>
+
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
+   source: !ref <wav2vec2_hub>
+   output_norm: True
+   freeze: !ref <freeze_wav2vec>
+   save_path: !ref <wav2vec2_folder>
+
+#####
+# Uncomment this block if you prefer to use a Fairseq pretrained model instead
+# of a HuggingFace one. Here, we provide an URL that is obtained from the
+# Fairseq github for the multilingual XLSR.
+#
+#wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt
+#wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2
+#    pretrained_path: !ref <wav2vec2_url>
+#    output_norm: True
+#    freeze: False
+#    save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
+
+ctc_lin: !new:speechbrain.nnet.linear.Linear
+   input_size: !ref <dnn_neurons>
+   n_neurons: !ref <output_neurons>
+
+log_softmax: !new:speechbrain.nnet.activations.Softmax
+   apply_log: True
+
+ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
+   blank_index: !ref <blank_index>
+
+modules:
+   wav2vec2: !ref <wav2vec2>
+   enc: !ref <enc>
+   ctc_lin: !ref <ctc_lin>
+
+model: !new:torch.nn.ModuleList
+   - [!ref <enc>, !ref <ctc_lin>]
+
+model_opt_class: !name:torch.optim.Adadelta
+   lr: !ref <lr>
+   rho: 0.95
+   eps: 1.e-8
+
+wav2vec_opt_class: !name:torch.optim.Adam
+   lr: !ref <lr_wav2vec>
+
+lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
+   initial_value: !ref <lr>
+   improvement_threshold: 0.0025
+   annealing_factor: 0.8
+   patient: 0
+
+lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
+   initial_value: !ref <lr_wav2vec>
+   improvement_threshold: 0.0025
+   annealing_factor: 0.9
+   patient: 0
+
+label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder
+
+# This is the RNNLM that is used according to the Huggingface repository
+# NB: It has to match the pre-trained RNNLM!!
+lm_model: !new:speechbrain.lobes.models.RNNLM.RNNLM
+   output_neurons: 1000
+   embedding_dim: 128
+   activation: !name:torch.nn.LeakyReLU
+   dropout: 0.0
+   rnn_layers: 2
+   rnn_neurons: 2048
+   dnn_blocks: 1
+   dnn_neurons: 512
+   return_hidden: True  # For inference
+
+
+tokenizer: !new:sentencepiece.SentencePieceProcessor
+
+############################## Decoding ########################################
+
+# Decoding parameters
+lm_weight: 0.5
+blank_index: 0
+
+# topk is the number of hypotheses that will be rescored in the rescorer
+# lowering this value might decrease the wer, but will increase speed.
+test_beam_search:
+   beam_size: 20
+   topk: 20
+   blank_index: !ref <blank_index>
+   space_token: ' ' # make sure this is the same as the one used in the tokenizer
+   beam_prune_logp: -12.0
+   token_prune_min_logp: -12.0
+   prune_history: False
+   alpha: 0.8
+   beta: 1.2
+
+rnnlm: !new:speechbrain.decoders.scorer.RNNLMRescorer
+   language_model: !ref <lm_model>
+   tokenizer: !ref <tokenizer>
+   bos_index: 0
+   eos_index: 0
+   pad_index: 0
+
+rescorer: !new:speechbrain.decoders.scorer.RescorerBuilder
+   rescorers: [!ref <rnnlm>]
+   weights:
+      rnnlm: !ref <lm_weight>
+
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+   orig_freq: !ref <sample_rate>
+   speeds: [95, 100, 105]
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+   drop_freq_low: 0
+   drop_freq_high: 1
+   drop_freq_count_low: 1
+   drop_freq_count_high: 3
+   drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+   drop_length_low: 1000
+   drop_length_high: 2000
+   drop_count_low: 1
+   drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+   concat_original: True
+   min_augmentations: 4
+   max_augmentations: 4
+   augment_prob: 1.0
+   augmentations: [
+      !ref <speed_perturb>,
+      !ref <drop_freq>,
+      !ref <drop_chunk>]
+
+############################## Logging and Pretrainer ##########################
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+   checkpoints_dir: !ref <save_folder>
+   recoverables:
+      wav2vec2: !ref <wav2vec2>
+      model: !ref <model>
+      scheduler_model: !ref <lr_annealing_model>
+      scheduler_wav2vec: !ref <lr_annealing_wav2vec>
+      counter: !ref <epoch_counter>
+      tokenizer: !ref <label_encoder>
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+   save_file: !ref <train_log>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+
+cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+   split_tokens: True
+
+# The pretrainer allows a mapping between pretrained files and instances that
+# are declared in the yaml. E.g here, we will download the file lm.ckpt
+# and it will be loaded into "lm" which is pointing to the <lm_model> defined
+# before.
+pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
+   collect_in: !ref <save_folder>
+   loadables:
+      lm: !ref <lm_model>
+      tokenizer: !ref <tokenizer>
+   paths:
+      lm: !ref <pretrained_lm_tokenizer_path>/lm.ckpt
+      tokenizer: !ref <pretrained_lm_tokenizer_path>/tokenizer.ckpt
diff --git a/recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec_transformer_rescoring.yaml b/recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec_transformer_rescoring.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d806b20cfd5ebf408d0f8b60e83909521563288a
--- /dev/null
+++ b/recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec_transformer_rescoring.yaml
@@ -0,0 +1,253 @@
+# ################################
+# Model: wav2vec2 + DNN + CTC
+# Decoding AM: Greedy for validation, and Rescoring + Beam search for testing
+# Augmentation: SpecAugment
+# Authors: Adel Moumen 2023
+# ################################
+
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1986
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/train_wav2vec2_char_transformer_rescoring/<seed>
+output_wer_folder: !ref <output_folder>/
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# URL for the biggest Fairseq english wav2vec2 model.
+wav2vec2_hub: facebook/wav2vec2-large-960h-lv60-self
+wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
+
+# Data files
+data_folder: !PLACEHOLDER # e,g./path/to/LibriSpeech
+# noise/ris dataset will automatically be downloaded
+# data_folder_rirs: !ref <data_folder>
+train_splits: ["train-clean-100", "train-clean-360", "train-other-500"]
+dev_splits: ["dev-clean"]
+test_splits: ["test-clean", "test-other"]
+skip_prep: False
+ckpt_interval_minutes: 25 # save checkpoint every N min
+train_csv: !ref <output_folder>/train.csv
+valid_csv: !ref <output_folder>/dev-clean.csv
+test_csv:
+   - !ref <output_folder>/test-clean.csv
+   - !ref <output_folder>/test-other.csv
+
+####################### Training Parameters ####################################
+
+number_of_epochs: 1
+lr: 0.9
+lr_wav2vec: 0.0001
+sorting: ascending
+precision: fp32 # bf16, fp16 or fp32
+sample_rate: 16000
+
+# With data_parallel batch_size is split into N jobs
+# With DDP batch_size is multiplied by N jobs
+# Must be 3 per GPU to fit 32GB of VRAM
+batch_size: 6
+test_batch_size: 1
+
+# Dataloader options
+train_dataloader_opts:
+   batch_size: !ref <batch_size>
+
+valid_dataloader_opts:
+   batch_size: !ref <batch_size>
+
+test_dataloader_opts:
+   batch_size: !ref <test_batch_size>
+
+####################### Model Parameters #######################################
+activation: !name:torch.nn.LeakyReLU
+dnn_layers: 2
+dnn_neurons: 1024
+freeze_wav2vec: True
+
+# Outputs
+output_neurons: 29  # BPE size, index(blank/eos/bos) = 0
+
+
+pretrained_lm_tokenizer_path: speechbrain/asr-transformer-transformerlm-librispeech
+
+# This is the TransformerLM that is used according to the Huggingface repository
+# Visit the HuggingFace model corresponding to the pretrained_lm_tokenizer_path
+# For more details about the model!
+# NB: It has to match the pre-trained TransformerLM!!
+lm_model: !new:speechbrain.lobes.models.transformer.TransformerLM.TransformerLM # yamllint disable-line rule:line-length
+   vocab: 5000
+   d_model: 768
+   nhead: 12
+   num_encoder_layers: 12
+   num_decoder_layers: 0
+   d_ffn: 3072
+   dropout: 0.0
+   activation: !name:torch.nn.GELU
+   normalize_before: False
+
+tokenizer: !new:sentencepiece.SentencePieceProcessor
+
+# Decoding parameters
+lm_weight: 0.5
+blank_index: 0
+
+#
+# Functions and classes
+#
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+   limit: !ref <number_of_epochs>
+
+
+enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
+   input_shape: [null, null, 1024]
+   activation: !ref <activation>
+   dnn_blocks: !ref <dnn_layers>
+   dnn_neurons: !ref <dnn_neurons>
+
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
+   source: !ref <wav2vec2_hub>
+   output_norm: True
+   freeze: !ref <freeze_wav2vec>
+   save_path: !ref <wav2vec2_folder>
+
+#####
+# Uncomment this block if you prefer to use a Fairseq pretrained model instead
+# of a HuggingFace one. Here, we provide an URL that is obtained from the
+# Fairseq github for the multilingual XLSR.
+#
+#wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt
+#wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2
+#    pretrained_path: !ref <wav2vec2_url>
+#    output_norm: True
+#    freeze: False
+#    save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
+
+ctc_lin: !new:speechbrain.nnet.linear.Linear
+   input_size: !ref <dnn_neurons>
+   n_neurons: !ref <output_neurons>
+
+log_softmax: !new:speechbrain.nnet.activations.Softmax
+   apply_log: True
+
+ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
+   blank_index: !ref <blank_index>
+
+modules:
+   wav2vec2: !ref <wav2vec2>
+   enc: !ref <enc>
+   ctc_lin: !ref <ctc_lin>
+
+model: !new:torch.nn.ModuleList
+   - [!ref <enc>, !ref <ctc_lin>]
+
+model_opt_class: !name:torch.optim.Adadelta
+   lr: !ref <lr>
+   rho: 0.95
+   eps: 1.e-8
+
+wav2vec_opt_class: !name:torch.optim.Adam
+   lr: !ref <lr_wav2vec>
+
+lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
+   initial_value: !ref <lr>
+   improvement_threshold: 0.0025
+   annealing_factor: 0.8
+   patient: 0
+
+lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
+   initial_value: !ref <lr_wav2vec>
+   improvement_threshold: 0.0025
+   annealing_factor: 0.9
+   patient: 0
+
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+   orig_freq: !ref <sample_rate>
+   speeds: [95, 100, 105]
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+   drop_freq_low: 0
+   drop_freq_high: 1
+   drop_freq_count_low: 1
+   drop_freq_count_high: 3
+   drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+   drop_length_low: 1000
+   drop_length_high: 2000
+   drop_count_low: 1
+   drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+   concat_original: True
+   min_augmentations: 4
+   max_augmentations: 4
+   augment_prob: 1.0
+   augmentations: [
+      !ref <speed_perturb>,
+      !ref <drop_freq>,
+      !ref <drop_chunk>]
+
+label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder
+
+############################## Decoding ########################################
+
+# topk is the number of hypotheses that will be rescored in the rescorer
+# lowering this value might decrease the wer, but will increase speed.
+test_beam_search:
+   beam_size: 20
+   topk: 20
+   blank_index: !ref <blank_index>
+   space_token: ' ' # make sure this is the same as the one used in the tokenizer
+   beam_prune_logp: -12.0
+   token_prune_min_logp: -12.0
+   prune_history: False
+   alpha: 0.8
+   beta: 1.2
+
+transformerlm: !new:speechbrain.decoders.scorer.TransformerLMRescorer
+   language_model: !ref <lm_model>
+   tokenizer: !ref <tokenizer>
+   pad_index: 0
+   bos_index: 1
+   eos_index: 2
+
+rescorer: !new:speechbrain.decoders.scorer.RescorerBuilder
+   rescorers: [!ref <transformerlm>]
+   weights:
+      transformerlm: !ref <lm_weight>
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+   checkpoints_dir: !ref <save_folder>
+   recoverables:
+      wav2vec2: !ref <wav2vec2>
+      model: !ref <model>
+      scheduler_model: !ref <lr_annealing_model>
+      scheduler_wav2vec: !ref <lr_annealing_wav2vec>
+      counter: !ref <epoch_counter>
+      tokenizer: !ref <label_encoder>
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+   save_file: !ref <train_log>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+
+cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+   split_tokens: True
+
+# The pretrainer allows a mapping between pretrained files and instances that
+# are declared in the yaml. E.g here, we will download the file lm.ckpt
+# and it will be loaded into "lm" which is pointing to the <lm_model> defined
+# before.
+pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
+   collect_in: !ref <save_folder>
+   loadables:
+      lm: !ref <lm_model>
+      tokenizer: !ref <tokenizer>
+   paths:
+      lm: !ref <pretrained_lm_tokenizer_path>/lm.ckpt
+      tokenizer: !ref <pretrained_lm_tokenizer_path>/tokenizer.ckpt
diff --git a/recipes/LibriSpeech/ASR/CTC/hparams/train_hf_whisper_encoder.yaml b/recipes/LibriSpeech/ASR/CTC/hparams/train_hf_whisper_encoder.yaml
index 7d2f27b0b7574eec8e05384d765c4b82cfae1d5a..ba20bf2acfc55dc14ff95b92d070f69261321761 100644
--- a/recipes/LibriSpeech/ASR/CTC/hparams/train_hf_whisper_encoder.yaml
+++ b/recipes/LibriSpeech/ASR/CTC/hparams/train_hf_whisper_encoder.yaml
@@ -1,7 +1,8 @@
 # ################################
 # Model: Whisper (Encoder only) + DNN + CTC
+# Decoding AM: Greedy for validation, and Beam search for testing
 # Augmentation: TimeDomainSpecAugment
-# Authors: Titouan Parcollet 2022
+# Authors: Titouan Parcollet 2022, Adel Moumen 2023
 # ################################
 
 # Seed needs to be set at top of yaml, before objects with parameters are made
@@ -30,13 +31,14 @@ test_csv:
    - !ref <output_folder>/test-clean.csv
    - !ref <output_folder>/test-other.csv
 
-# Training parameters
+####################### Training Parameters ####################################
+
 number_of_epochs: 15
 warmup_steps: 1000 # We freeze whisper for 1000 steps to let the CTC adapt
 lr: 0.0008
 lr_whisper: 0.0001
 sorting: random
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 
 # BPE parameters
@@ -46,7 +48,7 @@ character_coverage: 1.0
 # With data_parallel batch_size is split into N jobs
 # With DDP batch_size is multiplied by N jobs
 batch_size: 6
-test_batch_size: 1
+test_batch_size: 8
 num_workers: 4
 
 # Dataloader options
@@ -60,7 +62,7 @@ valid_dataloader_opts:
 test_dataloader_opts:
    batch_size: !ref <test_batch_size>
 
-# Model parameters
+####################### Model Parameters #######################################
 dnn_neurons: 1024
 freeze_whisper: False
 whisper_output_dim: 512
@@ -68,8 +70,6 @@ whisper_output_dim: 512
 
 # Outputs
 output_neurons: 29  # BPE size, index(blank/eos/bos) = 0
-
-# Decoding parameters
 blank_index: 0
 
 #
@@ -78,9 +78,6 @@ blank_index: 0
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-   sample_rate: !ref <sample_rate>
-   speeds: [95, 100, 105]
 
 enc: !new:speechbrain.nnet.containers.Sequential
    input_shape: [null, null, !ref <whisper_output_dim>]
@@ -104,7 +101,7 @@ enc: !new:speechbrain.nnet.containers.Sequential
    bn3: !name:speechbrain.nnet.normalization.LayerNorm
    activation3: !new:torch.nn.LeakyReLU
 
-whisper: !new:speechbrain.lobes.models.huggingface_whisper.HuggingFaceWhisper
+whisper: !new:speechbrain.lobes.models.huggingface_transformers.whisper.Whisper
    source: !ref <whisper_hub>
    freeze: !ref <freeze_whisper>
    save_path: !ref <whisper_folder>
@@ -146,6 +143,57 @@ lr_annealing_whisper: !new:speechbrain.nnet.schedulers.NewBobScheduler
    annealing_factor: 0.75
    patient: 0
 
+############################## Decoding ########################################
+
+test_beam_search:
+   beam_size: 143
+   topk: 1
+   blank_index: !ref <blank_index>
+   space_token: ' ' # make sure this is the same as the one used in the tokenizer
+   beam_prune_logp: -12.0
+   token_prune_min_logp: -1.2
+   prune_history: True
+   alpha: 0.8
+   beta: 1.2
+   # can be downloaded from here https://www.openslr.org/11/ or trained with kenLM
+   # It can either be a .bin or .arpa ; note: .arpa is much slower at loading
+   # If you don't want to use an LM, comment it out or set it to null
+   kenlm_model_path: null
+
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+   orig_freq: !ref <sample_rate>
+   speeds: [95, 100, 105]
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+   drop_freq_low: 0
+   drop_freq_high: 1
+   drop_freq_count_low: 1
+   drop_freq_count_high: 3
+   drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+   drop_length_low: 1000
+   drop_length_high: 2000
+   drop_count_low: 1
+   drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+   concat_original: True
+   min_augmentations: 4
+   max_augmentations: 4
+   augment_prob: 1.0
+   augmentations: [
+      !ref <speed_perturb>,
+      !ref <drop_freq>,
+      !ref <drop_chunk>]
+
+############################## Logging and Pretrainer ##########################
 
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
diff --git a/recipes/LibriSpeech/ASR/CTC/hparams/train_sb_wav2vec.yaml b/recipes/LibriSpeech/ASR/CTC/hparams/train_sb_wav2vec.yaml
index 0af1c68f064337f1129e8d4a6549e551d0966f7e..1b281b35c5afc35265592e26d4f3ff04d03a8d5b 100644
--- a/recipes/LibriSpeech/ASR/CTC/hparams/train_sb_wav2vec.yaml
+++ b/recipes/LibriSpeech/ASR/CTC/hparams/train_sb_wav2vec.yaml
@@ -1,7 +1,8 @@
 # ################################
 # Model: wav2vec2 + DNN + CTC
+# Decoding AM: Greedy for validation, and Beam search for testing
 # Augmentation: SpecAugment
-# Authors: Sung-Lin Yeh 2021, Rudolf A. Braun 2022, Titouan Parcollet 2022
+# Authors: Sung-Lin Yeh 2021, Rudolf A. Braun 2022, Titouan Parcollet 2022, Adel Moumen 2023
 # ################################
 
 # Seed needs to be set at top of yaml, before objects with parameters are made
@@ -14,7 +15,7 @@ train_log: !ref <output_folder>/train_log.txt
 
 # Path of the SpeechBrain checkpoints containing the pretrained wav2vec2 model
 # It can be a local path or a HuggingFace hub containing the model
-wav2vec2_hub: !PLACEHOLDER
+wav2vec2_hub: facebook/wav2vec2-large-960h-lv60-self
 wav2vec_output_dim: 768 # This corresponds to the embedding size of the w2v2
 
 # Data files
@@ -32,11 +33,11 @@ test_csv:
    - !ref <output_folder>/test-clean.csv
    - !ref <output_folder>/test-other.csv
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 30
 lr: 0.0003
 lr_wav2vec: 0.00005
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 sorting: ascending
 num_workers: 2
@@ -57,18 +58,15 @@ valid_dataloader_opts:
 test_dataloader_opts:
    batch_size: !ref <test_batch_size>
 
-# Model parameters
+####################### Model Parameters #######################################
 dnn_activation: !new:torch.nn.LeakyReLU
 dnn_neurons: 1280
 dnn_dropout: 0.15
+freeze_wav2vec: False
 
 # Outputs
 output_neurons: 29  # BPE size, index(blank/eos/bos) = 0
-
-# Decoding parameters
 blank_index: 0
-use_language_modelling: False
-# ngram_lm_path: !PLACEHOLDER
 
 #
 # Functions and classes
@@ -76,9 +74,6 @@ use_language_modelling: False
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-   sample_rate: !ref <sample_rate>
-   speeds: [95, 100, 105]
 
 enc: !new:speechbrain.nnet.containers.Sequential
    input_shape: [null, null, !ref <wav2vec_output_dim>]
@@ -169,6 +164,58 @@ lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
    annealing_factor: 0.7
    patient: 0
 
+############################## Decoding ########################################
+
+test_beam_search:
+   beam_size: 200
+   topk: 1
+   blank_index: !ref <blank_index>
+   space_token: ' ' # make sure this is the same as the one used in the tokenizer
+   beam_prune_logp: -10.0
+   token_prune_min_logp: -5.0
+   prune_history: True
+   alpha: 0.8
+   beta: 1.2
+   # can be downloaded from here https://www.openslr.org/11/ or trained with kenLM
+   # It can either be a .bin or .arpa ; note: .arpa is much slower at loading
+   # If you don't want to use an LM, comment it out or set it to null
+   kenlm_model_path: null
+
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+   orig_freq: !ref <sample_rate>
+   speeds: [95, 100, 105]
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+   drop_freq_low: 0
+   drop_freq_high: 1
+   drop_freq_count_low: 1
+   drop_freq_count_high: 3
+   drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+   drop_length_low: 1000
+   drop_length_high: 2000
+   drop_count_low: 1
+   drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+   concat_original: True
+   min_augmentations: 4
+   max_augmentations: 4
+   augment_prob: 1.0
+   augmentations: [
+      !ref <speed_perturb>,
+      !ref <drop_freq>,
+      !ref <drop_chunk>]
+
+############################## Logging and Pretrainer ##########################
+
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
diff --git a/recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py b/recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py
index 45a3a6fcafd9e803da6a6f9f53956876a6be0f4e..1f4ccdd2c6bdcdf62563989fb2c4d2e5de916fd6 100644
--- a/recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py
+++ b/recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py
@@ -1,7 +1,10 @@
 #!/usr/bin/env/python3
 """Recipe for training a wav2vec-based ctc ASR system with librispeech.
 The system employs wav2vec as its encoder. Decoding is performed with
-ctc greedy decoder.
+ctc greedy decoder during validation and a beam search with an optional
+language model during test. The test searcher can be chosen from the following
+options: CTCBeamSearcher, CTCPrefixBeamSearcher, TorchAudioCTCPrefixBeamSearcher.
+
 To run this recipe, do the following:
 > python train_with_wav2vec.py hparams/train_{hf,sb}_wav2vec.yaml
 The neural network is trained on CTC likelihood target and character units
@@ -16,8 +19,8 @@ Authors
  * Abdel Heba 2020
  * Peter Plantinga 2020
  * Samuele Cornell 2020
+ * Adel Moumen 2023
 """
-
 import os
 import sys
 import torch
@@ -37,18 +40,14 @@ class ASR(sb.Brain):
         batch = batch.to(self.device)
         wavs, wav_lens = batch.sig
         wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
+
         # Downsample the inputs if specified
         if hasattr(self.modules, "downsampler"):
             wavs = self.modules.downsampler(wavs)
-        # Add augmentation if specified
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.modules, "env_corrupt"):
-                wavs_noise = self.modules.env_corrupt(wavs, wav_lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                wav_lens = torch.cat([wav_lens, wav_lens])
 
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
 
         # Forward pass
 
@@ -74,13 +73,24 @@ class ASR(sb.Brain):
             )
 
         p_ctc = self.hparams.log_softmax(logits)
-        if stage == sb.Stage.VALID or (
-            stage == sb.Stage.TEST and not self.hparams.use_language_modelling
-        ):
 
+        if stage == sb.Stage.VALID:
             p_tokens = sb.decoders.ctc_greedy_decode(
                 p_ctc, wav_lens, blank_id=self.hparams.blank_index
             )
+        elif stage == sb.Stage.TEST:
+            p_tokens = test_searcher(p_ctc, wav_lens)
+
+            candidates = []
+            scores = []
+
+            for batch in p_tokens:
+                candidates.append([hyp.text for hyp in batch])
+                scores.append([hyp.score for hyp in batch])
+
+            if hasattr(self.hparams, "rescorer"):
+                p_tokens, _ = self.hparams.rescorer.rescore(candidates, scores)
+
         return p_ctc, wav_lens, p_tokens
 
     def compute_objectives(self, predictions, batch, stage):
@@ -91,9 +101,15 @@ class ASR(sb.Brain):
         ids = batch.id
         tokens, tokens_lens = batch.tokens
 
-        if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
-            tokens = torch.cat([tokens, tokens], dim=0)
-            tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0)
+        # Labels must be extended if parallel augmentation or concatenated
+        # augmentation was performed on the input (increasing the time dimension)
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            (
+                tokens,
+                tokens_lens,
+            ) = self.hparams.wav_augment.replicate_multiple_labels(
+                tokens, tokens_lens
+            )
 
         loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
         loss = loss_ctc
@@ -104,63 +120,22 @@ class ASR(sb.Brain):
                 "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
                 for utt_seq in predicted_tokens
             ]
-            target_words = [wrd.split(" ") for wrd in batch.wrd]
-            self.wer_metric.append(ids, predicted_words, target_words)
-            self.cer_metric.append(ids, predicted_words, target_words)
-        if stage == sb.Stage.TEST:  # Language model decoding only used for test
-            if self.hparams.use_language_modelling:
-                predicted_words = []
-                for logs in p_ctc:
-                    text = decoder.decode(logs.detach().cpu().numpy())
-                    predicted_words.append(text.split(" "))
+        elif stage == sb.Stage.TEST:
+            if hasattr(self.hparams, "rescorer"):
+                predicted_words = [
+                    hyp[0].split(" ") for hyp in predicted_tokens
+                ]
             else:
                 predicted_words = [
-                    "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
-                    for utt_seq in predicted_tokens
+                    hyp[0].text.split(" ") for hyp in predicted_tokens
                 ]
+
+        if stage != sb.Stage.TRAIN:
             target_words = [wrd.split(" ") for wrd in batch.wrd]
             self.wer_metric.append(ids, predicted_words, target_words)
             self.cer_metric.append(ids, predicted_words, target_words)
-        return loss
 
-    def fit_batch(self, batch):
-        should_step = self.step % self.grad_accumulation_factor == 0
-
-        # Managing automatic mixed precision
-        if self.auto_mix_prec:
-            self.wav2vec_optimizer.zero_grad()
-            self.model_optimizer.zero_grad()
-            with torch.cuda.amp.autocast():
-                with self.no_sync():
-                    outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-                loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-            with self.no_sync(not should_step):
-                self.scaler.scale(
-                    loss / self.grad_accumulation_factor
-                ).backward()
-            if should_step:
-                if not self.hparams.freeze_wav2vec:
-                    self.scaler.unscale_(self.wav2vec_optimizer)
-                self.scaler.unscale_(self.model_optimizer)
-                if self.check_gradients(loss):
-                    self.scaler.step(self.wav2vec_optimizer)
-                    self.scaler.step(self.model_optimizer)
-                self.scaler.update()
-                self.optimizer_step += 1
-        else:
-            with self.no_sync():
-                outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-            loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-            (loss / self.grad_accumulation_factor).backward()
-            if should_step:
-                if self.check_gradients(loss):
-                    self.wav2vec_optimizer.step()
-                    self.model_optimizer.step()
-                self.wav2vec_optimizer.zero_grad()
-                self.model_optimizer.zero_grad()
-                self.optimizer_step += 1
-
-        return loss.detach().cpu()
+        return loss
 
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
@@ -168,6 +143,10 @@ class ASR(sb.Brain):
             self.cer_metric = self.hparams.cer_computer()
             self.wer_metric = self.hparams.error_rate_computer()
 
+        if stage == sb.Stage.TEST:
+            if hasattr(self.hparams, "rescorer"):
+                self.hparams.rescorer.move_rescorers_to_device()
+
     def on_stage_end(self, stage, stage_loss, epoch):
         """Gets called at the end of an epoch."""
         # Compute/store important stats
@@ -215,7 +194,7 @@ class ASR(sb.Brain):
 
     def init_optimizers(self):
         "Initializes the wav2vec2 optimizer and model optimizer"
-        # Handling SpeechBrain vs HuggingFance pretrained models
+        # Handling SpeechBrain vs HuggingFace pretrained models
         if hasattr(self.modules, "extractor"):  # SpeechBrain pretrained model
             self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
                 self.modules.encoder_wrapper.parameters()
@@ -230,16 +209,20 @@ class ASR(sb.Brain):
             self.hparams.model.parameters()
         )
 
+        # save the optimizers in a dictionary
+        # the key will be used in `freeze_optimizers()`
+        self.optimizers_dict = {
+            "model_optimizer": self.model_optimizer,
+        }
+        if not self.hparams.freeze_wav2vec:
+            self.optimizers_dict["wav2vec_optimizer"] = self.wav2vec_optimizer
+
         if self.checkpointer is not None:
             self.checkpointer.add_recoverable(
                 "wav2vec_opt", self.wav2vec_optimizer
             )
             self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
 
-    def zero_grad(self, set_to_none=False):
-        self.wav2vec_optimizer.zero_grad(set_to_none)
-        self.model_optimizer.zero_grad(set_to_none)
-
 
 def dataio_prepare(hparams):
     """This function prepares the datasets to be used in the brain class.
@@ -340,7 +323,6 @@ if __name__ == "__main__":
     # CLI:
     hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
 
-    # If distributed_launch=True then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -377,30 +359,6 @@ if __name__ == "__main__":
         hparams
     )
 
-    # Loading the labels for the LM decoding and the CTC decoder
-    if hasattr(hparams, "use_language_modelling"):
-        if hparams["use_language_modelling"]:
-            try:
-                from pyctcdecode import build_ctcdecoder
-            except ImportError:
-                err_msg = "Optional dependencies must be installed to use pyctcdecode.\n"
-                err_msg += "Install using `pip install kenlm pyctcdecode`.\n"
-                raise ImportError(err_msg)
-
-            ind2lab = label_encoder.ind2lab
-            labels = [ind2lab[x] for x in range(len(ind2lab))]
-            labels = [""] + labels[
-                1:
-            ]  # Replace the <blank> token with a blank character, needed for PyCTCdecode
-            decoder = build_ctcdecoder(
-                labels,
-                kenlm_model_path=hparams["ngram_lm_path"],  # .arpa or .bin
-                alpha=0.5,  # Default by KenLM
-                beta=1.0,  # Default by KenLM
-            )
-    else:
-        hparams["use_language_modelling"] = False
-
     # Trainer initialization
     asr_brain = ASR(
         modules=hparams["modules"],
@@ -412,12 +370,21 @@ if __name__ == "__main__":
     # We load the pretrained wav2vec2 model
     if "pretrainer" in hparams.keys():
         run_on_main(hparams["pretrainer"].collect_files)
-        hparams["pretrainer"].load_collected(asr_brain.device)
+        hparams["pretrainer"].load_collected()
 
     # We dynamicaly add the tokenizer to our brain class.
     # NB: This tokenizer corresponds to the one used for the LM!!
     asr_brain.tokenizer = label_encoder
 
+    ind2lab = label_encoder.ind2lab
+    vocab_list = [ind2lab[x] for x in range(len(ind2lab))]
+
+    from speechbrain.decoders.ctc import CTCBeamSearcher
+
+    test_searcher = CTCBeamSearcher(
+        **hparams["test_beam_search"], vocab_list=vocab_list,
+    )
+
     # Training
     asr_brain.fit(
         asr_brain.hparams.epoch_counter,
diff --git a/recipes/LibriSpeech/ASR/CTC/train_with_wav2vec_k2.py b/recipes/LibriSpeech/ASR/CTC/train_with_wav2vec_k2.py
new file mode 100644
index 0000000000000000000000000000000000000000..588ee2fa547845a41e63f486ee37423f634b7b66
--- /dev/null
+++ b/recipes/LibriSpeech/ASR/CTC/train_with_wav2vec_k2.py
@@ -0,0 +1,487 @@
+#!/usr/bin/env/python3
+"""Recipe for training a wav2vec-based ctc ASR system with librispeech.
+The system employs wav2vec as its encoder. Decoding is performed with
+k2 through the use of a decoding graph and, optionally, a rescoring LM.
+To run this recipe, do the following:
+> python train_with_wav2vec.py hparams/train_{hf,sb}_wav2vec.yaml
+The neural network is trained on CTC likelihood target and character units
+are used as basic recognition tokens.
+
+Authors
+ * Pierre Champion 2023
+ * Zeyu Zhao 2023
+ * Georgios Karakasidis 2023
+ * Rudolf A Braun 2022
+ * Titouan Parcollet 2022
+ * Sung-Lin Yeh 2021
+ * Ju-Chieh Chou 2020
+ * Mirco Ravanelli 2020
+ * Abdel Heba 2020
+ * Peter Plantinga 2020
+ * Samuele Cornell 2020
+"""
+
+import os
+import sys
+import torch
+import logging
+import speechbrain as sb
+from speechbrain.utils.distributed import run_on_main, if_main_process
+from hyperpyyaml import load_hyperpyyaml
+from collections import defaultdict
+from pathlib import Path
+
+import speechbrain.k2_integration as sbk2
+
+logger = logging.getLogger(__name__)
+
+
+# Define training procedure
+class ASR(sb.Brain):
+    def compute_forward(self, batch, stage):
+        """Forward computations from the waveform batches to the output probabilities."""
+        batch = batch.to(self.device)
+        wavs, wav_lens = batch.sig
+        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
+
+        # Downsample the inputs if specified
+        if hasattr(self.modules, "downsampler"):
+            wavs = self.modules.downsampler(wavs)
+
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+
+        # Forward pass
+
+        # Handling SpeechBrain vs HuggingFance pretrained models
+        if hasattr(self.modules, "extractor"):  # SpeechBrain pretrained model
+            latents = self.modules.extractor(wavs)
+            feats = self.modules.encoder_wrapper(latents, wav_lens=wav_lens)[
+                "embeddings"
+            ]
+        else:  # HuggingFace pretrained model
+            feats = self.modules.wav2vec2(wavs, wav_lens)
+
+        x = self.modules.enc(feats)
+
+        # Compute outputs
+        logits = self.modules.ctc_lin(x)
+
+        # Upsample the inputs if they have been highly downsampled
+        if hasattr(self.hparams, "upsampling") and self.hparams.upsampling:
+            logits = logits.view(
+                logits.shape[0], -1, self.hparams.output_neurons
+            )
+
+        p_ctc = self.hparams.log_softmax(logits)
+        paths = None
+        if stage == sb.Stage.VALID or stage == sb.Stage.TEST:
+            # Decode token terms to words
+            lattice = sbk2.lattice_decoder.get_lattice(
+                p_ctc,
+                wav_lens,
+                self.decoder["decoding_graph"],
+                search_beam=self.hparams.test_search_beam,
+                output_beam=self.hparams.test_output_beam,
+                ac_scale=self.hparams.ac_scale,
+                max_active_states=self.hparams.test_max_active_state,
+                min_active_states=self.hparams.test_min_active_state,
+            )
+        if stage == sb.Stage.VALID:
+            # 1best decoding for fast valid
+            paths = {"onebest": sbk2.lattice_decoder.one_best_decoding(lattice)}
+        elif stage == sb.Stage.TEST:
+            # user defined decoding for test
+            paths = self.decoder["decoding_method"](lattice)
+
+        return p_ctc, wav_lens, paths
+
+    def compute_objectives(self, predictions, batch, stage):
+        """Computes the loss (CTC+NLL) given predictions and targets."""
+
+        p_ctc, wav_lens, paths = predictions
+
+        # Sort batch to be descending by length of wav files, which is required
+        # by `k2.intersect_dense` called in `k2.ctc_loss`
+        indices = torch.argsort(wav_lens, descending=True)
+        p_ctc = p_ctc[indices]
+        wav_lens = wav_lens[indices]
+        texts = [batch.wrd[i] for i in indices]
+
+        is_training = stage == sb.Stage.TRAIN
+        loss = self.hparams.ctc_cost(
+            log_probs=p_ctc,
+            input_lens=wav_lens,
+            graph_compiler=self.graph_compiler,
+            texts=texts,
+            is_training=is_training,
+        )
+
+        if stage == sb.Stage.TEST or stage == sb.Stage.VALID:
+            for k, path in paths.items():
+                predicted_texts = sbk2.utils.lattice_paths_to_text(
+                    path, self.lexicon.word_table
+                )
+
+                predicted_words = [wrd.split(" ") for wrd in predicted_texts]
+                target_words = [wrd.split(" ") for wrd in batch.wrd]
+                self.wer_metrics[k].append(
+                    batch.id, predicted_words, target_words
+                )
+                self.cer_metrics[k].append(
+                    batch.id, predicted_words, target_words
+                )
+            # For TEST and VALID stages, the loss value is not exact.
+            # The <UNK> words have a target length (e.g., number of phones or characters) of 1.
+            # As such, sentences with <UNK> have a higher loss during CTC loss 'mean' reduction mode.
+            # It does not impact training.
+        return loss
+
+    def on_stage_start(self, stage, epoch):
+        """Gets called at the beginning of each epoch. In this case,
+        it initializes the wer and cer metric watchers. If the decoding
+        method is whole-lattice-rescoring then a list of wer/cer metrics
+        will be initialized (for each lm scale). Otherwise, a single class
+        will be initialized for wer and cer, respectively.
+        """
+        if stage == sb.Stage.VALID:
+            logger.info("Valid stage")
+        if stage == sb.Stage.TEST:
+            logger.info("Test stage")
+        self.cer_metrics = defaultdict(self.hparams.cer_computer)
+        self.wer_metrics = defaultdict(self.hparams.error_rate_computer)
+
+    def on_stage_end(self, stage, stage_loss, epoch):
+        """Gets called at the end of an epoch. During testing, its primary goal
+        is to summarize the WER/CER stats and save them in a file.
+        """
+        # Compute/store important stats
+        stage_stats = {"loss": stage_loss}
+        if stage == sb.Stage.TRAIN:
+            self.train_stats = stage_stats
+        else:
+            # Only report the fist config (first rescoring_lm_scale value)
+            stage_stats["CER"] = list(self.cer_metrics.values())[0].summarize(
+                "error_rate"
+            )
+            stage_stats["WER"] = list(self.wer_metrics.values())[0].summarize(
+                "error_rate"
+            )
+
+        # Perform end-of-iteration things, like annealing, logging, etc.
+        if stage == sb.Stage.VALID:
+            old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
+                stage_stats["loss"]
+            )
+            old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
+                stage_stats["loss"]
+            )
+            sb.nnet.schedulers.update_learning_rate(
+                self.model_optimizer, new_lr_model
+            )
+            sb.nnet.schedulers.update_learning_rate(
+                self.wav2vec_optimizer, new_lr_wav2vec
+            )
+            self.hparams.train_logger.log_stats(
+                stats_meta={
+                    "epoch": epoch,
+                    "lr_model": old_lr_model,
+                    "lr_wav2vec": old_lr_wav2vec,
+                },
+                train_stats=self.train_stats,
+                valid_stats=stage_stats,
+            )
+            self.checkpointer.save_and_keep_only(
+                meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
+            )
+        elif stage == sb.Stage.TEST:
+            self.hparams.train_logger.log_stats(
+                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
+                test_stats=stage_stats,
+            )
+            if if_main_process():
+                for k, stat in self.wer_metrics.items():
+                    with open(self.hparams.wer_file + f"_{k}.txt", "w") as w:
+                        stat.write_stats(w)
+
+    def init_optimizers(self):
+        "Initializes the wav2vec2 optimizer and model optimizer"
+        # Handling SpeechBrain vs HuggingFace pretrained models
+        if hasattr(self.modules, "extractor"):  # SpeechBrain pretrained model
+            self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
+                self.modules.encoder_wrapper.parameters()
+            )
+
+        else:  # HuggingFace pretrained model
+            self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
+                self.modules.wav2vec2.parameters()
+            )
+
+        self.model_optimizer = self.hparams.model_opt_class(
+            self.hparams.model.parameters()
+        )
+
+        # save the optimizers in a dictionary
+        # the key will be used in `freeze_optimizers()`
+        self.optimizers_dict = {
+            "model_optimizer": self.model_optimizer,
+        }
+        if not self.hparams.freeze_wav2vec:
+            self.optimizers_dict["wav2vec_optimizer"] = self.wav2vec_optimizer
+
+        if self.checkpointer is not None:
+            self.checkpointer.add_recoverable(
+                "wav2vec_opt", self.wav2vec_optimizer
+            )
+            self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
+
+
+def dataio_prepare(hparams):
+    """This function prepares the datasets to be used in the brain class.
+    It also defines the data processing pipeline through user-defined functions.
+    """
+    data_folder = hparams["data_folder"]
+
+    train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
+        csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
+    )
+
+    if hparams["sorting"] == "ascending":
+        # we sort training data to speed up training and get better results.
+        train_data = train_data.filtered_sorted(sort_key="duration")
+        # when sorting do not shuffle in dataloader ! otherwise is pointless
+        hparams["train_dataloader_opts"]["shuffle"] = False
+
+    elif hparams["sorting"] == "descending":
+        train_data = train_data.filtered_sorted(
+            sort_key="duration", reverse=True
+        )
+        # when sorting do not shuffle in dataloader ! otherwise is pointless
+        hparams["train_dataloader_opts"]["shuffle"] = False
+
+    elif hparams["sorting"] == "random":
+        pass
+
+    else:
+        raise NotImplementedError(
+            "sorting must be random, ascending or descending"
+        )
+
+    valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
+        csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
+    )
+    valid_data = valid_data.filtered_sorted(sort_key="duration")
+
+    # test is separate
+    test_datasets = {}
+    for csv_file in hparams["test_csv"]:
+        name = Path(csv_file).stem
+        test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
+            csv_path=csv_file, replacements={"data_root": data_folder}
+        )
+        test_datasets[name] = test_datasets[name].filtered_sorted(
+            sort_key="duration"
+        )
+
+    datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
+
+    # 2. Define audio pipeline:
+    @sb.utils.data_pipeline.takes("wav")
+    @sb.utils.data_pipeline.provides("sig")
+    def audio_pipeline(wav):
+        sig = sb.dataio.dataio.read_audio(wav)
+        return sig
+
+    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
+
+    # 3. Define text pipeline:
+    @sb.utils.data_pipeline.takes("wrd")
+    @sb.utils.data_pipeline.provides("wrd", "char_list")
+    def text_pipeline(wrd):
+        yield wrd
+        char_list = list(wrd)
+        yield char_list
+
+    sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
+
+    # 4. Set output:
+    sb.dataio.dataset.set_output_keys(
+        datasets, ["id", "sig", "wrd", "char_list"],
+    )
+
+    return train_data, valid_data, test_datasets
+
+
+if __name__ == "__main__":
+    # CLI:
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+
+    # If distributed_launch=True then
+    # create ddp_group with the right communication protocol
+    sb.utils.distributed.ddp_init_group(run_opts)
+
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # env_corrupt is not supported with k2 yet
+    if hparams.get("env_corrupt", None):
+        raise NotImplementedError("env_corrupt is not supported with k2 yet")
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    # Dataset prep (parsing Librispeech)
+    import librispeech_prepare
+
+    # multi-gpu (ddp) save data preparation
+    run_on_main(
+        librispeech_prepare.prepare_librispeech,
+        kwargs={
+            "data_folder": hparams["data_folder"],
+            "tr_splits": hparams["train_splits"],
+            "dev_splits": hparams["dev_splits"],
+            "te_splits": hparams["test_splits"],
+            "save_folder": hparams["output_folder"],
+            "merge_lst": hparams["train_splits"],
+            "merge_name": "train.csv",
+            "skip_prep": hparams["skip_prep"],
+        },
+    )
+
+    # Download the vocabulary file for librispeech
+    librispeech_prepare.download_librispeech_vocab_text(
+        destination=hparams["vocab_file"]
+    )
+
+    # here we create the datasets objects as well as tokenization and encoding
+    train_data, valid_data, test_datasets = dataio_prepare(hparams)
+
+    # Create the lexicon.txt for k2
+    run_on_main(
+        sbk2.lexicon.prepare_char_lexicon,
+        kwargs={
+            "lang_dir": hparams["lang_dir"],
+            "vocab_files": [hparams["vocab_file"]],
+            "extra_csv_files": [hparams["output_folder"] + "/train.csv"]
+            if not hparams["skip_prep"]
+            else [],
+            "add_word_boundary": hparams["add_word_boundary"],
+        },
+    )
+
+    caching = (
+        {"cache": False}
+        if "caching" in hparams and hparams["caching"] is False
+        else {}
+    )
+
+    # Create the lang directory for k2
+    run_on_main(
+        sbk2.prepare_lang.prepare_lang,
+        kwargs={
+            "lang_dir": hparams["lang_dir"],
+            "sil_prob": hparams["sil_prob"],
+            **caching,
+        },
+    )
+
+    # OpenSLR ngram models
+    if (
+        hparams["G_arpa"] + ".gz"
+        in librispeech_prepare.OPEN_SLR_11_NGRAM_MODELs
+        and hparams["G_rescoring_arpa"] + ".gz"
+        in librispeech_prepare.OPEN_SLR_11_NGRAM_MODELs
+        and (
+            hparams["compose_HL_with_G"]
+            or hparams["decoding_method"] == "whole-lattice-rescoring"
+        )
+    ):
+        librispeech_prepare.download_openslr_librispeech_lm(
+            destination=hparams["lm_dir"],
+            rescoring_lm=(
+                hparams["decoding_method"] == "whole-lattice-rescoring"
+            ),
+        )
+    # SB ngram models
+    elif (
+        "sb" in hparams["G_arpa"]
+        and "sb" in hparams["G_rescoring_arpa"]
+        and (
+            hparams["compose_HL_with_G"]
+            or hparams["decoding_method"] == "whole-lattice-rescoring"
+        )
+    ):
+        librispeech_prepare.download_sb_librispeech_lm(
+            destination=hparams["lm_dir"],
+            rescoring_lm=(
+                hparams["decoding_method"] == "whole-lattice-rescoring"
+            ),
+        )
+
+    # Trainer initialization
+    asr_brain = ASR(
+        modules=hparams["modules"],
+        hparams=hparams,
+        run_opts=run_opts,
+        checkpointer=hparams["checkpointer"],
+    )
+
+    lexicon = sbk2.lexicon.Lexicon(hparams["lang_dir"])
+    graph_compiler = sbk2.graph_compiler.CtcGraphCompiler(
+        lexicon, device=asr_brain.device,
+    )
+
+    decoding_params = {}
+    for param_name in (
+        "compose_HL_with_G",
+        "lm_dir",
+        "decoding_method",
+        "caching",
+        "G_arpa",
+        "G_rescoring_arpa",
+        "lang_dir",
+        "output_folder",
+        "rescoring_lm_scale",
+    ):
+        if param_name in hparams:
+            decoding_params[param_name] = hparams[param_name]
+
+    decoder = sbk2.lattice_decoder.get_decoding(
+        decoding_params, graph_compiler, device=asr_brain.device
+    )
+
+    # Add attributes to asr_brain
+    setattr(asr_brain, "lexicon", lexicon)
+    setattr(asr_brain, "graph_compiler", graph_compiler)
+    setattr(asr_brain, "decoder", decoder)
+
+    # We load the pretrained wav2vec2 model
+    if "pretrainer" in hparams.keys():
+        run_on_main(hparams["pretrainer"].collect_files)
+        hparams["pretrainer"].load_collected(asr_brain.device)
+
+    # Training
+    asr_brain.fit(
+        asr_brain.hparams.epoch_counter,
+        train_data,
+        valid_data,
+        train_loader_kwargs=hparams["train_dataloader_opts"],
+        valid_loader_kwargs=hparams["valid_dataloader_opts"],
+    )
+
+    # Testing
+    for k in test_datasets.keys():  # keys are test_clean, test_other etc
+        wer_dir = os.path.join(hparams["output_wer_folder"], f"metric_{k}")
+        os.makedirs(wer_dir, exist_ok=True)
+        exp = "HLG" if hparams["compose_HL_with_G"] else "HL"
+        asr_brain.hparams.wer_file = os.path.join(wer_dir, f"wer_{exp}")
+        asr_brain.evaluate(
+            test_datasets[k],
+            test_loader_kwargs=hparams["test_dataloader_opts"],
+            min_key="WER",
+        )
diff --git a/recipes/LibriSpeech/ASR/CTC/train_with_whisper.py b/recipes/LibriSpeech/ASR/CTC/train_with_whisper.py
index 4dbe01ab81d7a81a3b6bfc1d8788a1d771825a5b..d575265e86f0749f1cec8f30e26241b08c07d281 100644
--- a/recipes/LibriSpeech/ASR/CTC/train_with_whisper.py
+++ b/recipes/LibriSpeech/ASR/CTC/train_with_whisper.py
@@ -41,10 +41,9 @@ class ASR(sb.Brain):
         wavs, wav_lens = batch.sig
         wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
 
-        # Add augmentation if specified
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
 
         # Forward pass
 
@@ -56,10 +55,12 @@ class ASR(sb.Brain):
         p_tokens = None
         logits = self.modules.ctc_lin(x)
         p_ctc = self.hparams.log_softmax(logits)
-        if stage != sb.Stage.TRAIN:
+        if stage == sb.Stage.VALID:
             p_tokens = sb.decoders.ctc_greedy_decode(
                 p_ctc, wav_lens, blank_id=self.hparams.blank_index
             )
+        elif stage == sb.Stage.TEST:
+            p_tokens = test_searcher(p_ctc, wav_lens)
 
         return p_ctc, wav_lens, p_tokens
 
@@ -71,16 +72,32 @@ class ASR(sb.Brain):
         ids = batch.id
         tokens, tokens_lens = batch.tokens
 
+        # Labels must be extended if parallel augmentation or concatenated
+        # augmentation was performed on the input (increasing the time dimension)
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            (
+                tokens,
+                tokens_lens,
+            ) = self.hparams.wav_augment.replicate_multiple_labels(
+                tokens, tokens_lens
+            )
+
         loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
         loss = loss_ctc
 
-        if stage != sb.Stage.TRAIN:
+        if stage == sb.Stage.VALID:
 
             # Decode token terms to words
             predicted_words = self.tokenizer(
                 predicted_tokens, task="decode_from_list"
             )
 
+        elif stage == sb.Stage.TEST:
+            predicted_words = [
+                hyp[0].text.split(" ") for hyp in predicted_tokens
+            ]
+
+        if stage != sb.Stage.TRAIN:
             # Convert indices to words
             target_words = undo_padding(tokens, tokens_lens)
             target_words = self.tokenizer(target_words, task="decode_from_list")
@@ -90,45 +107,6 @@ class ASR(sb.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        should_step = self.step % self.grad_accumulation_factor == 0
-
-        # Managing automatic mixed precision
-        if self.auto_mix_prec:
-            self.whisper_optimizer.zero_grad()
-            self.model_optimizer.zero_grad()
-            with torch.cuda.amp.autocast():
-                outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-            loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-            self.scaler.scale(loss / self.grad_accumulation_factor).backward()
-            if should_step:
-                self.scaler.unscale_(self.whisper_optimizer)
-                self.scaler.unscale_(self.model_optimizer)
-                if self.check_gradients(loss):
-                    if self.optimizer_step > self.hparams.warmup_steps:
-                        # Here we added a warmup to the CTC encoder to make sure that
-                        # it does not screw the whisper with too large gradients.
-                        self.scaler.step(self.whisper_optimizer)
-                    self.scaler.step(self.model_optimizer)
-                self.scaler.update()
-                self.optimizer_step += 1
-        else:
-            outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-            loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-            (loss / self.grad_accumulation_factor).backward()
-            if should_step:
-                if self.check_gradients(loss):
-                    # Here we added a warmup to the CTC encoder to make sure that
-                    # it does not screw the whisper with too large gradients.
-                    if self.optimizer_step > self.hparams.warmup_steps:
-                        self.whisper_optimizer.step()
-                    self.model_optimizer.step()
-                self.whisper_optimizer.zero_grad()
-                self.model_optimizer.zero_grad()
-                self.optimizer_step += 1
-
-        return loss.detach().cpu()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
         if stage != sb.Stage.TRAIN:
@@ -190,12 +168,32 @@ class ASR(sb.Brain):
             self.hparams.model.parameters()
         )
 
+        # save the optimizers in a dictionary
+        # the key will be used in `freeze_optimizers()`
+        self.optimizers_dict = {
+            "model_optimizer": self.model_optimizer,
+            "whisper_optimizer": self.whisper_optimizer,
+        }
+
         if self.checkpointer is not None:
             self.checkpointer.add_recoverable(
                 "whisper_opt", self.whisper_optimizer
             )
             self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
 
+    def freeze_optimizers(self, optimizers):
+        """Freezes the wav2vec2 optimizer according to the warmup steps"""
+        valid_optimizers = {}
+        if not self.hparams.freeze_whisper:
+            # Here we added a warmup to the CTC encoder to make sure that
+            # it does not break the whisper with too large gradients.
+            if self.optimizer_step > self.hparams.warmup_steps:
+                valid_optimizers["whisper_optimizer"] = optimizers[
+                    "whisper_optimizer"
+                ]
+        valid_optimizers["model_optimizer"] = optimizers["model_optimizer"]
+        return valid_optimizers
+
 
 def dataio_prepare(hparams, tokenizer):
     """This function prepares the datasets to be used in the brain class.
@@ -283,7 +281,6 @@ if __name__ == "__main__":
     # CLI:
     hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
 
-    # If distributed_launch=True then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -345,6 +342,16 @@ if __name__ == "__main__":
     # NB: This tokenizer corresponds to the one used for the LM!!
     asr_brain.tokenizer = tokenizer
 
+    vocab_list = [
+        tokenizer.sp.id_to_piece(i) for i in range(tokenizer.sp.vocab_size())
+    ]
+
+    from speechbrain.decoders.ctc import CTCBeamSearcher
+
+    test_searcher = CTCBeamSearcher(
+        **hparams["test_beam_search"], vocab_list=vocab_list,
+    )
+
     # Training
     asr_brain.fit(
         asr_brain.hparams.epoch_counter,
diff --git a/recipes/LibriSpeech/ASR/seq2seq/hparams/train_BPE_1000.yaml b/recipes/LibriSpeech/ASR/seq2seq/hparams/train_BPE_1000.yaml
index bbebbdcfae24442e690b37661cdb33e7d35b61d5..3d0aaa200486f70767bec1c3e20c9058a0c978f1 100644
--- a/recipes/LibriSpeech/ASR/seq2seq/hparams/train_BPE_1000.yaml
+++ b/recipes/LibriSpeech/ASR/seq2seq/hparams/train_BPE_1000.yaml
@@ -27,21 +27,25 @@ pretrained_lm_tokenizer_path: speechbrain/asr-crdnn-rnnlm-librispeech
 
 # Data files
 data_folder: !PLACEHOLDER # e,g./path/to/LibriSpeech
-# noise/ris dataset will automatically be downloaded
-data_folder_rirs: !ref <data_folder> # where to store noisy data for augment (change it if needed)
 
 train_splits: ["train-clean-100", "train-clean-360", "train-other-500"]
 dev_splits: ["dev-clean"]
 test_splits: ["test-clean", "test-other"]
 skip_prep: False
 ckpt_interval_minutes: 15 # save checkpoint every N min
-train_csv: !ref <output_folder>/train.csv
-valid_csv: !ref <output_folder>/dev-clean.csv
+train_csv: !ref <save_folder>/train.csv
+valid_csv: !ref <save_folder>/dev-clean.csv
 test_csv:
-   - !ref <output_folder>/test-clean.csv
-   - !ref <output_folder>/test-other.csv
+   - !ref <save_folder>/test-clean.csv
+   - !ref <save_folder>/test-other.csv
+
+# Data for augmentation
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+noise_annotation: !ref <save_folder>/noise.csv #The data manifest files are created by the data preparation script
+
+####################### Training Parameters ####################################
 
-# Training parameters
 number_of_epochs: 15
 number_of_ctc_epochs: 5
 batch_size: 8
@@ -49,14 +53,19 @@ lr: 1.0
 ctc_weight: 0.5
 sorting: ascending
 dynamic_batching: False
+precision: fp32 # bf16, fp16 or fp32
 
 # dynamic batching parameters, if used
+feats_hop_size: 0.01
+max_batch_length: 20000 # in terms of frames
+shuffle: True
+batch_ordering: random
+num_buckets: 20
 dynamic_batch_sampler:
-   feats_hop_size: 0.01
-   max_batch_len: 20000 # in terms of frames
-   shuffle_ex: True
-   batch_ordering: random
-   num_buckets: 20
+   max_batch_length: !ref <max_batch_length>
+   shuffle: !ref <shuffle>
+   batch_ordering: !ref <batch_ordering>
+   num_buckets: !ref <num_buckets>
 
 # Feature parameters
 sample_rate: 16000
@@ -69,16 +78,20 @@ opt_class: !name:torch.optim.Adadelta
    eps: 1.e-8
 
 # Dataloader options
+num_workers: 4
 train_dataloader_opts:
+   num_workers: !ref <num_workers>
    batch_size: !ref <batch_size>
 
 valid_dataloader_opts:
+   num_workers: !ref <num_workers>
    batch_size: !ref <batch_size>
 
 test_dataloader_opts:
-   batch_size: !ref <batch_size>
+   batch_size: 1
+
+####################### Model Parameters #######################################
 
-# Model parameters
 activation: !name:torch.nn.LeakyReLU
 dropout: 0.15
 cnn_blocks: 2
@@ -107,12 +120,13 @@ test_beam_size: 80
 eos_threshold: 1.5
 using_max_attn_shift: True
 max_attn_shift: 240
-lm_weight: 0.50
-ctc_weight_decode: 0.0
-coverage_penalty: 1.5
 temperature: 1.25
 temperature_lm: 1.25
 
+# Scoring parameters
+lm_weight: 0.5
+coverage_penalty: 1.5
+
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>
 
@@ -124,17 +138,6 @@ compute_features: !new:speechbrain.lobes.features.Fbank
    n_fft: !ref <n_fft>
    n_mels: !ref <n_mels>
 
-env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt
-   openrir_folder: !ref <data_folder_rirs>
-   babble_prob: 0.0
-   reverb_prob: 0.0
-   noise_prob: 1.0
-   noise_snr_low: 0
-   noise_snr_high: 15
-
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-   sample_rate: !ref <sample_rate>
-   speeds: [95, 100, 105]
 
 enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
    input_shape: [null, null, !ref <n_mels>]
@@ -214,49 +217,56 @@ modules:
    ctc_lin: !ref <ctc_lin>
    seq_lin: !ref <seq_lin>
    normalize: !ref <normalize>
-   env_corrupt: !ref <env_corrupt>
    lm_model: !ref <lm_model>
 
 model: !new:torch.nn.ModuleList
    - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <ctc_lin>, !ref <seq_lin>]
 
+############################## Decoding & optimiser ############################
+
+coverage_scorer: !new:speechbrain.decoders.scorer.CoverageScorer
+   vocab_size: !ref <output_neurons>
+
+rnnlm_scorer: !new:speechbrain.decoders.scorer.RNNLMScorer
+   language_model: !ref <lm_model>
+   temperature: !ref <temperature_lm>
+
+scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
+   full_scorers: [!ref <rnnlm_scorer>,
+                  !ref <coverage_scorer>]
+   weights:
+      rnnlm: !ref <lm_weight>
+      coverage: !ref <coverage_penalty>
+
+# Search
 valid_search: !new:speechbrain.decoders.S2SRNNBeamSearcher
    embedding: !ref <emb>
    decoder: !ref <dec>
    linear: !ref <seq_lin>
-   ctc_linear: !ref <ctc_lin>
    bos_index: !ref <bos_index>
    eos_index: !ref <eos_index>
-   blank_index: !ref <blank_index>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    beam_size: !ref <valid_beam_size>
    eos_threshold: !ref <eos_threshold>
    using_max_attn_shift: !ref <using_max_attn_shift>
    max_attn_shift: !ref <max_attn_shift>
-   coverage_penalty: !ref <coverage_penalty>
    temperature: !ref <temperature>
 
-test_search: !new:speechbrain.decoders.S2SRNNBeamSearchLM
+test_search: !new:speechbrain.decoders.S2SRNNBeamSearcher
    embedding: !ref <emb>
    decoder: !ref <dec>
    linear: !ref <seq_lin>
-   ctc_linear: !ref <ctc_lin>
-   language_model: !ref <lm_model>
    bos_index: !ref <bos_index>
    eos_index: !ref <eos_index>
-   blank_index: !ref <blank_index>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    beam_size: !ref <test_beam_size>
    eos_threshold: !ref <eos_threshold>
    using_max_attn_shift: !ref <using_max_attn_shift>
    max_attn_shift: !ref <max_attn_shift>
-   coverage_penalty: !ref <coverage_penalty>
-   lm_weight: !ref <lm_weight>
-   ctc_weight: !ref <ctc_weight_decode>
    temperature: !ref <temperature>
-   temperature_lm: !ref <temperature_lm>
+   scorer: !ref <scorer>
 
 lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
@@ -264,6 +274,57 @@ lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
    annealing_factor: 0.8
    patient: 0
 
+############################## Augmentations ###################################
+
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+   URL: !ref <NOISE_DATASET_URL>
+   dest_folder: !ref <data_folder_noise>
+   ext: wav
+   csv_file: !ref <noise_annotation>
+
+# Add noise to input signal
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+   csv_file: !ref <noise_annotation>
+   snr_low: 0
+   snr_high: 15
+   noise_sample_rate: !ref <sample_rate>
+   clean_sample_rate: !ref <sample_rate>
+   num_workers: !ref <num_workers>
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+   orig_freq: !ref <sample_rate>
+   speeds: [95, 100, 105]
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+   drop_freq_low: 0
+   drop_freq_high: 1
+   drop_freq_count_low: 1
+   drop_freq_count_high: 3
+   drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+   drop_length_low: 1000
+   drop_length_high: 2000
+   drop_count_low: 1
+   drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+   concat_original: True
+   min_augmentations: 4
+   max_augmentations: 4
+   augment_prob: 1.0
+   augmentations: [
+      !ref <add_noise>,
+      !ref <speed_perturb>,
+      !ref <drop_freq>,
+      !ref <drop_chunk>]
+
+############################## Logging and Pretrainer ##########################
+
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
diff --git a/recipes/LibriSpeech/ASR/seq2seq/hparams/train_BPE_1000_sligru.yaml b/recipes/LibriSpeech/ASR/seq2seq/hparams/train_BPE_1000_sligru.yaml
index 15cded354480056cdf9cfefcfa1e8482e79b2494..355c49d36be26aeed18eac40ef91ea9249d13905 100644
--- a/recipes/LibriSpeech/ASR/seq2seq/hparams/train_BPE_1000_sligru.yaml
+++ b/recipes/LibriSpeech/ASR/seq2seq/hparams/train_BPE_1000_sligru.yaml
@@ -27,21 +27,25 @@ pretrained_lm_tokenizer_path: speechbrain/asr-crdnn-rnnlm-librispeech
 
 # Data files
 data_folder: !PLACEHOLDER # e,g./path/to/LibriSpeech
-# noise/ris dataset will automatically be downloaded
-data_folder_rirs: !ref <data_folder> # where to store noisy data for augment (change it if needed)
 
 train_splits: ["train-clean-100", "train-clean-360", "train-other-500"]
 dev_splits: ["dev-clean"]
 test_splits: ["test-clean", "test-other"]
 skip_prep: False
 ckpt_interval_minutes: 15 # save checkpoint every N min
-train_csv: !ref <output_folder>/train.csv
-valid_csv: !ref <output_folder>/dev-clean.csv
+train_csv: !ref <save_folder>/train.csv
+valid_csv: !ref <save_folder>/dev-clean.csv
 test_csv:
-   - !ref <output_folder>/test-clean.csv
-   - !ref <output_folder>/test-other.csv
+   - !ref <save_folder>/test-clean.csv
+   - !ref <save_folder>/test-other.csv
+
+# Data for augmentation
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+noise_annotation: !ref <save_folder>/noise.csv #The data manifest files are created by the data preparation script
+
+####################### Training Parameters ####################################
 
-# Training parameters
 number_of_epochs: 15
 number_of_ctc_epochs: 15
 batch_size: 24
@@ -49,14 +53,19 @@ lr: 1.0
 ctc_weight: 0.5
 sorting: ascending
 dynamic_batching: False
+precision: fp32 # bf16, fp16 or fp32
 
 # dynamic batching parameters, if used
+feats_hop_size: 0.01
+max_batch_length: 20000 # in terms of frames
+shuffle: True
+batch_ordering: random
+num_buckets: 20
 dynamic_batch_sampler:
-   feats_hop_size: 0.01
-   max_batch_len: 20000 # in terms of frames
-   shuffle_ex: True
-   batch_ordering: random
-   num_buckets: 20
+   max_batch_length: !ref <max_batch_length>
+   shuffle: !ref <shuffle>
+   batch_ordering: !ref <batch_ordering>
+   num_buckets: !ref <num_buckets>
 
 # Feature parameters
 sample_rate: 16000
@@ -69,16 +78,20 @@ opt_class: !name:torch.optim.Adadelta
    eps: 1.e-8
 
 # Dataloader options
+num_workers: 4
 train_dataloader_opts:
+   num_workers: !ref <num_workers>
    batch_size: !ref <batch_size>
 
 valid_dataloader_opts:
+   num_workers: !ref <num_workers>
    batch_size: !ref <batch_size>
 
 test_dataloader_opts:
-   batch_size: !ref <batch_size>
+   batch_size: 1
+
+####################### Model Parameters #######################################
 
-# Model parameters
 activation: !name:torch.nn.LeakyReLU
 dropout: 0.15
 cnn_blocks: 2
@@ -108,7 +121,6 @@ eos_threshold: 1.5
 using_max_attn_shift: True
 max_attn_shift: 240
 lm_weight: 0.50
-ctc_weight_decode: 0.0
 coverage_penalty: 1.5
 temperature: 1.25
 temperature_lm: 1.25
@@ -124,18 +136,6 @@ compute_features: !new:speechbrain.lobes.features.Fbank
    n_fft: !ref <n_fft>
    n_mels: !ref <n_mels>
 
-env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt
-   openrir_folder: !ref <data_folder_rirs>
-   babble_prob: 0.0
-   reverb_prob: 0.0
-   noise_prob: 1.0
-   noise_snr_low: 0
-   noise_snr_high: 15
-
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-   sample_rate: !ref <sample_rate>
-   speeds: [95, 100, 105]
-
 enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
    input_shape: [null, null, !ref <n_mels>]
    activation: !ref <activation>
@@ -214,49 +214,56 @@ modules:
    ctc_lin: !ref <ctc_lin>
    seq_lin: !ref <seq_lin>
    normalize: !ref <normalize>
-   env_corrupt: !ref <env_corrupt>
    lm_model: !ref <lm_model>
 
 model: !new:torch.nn.ModuleList
    - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <ctc_lin>, !ref <seq_lin>]
 
+############################## Decoding & optimiser ############################
+
+coverage_scorer: !new:speechbrain.decoders.scorer.CoverageScorer
+   vocab_size: !ref <output_neurons>
+
+rnnlm_scorer: !new:speechbrain.decoders.scorer.RNNLMScorer
+   language_model: !ref <lm_model>
+   temperature: !ref <temperature_lm>
+
+scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
+   full_scorers: [!ref <rnnlm_scorer>,
+                  !ref <coverage_scorer>]
+   weights:
+      rnnlm: !ref <lm_weight>
+      coverage: !ref <coverage_penalty>
+
+# Search
 valid_search: !new:speechbrain.decoders.S2SRNNBeamSearcher
    embedding: !ref <emb>
    decoder: !ref <dec>
    linear: !ref <seq_lin>
-   ctc_linear: !ref <ctc_lin>
    bos_index: !ref <bos_index>
    eos_index: !ref <eos_index>
-   blank_index: !ref <blank_index>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    beam_size: !ref <valid_beam_size>
    eos_threshold: !ref <eos_threshold>
    using_max_attn_shift: !ref <using_max_attn_shift>
    max_attn_shift: !ref <max_attn_shift>
-   coverage_penalty: !ref <coverage_penalty>
    temperature: !ref <temperature>
 
-test_search: !new:speechbrain.decoders.S2SRNNBeamSearchLM
+test_search: !new:speechbrain.decoders.S2SRNNBeamSearcher
    embedding: !ref <emb>
    decoder: !ref <dec>
    linear: !ref <seq_lin>
-   ctc_linear: !ref <ctc_lin>
-   language_model: !ref <lm_model>
    bos_index: !ref <bos_index>
    eos_index: !ref <eos_index>
-   blank_index: !ref <blank_index>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    beam_size: !ref <test_beam_size>
    eos_threshold: !ref <eos_threshold>
    using_max_attn_shift: !ref <using_max_attn_shift>
    max_attn_shift: !ref <max_attn_shift>
-   coverage_penalty: !ref <coverage_penalty>
-   lm_weight: !ref <lm_weight>
-   ctc_weight: !ref <ctc_weight_decode>
    temperature: !ref <temperature>
-   temperature_lm: !ref <temperature_lm>
+   scorer: !ref <scorer>
 
 lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
@@ -264,6 +271,57 @@ lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
    annealing_factor: 0.8
    patient: 0
 
+############################## Augmentations ###################################
+
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+   URL: !ref <NOISE_DATASET_URL>
+   dest_folder: !ref <data_folder_noise>
+   ext: wav
+   csv_file: !ref <noise_annotation>
+
+# Add noise to input signal
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+   csv_file: !ref <noise_annotation>
+   snr_low: 0
+   snr_high: 15
+   noise_sample_rate: !ref <sample_rate>
+   clean_sample_rate: !ref <sample_rate>
+   num_workers: !ref <num_workers>
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+   orig_freq: !ref <sample_rate>
+   speeds: [95, 100, 105]
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+   drop_freq_low: 0
+   drop_freq_high: 1
+   drop_freq_count_low: 1
+   drop_freq_count_high: 3
+   drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+   drop_length_low: 1000
+   drop_length_high: 2000
+   drop_count_low: 1
+   drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+   concat_original: True
+   min_augmentations: 4
+   max_augmentations: 4
+   augment_prob: 1.0
+   augmentations: [
+      !ref <add_noise>,
+      !ref <speed_perturb>,
+      !ref <drop_freq>,
+      !ref <drop_chunk>]
+
+############################## Logging and Pretrainer ##########################
+
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
diff --git a/recipes/LibriSpeech/ASR/seq2seq/hparams/train_BPE_5000.yaml b/recipes/LibriSpeech/ASR/seq2seq/hparams/train_BPE_5000.yaml
index 9b1bf36ebff6eff18fb64b6778ca1b9fe651460d..3046dfea80643d7c90025e21e33a830f6b615613 100644
--- a/recipes/LibriSpeech/ASR/seq2seq/hparams/train_BPE_5000.yaml
+++ b/recipes/LibriSpeech/ASR/seq2seq/hparams/train_BPE_5000.yaml
@@ -28,20 +28,25 @@ pretrained_lm_tokenizer_path: speechbrain/asr-crdnn-transformerlm-librispeech
 
 # Data files
 data_folder: !PLACEHOLDER # e,g./path/to/LibriSpeech
-# noise/ris dataset will automatically be downloaded
-data_folder_rirs: !ref <data_folder>
+
 train_splits: ["train-clean-100", "train-clean-360", "train-other-500"]
 dev_splits: ["dev-clean"]
 test_splits: ["test-clean", "test-other"]
 skip_prep: False
 ckpt_interval_minutes: 25 # save checkpoint every N min
-train_csv: !ref <output_folder>/train.csv
-valid_csv: !ref <output_folder>/dev-clean.csv
+train_csv: !ref <save_folder>/train.csv
+valid_csv: !ref <save_folder>/dev-clean.csv
 test_csv:
-   - !ref <output_folder>/test-clean.csv
-   - !ref <output_folder>/test-other.csv
+   - !ref <save_folder>/test-clean.csv
+   - !ref <save_folder>/test-other.csv
+
+# Data for augmentation
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+noise_annotation: !ref <save_folder>/noise.csv #The data manifest files are created by the data preparation script
+
+####################### Training Parameters ####################################
 
-# Training parameters
 number_of_epochs: 25
 number_of_ctc_epochs: 25
 batch_size: 8
@@ -49,14 +54,19 @@ lr: 1.0
 ctc_weight: 0.5
 sorting: ascending
 dynamic_batching: False
+precision: fp32 # bf16, fp16 or fp32
 
 # dynamic batching parameters, if used
+feats_hop_size: 0.01
+max_batch_length: 20000 # in terms of frames
+shuffle: True
+batch_ordering: random
+num_buckets: 20
 dynamic_batch_sampler:
-   feats_hop_size: 0.01
-   max_batch_len: 20000 # in terms of frames
-   shuffle_ex: True
-   batch_ordering: random
-   num_buckets: 20
+   max_batch_length: !ref <max_batch_length>
+   shuffle: !ref <shuffle>
+   batch_ordering: !ref <batch_ordering>
+   num_buckets: !ref <num_buckets>
 
 # Feature parameters
 sample_rate: 16000
@@ -69,16 +79,20 @@ opt_class: !name:torch.optim.Adadelta
    eps: 1.e-8
 
 # Dataloader options
+num_workers: 4
 train_dataloader_opts:
+   num_workers: !ref <num_workers>
    batch_size: !ref <batch_size>
 
 valid_dataloader_opts:
+   num_workers: !ref <num_workers>
    batch_size: !ref <batch_size>
 
 test_dataloader_opts:
    batch_size: 1
 
-# Model parameters
+####################### Model Parameters #######################################
+
 activation: !name:torch.nn.LeakyReLU
 dropout: 0.15
 cnn_blocks: 2
@@ -104,15 +118,18 @@ min_decode_ratio: 0.0
 max_decode_ratio: 1.0
 valid_beam_size: 20
 test_beam_size: 40
+using_eos_threshold: True
 eos_threshold: 1.5
 using_max_attn_shift: True
 max_attn_shift: 300
-lm_weight: 0.80
-ctc_weight_decode: 0.40
+lm_weight: 0.8
+temperature: 1.0
 ctc_window_size: 200
+
+# Scoring parameters
+ctc_weight_decode: 0.40
 coverage_penalty: 1.5
-temperature: 1.0
-temperature_lm: 1.0
+
 
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>
@@ -125,17 +142,6 @@ compute_features: !new:speechbrain.lobes.features.Fbank
    n_fft: !ref <n_fft>
    n_mels: !ref <n_mels>
 
-env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt
-   openrir_folder: !ref <data_folder_rirs>
-   babble_prob: 0.0
-   reverb_prob: 0.0
-   noise_prob: 1.0
-   noise_snr_low: 0
-   noise_snr_high: 15
-
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-   sample_rate: !ref <sample_rate>
-   speeds: [95, 100, 105]
 
 enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
    input_shape: [null, null, !ref <n_mels>]
@@ -216,50 +222,71 @@ modules:
    ctc_lin: !ref <ctc_lin>
    seq_lin: !ref <seq_lin>
    normalize: !ref <normalize>
-   env_corrupt: !ref <env_corrupt>
    lm_model: !ref <lm_model>
 
 model: !new:torch.nn.ModuleList
    - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <ctc_lin>, !ref <seq_lin>]
 
+############################## Decoding & optimiser ############################
+
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
+   eos_index: !ref <eos_index>
+   blank_index: !ref <blank_index>
+   ctc_fc: !ref <ctc_lin>
+   ctc_window_size: !ref <ctc_window_size>
+
+coverage_scorer: !new:speechbrain.decoders.scorer.CoverageScorer
+   vocab_size: !ref <output_neurons>
+
+transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer
+   language_model: !ref <lm_model>
+
+valid_scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
+   full_scorers: [!ref <coverage_scorer>]
+   weights:
+      coverage: !ref <coverage_penalty>
+
+test_scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
+   full_scorers: [
+      !ref <transformerlm_scorer>,
+      !ref <coverage_scorer>]
+   partial_scorers: [!ref <ctc_scorer>]
+   weights:
+      transformerlm: !ref <lm_weight>
+      coverage: !ref <coverage_penalty>
+      ctc: !ref <ctc_weight_decode>
+
+# Search
 valid_search: !new:speechbrain.decoders.S2SRNNBeamSearcher
    embedding: !ref <emb>
    decoder: !ref <dec>
    linear: !ref <seq_lin>
-   ctc_linear: !ref <ctc_lin>
    bos_index: !ref <bos_index>
    eos_index: !ref <eos_index>
-   blank_index: !ref <blank_index>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    beam_size: !ref <valid_beam_size>
    eos_threshold: !ref <eos_threshold>
    using_max_attn_shift: !ref <using_max_attn_shift>
    max_attn_shift: !ref <max_attn_shift>
-   coverage_penalty: !ref <coverage_penalty>
    temperature: !ref <temperature>
+   scorer: !ref <valid_scorer>
 
-test_search: !new:speechbrain.decoders.S2SRNNBeamSearchTransformerLM
+test_search: !new:speechbrain.decoders.S2SRNNBeamSearcher
    embedding: !ref <emb>
    decoder: !ref <dec>
    linear: !ref <seq_lin>
-   ctc_linear: !ref <ctc_lin>
-   language_model: !ref <lm_model>
    bos_index: !ref <bos_index>
    eos_index: !ref <eos_index>
-   blank_index: !ref <blank_index>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    beam_size: !ref <test_beam_size>
    eos_threshold: !ref <eos_threshold>
    using_max_attn_shift: !ref <using_max_attn_shift>
    max_attn_shift: !ref <max_attn_shift>
-   coverage_penalty: !ref <coverage_penalty>
-   lm_weight: !ref <lm_weight>
-   ctc_weight: !ref <ctc_weight_decode>
-   ctc_window_size: !ref <ctc_window_size>
+   using_eos_threshold: !ref <using_eos_threshold>
    temperature: !ref <temperature>
-   temperature_lm: !ref <temperature_lm>
+   scorer: !ref <test_scorer>
 
 lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
@@ -267,6 +294,57 @@ lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
    annealing_factor: 0.8
    patient: 0
 
+############################## Augmentations ###################################
+
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+   URL: !ref <NOISE_DATASET_URL>
+   dest_folder: !ref <data_folder_noise>
+   ext: wav
+   csv_file: !ref <noise_annotation>
+
+# Add noise to input signal
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+   csv_file: !ref <noise_annotation>
+   snr_low: 0
+   snr_high: 15
+   noise_sample_rate: !ref <sample_rate>
+   clean_sample_rate: !ref <sample_rate>
+   num_workers: !ref <num_workers>
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+   orig_freq: !ref <sample_rate>
+   speeds: [95, 100, 105]
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+   drop_freq_low: 0
+   drop_freq_high: 1
+   drop_freq_count_low: 1
+   drop_freq_count_high: 3
+   drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+   drop_length_low: 1000
+   drop_length_high: 2000
+   drop_count_low: 1
+   drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+   concat_original: True
+   min_augmentations: 4
+   max_augmentations: 4
+   augment_prob: 1.0
+   augmentations: [
+      !ref <add_noise>,
+      !ref <speed_perturb>,
+      !ref <drop_freq>,
+      !ref <drop_chunk>]
+
+############################## Logging and Pretrainer ##########################
+
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
diff --git a/recipes/LibriSpeech/ASR/seq2seq/train.py b/recipes/LibriSpeech/ASR/seq2seq/train.py
index 4442598514256c8240206649ac8e5acd8ea5f3fc..7f535100876f9b05a935b30af641bc66f47dc168 100644
--- a/recipes/LibriSpeech/ASR/seq2seq/train.py
+++ b/recipes/LibriSpeech/ASR/seq2seq/train.py
@@ -3,37 +3,28 @@
 The system employs an encoder, a decoder, and an attention mechanism
 between them. Decoding is performed with beamsearch coupled with a neural
 language model.
-
 To run this recipe, do the following:
 > python train.py hparams/train_BPE1000.yaml
-
 With the default hyperparameters, the system employs a CRDNN encoder.
 The decoder is based on a standard  GRU. Beamsearch coupled with a RNN
 language model is used  on the top of decoder probabilities.
-
 The neural network is trained on both CTC and negative-log likelihood
 targets and sub-word units estimated with Byte Pairwise Encoding (BPE)
 are used as basic recognition tokens. Training is performed on the full
 LibriSpeech dataset (960 h).
-
 The experiment file is flexible enough to support a large variety of
 different systems. By properly changing the parameter files, you can try
 different encoders, decoders, tokens (e.g, characters instead of BPE),
 training split (e.g, train-clean 100 rather than the full one), and many
 other possible variations.
-
 This recipe assumes that the tokenizer and the LM are already trained.
 To avoid token mismatches, the tokenizer used for the acoustic model is
 the same use for the LM.  The recipe downloads the pre-trained tokenizer
 and LM.
-
 If you would like to train a full system from scratch do the following:
 1- Train a tokenizer (see ../../Tokenizer)
 2- Train a language model (see ../../LM)
 3- Train the acoustic model (with this code).
-
-
-
 Authors
  * Ju-Chieh Chou 2020
  * Mirco Ravanelli 2020
@@ -63,16 +54,10 @@ class ASR(sb.Brain):
         tokens_bos, _ = batch.tokens_bos
         wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
 
-        # Add augmentation if specified
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.modules, "env_corrupt"):
-                wavs_noise = self.modules.env_corrupt(wavs, wav_lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                wav_lens = torch.cat([wav_lens, wav_lens])
-                tokens_bos = torch.cat([tokens_bos, tokens_bos], dim=0)
-
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+            tokens_bos = self.hparams.wav_augment.replicate_labels(tokens_bos)
 
         # Forward pass
         feats = self.hparams.compute_features(wavs)
@@ -86,45 +71,43 @@ class ASR(sb.Brain):
         p_seq = self.hparams.log_softmax(logits)
 
         # Compute outputs
+        p_ctc, p_tokens = None, None
         if stage == sb.Stage.TRAIN:
             current_epoch = self.hparams.epoch_counter.current
             if current_epoch <= self.hparams.number_of_ctc_epochs:
                 # Output layer for ctc log-probabilities
                 logits = self.modules.ctc_lin(x)
                 p_ctc = self.hparams.log_softmax(logits)
-                return p_ctc, p_seq, wav_lens
-            else:
-                return p_seq, wav_lens
         else:
             if stage == sb.Stage.VALID:
-                p_tokens, scores = self.hparams.valid_search(x, wav_lens)
+                # Get token strings from index prediction
+                p_tokens, _, _, _ = self.hparams.valid_search(x, wav_lens)
             else:
-                p_tokens, scores = self.hparams.test_search(x, wav_lens)
-            return p_seq, wav_lens, p_tokens
+                p_tokens, _, _, _ = self.hparams.test_search(x, wav_lens)
+
+        return p_ctc, p_seq, wav_lens, p_tokens
 
     def compute_objectives(self, predictions, batch, stage):
         """Computes the loss (CTC+NLL) given predictions and targets."""
 
         current_epoch = self.hparams.epoch_counter.current
-        if stage == sb.Stage.TRAIN:
-            if current_epoch <= self.hparams.number_of_ctc_epochs:
-                p_ctc, p_seq, wav_lens = predictions
-            else:
-                p_seq, wav_lens = predictions
-        else:
-            p_seq, wav_lens, predicted_tokens = predictions
+        p_ctc, p_seq, wav_lens, predicted_tokens = predictions
 
         ids = batch.id
         tokens_eos, tokens_eos_lens = batch.tokens_eos
         tokens, tokens_lens = batch.tokens
 
-        if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
-            tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
-            tokens_eos_lens = torch.cat(
-                [tokens_eos_lens, tokens_eos_lens], dim=0
+        # Labels must be extended if parallel augmentation or concatenated
+        # augmentation was performed on the input (increasing the time dimension)
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            (
+                tokens,
+                tokens_lens,
+                tokens_eos,
+                tokens_eos_lens,
+            ) = self.hparams.wav_augment.replicate_multiple_labels(
+                tokens, tokens_lens, tokens_eos, tokens_eos_lens
             )
-            tokens = torch.cat([tokens, tokens], dim=0)
-            tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0)
 
         loss_seq = self.hparams.seq_cost(
             p_seq, tokens_eos, length=tokens_eos_lens
@@ -155,23 +138,6 @@ class ASR(sb.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-        loss.backward()
-        if self.check_gradients(loss):
-            self.optimizer.step()
-        self.optimizer.zero_grad()
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        with torch.no_grad():
-            loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
         if stage != sb.Stage.TRAIN:
@@ -301,26 +267,18 @@ def dataio_prepare(hparams):
         from speechbrain.dataio.batch import PaddedBatch  # noqa
 
         dynamic_hparams = hparams["dynamic_batch_sampler"]
-        hop_size = dynamic_hparams["feats_hop_size"]
-
-        num_buckets = dynamic_hparams["num_buckets"]
+        hop_size = hparams["feats_hop_size"]
 
         train_batch_sampler = DynamicBatchSampler(
             train_data,
-            dynamic_hparams["max_batch_len"],
-            num_buckets=num_buckets,
             length_func=lambda x: x["duration"] * (1 / hop_size),
-            shuffle=dynamic_hparams["shuffle_ex"],
-            batch_ordering=dynamic_hparams["batch_ordering"],
+            **dynamic_hparams,
         )
 
         valid_batch_sampler = DynamicBatchSampler(
             valid_data,
-            dynamic_hparams["max_batch_len"],
-            num_buckets=num_buckets,
             length_func=lambda x: x["duration"] * (1 / hop_size),
-            shuffle=dynamic_hparams["shuffle_ex"],
-            batch_ordering=dynamic_hparams["batch_ordering"],
+            **dynamic_hparams,
         )
 
     return (
@@ -337,7 +295,6 @@ if __name__ == "__main__":
     # CLI:
     hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -368,6 +325,7 @@ if __name__ == "__main__":
             "skip_prep": hparams["skip_prep"],
         },
     )
+    run_on_main(hparams["prepare_noise_data"])
 
     # here we create the datasets objects as well as tokenization and encoding
     (
@@ -381,7 +339,7 @@ if __name__ == "__main__":
     # We download the pretrained LM from HuggingFace (or elsewhere depending on
     # the path given in the YAML file). The tokenizer is loaded at the same time.
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     # Trainer initialization
     asr_brain = ASR(
diff --git a/recipes/LibriSpeech/ASR/transducer/README.md b/recipes/LibriSpeech/ASR/transducer/README.md
index f3e6d3d1eaef8ec41c2782c5e32f9c29124aff98..b55a6fd90c93395b50b3525cc4fcde16cd540ba9 100644
--- a/recipes/LibriSpeech/ASR/transducer/README.md
+++ b/recipes/LibriSpeech/ASR/transducer/README.md
@@ -20,13 +20,61 @@ pip install numba
 python train.py hparams/conformer_transducer.yaml
 ```
 
+## Precision Notes
+If your GPU effectively supports fp16 (half-precision) computations, it is recommended to execute the training script with the `--precision=fp16` (or `--precision=bf16`) option.
+Enabling half precision can significantly reduce the peak VRAM requirements. For example, in the case of the Conformer Transducer recipe trained with Librispeech, the peak VRAM decreases from 39GB to 12GB when using fp16.
+According to our tests, the performance is not affected.
+
 # Librispeech Results
 
 Dev. clean is evaluated with Greedy Decoding while the test sets are using Greedy Decoding OR a RNNLM + Beam Search.
+Evaluation is performed in fp32. However, we found that during inference, fp16 or bf16 autocast has very little incidence on the WER.
+
+| Release | Hyperparams file | Train precision | Dev-clean Greedy | Test-clean Greedy | Test-other Greedy | Test-clean BS+RNNLM | Test-other BS+RNNLM | Model link | GPUs |
+|:-------------:|:---------------------------:|:-:| :------:| :-----------:| :------------------:| :------------------:| :------------------:| :--------:| :-----------:|
+| 2023-12-12 | conformer_transducer.yaml `streaming: True` | bf16 | 2.56% | 2.72% | 6.47% | \* | \* | https://drive.google.com/drive/folders/1QtQz1Bkd_QPYnf3CyxhJ57ovbSZC2EhN?usp=sharing | [4x A100SXM4 40GB](https://docs.alliancecan.ca/wiki/Narval/en) |
+
+<sub>\*: not evaluated due to performance issues, see [issue #2301](https://github.com/speechbrain/speechbrain/issues/2301)</sub>
+
+## Streaming model
+
+### WER vs chunk size & left context
+
+The following matrix presents the Word Error Rate (WER%) achieved on LibriSpeech
+`test-clean` with various chunk sizes (in ms) and left context sizes (in # of
+chunks).
+
+The relative difference is not trivial to interpret, because we are not testing
+against a continuous stream of speech, but rather against utterances of various
+lengths. This tends to bias results in favor of larger chunk sizes.
+
+The chunk size might not accurately represent expected latency due to slight
+padding differences in streaming contexts.
+
+The left chunk size is not representative of the receptive field of the model.
+Because the model caches the streaming context at different layers, the model
+may end up forming indirect dependencies to audio many seconds ago.
+
+|       | full | cs=32 (1280ms) | 24 (960ms) | 16 (640ms) | 12 (480ms) | 8 (320ms) |
+|:-----:|:----:|:-----:|:-----:|:-----:|:-----:|:-----:|
+| full  | 2.72%| -     | -     | -     | -     | -     |
+| lc=32 | -    | 3.09% | 3.07% | 3.26% | 3.31% | 3.44% |
+| 16    | -    | 3.10% | 3.07% | 3.27% | 3.32% | 3.50% |
+| 8     | -    | 3.10% | 3.11% | 3.31% | 3.39% | 3.62% |
+| 4     | -    | 3.12% | 3.13% | 3.37% | 3.51% | 3.80% |
+| 2     | -    | 3.19% | 3.24% | 3.50% | 3.79% | 4.38% |
+
+### Inference
+
+Once your model is trained, you need a few manual steps in order to use it with the high-level streaming interfaces (`speechbrain.inference.ASR.StreamingASR`):
+
+1. Create a new directory where you want to store the model.
+2. Copy `results/conformer_transducer/<seed>/lm.ckpt` (optional; currently, for streaming rescoring LMs might be unsupported) and `tokenizer.ckpt` to that directory.
+3. Copy `results/conformer_transducer/<seed>/save/CKPT+????/model.ckpt` and `normalizer.ckpt` to that directory.
+4. Copy your hyperparameters file to that directory. Uncomment the streaming specific keys and remove any training-specific keys. Alternatively, grab the inference hyperparameters YAML for this model from HuggingFace and adapt it to any changes you may have done.
+5. You can now instantiate a `StreamingASR` with your model using `StreamingASR.from_hparams("/path/to/model/")`.
 
-| Release | Hyperparams file | Dev-clean Greedy | Test-clean Greedy | Test-other Greedy | Test-clean BS+RNNLM | Test-other BS+RNNLM | Model link | GPUs |
-|:-------------:|:---------------------------:| :------:| :-----------:| :------------------:| :------------------:| :------------------:| :--------:| :-----------:|
-| 2023-07-19 | conformer_transducer.yaml | 2.62 | 2.84 | 6.98 | 2.62 | 6.31 | https://drive.google.com/drive/folders/1QtQz1Bkd_QPYnf3CyxhJ57ovbSZC2EhN?usp=sharing | 3x3090 24GB |
+The contents of that directory may be uploaded as a HuggingFace model, in which case the model source path can just be specified as `youruser/yourmodel`.
 
 # **About SpeechBrain**
 - Website: https://speechbrain.github.io/
diff --git a/recipes/LibriSpeech/ASR/transducer/hparams/conformer_transducer.yaml b/recipes/LibriSpeech/ASR/transducer/hparams/conformer_transducer.yaml
index 288c517d9d1053d24a66f93479c3af6cad186ddd..6e7475403a53cb99939d66d6d90abe6ca0c1c682 100644
--- a/recipes/LibriSpeech/ASR/transducer/hparams/conformer_transducer.yaml
+++ b/recipes/LibriSpeech/ASR/transducer/hparams/conformer_transducer.yaml
@@ -10,7 +10,7 @@
 
 # Seed needs to be set at top of yaml, before objects with parameters are made
 seed: 3407
-__set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
 output_folder: !ref results/conformer_transducer_large/<seed>
 output_wer_folder: !ref <output_folder>/
 save_folder: !ref <output_folder>/save
@@ -29,7 +29,6 @@ data_folder: !PLACEHOLDER # e.g, /localscratch/LibriSpeech
 # If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES
 # then data_folder_rirs should be /localscratch/xxx_corpus
 # otherwise the dataset will automatically be downloaded
-data_folder_rirs: !ref <data_folder>
 train_splits: ["train-clean-100", "train-clean-360", "train-other-500"]
 dev_splits: ["dev-clean"]
 test_splits: ["test-clean", "test-other"]
@@ -41,9 +40,10 @@ test_csv:
 skip_prep: False
 ckpt_interval_minutes: 5 # save checkpoint every N min
 
-# Training parameters
+####################### Training Parameters ####################################
+
 # To make Transformers converge, the global bath size should be large enough.
-# The global batch size is computed as batch_size * n_gpus * gradient_accumulation.
+# The global batch size is computed as batch_size * n_gpus * grad_accumulation_factor.
 # Empirically, we found that this value should be >= 128.
 # Please, set your parameters accordingly.
 number_of_epochs: 100
@@ -57,12 +57,14 @@ ctc_weight: 0.3 # Multitask with CTC for the encoder (0.0 = disabled)
 ce_weight: 0.0 # Multitask with CE for the decoder (0.0 = disabled)
 max_grad_norm: 5.0
 loss_reduction: 'batchmean'
+precision: fp32 # bf16, fp16 or fp32
 
 # The batch size is used if and only if dynamic batching is set to False
 # Validation and testing are done with fixed batches and not dynamic batching.
 batch_size: 8
 grad_accumulation_factor: 4
 sorting: random
+avg_checkpoints: 10 # Number of checkpoints to average for evaluation
 
 # Feature parameters
 sample_rate: 16000
@@ -70,6 +72,28 @@ n_fft: 512
 n_mels: 80
 win_length: 32
 
+# Streaming & dynamic chunk training options
+# At least for the current architecture on LibriSpeech, we found out that
+# non-streaming accuracy is very similar between `streaming: True` and
+# `streaming: False`.
+streaming: True  # controls all Dynamic Chunk Training & chunk size & left context mechanisms
+
+# Configuration for Dynamic Chunk Training.
+# In this model, a chunk is roughly equivalent to 40ms of audio.
+dynchunktrain_config_sampler: !new:speechbrain.utils.dynamic_chunk_training.DynChunkTrainConfigRandomSampler # yamllint disable-line rule:line-length
+   chunkwise_prob: 0.6 # Probability during a batch to limit attention and sample a random chunk size in the following range
+   chunk_size_min: 8 # Minimum chunk size (if in a DynChunkTrain batch)
+   chunk_size_max: 32 # Maximum chunk size (if in a DynChunkTrain batch)
+   limited_left_context_prob: 0.75 # If in a DynChunkTrain batch, the probability during a batch to restrict left context to a random number of chunks
+   left_context_chunks_min: 2 # Minimum left context size (in # of chunks)
+   left_context_chunks_max: 32 # Maximum left context size (in # of chunks)
+   # If you specify a valid/test config, you can optionally have evaluation be
+   # done with a specific DynChunkTrain configuration.
+   # valid_config: !new:speechbrain.utils.dynamic_chunk_training.DynChunkTrainConfig
+   #    chunk_size: 24
+   #    left_context_size: 16
+   # test_config: ...
+
 # Dataloader options
 train_dataloader_opts:
    batch_size: !ref <batch_size>
@@ -97,7 +121,8 @@ dynamic_batch_sampler:
    batch_ordering: random
    max_batch_ex: 256
 
-# Model parameters
+####################### Model Parameters #######################################
+
 # Transformer
 d_model: 512
 joint_dim: 640
@@ -141,30 +166,55 @@ compute_features: !new:speechbrain.lobes.features.Fbank
    n_mels: !ref <n_mels>
    win_length: !ref <win_length>
 
-augmentation: !new:speechbrain.lobes.augment.SpecAugment
-   time_warp: True
-   time_warp_window: 5
-   time_warp_mode: bicubic
-   freq_mask: True
-   n_freq_mask: 2
-   time_mask: True
-   n_time_mask: 5
-   replace_with_zero: False
-   freq_mask_width: 30
-   time_mask_width: 20
-
-speed_perturb: !new:speechbrain.processing.speech_augmentation.SpeedPerturb
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
    orig_freq: !ref <sample_rate>
    speeds: [95, 100, 105]
 
-# Uncomment if interested in env corruption
-# env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt
-#   openrir_folder: !ref <data_folder_rirs>
-#   babble_prob: 0.0
-#   reverb_prob: 0.0
-#   noise_prob: 1.0
-#   noise_snr_low: 0
-#   noise_snr_high: 15
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+   min_augmentations: 1
+   max_augmentations: 1
+   augment_prob: 1.0
+   augmentations: [!ref <speed_perturb>]
+
+
+# Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+   drop_length_low: 15
+   drop_length_high: 25
+   drop_count_low: 5
+   drop_count_high: 5
+   replace: "zeros"
+
+# Frequency Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+   drop_length_low: 25
+   drop_length_high: 35
+   drop_count_low: 2
+   drop_count_high: 2
+   replace: "zeros"
+   dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+   parallel_augment: False
+   concat_original: False
+   repeat_augment: 1
+   shuffle_augmentations: False
+   min_augmentations: 3
+   max_augmentations: 3
+   augment_prob: 1.0
+   augmentations: [
+      !ref <time_drop>,
+      !ref <freq_drop>,
+      !ref <time_warp>]
+
+############################## Models ##########################################
 
 CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
    input_shape: (8, 10, 80)
@@ -276,7 +326,6 @@ modules:
    Tjoint: !ref <Tjoint>
    transducer_lin: !ref <transducer_lin>
    normalize: !ref <normalize>
-   augmentation: !ref <augmentation>
    lm_model: !ref <lm_model>
    proj_ctc: !ref <proj_ctc>
    proj_dec: !ref <proj_dec>
@@ -288,6 +337,8 @@ modules:
 model: !new:torch.nn.ModuleList
    - [!ref <CNN>, !ref <enc>, !ref <emb>, !ref <dec>, !ref <proj_enc>, !ref <proj_dec>, !ref <proj_ctc>, !ref <transducer_lin>]
 
+############################## Decoding & optimiser ############################
+
 # Tokenizer initialization
 tokenizer: !new:sentencepiece.SentencePieceProcessor
 
@@ -321,6 +372,8 @@ noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler
    lr_initial: !ref <lr>
    n_warmup_steps: !ref <warmup_steps>
 
+############################## Logging and Pretrainer ##########################
+
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
@@ -346,3 +399,23 @@ error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
 
 cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
    split_tokens: True
+
+# for the inference hparams, you will need to include and uncomment something like this:
+
+# make_tokenizer_streaming_context: !name:speechbrain.tokenizers.SentencePiece.SentencePieceDecoderStreamingContext
+# tokenizer_decode_streaming: !name:speechbrain.tokenizers.SentencePiece.spm_decode_preserve_leading_space
+
+# make_decoder_streaming_context: !name:speechbrain.decoders.transducer.TransducerGreedySearcherStreamingContext # default constructor
+# decoding_function: !name:speechbrain.decoders.transducer.TransducerBeamSearcher.transducer_greedy_decode_streaming
+#    - !ref <Greedysearcher>  # self
+
+# fea_streaming_extractor: !new:speechbrain.lobes.features.StreamingFeatureWrapper
+#    module: !new:speechbrain.nnet.containers.LengthsCapableSequential
+#       - !ref <compute_features>
+#       - !ref <normalize>
+#       - !ref <CNN>
+#    # don't consider normalization as part of the input filter chain.
+#    # normalization will operate at chunk level, which mismatches training
+#    # somewhat, but does not appear to result in noticeable degradation.
+#    properties: !apply:speechbrain.utils.filter_analysis.stack_filter_properties
+#       - [!ref <compute_features>, !ref <CNN>]
diff --git a/recipes/LibriSpeech/ASR/transducer/train.py b/recipes/LibriSpeech/ASR/transducer/train.py
index 60a8e4d1b5ea68445c3675d2ca7d03b867c2c552..84d7e05ffa3701c172be2c79765a59bade434d30 100644
--- a/recipes/LibriSpeech/ASR/transducer/train.py
+++ b/recipes/LibriSpeech/ASR/transducer/train.py
@@ -50,35 +50,43 @@ class ASR(sb.Brain):
         wavs, wav_lens = batch.sig
         tokens_with_bos, token_with_bos_lens = batch.tokens_bos
 
-        # Add env corruption if specified
+        # Add waveform augmentation if specified.
         if stage == sb.Stage.TRAIN:
-            if hasattr(self.modules, "env_corrupt"):
-                wavs_noise = self.modules.env_corrupt(wavs, wav_lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                wav_lens = torch.cat([wav_lens, wav_lens])
-                batch.sig = wavs, wav_lens
-                tokens_with_bos = torch.cat(
-                    [tokens_with_bos, tokens_with_bos], dim=0
+            if hasattr(self.hparams, "wav_augment"):
+                wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+                tokens_with_bos = self.hparams.wav_augment.replicate_labels(
+                    tokens_with_bos
                 )
-                token_with_bos_lens = torch.cat(
-                    [token_with_bos_lens, token_with_bos_lens]
-                )
-                batch.tokens_bos = tokens_with_bos, token_with_bos_lens
-
-            if hasattr(self.hparams, "speed_perturb"):
-                wavs = hparams["speed_perturb"](wavs)
 
-        # Forward pass
         feats = self.hparams.compute_features(wavs)
+
+        # Add feature augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "fea_augment"):
+            feats, fea_lens = self.hparams.fea_augment(feats, wav_lens)
+            tokens_with_bos = self.hparams.fea_augment.replicate_labels(
+                tokens_with_bos
+            )
+
         current_epoch = self.hparams.epoch_counter.current
-        feats = self.modules.normalize(feats, wav_lens, epoch=current_epoch)
 
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "augmentation"):
-                feats = self.hparams.augmentation(feats)
+        # Old models may not have the streaming hparam, we don't break them in
+        # any other way so just check for its presence
+        if hasattr(self.hparams, "streaming") and self.hparams.streaming:
+            dynchunktrain_config = self.hparams.dynchunktrain_config_sampler(
+                stage
+            )
+        else:
+            dynchunktrain_config = None
+
+        feats = self.modules.normalize(feats, wav_lens, epoch=current_epoch)
 
         src = self.modules.CNN(feats)
-        x = self.modules.enc(src, wav_lens, pad_idx=self.hparams.pad_index)
+        x = self.modules.enc(
+            src,
+            wav_lens,
+            pad_idx=self.hparams.pad_index,
+            dynchunktrain_config=dynchunktrain_config,
+        )
         x = self.modules.proj_enc(x)
 
         e_in = self.modules.emb(tokens_with_bos)
@@ -146,11 +154,18 @@ class ASR(sb.Brain):
         else:
             logits_transducer, wav_lens, predicted_tokens = predictions
 
-        if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
-            tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
-            token_eos_lens = torch.cat([token_eos_lens, token_eos_lens], dim=0)
-            tokens = torch.cat([tokens, tokens], dim=0)
-            token_lens = torch.cat([token_lens, token_lens], dim=0)
+        if stage == sb.Stage.TRAIN:
+            # Labels must be extended if parallel augmentation or concatenated
+            # augmentation was performed on the input (increasing the time dimension)
+            if hasattr(self.hparams, "fea_augment"):
+                (
+                    tokens,
+                    token_lens,
+                    tokens_eos,
+                    token_eos_lens,
+                ) = self.hparams.fea_augment.replicate_multiple_labels(
+                    tokens, token_lens, tokens_eos, token_eos_lens
+                )
 
         if stage == sb.Stage.TRAIN:
             CTC_loss = 0.0
@@ -189,55 +204,10 @@ class ASR(sb.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        should_step = self.step % self.grad_accumulation_factor == 0
-
-        with self.no_sync(not should_step):
-            # Managing automatic mixed precision
-            if self.auto_mix_prec:
-                with torch.autocast(torch.device(self.device).type):
-                    outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-
-                # Losses are excluded from mixed precision to avoid instabilities
-                loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-
-                self.scaler.scale(
-                    loss / self.grad_accumulation_factor
-                ).backward()
-
-                if should_step:
-                    self.scaler.unscale_(self.optimizer)
-                    if self.check_gradients(loss):
-                        self.scaler.step(self.optimizer)
-                    self.scaler.update()
-                    self.zero_grad(set_to_none=True)
-                    self.optimizer_step += 1
-                    self.hparams.noam_annealing(self.optimizer)
-            else:
-                if self.bfloat16_mix_prec:
-                    with torch.autocast(
-                        device_type=torch.device(self.device).type,
-                        dtype=torch.bfloat16,
-                    ):
-                        outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-                        loss = self.compute_objectives(
-                            outputs, batch, sb.Stage.TRAIN
-                        )
-                else:
-                    outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-                    loss = self.compute_objectives(
-                        outputs, batch, sb.Stage.TRAIN
-                    )
-                (loss / self.grad_accumulation_factor).backward()
-                if should_step:
-                    if self.check_gradients(loss):
-                        self.optimizer.step()
-                    self.zero_grad(set_to_none=True)
-                    self.optimizer_step += 1
-                    self.hparams.noam_annealing(self.optimizer)
-
-        self.on_fit_batch_end(batch, outputs, loss, should_step)
-        return loss.detach().cpu()
+    def on_fit_batch_end(self, batch, outputs, loss, should_step):
+        """At the end of the optimizer step, apply noam annealing."""
+        if should_step:
+            self.hparams.noam_annealing(self.optimizer)
 
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
@@ -256,7 +226,7 @@ class ASR(sb.Brain):
             stage_stats["WER"] = self.wer_metric.summarize("error_rate")
 
         # Perform end-of-iteration things, like annealing, logging, etc.
-        if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+        if stage == sb.Stage.VALID:
 
             lr = self.hparams.noam_annealing.current_lr
             steps = self.optimizer_step
@@ -274,11 +244,12 @@ class ASR(sb.Brain):
                 train_stats=self.train_stats,
                 valid_stats=stage_stats,
             )
-            # We save multiple checkpoints as we will average them!
             self.checkpointer.save_and_keep_only(
                 meta={"WER": stage_stats["WER"], "epoch": epoch},
                 min_keys=["WER"],
+                num_to_keep=self.hparams.avg_checkpoints,
             )
+
         elif stage == sb.Stage.TEST:
             self.hparams.train_logger.log_stats(
                 stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
@@ -288,15 +259,24 @@ class ASR(sb.Brain):
                 with open(self.hparams.test_wer_file, "w") as w:
                     self.wer_metric.write_stats(w)
 
+            # save the averaged checkpoint at the end of the evaluation stage
+            # delete the rest of the intermediate checkpoints
+            # WER is set to -0.1 so checkpointer only keeps the averaged checkpoint
+            self.checkpointer.save_and_keep_only(
+                meta={"WER": -0.1, "epoch": epoch},
+                min_keys=["WER"],
+                num_to_keep=1,
+            )
+
     def on_evaluate_start(self, max_key=None, min_key=None):
         """perform checkpoint averge if needed"""
         super().on_evaluate_start()
 
         ckpts = self.checkpointer.find_checkpoints(
-            max_key=max_key, min_key=min_key
+            max_key=max_key, min_key=min_key,
         )
         ckpt = sb.utils.checkpoints.average_checkpoints(
-            ckpts, recoverable_name="model", device=self.device
+            ckpts, recoverable_name="model"
         )
 
         self.hparams.model.load_state_dict(ckpt, strict=True)
@@ -429,7 +409,13 @@ if __name__ == "__main__":
     # CLI:
     hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
 
-    # If --distributed_launch then
+    # Use torchaudio if the device is CPU
+    if run_opts.get("device") == "cpu":
+        if "use_torchaudio: True" in overrides:
+            overrides.replace("use_torchaudio: True", "use_torchaudio: False")
+        else:
+            overrides += "\nuse_torchaudio: True"
+
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -475,7 +461,7 @@ if __name__ == "__main__":
     # depending on the path given in the YAML file). The tokenizer is loaded at
     # the same time.
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     # Trainer initialization
     asr_brain = ASR(
diff --git a/recipes/LibriSpeech/ASR/transformer/README.md b/recipes/LibriSpeech/ASR/transformer/README.md
index 1d78f3dd645385b4d6047b44f3884f16b1186c69..120b891260d7e3efa753b996022dcbcfeac8ca26 100644
--- a/recipes/LibriSpeech/ASR/transformer/README.md
+++ b/recipes/LibriSpeech/ASR/transformer/README.md
@@ -27,8 +27,33 @@ installed in your environment (see extra-requirements.txt)**
 | 23-05-23 | branchformer_large.yaml | 2.72 (1.9 with LM) | 2.04 | 4.13 | Not Avail. | [GoogleDrive](https://www.dropbox.com/sh/gxkye4efa6hvl2c/AADO85EkkfbIGe5KjBAU6BrEa?dl=0) | 4xA100 80GB |
 | 23-05-23 | conformer_large.yaml | 2.62 (1.9 with LM) | 2.01 | 4.52 | [HuggingFace](https://huggingface.co/speechbrain/asr-conformer-transformerlm-librispeech) | [GoogleDrive](https://www.dropbox.com/sh/ef3chrau8i45ip1/AAD9un8oabOB1a9OiSomZEhZa?dl=0) | 4xA100 80GB |
 | 24-03-22 | transformer.yaml | 3.32 | 2.27 | 5.53 | [HuggingFace](https://huggingface.co/speechbrain/asr-transformer-transformerlm-librispeech) | [GoogleDrive](https://www.dropbox.com/sh/653kq8h2k87md4p/AAByAaAryXtQKpRzYtzV9ih5a?dl=0) | 4xV100 32GB |
-| 24-03-22 | conformer_small.yaml | 4.05 | 2.49 | 6.1 (**only 13.3M parameters**) | [HuggingFace](https://huggingface.co/speechbrain/asr-conformersmall-transformerlm-librispeech) | [GoogleDrive](https://www.dropbox.com/sh/s0x6ni124858b8i/AAALaCH6sGTMRUVTjh8Tm8Jwa?dl=0) | 1xV100 32GB |
+| 24-03-22 | conformer_small.yaml | 4.05 | 2.49 | 6.1 (**only 13.3M parameters**) | [HuggingFace](https://huggingface.co/speechbrain/asr-conformersmall-transformerlm-librispeech) | [DropBox](https://www.dropbox.com/sh/s0x6ni124858b8i/AAALaCH6sGTMRUVTjh8Tm8Jwa?dl=0) | 1xV100 32GB |
 | 06-12-23 | train_hf_whisper.yaml | 3.60 | Not Avail. | Not Avail. | Not Avail. | Not Avail. | 1xA100 40GB |
+| 27-03-23 | hyperconformer_8M.yaml | 4.69 | 2.55 | 6.61 (**only 7.9M parameters**) | NA |  [DropBox](https://www.dropbox.com/sh/8jc96avmivr8fke/AABrFEhtWy_3-Q7BHhkh0enwa?dl=0) | 1xP40 24GB
+| 27-03-23 | hyperconformer_22M.yaml | 3.19 | 2.23 | 5.54  (**only 21.7M parameters**)  | NA | [DropBox](https://www.dropbox.com/sh/30xsmqj13jexzoh/AACvZNtX1Fsr0Wa1Z3C9rHLXa?dl=0) | 1xP40 24GB
+| 03-09-23 | hyperbranchformer_13M.yaml | NA | 2.54 | 6.58  | NA | soon | 1xP40 24GB
+| 03-09-23 | hyperbranchformer_25M.yaml | NA | 2.36 | 5.89 | NA | soon | 1xP40 24GB
+| 05-01-24 | bayesspeech.yaml | 4.28 | 2.84 | 6.27 | NA | [DropBox](https://www.dropbox.com/scl/fo/cdken4jqfj96ev1v84jxm/h?rlkey=25eu1ytgm5ac51zqj8p65zwxd&dl=0) | 1xV100 32GB |
+
+# **About HyperConformer**
+HyperConformer is a new architecture, which replaces the self-attention mechanism of Conformer with the linear-time token mixing architecture HyperMixer.
+It achieves competitive or better results than Conformer while requiring less memory and compute.
+
+- Paper: https://arxiv.org/abs/2305.18281
+- HyperMixer code: https://github.com/idiap/hypermixing
+
+Please cite HyperConformer if you use it for your research or business.
+
+```bibtex
+@inproceedings{mai23_interspeech,
+  author={Florian Mai and Juan Zuluaga-Gomez and Titouan Parcollet and Petr Motlicek},
+  title={{HyperConformer}: Multi-head HyperMixer for Efficient Speech Recognition},
+  year=2023,
+  booktitle={Proc. Interspeech 2023},
+  pages={2213--2217},
+  doi={10.21437/Interspeech.2023-1611}
+}
+```
 
 # **About SpeechBrain**
 - Website: https://speechbrain.github.io/
diff --git a/recipes/LibriSpeech/ASR/transformer/extra_requirements.txt b/recipes/LibriSpeech/ASR/transformer/extra_requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cd14b4fe5361ce69b3d5545d5a9311d631128429
--- /dev/null
+++ b/recipes/LibriSpeech/ASR/transformer/extra_requirements.txt
@@ -0,0 +1 @@
+bayestorch>=0.0.3 # For Bayes ASR recipe
diff --git a/recipes/LibriSpeech/ASR/transformer/hparams/bayesspeech.yaml b/recipes/LibriSpeech/ASR/transformer/hparams/bayesspeech.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2eee3646ecc87e806ae067ba4417178a17dbb021
--- /dev/null
+++ b/recipes/LibriSpeech/ASR/transformer/hparams/bayesspeech.yaml
@@ -0,0 +1,355 @@
+# ############################################################################
+# Model: E2E ASR with Bayesian Transformer (https://arxiv.org/abs/2301.11276)
+# Encoder: Bayesian Transformer Encoder
+# Decoder: Bayesian Transformer Decoder + (CTC/ATT joint) beamsearch + TransformerLM
+# Tokens: unigram
+# losses: CTC + KLdiv (Label Smoothing loss)
+# Training: Librispeech 960h
+# Authors:  Jianyuan Zhong, Titouan Parcollet, Samuele Cornell, Luca Della Libera
+# ############################################################################
+# Seed needs to be set at top of yaml, before objects with parameters are made
+
+seed: 74443
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/bayesspeech/<seed>
+output_wer_folder: !ref <output_folder>/
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# Language model (LM) pretraining
+# NB: To avoid mismatch, the speech recognizer must be trained with the same
+# tokenizer used for LM training. Here, we download everything from the
+# speechbrain HuggingFace repository. However, a local path pointing to a
+# directory containing the lm.ckpt and tokenizer.ckpt may also be specified
+# instead. E.g if you want to use your own LM / tokenizer.
+pretrained_lm_tokenizer_path: speechbrain/asr-transformer-transformerlm-librispeech
+
+# Data files
+data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech
+# If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES
+# then data_folder_rirs should be /localscratch/xxx_corpus
+# otherwise the dataset will automatically be downloaded
+# data_folder_rirs: !ref <data_folder>
+train_splits: ["train-clean-100", "train-clean-360", "train-other-500"]
+dev_splits: ["dev-clean"]
+test_splits: ["test-clean", "test-other"]
+skip_prep: False
+train_csv: !ref <output_folder>/train.csv
+valid_csv: !ref <output_folder>/dev-clean.csv
+test_csv:
+    - !ref <output_folder>/test-clean.csv
+    - !ref <output_folder>/test-other.csv
+
+ckpt_interval_minutes: 30 # save checkpoint every N min
+
+####################### Training Parameters ####################################
+
+# To make Transformers converge, the global bath size should be large enough.
+# The global batch size is computed as batch_size * n_gpus * grad_accumulation_factor.
+# Empirically, we found that this value should be >= 128.
+# Please, set your parameters accordingly.
+number_of_epochs: 30
+batch_size: 32 # This works for 1x GPU with 40GB with no dynamic batching
+ctc_weight: 0.3
+grad_accumulation_factor: 1
+max_grad_norm: 5.0
+loss_reduction: 'batchmean'
+sorting: random
+num_workers: 4
+precision: fp32 # bf16, fp16 or fp32
+avg_checkpoints: 10 # Number of checkpoints to average for evaluation
+
+# index
+pad_index: 0
+bos_index: 1
+eos_index: 2
+
+# This setup works well for V100 32GB GPU, adapts it to your needs.
+# Or turn it off (but training speed will decrease)
+dynamic_batching: True
+max_batch_length_train: 600
+max_batch_length_val: 100 # we reduce it as the beam is much wider (VRAM)
+num_bucket: 200
+shuffle: True  # if true re-creates batches at each epoch shuffling examples.
+batch_ordering: random
+max_batch_ex: 128
+
+dynamic_batch_sampler_train:
+    max_batch_length: !ref <max_batch_length_train>
+    num_buckets: !ref <num_bucket>
+    shuffle: !ref <shuffle>
+    batch_ordering: !ref <batch_ordering>
+    max_batch_ex: !ref <max_batch_ex>
+
+dynamic_batch_sampler_valid:
+    max_batch_length: !ref <max_batch_length_val>
+    num_buckets: !ref <num_bucket>
+    shuffle: !ref <shuffle>
+    batch_ordering: !ref <batch_ordering>
+    max_batch_ex: !ref <max_batch_ex>
+
+# stages related parameters
+lr_adam: 0.001
+
+# Feature parameters
+sample_rate: 16000
+n_fft: 400
+n_mels: 80
+
+# Dataloader options
+train_dataloader_opts:
+    batch_size: !ref <batch_size>
+    shuffle: True
+    num_workers: !ref <num_workers>
+    collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
+        padding_kwargs:
+            value: !ref <pad_index>
+
+valid_dataloader_opts:
+    batch_size: 1
+    collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
+        padding_kwargs:
+            value: !ref <pad_index>
+
+test_dataloader_opts:
+    batch_size: 1
+    collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
+        padding_kwargs:
+            value: !ref <pad_index>
+
+####################### Model Parameters #######################################
+# Transformer
+d_model: 512
+nhead: 4
+num_encoder_layers: 12
+num_decoder_layers: 6
+d_ffn: 2048
+transformer_dropout: 0.0
+activation: !name:torch.nn.GELU
+output_neurons: 5000
+
+# Bayesian inference parameters
+normal_prior_log_scale: -1.0
+normal_posterior_softplus_inv_scale: -5.0
+kl_div_weight: 0.000001  # Set based on the number of model parameters
+num_eval_mc_samples: 10
+
+# Outputs
+blank_index: 0
+label_smoothing: 0.0
+
+# Decoding parameters
+min_decode_ratio: 0.0
+max_decode_ratio: 1.0
+valid_search_interval: 10
+valid_beam_size: 10
+test_beam_size: 66
+
+# Scoring parameters
+lm_weight: 0.60
+ctc_weight_decode: 0.40
+
+############################## Models ##########################################
+
+CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
+    input_shape: (8, 10, 80)
+    num_blocks: 3
+    num_layers_per_block: 1
+    out_channels: (64, 64, 64)
+    kernel_sizes: (5, 5, 1)
+    strides: (2, 2, 1)
+    residuals: (False, False, True)
+
+Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length
+    input_size: 1280
+    tgt_vocab: !ref <output_neurons>
+    d_model: !ref <d_model>
+    nhead: !ref <nhead>
+    num_encoder_layers: !ref <num_encoder_layers>
+    num_decoder_layers: !ref <num_decoder_layers>
+    d_ffn: !ref <d_ffn>
+    dropout: !ref <transformer_dropout>
+    activation: !ref <activation>
+    encoder_module: transformer
+    attention_type: regularMHA
+    normalize_before: True
+    causal: False
+
+# This is the TransformerLM that is used according to the Huggingface repository
+# Visit the HuggingFace model corresponding to the pretrained_lm_tokenizer_path
+# For more details about the model!
+# NB: It has to match the pre-trained TransformerLM!!
+lm_model: !new:speechbrain.lobes.models.transformer.TransformerLM.TransformerLM # yamllint disable-line rule:line-length
+    vocab: !ref <output_neurons>
+    d_model: 768
+    nhead: 12
+    num_encoder_layers: 12
+    num_decoder_layers: 0
+    d_ffn: 3072
+    dropout: 0.0
+    activation: !name:torch.nn.GELU
+    normalize_before: False
+
+tokenizer: !new:sentencepiece.SentencePieceProcessor
+
+ctc_lin: !new:speechbrain.nnet.linear.Linear
+    input_size: !ref <d_model>
+    n_neurons: !ref <output_neurons>
+
+seq_lin: !new:speechbrain.nnet.linear.Linear
+    input_size: !ref <d_model>
+    n_neurons: !ref <output_neurons>
+
+modules:
+    CNN: !ref <CNN>
+    Transformer: !ref <Transformer>
+    seq_lin: !ref <seq_lin>
+    ctc_lin: !ref <ctc_lin>
+    normalize: !ref <normalize>
+
+model: !new:torch.nn.ModuleList
+    - [!ref <CNN>, !ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
+
+# define two optimizers here for two-stage training
+Adam: !name:torch.optim.Adam
+    lr: !ref <lr_adam>
+    betas: (0.9, 0.98)
+    eps: 0.000000001
+
+
+############################## Decoding & optimiser ############################
+
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
+    eos_index: !ref <eos_index>
+    blank_index: !ref <blank_index>
+    ctc_fc: !ref <ctc_lin>
+
+transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer
+    language_model: !ref <lm_model>
+    temperature: 1.15
+
+scorer_valid_search: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+
+scorer_test_search: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <ctc_scorer>, !ref <transformerlm_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+        transformerlm: !ref <lm_weight>
+
+valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <valid_beam_size>
+    using_eos_threshold: False
+    length_normalization: True
+    scorer: !ref <scorer_valid_search>
+
+test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <test_beam_size>
+    temperature: 1.15
+    using_eos_threshold: False
+    length_normalization: True
+    scorer: !ref <scorer_test_search>
+
+log_softmax: !new:torch.nn.LogSoftmax
+    dim: -1
+
+ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
+    blank_index: !ref <blank_index>
+    reduction: !ref <loss_reduction>
+
+seq_cost: !name:speechbrain.nnet.losses.kldiv_loss
+    label_smoothing: !ref <label_smoothing>
+    reduction: !ref <loss_reduction>
+
+noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler
+    lr_initial: !ref <lr_adam>
+    n_warmup_steps: 25000
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        model: !ref <model>
+        noam_scheduler: !ref <noam_annealing>
+        normalizer: !ref <normalize>
+        counter: !ref <epoch_counter>
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+normalize: !new:speechbrain.processing.features.InputNormalization
+    norm_type: global
+    update_until_epoch: 4
+
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: [95, 100, 105]
+
+# Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 15
+    drop_length_high: 25
+    drop_count_low: 4
+    drop_count_high: 4
+    replace: "mean"
+
+# Frequency Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 10
+    drop_length_high: 20
+    drop_count_low: 4
+    drop_count_high: 4
+    replace: "mean"
+    dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>]
+
+compute_features: !new:speechbrain.lobes.features.Fbank
+    sample_rate: !ref <sample_rate>
+    n_fft: !ref <n_fft>
+    n_mels: !ref <n_mels>
+
+############################## Logging and Pretrainer ##########################
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats
+
+# The pretrainer allows a mapping between pretrained files and instances that
+# are declared in the yaml. E.g here, we will download the file lm.ckpt
+# and it will be loaded into "lm" which is pointing to the <lm_model> defined
+# before.
+pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
+    collect_in: !ref <save_folder>
+    loadables:
+        lm: !ref <lm_model>
+        tokenizer: !ref <tokenizer>
+    paths:
+        lm: !ref <pretrained_lm_tokenizer_path>/lm.ckpt
+        tokenizer: !ref <pretrained_lm_tokenizer_path>/tokenizer.ckpt
diff --git a/recipes/LibriSpeech/ASR/transformer/hparams/branchformer_large.yaml b/recipes/LibriSpeech/ASR/transformer/hparams/branchformer_large.yaml
index 6f76ab3d68c44fe0e6b59b6220ca350751db1cca..02fc2eac46b99dfa8801fb986f0fbf0a507ffde9 100644
--- a/recipes/LibriSpeech/ASR/transformer/hparams/branchformer_large.yaml
+++ b/recipes/LibriSpeech/ASR/transformer/hparams/branchformer_large.yaml
@@ -41,9 +41,11 @@ test_csv:
     - !ref <output_folder>/test-clean.csv
     - !ref <output_folder>/test-other.csv
 
-# Training parameters
+####################### Training Parameters ####################################
+
 # To make Transformers converge, the global bath size should be large enough.
-# The global batch size is computed as batch_size * n_gpus * gradient_accumulation.
+# The global batch size is computed as batch_size * n_gpus *
+# grad_accumulation_factor.
 # Empirically, we found that this value should be >= 128.
 # Please, set your parameters accordingly.
 number_of_epochs: 120
@@ -54,6 +56,8 @@ max_grad_norm: 5.0
 loss_reduction: 'batchmean'
 sorting: random
 num_workers: 4
+precision: fp32 # bf16, fp16 or fp32
+avg_checkpoints: 10 # Number of checkpoints to average for evaluation
 
 # stages related parameters
 # stage_one_epochs: 90
@@ -69,17 +73,25 @@ win_length: 32
 # This setup works well for A100 80GB GPU, adapts it to your needs.
 # Or turn it off (but training speed will decrease)
 dynamic_batching: True
-max_batch_len: 500
-max_batch_len_val: 100 # we reduce it as the beam is much wider (VRAM)
+max_batch_length_train: 500
+max_batch_length_val: 100 # we reduce it as the beam is much wider (VRAM)
 num_bucket: 200
+shuffle: True # if true re-creates batches at each epoch shuffling examples.
+max_batch_ex: 128
+batch_ordering: random
+dynamic_batch_sampler_train:
+    max_batch_length: !ref <max_batch_length_train>
+    num_buckets: !ref <num_bucket>
+    shuffle: !ref <shuffle>
+    batch_ordering: !ref <batch_ordering>
+    max_batch_ex: !ref <max_batch_ex>
 
-dynamic_batch_sampler:
-    max_batch_len: !ref <max_batch_len>
-    max_batch_len_val: !ref <max_batch_len_val>
+dynamic_batch_sampler_valid:
+    max_batch_length: !ref <max_batch_length_val>
     num_buckets: !ref <num_bucket>
-    shuffle_ex: True # if true re-creates batches at each epoch shuffling examples.
-    batch_ordering: random
-    max_batch_ex: 128
+    shuffle: !ref <shuffle>
+    batch_ordering: !ref <batch_ordering>
+    max_batch_ex: !ref <max_batch_ex>
 
 # Dataloader options
 train_dataloader_opts:
@@ -93,7 +105,8 @@ valid_dataloader_opts:
 test_dataloader_opts:
     batch_size: 1
 
-####################### Model parameters ###########################
+####################### Model Parameters #######################################
+
 # Transformer
 d_model: 512
 nhead: 8
@@ -121,7 +134,7 @@ test_beam_size: 66
 lm_weight: 0.60
 ctc_weight_decode: 0.40
 
-############################## models ################################
+############################## Models ##########################################
 
 CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
     input_shape: (8, 10, 80)
@@ -194,34 +207,51 @@ Adam: !name:torch.optim.AdamW
     eps: 0.000000001
     weight_decay: !ref <weight_decay>
 
-valid_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
-    bos_index: !ref <bos_index>
+####################### Decoding & optimiser ###################################
+
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
     eos_index: !ref <eos_index>
     blank_index: !ref <blank_index>
+    ctc_fc: !ref <ctc_lin>
+
+
+transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer
+    language_model: !ref <lm_model>
+    temperature: 1.15
+
+scorer_valid_search: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+
+scorer_test_search: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <transformerlm_scorer>, !ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+        transformerlm: !ref <lm_weight>
+
+valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <valid_beam_size>
-    ctc_weight: !ref <ctc_weight_decode>
     using_eos_threshold: False
-    length_normalization: False
-
+    length_normalization: True
+    scorer: !ref <scorer_valid_search>
 
-test_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
+test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <test_beam_size>
-    ctc_weight: !ref <ctc_weight_decode>
-    lm_weight: !ref <lm_weight>
-    lm_modules: !ref <lm_model>
     temperature: 1.15
-    temperature_lm: 1.15
     using_eos_threshold: False
     length_normalization: True
+    scorer: !ref <scorer_test_search>
 
 log_softmax: !new:torch.nn.LogSoftmax
     dim: -1
@@ -249,28 +279,50 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.SpecAugment
-    time_warp: True
-    time_warp_window: 5
-    time_warp_mode: bicubic
-    freq_mask: True
-    n_freq_mask: 2
-    time_mask: True
-    n_time_mask: 10
-    replace_with_zero: False
-    freq_mask_width: 30
-    time_mask_width: 25
-
-speed_perturb: !new:speechbrain.processing.speech_augmentation.SpeedPerturb
+####################### Augmentations ###########################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
     orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
+# Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 15
+    drop_length_high: 25
+    drop_count_low: 4
+    drop_count_high: 4
+    replace: "mean"
+
+# Freq Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 10
+    drop_length_high: 20
+    drop_count_low: 4
+    drop_count_high: 4
+    replace: "mean"
+    dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>]
+
 compute_features: !new:speechbrain.lobes.features.Fbank
     sample_rate: !ref <sample_rate>
     n_fft: !ref <n_fft>
     win_length: !ref <win_length>
     n_mels: !ref <n_mels>
 
+############################## Logging and Pretrainer ##########################
+
 train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
     save_file: !ref <train_log>
 
diff --git a/recipes/LibriSpeech/ASR/transformer/hparams/conformer_large.yaml b/recipes/LibriSpeech/ASR/transformer/hparams/conformer_large.yaml
index 85569db0b7c6f98487e742b17098d0e448689683..7cdd4c06f91cac155d273ba9a9f9edd7dcc52885 100644
--- a/recipes/LibriSpeech/ASR/transformer/hparams/conformer_large.yaml
+++ b/recipes/LibriSpeech/ASR/transformer/hparams/conformer_large.yaml
@@ -41,9 +41,10 @@ test_csv:
     - !ref <output_folder>/test-clean.csv
     - !ref <output_folder>/test-other.csv
 
-# Training parameters
+####################### Training Parameters ####################################
+
 # To make Transformers converge, the global bath size should be large enough.
-# The global batch size is computed as batch_size * n_gpus * gradient_accumulation.
+# The global batch size is computed as batch_size * n_gpus * grad_accumulation_factor.
 # Empirically, we found that this value should be >= 128.
 # Please, set your parameters accordingly.
 number_of_epochs: 120
@@ -54,6 +55,8 @@ max_grad_norm: 5.0
 loss_reduction: 'batchmean'
 sorting: random
 num_workers: 4
+precision: fp32 # bf16, fp16 or fp32
+avg_checkpoints: 10 # Number of checkpoints to average for evaluation
 
 # stages related parameters
 lr_adam: 0.0008
@@ -67,17 +70,26 @@ win_length: 32
 # This setup works well for A100 80GB GPU, adapts it to your needs.
 # Or turn it off (but training speed will decrease)
 dynamic_batching: True
-max_batch_len: 500
-max_batch_len_val: 100 # we reduce it as the beam is much wider (VRAM)
+max_batch_length_train: 500
+max_batch_length_val: 100 # we reduce it as the beam is much wider (VRAM)
 num_bucket: 200
+shuffle: True # if true re-creates batches at each epoch shuffling examples.
+batch_ordering: random
+max_batch_ex: 256
+
+dynamic_batch_sampler_train:
+    max_batch_length: !ref <max_batch_length_train>
+    num_buckets: !ref <num_bucket>
+    shuffle: !ref <shuffle>
+    batch_ordering: !ref <batch_ordering>
+    max_batch_ex: !ref <max_batch_ex>
 
-dynamic_batch_sampler:
-    max_batch_len: !ref <max_batch_len>
-    max_batch_len_val: !ref <max_batch_len_val>
+dynamic_batch_sampler_valid:
+    max_batch_length: !ref <max_batch_length_val>
     num_buckets: !ref <num_bucket>
-    shuffle_ex: True # if true re-creates batches at each epoch shuffling examples.
-    batch_ordering: random
-    max_batch_ex: 256
+    shuffle: !ref <shuffle>
+    batch_ordering: !ref <batch_ordering>
+    max_batch_ex: !ref <max_batch_ex>
 
 # Dataloader options
 train_dataloader_opts:
@@ -91,7 +103,8 @@ valid_dataloader_opts:
 test_dataloader_opts:
     batch_size: 1
 
-####################### Model parameters ###########################
+####################### Model Parameters #######################################
+
 # Transformer
 d_model: 512
 nhead: 8
@@ -118,7 +131,7 @@ test_beam_size: 66
 lm_weight: 0.60
 ctc_weight_decode: 0.40
 
-############################## models ################################
+############################## Models ##########################################
 
 CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
     input_shape: (8, 10, 80)
@@ -180,43 +193,60 @@ modules:
     ctc_lin: !ref <ctc_lin>
     normalize: !ref <normalize>
 
-model: !new:torch.nn.ModuleList
-    - [!ref <CNN>, !ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
-
 # define two optimizers here for two-stage training
 Adam: !name:torch.optim.AdamW
     lr: !ref <lr_adam>
     betas: (0.9, 0.98)
     eps: 0.000000001
 
-valid_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
-    bos_index: !ref <bos_index>
+model: !new:torch.nn.ModuleList
+    - [!ref <CNN>, !ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
+
+####################### Decoding & optimiser ###########################
+
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
     eos_index: !ref <eos_index>
     blank_index: !ref <blank_index>
+    ctc_fc: !ref <ctc_lin>
+
+
+transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer
+    language_model: !ref <lm_model>
+    temperature: 1.15
+
+scorer_test_search: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <transformerlm_scorer>, !ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+        transformerlm: !ref <lm_weight>
+
+scorer_valid_search: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+
+valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <valid_beam_size>
-    ctc_weight: !ref <ctc_weight_decode>
     using_eos_threshold: False
-    length_normalization: False
-
+    length_normalization: True
+    scorer: !ref <scorer_valid_search>
 
-test_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
+test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <test_beam_size>
-    ctc_weight: !ref <ctc_weight_decode>
-    lm_weight: !ref <lm_weight>
-    lm_modules: !ref <lm_model>
     temperature: 1.15
-    temperature_lm: 1.15
     using_eos_threshold: False
     length_normalization: True
+    scorer: !ref <scorer_test_search>
 
 log_softmax: !new:torch.nn.LogSoftmax
     dim: -1
@@ -244,28 +274,50 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.SpecAugment
-    time_warp: True
-    time_warp_window: 5
-    time_warp_mode: bicubic
-    freq_mask: True
-    n_freq_mask: 2
-    time_mask: True
-    n_time_mask: 10
-    replace_with_zero: False
-    freq_mask_width: 30
-    time_mask_width: 25
-
-speed_perturb: !new:speechbrain.processing.speech_augmentation.SpeedPerturb
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
     orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
+# Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 15
+    drop_length_high: 25
+    drop_count_low: 4
+    drop_count_high: 4
+    replace: "mean"
+
+# Freq Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 10
+    drop_length_high: 20
+    drop_count_low: 4
+    drop_count_high: 4
+    replace: "mean"
+    dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>]
+
 compute_features: !new:speechbrain.lobes.features.Fbank
     sample_rate: !ref <sample_rate>
     n_fft: !ref <n_fft>
     n_mels: !ref <n_mels>
     win_length: !ref <win_length>
 
+############################## Logging and Pretrainer ##########################
+
 train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
     save_file: !ref <train_log>
 
diff --git a/recipes/LibriSpeech/ASR/transformer/hparams/conformer_small.yaml b/recipes/LibriSpeech/ASR/transformer/hparams/conformer_small.yaml
index 2672107a1dd9a6c579e3e4ae871643412df6b59f..a24e6649a28488cf2d3cdf4a29242717134d6ae8 100644
--- a/recipes/LibriSpeech/ASR/transformer/hparams/conformer_small.yaml
+++ b/recipes/LibriSpeech/ASR/transformer/hparams/conformer_small.yaml
@@ -40,9 +40,10 @@ test_csv:
     - !ref <output_folder>/test-clean.csv
     - !ref <output_folder>/test-other.csv
 
-# Training parameters
+####################### Training Parameters ####################################
+
 # To make Transformers converge, the global bath size should be large enough.
-# The global batch size is computed as batch_size * n_gpus * gradient_accumulation.
+# The global batch size is computed as batch_size * n_gpus * grad_accumulation_factor.
 # Empirically, we found that this value should be >= 128.
 # Please, set your parameters accordingly.
 number_of_epochs: 110
@@ -53,6 +54,8 @@ max_grad_norm: 5.0
 loss_reduction: 'batchmean'
 sorting: random
 num_workers: 4
+precision: fp32 # bf16, fp16 or fp32
+avg_checkpoints: 10 # Number of checkpoints to average for evaluation
 
 # stages related parameters
 # stage_one_epochs: 90
@@ -67,17 +70,26 @@ n_mels: 80
 # This setup works well for V100 32GB GPU, adapts it to your needs.
 # Or turn it off (but training speed will decrease)
 dynamic_batching: True
-max_batch_len: 900
-max_batch_len_val: 100 # we reduce it as the beam is much wider (VRAM)
+max_batch_length_train: 900
+max_batch_length_val: 100 # we reduce it as the beam is much wider (VRAM)
 num_bucket: 200
+shuffle: True # if true re-creates batches at each epoch shuffling examples.
+batch_ordering: random
+max_batch_ex: 128
+
+dynamic_batch_sampler_train:
+    max_batch_length: !ref <max_batch_length_train>
+    num_buckets: !ref <num_bucket>
+    shuffle: !ref <shuffle>
+    batch_ordering: !ref <batch_ordering>
+    max_batch_ex: !ref <max_batch_ex>
 
-dynamic_batch_sampler:
-    max_batch_len: !ref <max_batch_len>
-    max_batch_len_val: !ref <max_batch_len_val>
+dynamic_batch_sampler_valid:
+    max_batch_length: !ref <max_batch_length_val>
     num_buckets: !ref <num_bucket>
-    shuffle_ex: True # if true re-creates batches at each epoch shuffling examples.
-    batch_ordering: random
-    max_batch_ex: 128
+    shuffle: !ref <shuffle>
+    batch_ordering: !ref <batch_ordering>
+    max_batch_ex: !ref <max_batch_ex>
 
 # Dataloader options
 train_dataloader_opts:
@@ -91,7 +103,8 @@ valid_dataloader_opts:
 test_dataloader_opts:
     batch_size: 1
 
-####################### Model parameters ###########################
+####################### Model Parameters #######################################
+
 # Transformer
 d_model: 144
 nhead: 4
@@ -118,7 +131,7 @@ test_beam_size: 66
 lm_weight: 0.60
 ctc_weight_decode: 0.40
 
-############################## models ################################
+############################## Models ##########################################
 
 CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
     input_shape: (8, 10, 80)
@@ -189,39 +202,52 @@ Adam: !name:torch.optim.Adam
     betas: (0.9, 0.98)
     eps: 0.000000001
 
-#SGD: !name:torch.optim.SGD
-#    lr: !ref <lr_sgd>
-#    momentum: 0.99
-#    nesterov: True
+############################## Decoding & optimiser ############################
 
-valid_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
-    bos_index: !ref <bos_index>
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
     eos_index: !ref <eos_index>
     blank_index: !ref <blank_index>
+    ctc_fc: !ref <ctc_lin>
+
+
+transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer
+    language_model: !ref <lm_model>
+    temperature: 1.15
+
+scorer_test_search: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <transformerlm_scorer>, !ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+        transformerlm: !ref <lm_weight>
+
+scorer_valid_search: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+
+valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <valid_beam_size>
-    ctc_weight: !ref <ctc_weight_decode>
     using_eos_threshold: False
-    length_normalization: False
-
+    length_normalization: True
+    scorer: !ref <scorer_valid_search>
 
-test_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
+test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <test_beam_size>
-    ctc_weight: !ref <ctc_weight_decode>
-    lm_weight: !ref <lm_weight>
-    lm_modules: !ref <lm_model>
     temperature: 1.15
-    temperature_lm: 1.15
     using_eos_threshold: False
     length_normalization: True
+    scorer: !ref <scorer_test_search>
+
 
 log_softmax: !new:torch.nn.LogSoftmax
     dim: -1
@@ -249,27 +275,49 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.SpecAugment
-    time_warp: True
-    time_warp_window: 5
-    time_warp_mode: bicubic
-    freq_mask: True
-    n_freq_mask: 2
-    time_mask: True
-    n_time_mask: 2
-    replace_with_zero: False
-    freq_mask_width: 30
-    time_mask_width: 40
-
-speed_perturb: !new:speechbrain.processing.speech_augmentation.SpeedPerturb
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
     orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
+# Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 15
+    drop_length_high: 25
+    drop_count_low: 4
+    drop_count_high: 4
+    replace: "mean"
+
+# Freq Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 10
+    drop_length_high: 20
+    drop_count_low: 4
+    drop_count_high: 4
+    replace: "mean"
+    dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>]
+
 compute_features: !new:speechbrain.lobes.features.Fbank
     sample_rate: !ref <sample_rate>
     n_fft: !ref <n_fft>
     n_mels: !ref <n_mels>
 
+############################## Logging and Pretrainer ##########################
+
 train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
     save_file: !ref <train_log>
 
diff --git a/recipes/LibriSpeech/ASR/transformer/hparams/hyperbranchformer_13M.yaml b/recipes/LibriSpeech/ASR/transformer/hparams/hyperbranchformer_13M.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4b6ca718f306a9d81a406d4c92c6d8536aa9721a
--- /dev/null
+++ b/recipes/LibriSpeech/ASR/transformer/hparams/hyperbranchformer_13M.yaml
@@ -0,0 +1,342 @@
+# ############################################################################
+# Model: E2E ASR with Transformer
+# Encoder: HyperConformer Encoder
+# Decoder: Transformer Decoder + (CTC/ATT joint) beamsearch + TransformerLM
+# Tokens: unigram
+# losses: CTC + KLdiv (Label Smoothing loss)
+# Training: Librispeech 960h
+# Authors:  Juan Pablo Zuluaga, Florian Mai, Titouan Parcollet
+# ############################################################################
+# Seed needs to be set at top of yaml, before objects with parameters are made
+
+seed: 7775
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/hyperbranchformer_13M/<seed>
+output_wer_folder: !ref <output_folder>/
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# Language model (LM) pretraining
+# NB: To avoid mismatch, the speech recognizer must be trained with the same
+# tokenizer used for LM training. Here, we download everything from the
+# speechbrain HuggingFace repository. However, a local path pointing to a
+# directory containing the lm.ckpt and tokenizer.ckpt may also be specified
+# instead. E.g if you want to use your own LM / tokenizer.
+pretrained_lm_tokenizer_path: speechbrain/asr-transformer-transformerlm-librispeech
+
+# Data files
+data_folder: !PLACEHOLDER  # e.g., /path/to/LibriSpeech
+# If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES
+# then data_folder_rirs should be /localscratch/xxx_corpus
+# otherwise the dataset will automatically be downloaded
+# data_folder_rirs: !ref <data_folder>
+train_splits: ["train-clean-100", "train-clean-360", "train-other-500"]
+dev_splits: ["dev-clean"]
+test_splits: ["test-clean", "test-other"]
+skip_prep: False
+train_csv: !ref <output_folder>/train.csv
+valid_csv: !ref <output_folder>/dev-clean.csv
+test_csv:
+    - !ref <output_folder>/test-clean.csv
+    - !ref <output_folder>/test-other.csv
+
+############################## Training Parameters #############################
+
+# To make Transformers converge, the global bath size should be large enough.
+# The global batch size is computed as batch_size * n_gpus * grad_accumulation_factor.
+# Empirically, we found that this value should be >= 128.
+# Please, set your parameters accordingly.
+number_of_epochs: 110
+batch_size: 16 # This works for 2x GPUs with 32GB
+ctc_weight: 0.3
+grad_accumulation_factor: 3
+max_grad_norm: 5.0
+loss_reduction: 'batchmean'
+sorting: random
+num_workers: 4
+avg_checkpoints: 10 # Number of checkpoints to average for evaluation
+
+# stages related parameters
+lr_adam: 0.001
+
+# Feature parameters
+sample_rate: 16000
+n_fft: 400
+n_mels: 80
+
+# This setup works well for a P40 24GB GPU, adapt it to your needs.
+# Or turn it off (but training speed will decrease)
+dynamic_batching: True
+max_batch_length_train: 600
+max_batch_length_val: 100 # we reduce it as the beam is much wider (VRAM)
+num_bucket: 200
+shuffle: True # if true re-creates batches at each epoch shuffling examples.
+max_batch_ex: 128
+batch_ordering: random
+
+dynamic_batch_sampler_train:
+    max_batch_length: !ref <max_batch_length_train>
+    num_buckets: !ref <num_bucket>
+    shuffle: !ref <shuffle>
+    batch_ordering: !ref <batch_ordering>
+    max_batch_ex: !ref <max_batch_ex>
+
+dynamic_batch_sampler_valid:
+    max_batch_length: !ref <max_batch_length_val>
+    num_buckets: !ref <num_bucket>
+    shuffle: !ref <shuffle>
+    batch_ordering: !ref <batch_ordering>
+    max_batch_ex: !ref <max_batch_ex>
+
+# Dataloader options
+train_dataloader_opts:
+    batch_size: !ref <batch_size>
+    shuffle: True
+    num_workers: !ref <num_workers>
+
+valid_dataloader_opts:
+    batch_size: 1
+
+test_dataloader_opts:
+    batch_size: 1
+
+####################### Model Parameters #######################################
+
+# Transformer
+d_model: 144
+nhead: 8
+num_encoder_layers: 10
+num_decoder_layers: 4
+csgu_linear_units: 3072
+csgu_kernel_size: 31
+transformer_dropout: 0.1
+activation: !name:torch.nn.GELU
+output_neurons: 5000
+# specify 'hypermixing' for usage of multi-head HyperMixer instead of MultiHeadAttention
+# You can also specify RelPosMHAXL for conformer
+attention_type: hypermixing
+
+# option 1) 'conformer' for HyperConformer; option 2) 'transformer' for vanilla HyperMixer
+encoder_module: branchformer
+
+# Outputs
+blank_index: 0
+label_smoothing: 0.1
+pad_index: 0
+bos_index: 1
+eos_index: 2
+
+# Decoding parameters
+min_decode_ratio: 0.0
+max_decode_ratio: 1.0
+valid_search_interval: 30
+valid_beam_size: 10
+test_beam_size: 66
+lm_weight: 0.60
+ctc_weight_decode: 0.40
+
+############################## Models ##########################################
+
+CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
+    input_shape: (8, 10, 80)
+    num_blocks: 2
+    num_layers_per_block: 1
+    out_channels: (64, 32)
+    kernel_sizes: (3, 3)
+    strides: (2, 2)
+    residuals: (False, False)
+
+Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length
+    input_size: 640
+    tgt_vocab: !ref <output_neurons>
+    d_model: !ref <d_model>
+    nhead: !ref <nhead>
+    num_encoder_layers: !ref <num_encoder_layers>
+    num_decoder_layers: !ref <num_decoder_layers>
+    dropout: !ref <transformer_dropout>
+    activation: !ref <activation>
+    branchformer_activation: !ref <activation>
+    encoder_module: !ref <encoder_module>
+    csgu_linear_units: !ref <csgu_linear_units>
+    kernel_size: !ref <csgu_kernel_size>
+    attention_type: !ref <attention_type>
+    normalize_before: True
+    causal: False
+
+# This is the TransformerLM that is used according to the Huggingface repository
+# Visit the HuggingFace model corresponding to the pretrained_lm_tokenizer_path
+# For more details about the model!
+# NB: It has to match the pre-trained TransformerLM!!
+lm_model: !new:speechbrain.lobes.models.transformer.TransformerLM.TransformerLM # yamllint disable-line rule:line-length
+    vocab: !ref <output_neurons>
+    d_model: 768
+    nhead: 12
+    num_encoder_layers: 12
+    num_decoder_layers: 0
+    d_ffn: 3072
+    dropout: 0.0
+    activation: !name:torch.nn.GELU
+    normalize_before: False
+
+tokenizer: !new:sentencepiece.SentencePieceProcessor
+
+ctc_lin: !new:speechbrain.nnet.linear.Linear
+    input_size: !ref <d_model>
+    n_neurons: !ref <output_neurons>
+
+seq_lin: !new:speechbrain.nnet.linear.Linear
+    input_size: !ref <d_model>
+    n_neurons: !ref <output_neurons>
+
+normalize: !new:speechbrain.processing.features.InputNormalization
+    norm_type: global
+    update_until_epoch: 4
+
+modules:
+    CNN: !ref <CNN>
+    Transformer: !ref <Transformer>
+    seq_lin: !ref <seq_lin>
+    ctc_lin: !ref <ctc_lin>
+    normalize: !ref <normalize>
+
+model: !new:torch.nn.ModuleList
+    - [!ref <CNN>, !ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
+
+# define two optimizers here for two-stage training
+Adam: !name:torch.optim.Adam
+    lr: !ref <lr_adam>
+    betas: (0.9, 0.98)
+    eps: 0.000000001
+
+############################## Decoding & optimiser ############################
+
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
+    eos_index: !ref <eos_index>
+    blank_index: !ref <blank_index>
+    ctc_fc: !ref <ctc_lin>
+
+transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer
+    language_model: !ref <lm_model>
+    temperature: 1.15
+
+scorer_valid_search: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+
+scorer_test_search: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <transformerlm_scorer>, !ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+        transformerlm: !ref <lm_weight>
+
+valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <valid_beam_size>
+    using_eos_threshold: False
+    length_normalization: True
+    scorer: !ref <scorer_valid_search>
+
+test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <test_beam_size>
+    temperature: 1.15
+    using_eos_threshold: False
+    length_normalization: True
+    scorer: !ref <scorer_test_search>
+
+log_softmax: !new:torch.nn.LogSoftmax
+    dim: -1
+
+ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
+    blank_index: !ref <blank_index>
+    reduction: !ref <loss_reduction>
+
+seq_cost: !name:speechbrain.nnet.losses.kldiv_loss
+    label_smoothing: !ref <label_smoothing>
+    reduction: !ref <loss_reduction>
+
+noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler
+    lr_initial: !ref <lr_adam>
+    n_warmup_steps: 25000
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        model: !ref <model>
+        noam_scheduler: !ref <noam_annealing>
+        normalizer: !ref <normalize>
+        counter: !ref <epoch_counter>
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: [95, 100, 105]
+
+# Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 15
+    drop_length_high: 25
+    drop_count_low: 4
+    drop_count_high: 4
+    replace: "mean"
+
+# Freq Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 10
+    drop_length_high: 20
+    drop_count_low: 4
+    drop_count_high: 4
+    replace: "mean"
+    dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>]
+
+compute_features: !new:speechbrain.lobes.features.Fbank
+    sample_rate: !ref <sample_rate>
+    n_fft: !ref <n_fft>
+    n_mels: !ref <n_mels>
+
+############################## Logging and Pretrainer ##########################
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats
+
+# The pretrainer allows a mapping between pretrained files and instances that
+# are declared in the yaml. E.g here, we will download the file lm.ckpt
+# and it will be loaded into "lm" which is pointing to the <lm_model> defined
+# before.
+pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
+    collect_in: !ref <save_folder>
+    loadables:
+        lm: !ref <lm_model>
+        tokenizer: !ref <tokenizer>
+    paths:
+        lm: !ref <pretrained_lm_tokenizer_path>/lm.ckpt
+        tokenizer: !ref <pretrained_lm_tokenizer_path>/tokenizer.ckpt
diff --git a/recipes/LibriSpeech/ASR/transformer/hparams/hyperbranchformer_25M.yaml b/recipes/LibriSpeech/ASR/transformer/hparams/hyperbranchformer_25M.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2e0242e311442c5ece15bdf30b2287f054a1475d
--- /dev/null
+++ b/recipes/LibriSpeech/ASR/transformer/hparams/hyperbranchformer_25M.yaml
@@ -0,0 +1,342 @@
+# ############################################################################
+# Model: E2E ASR with Transformer
+# Encoder: HyperConformer Encoder
+# Decoder: Transformer Decoder + (CTC/ATT joint) beamsearch + TransformerLM
+# Tokens: unigram
+# losses: CTC + KLdiv (Label Smoothing loss)
+# Training: Librispeech 960h
+# Authors:  Juan Pablo Zuluaga, Florian Mai, Titouan Parcollet
+# ############################################################################
+# Seed needs to be set at top of yaml, before objects with parameters are made
+
+seed: 7775
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/hyperbranchformer_25M/<seed>
+output_wer_folder: !ref <output_folder>/
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# Language model (LM) pretraining
+# NB: To avoid mismatch, the speech recognizer must be trained with the same
+# tokenizer used for LM training. Here, we download everything from the
+# speechbrain HuggingFace repository. However, a local path pointing to a
+# directory containing the lm.ckpt and tokenizer.ckpt may also be specified
+# instead. E.g if you want to use your own LM / tokenizer.
+pretrained_lm_tokenizer_path: speechbrain/asr-transformer-transformerlm-librispeech
+
+# Data files
+data_folder: !PLACEHOLDER  # e.g., /path/to/LibriSpeech
+# If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES
+# then data_folder_rirs should be /localscratch/xxx_corpus
+# otherwise the dataset will automatically be downloaded
+# data_folder_rirs: !ref <data_folder>
+train_splits: ["train-clean-100", "train-clean-360", "train-other-500"]
+dev_splits: ["dev-clean"]
+test_splits: ["test-clean", "test-other"]
+skip_prep: False
+train_csv: !ref <output_folder>/train.csv
+valid_csv: !ref <output_folder>/dev-clean.csv
+test_csv:
+    - !ref <output_folder>/test-clean.csv
+    - !ref <output_folder>/test-other.csv
+
+############################## Training Parameters #############################
+
+# To make Transformers converge, the global bath size should be large enough.
+# The global batch size is computed as batch_size * n_gpus * grad_accumulation_factor.
+# Empirically, we found that this value should be >= 128.
+# Please, set your parameters accordingly.
+number_of_epochs: 110
+batch_size: 16 # This works for 2x GPUs with 32GB
+ctc_weight: 0.3
+grad_accumulation_factor: 3
+max_grad_norm: 5.0
+loss_reduction: 'batchmean'
+sorting: random
+num_workers: 4
+avg_checkpoints: 10 # Number of checkpoints to average for evaluation
+
+# stages related parameters
+lr_adam: 0.001
+
+# Feature parameters
+sample_rate: 16000
+n_fft: 400
+n_mels: 80
+
+# This setup works well for a P40 24GB GPU, adapt it to your needs.
+# Or turn it off (but training speed will decrease)
+dynamic_batching: True
+max_batch_length_train: 600
+max_batch_length_val: 100 # we reduce it as the beam is much wider (VRAM)
+num_bucket: 200
+shuffle: True # if true re-creates batches at each epoch shuffling examples.
+max_batch_ex: 128
+batch_ordering: random
+
+dynamic_batch_sampler_train:
+    max_batch_length: !ref <max_batch_length_train>
+    num_buckets: !ref <num_bucket>
+    shuffle: !ref <shuffle>
+    batch_ordering: !ref <batch_ordering>
+    max_batch_ex: !ref <max_batch_ex>
+
+dynamic_batch_sampler_valid:
+    max_batch_length: !ref <max_batch_length_val>
+    num_buckets: !ref <num_bucket>
+    shuffle: !ref <shuffle>
+    batch_ordering: !ref <batch_ordering>
+    max_batch_ex: !ref <max_batch_ex>
+
+
+# Dataloader options
+train_dataloader_opts:
+    batch_size: !ref <batch_size>
+    shuffle: True
+    num_workers: !ref <num_workers>
+
+valid_dataloader_opts:
+    batch_size: 1
+
+test_dataloader_opts:
+    batch_size: 1
+
+####################### Model Parameters #######################################
+# Transformer
+d_model: 256
+nhead: 8
+num_encoder_layers: 10
+num_decoder_layers: 4
+csgu_linear_units: 3072
+csgu_kernel_size: 31
+transformer_dropout: 0.1
+activation: !name:torch.nn.GELU
+output_neurons: 5000
+# specify 'hypermixing' for usage of multi-head HyperMixer instead of MultiHeadAttention
+# You can also specify RelPosMHAXL for conformer
+attention_type: hypermixing
+
+# option 1) 'conformer' for HyperConformer; option 2) 'transformer' for vanilla HyperMixer
+encoder_module: branchformer
+
+# Outputs
+blank_index: 0
+label_smoothing: 0.1
+pad_index: 0
+bos_index: 1
+eos_index: 2
+
+# Decoding parameters
+min_decode_ratio: 0.0
+max_decode_ratio: 1.0
+valid_search_interval: 30
+valid_beam_size: 10
+test_beam_size: 66
+lm_weight: 0.60
+ctc_weight_decode: 0.40
+
+############################## Models ##########################################
+
+CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
+    input_shape: (8, 10, 80)
+    num_blocks: 2
+    num_layers_per_block: 1
+    out_channels: (64, 32)
+    kernel_sizes: (3, 3)
+    strides: (2, 2)
+    residuals: (False, False)
+
+Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length
+    input_size: 640
+    tgt_vocab: !ref <output_neurons>
+    d_model: !ref <d_model>
+    nhead: !ref <nhead>
+    num_encoder_layers: !ref <num_encoder_layers>
+    num_decoder_layers: !ref <num_decoder_layers>
+    dropout: !ref <transformer_dropout>
+    activation: !ref <activation>
+    branchformer_activation: !ref <activation>
+    encoder_module: !ref <encoder_module>
+    csgu_linear_units: !ref <csgu_linear_units>
+    kernel_size: !ref <csgu_kernel_size>
+    attention_type: !ref <attention_type>
+    normalize_before: True
+    causal: False
+
+# This is the TransformerLM that is used according to the Huggingface repository
+# Visit the HuggingFace model corresponding to the pretrained_lm_tokenizer_path
+# For more details about the model!
+# NB: It has to match the pre-trained TransformerLM!!
+lm_model: !new:speechbrain.lobes.models.transformer.TransformerLM.TransformerLM # yamllint disable-line rule:line-length
+    vocab: !ref <output_neurons>
+    d_model: 768
+    nhead: 12
+    num_encoder_layers: 12
+    num_decoder_layers: 0
+    d_ffn: 3072
+    dropout: 0.0
+    activation: !name:torch.nn.GELU
+    normalize_before: False
+
+tokenizer: !new:sentencepiece.SentencePieceProcessor
+
+ctc_lin: !new:speechbrain.nnet.linear.Linear
+    input_size: !ref <d_model>
+    n_neurons: !ref <output_neurons>
+
+seq_lin: !new:speechbrain.nnet.linear.Linear
+    input_size: !ref <d_model>
+    n_neurons: !ref <output_neurons>
+
+normalize: !new:speechbrain.processing.features.InputNormalization
+    norm_type: global
+    update_until_epoch: 4
+
+modules:
+    CNN: !ref <CNN>
+    Transformer: !ref <Transformer>
+    seq_lin: !ref <seq_lin>
+    ctc_lin: !ref <ctc_lin>
+    normalize: !ref <normalize>
+
+model: !new:torch.nn.ModuleList
+    - [!ref <CNN>, !ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
+
+# define two optimizers here for two-stage training
+Adam: !name:torch.optim.Adam
+    lr: !ref <lr_adam>
+    betas: (0.9, 0.98)
+    eps: 0.000000001
+
+############################## Decoding & optimiser ############################
+
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
+    eos_index: !ref <eos_index>
+    blank_index: !ref <blank_index>
+    ctc_fc: !ref <ctc_lin>
+
+transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer
+    language_model: !ref <lm_model>
+    temperature: 1.15
+
+scorer_valid_search: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+
+scorer_test_search: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <transformerlm_scorer>, !ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+        transformerlm: !ref <lm_weight>
+
+valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <valid_beam_size>
+    using_eos_threshold: False
+    length_normalization: True
+    scorer: !ref <scorer_valid_search>
+
+test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <test_beam_size>
+    temperature: 1.15
+    using_eos_threshold: False
+    length_normalization: True
+    scorer: !ref <scorer_test_search>
+
+log_softmax: !new:torch.nn.LogSoftmax
+    dim: -1
+
+ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
+    blank_index: !ref <blank_index>
+    reduction: !ref <loss_reduction>
+
+seq_cost: !name:speechbrain.nnet.losses.kldiv_loss
+    label_smoothing: !ref <label_smoothing>
+    reduction: !ref <loss_reduction>
+
+noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler
+    lr_initial: !ref <lr_adam>
+    n_warmup_steps: 25000
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        model: !ref <model>
+        noam_scheduler: !ref <noam_annealing>
+        normalizer: !ref <normalize>
+        counter: !ref <epoch_counter>
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: [95, 100, 105]
+
+# Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 15
+    drop_length_high: 25
+    drop_count_low: 4
+    drop_count_high: 4
+    replace: "mean"
+
+# Freq Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 10
+    drop_length_high: 20
+    drop_count_low: 4
+    drop_count_high: 4
+    replace: "mean"
+    dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>]
+
+compute_features: !new:speechbrain.lobes.features.Fbank
+    sample_rate: !ref <sample_rate>
+    n_fft: !ref <n_fft>
+    n_mels: !ref <n_mels>
+
+############################## Logging and Pretrainer ##########################
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats
+
+# The pretrainer allows a mapping between pretrained files and instances that
+# are declared in the yaml. E.g here, we will download the file lm.ckpt
+# and it will be loaded into "lm" which is pointing to the <lm_model> defined
+# before.
+pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
+    collect_in: !ref <save_folder>
+    loadables:
+        lm: !ref <lm_model>
+        tokenizer: !ref <tokenizer>
+    paths:
+        lm: !ref <pretrained_lm_tokenizer_path>/lm.ckpt
+        tokenizer: !ref <pretrained_lm_tokenizer_path>/tokenizer.ckpt
diff --git a/recipes/LibriSpeech/ASR/transformer/hparams/hyperconformer_22M.yaml b/recipes/LibriSpeech/ASR/transformer/hparams/hyperconformer_22M.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6e165ed5c65424d1432c3c7c39f6615175830ed5
--- /dev/null
+++ b/recipes/LibriSpeech/ASR/transformer/hparams/hyperconformer_22M.yaml
@@ -0,0 +1,340 @@
+# ############################################################################
+# Model: E2E ASR with Transformer
+# Encoder: HyperConformer Encoder
+# Decoder: Transformer Decoder + (CTC/ATT joint) beamsearch + TransformerLM
+# Tokens: unigram
+# losses: CTC + KLdiv (Label Smoothing loss)
+# Training: Librispeech 960h
+# Authors:  Juan Pablo Zuluaga, Florian Mai, Titouan Parcollet
+# ############################################################################
+# Seed needs to be set at top of yaml, before objects with parameters are made
+
+seed: 7775
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/hyperconformer_22M/<seed>
+output_wer_folder: !ref <output_folder>/
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# Language model (LM) pretraining
+# NB: To avoid mismatch, the speech recognizer must be trained with the same
+# tokenizer used for LM training. Here, we download everything from the
+# speechbrain HuggingFace repository. However, a local path pointing to a
+# directory containing the lm.ckpt and tokenizer.ckpt may also be specified
+# instead. E.g if you want to use your own LM / tokenizer.
+pretrained_lm_tokenizer_path: speechbrain/asr-transformer-transformerlm-librispeech
+
+# Data files
+data_folder: !PLACEHOLDER  # e.g., /path/to/LibriSpeech
+# If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES
+# then data_folder_rirs should be /localscratch/xxx_corpus
+# otherwise the dataset will automatically be downloaded
+# data_folder_rirs: !ref <data_folder>
+train_splits: ["train-clean-100", "train-clean-360", "train-other-500"]
+dev_splits: ["dev-clean"]
+test_splits: ["test-clean", "test-other"]
+skip_prep: False
+train_csv: !ref <output_folder>/train.csv
+valid_csv: !ref <output_folder>/dev-clean.csv
+test_csv:
+    - !ref <output_folder>/test-clean.csv
+    - !ref <output_folder>/test-other.csv
+
+####################### Training Parameters ####################################
+
+# To make Transformers converge, the global bath size should be large enough.
+# The global batch size is computed as batch_size * n_gpus * grad_accumulation_factor.
+# Empirically, we found that this value should be >= 128.
+# Please, set your parameters accordingly.
+number_of_epochs: 110
+batch_size: 16 # This works for 2x GPUs with 32GB
+ctc_weight: 0.3
+grad_accumulation_factor: 1
+max_grad_norm: 5.0
+loss_reduction: 'batchmean'
+sorting: random
+num_workers: 4
+avg_checkpoints: 10 # Number of checkpoints to average for evaluation
+
+# stages related parameters
+lr_adam: 0.001
+
+# Feature parameters
+sample_rate: 16000
+n_fft: 400
+n_mels: 80
+
+# This setup works well for a P40 24GB GPU, adapt it to your needs.
+# Or turn it off (but training speed will decrease)
+dynamic_batching: True
+max_batch_length_train: 600
+max_batch_length_val: 100 # we reduce it as the beam is much wider (VRAM)
+num_bucket: 200
+shuffle: True # if true re-creates batches at each epoch shuffling examples.
+max_batch_ex: 128
+batch_ordering: random
+
+dynamic_batch_sampler_train:
+    max_batch_length: !ref <max_batch_length_train>
+    num_buckets: !ref <num_bucket>
+    shuffle: !ref <shuffle>
+    batch_ordering: !ref <batch_ordering>
+    max_batch_ex: !ref <max_batch_ex>
+
+dynamic_batch_sampler_valid:
+    max_batch_length: !ref <max_batch_length_val>
+    num_buckets: !ref <num_bucket>
+    shuffle: !ref <shuffle>
+    batch_ordering: !ref <batch_ordering>
+    max_batch_ex: !ref <max_batch_ex>
+
+
+# Dataloader options
+train_dataloader_opts:
+    batch_size: !ref <batch_size>
+    shuffle: True
+    num_workers: !ref <num_workers>
+
+valid_dataloader_opts:
+    batch_size: 1
+
+test_dataloader_opts:
+    batch_size: 1
+
+####################### Model Parameters #######################################
+
+# Transformer
+d_model: 256
+nhead: 8
+num_encoder_layers: 10
+num_decoder_layers: 4
+d_ffn: 1024
+transformer_dropout: 0.1
+activation: !name:torch.nn.GELU
+output_neurons: 5000
+# specify 'hypermixing' for usage of multi-head HyperMixer instead of MultiHeadAttention
+# You can also specify RelPosMHAXL for conformer
+attention_type: hypermixing
+
+# option 1) 'conformer' for HyperConformer; option 2) 'transformer' for vanilla HyperMixer
+encoder_module: conformer
+
+# Outputs
+blank_index: 0
+label_smoothing: 0.0
+pad_index: 0
+bos_index: 1
+eos_index: 2
+
+# Decoding parameters
+min_decode_ratio: 0.0
+max_decode_ratio: 1.0
+valid_search_interval: 10
+valid_beam_size: 10
+test_beam_size: 66
+lm_weight: 0.60
+ctc_weight_decode: 0.40
+
+############################## Models ##########################################
+
+CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
+    input_shape: (8, 10, 80)
+    num_blocks: 2
+    num_layers_per_block: 1
+    out_channels: (64, 32)
+    kernel_sizes: (3, 3)
+    strides: (2, 2)
+    residuals: (False, False)
+
+Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length
+    input_size: 640
+    tgt_vocab: !ref <output_neurons>
+    d_model: !ref <d_model>
+    nhead: !ref <nhead>
+    num_encoder_layers: !ref <num_encoder_layers>
+    num_decoder_layers: !ref <num_decoder_layers>
+    d_ffn: !ref <d_ffn>
+    dropout: !ref <transformer_dropout>
+    activation: !ref <activation>
+    encoder_module: !ref <encoder_module>
+    attention_type: !ref <attention_type>
+    normalize_before: True
+    causal: False
+
+# This is the TransformerLM that is used according to the Huggingface repository
+# Visit the HuggingFace model corresponding to the pretrained_lm_tokenizer_path
+# For more details about the model!
+# NB: It has to match the pre-trained TransformerLM!!
+lm_model: !new:speechbrain.lobes.models.transformer.TransformerLM.TransformerLM # yamllint disable-line rule:line-length
+    vocab: !ref <output_neurons>
+    d_model: 768
+    nhead: 12
+    num_encoder_layers: 12
+    num_decoder_layers: 0
+    d_ffn: 3072
+    dropout: 0.0
+    activation: !name:torch.nn.GELU
+    normalize_before: False
+
+tokenizer: !new:sentencepiece.SentencePieceProcessor
+
+ctc_lin: !new:speechbrain.nnet.linear.Linear
+    input_size: !ref <d_model>
+    n_neurons: !ref <output_neurons>
+
+seq_lin: !new:speechbrain.nnet.linear.Linear
+    input_size: !ref <d_model>
+    n_neurons: !ref <output_neurons>
+
+normalize: !new:speechbrain.processing.features.InputNormalization
+    norm_type: global
+    update_until_epoch: 4
+
+modules:
+    CNN: !ref <CNN>
+    Transformer: !ref <Transformer>
+    seq_lin: !ref <seq_lin>
+    ctc_lin: !ref <ctc_lin>
+    normalize: !ref <normalize>
+
+model: !new:torch.nn.ModuleList
+    - [!ref <CNN>, !ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
+
+# define two optimizers here for two-stage training
+Adam: !name:torch.optim.Adam
+    lr: !ref <lr_adam>
+    betas: (0.9, 0.98)
+    eps: 0.000000001
+
+####################### Decoding & optimiser ###################################
+
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
+    eos_index: !ref <eos_index>
+    blank_index: !ref <blank_index>
+    ctc_fc: !ref <ctc_lin>
+
+transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer
+    language_model: !ref <lm_model>
+    temperature: 1.15
+
+scorer_valid_search: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+
+scorer_test_search: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <transformerlm_scorer>, !ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+        transformerlm: !ref <lm_weight>
+
+valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <valid_beam_size>
+    using_eos_threshold: False
+    length_normalization: True
+    scorer: !ref <scorer_valid_search>
+
+test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <test_beam_size>
+    temperature: 1.15
+    using_eos_threshold: False
+    length_normalization: True
+    scorer: !ref <scorer_test_search>
+
+log_softmax: !new:torch.nn.LogSoftmax
+    dim: -1
+
+ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
+    blank_index: !ref <blank_index>
+    reduction: !ref <loss_reduction>
+
+seq_cost: !name:speechbrain.nnet.losses.kldiv_loss
+    label_smoothing: !ref <label_smoothing>
+    reduction: !ref <loss_reduction>
+
+noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler
+    lr_initial: !ref <lr_adam>
+    n_warmup_steps: 25000
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        model: !ref <model>
+        noam_scheduler: !ref <noam_annealing>
+        normalizer: !ref <normalize>
+        counter: !ref <epoch_counter>
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: [95, 100, 105]
+
+# Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 15
+    drop_length_high: 25
+    drop_count_low: 4
+    drop_count_high: 4
+    replace: "mean"
+
+# Freq Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 10
+    drop_length_high: 20
+    drop_count_low: 4
+    drop_count_high: 4
+    replace: "mean"
+    dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>]
+
+compute_features: !new:speechbrain.lobes.features.Fbank
+    sample_rate: !ref <sample_rate>
+    n_fft: !ref <n_fft>
+    n_mels: !ref <n_mels>
+
+############################## Logging and Pretrainer ##########################
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats
+
+# The pretrainer allows a mapping between pretrained files and instances that
+# are declared in the yaml. E.g here, we will download the file lm.ckpt
+# and it will be loaded into "lm" which is pointing to the <lm_model> defined
+# before.
+pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
+    collect_in: !ref <save_folder>
+    loadables:
+        lm: !ref <lm_model>
+        tokenizer: !ref <tokenizer>
+    paths:
+        lm: !ref <pretrained_lm_tokenizer_path>/lm.ckpt
+        tokenizer: !ref <pretrained_lm_tokenizer_path>/tokenizer.ckpt
diff --git a/recipes/LibriSpeech/ASR/transformer/hparams/hyperconformer_8M.yaml b/recipes/LibriSpeech/ASR/transformer/hparams/hyperconformer_8M.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fe3bd599c7712a95d2d55300d0073329a068a7e5
--- /dev/null
+++ b/recipes/LibriSpeech/ASR/transformer/hparams/hyperconformer_8M.yaml
@@ -0,0 +1,340 @@
+# ############################################################################
+# Model: E2E ASR with Transformer
+# Encoder: HyperConformer Encoder
+# Decoder: Transformer Decoder + (CTC/ATT joint) beamsearch + TransformerLM
+# Tokens: unigram
+# losses: CTC + KLdiv (Label Smoothing loss)
+# Training: Librispeech 960h
+# Authors:  Juan Pablo Zuluaga, Florian Mai, Titouan Parcollet
+# ############################################################################
+# Seed needs to be set at top of yaml, before objects with parameters are made
+
+seed: 7775
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/hyperconformer_8M/<seed>
+output_wer_folder: !ref <output_folder>/
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# Language model (LM) pretraining
+# NB: To avoid mismatch, the speech recognizer must be trained with the same
+# tokenizer used for LM training. Here, we download everything from the
+# speechbrain HuggingFace repository. However, a local path pointing to a
+# directory containing the lm.ckpt and tokenizer.ckpt may also be specified
+# instead. E.g if you want to use your own LM / tokenizer.
+pretrained_lm_tokenizer_path: speechbrain/asr-transformer-transformerlm-librispeech
+
+# Data files
+data_folder: !PLACEHOLDER  # e.g., /path/to/LibriSpeech
+# If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES
+# then data_folder_rirs should be /localscratch/xxx_corpus
+# otherwise the dataset will automatically be downloaded
+# data_folder_rirs: !ref <data_folder>
+train_splits: ["train-clean-100", "train-clean-360", "train-other-500"]
+dev_splits: ["dev-clean"]
+test_splits: ["test-clean", "test-other"]
+skip_prep: False
+train_csv: !ref <output_folder>/train.csv
+valid_csv: !ref <output_folder>/dev-clean.csv
+test_csv:
+    - !ref <output_folder>/test-clean.csv
+    - !ref <output_folder>/test-other.csv
+
+####################### Training Parameters ####################################
+
+# To make Transformers converge, the global bath size should be large enough.
+# The global batch size is computed as batch_size * n_gpus * grad_accumulation_factor.
+# Empirically, we found that this value should be >= 128.
+# Please, set your parameters accordingly.
+number_of_epochs: 110
+batch_size: 16 # This works for 2x GPUs with 32GB
+ctc_weight: 0.3
+grad_accumulation_factor: 1
+max_grad_norm: 5.0
+loss_reduction: 'batchmean'
+sorting: random
+num_workers: 4
+avg_checkpoints: 10 # Number of checkpoints to average for evaluation
+
+# stages related parameters
+lr_adam: 0.001
+
+# Feature parameters
+sample_rate: 16000
+n_fft: 400
+n_mels: 80
+
+# This setup works well for a P40 24GB GPU, adapt it to your needs.
+# Or turn it off (but training speed will decrease)
+dynamic_batching: True
+max_batch_length_train: 600
+max_batch_length_val: 100 # we reduce it as the beam is much wider (VRAM)
+num_bucket: 200
+shuffle: True # if true re-creates batches at each epoch shuffling examples.
+max_batch_ex: 128
+batch_ordering: random
+
+dynamic_batch_sampler_train:
+    max_batch_length: !ref <max_batch_length_train>
+    num_buckets: !ref <num_bucket>
+    shuffle: !ref <shuffle>
+    batch_ordering: !ref <batch_ordering>
+    max_batch_ex: !ref <max_batch_ex>
+
+dynamic_batch_sampler_valid:
+    max_batch_length: !ref <max_batch_length_val>
+    num_buckets: !ref <num_bucket>
+    shuffle: !ref <shuffle>
+    batch_ordering: !ref <batch_ordering>
+    max_batch_ex: !ref <max_batch_ex>
+
+
+# Dataloader options
+train_dataloader_opts:
+    batch_size: !ref <batch_size>
+    shuffle: True
+    num_workers: !ref <num_workers>
+
+valid_dataloader_opts:
+    batch_size: 1
+
+test_dataloader_opts:
+    batch_size: 1
+
+####################### Model Parameters #######################################
+
+# Transformer
+d_model: 144
+nhead: 8
+num_encoder_layers: 10
+num_decoder_layers: 4
+d_ffn: 576
+transformer_dropout: 0.1
+activation: !name:torch.nn.GELU
+output_neurons: 5000
+# specify 'hypermixing' for usage of multi-head HyperMixer instead of MultiHeadAttention
+# You can also specify RelPosMHAXL for conformer
+attention_type: hypermixing
+
+# option 1) 'conformer' for HyperConformer; option 2) 'transformer' for vanilla HyperMixer
+encoder_module: conformer
+
+# Outputs
+blank_index: 0
+label_smoothing: 0.0
+pad_index: 0
+bos_index: 1
+eos_index: 2
+
+# Decoding parameters
+min_decode_ratio: 0.0
+max_decode_ratio: 1.0
+valid_search_interval: 10
+valid_beam_size: 10
+test_beam_size: 66
+lm_weight: 0.60
+ctc_weight_decode: 0.40
+
+############################## Models ##########################################
+
+CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
+    input_shape: (8, 10, 80)
+    num_blocks: 2
+    num_layers_per_block: 1
+    out_channels: (64, 32)
+    kernel_sizes: (3, 3)
+    strides: (2, 2)
+    residuals: (False, False)
+
+Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length
+    input_size: 640
+    tgt_vocab: !ref <output_neurons>
+    d_model: !ref <d_model>
+    nhead: !ref <nhead>
+    num_encoder_layers: !ref <num_encoder_layers>
+    num_decoder_layers: !ref <num_decoder_layers>
+    d_ffn: !ref <d_ffn>
+    dropout: !ref <transformer_dropout>
+    activation: !ref <activation>
+    encoder_module: !ref <encoder_module>
+    attention_type: !ref <attention_type>
+    normalize_before: True
+    causal: False
+
+# This is the TransformerLM that is used according to the Huggingface repository
+# Visit the HuggingFace model corresponding to the pretrained_lm_tokenizer_path
+# For more details about the model!
+# NB: It has to match the pre-trained TransformerLM!!
+lm_model: !new:speechbrain.lobes.models.transformer.TransformerLM.TransformerLM # yamllint disable-line rule:line-length
+    vocab: !ref <output_neurons>
+    d_model: 768
+    nhead: 12
+    num_encoder_layers: 12
+    num_decoder_layers: 0
+    d_ffn: 3072
+    dropout: 0.0
+    activation: !name:torch.nn.GELU
+    normalize_before: False
+
+tokenizer: !new:sentencepiece.SentencePieceProcessor
+
+ctc_lin: !new:speechbrain.nnet.linear.Linear
+    input_size: !ref <d_model>
+    n_neurons: !ref <output_neurons>
+
+seq_lin: !new:speechbrain.nnet.linear.Linear
+    input_size: !ref <d_model>
+    n_neurons: !ref <output_neurons>
+
+normalize: !new:speechbrain.processing.features.InputNormalization
+    norm_type: global
+    update_until_epoch: 4
+
+modules:
+    CNN: !ref <CNN>
+    Transformer: !ref <Transformer>
+    seq_lin: !ref <seq_lin>
+    ctc_lin: !ref <ctc_lin>
+    normalize: !ref <normalize>
+
+model: !new:torch.nn.ModuleList
+    - [!ref <CNN>, !ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
+
+# define two optimizers here for two-stage training
+Adam: !name:torch.optim.Adam
+    lr: !ref <lr_adam>
+    betas: (0.9, 0.98)
+    eps: 0.000000001
+
+####################### Decoding & optimiser ###########################
+
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
+    eos_index: !ref <eos_index>
+    blank_index: !ref <blank_index>
+    ctc_fc: !ref <ctc_lin>
+
+transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer
+    language_model: !ref <lm_model>
+    temperature: 1.15
+
+scorer_valid_search: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+
+scorer_test_search: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <transformerlm_scorer>, !ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+        transformerlm: !ref <lm_weight>
+
+valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <valid_beam_size>
+    using_eos_threshold: False
+    length_normalization: True
+    scorer: !ref <scorer_valid_search>
+
+test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <test_beam_size>
+    temperature: 1.15
+    using_eos_threshold: False
+    length_normalization: True
+    scorer: !ref <scorer_test_search>
+
+log_softmax: !new:torch.nn.LogSoftmax
+    dim: -1
+
+ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
+    blank_index: !ref <blank_index>
+    reduction: !ref <loss_reduction>
+
+seq_cost: !name:speechbrain.nnet.losses.kldiv_loss
+    label_smoothing: !ref <label_smoothing>
+    reduction: !ref <loss_reduction>
+
+noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler
+    lr_initial: !ref <lr_adam>
+    n_warmup_steps: 25000
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        model: !ref <model>
+        noam_scheduler: !ref <noam_annealing>
+        normalizer: !ref <normalize>
+        counter: !ref <epoch_counter>
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: [95, 100, 105]
+
+# Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 15
+    drop_length_high: 25
+    drop_count_low: 4
+    drop_count_high: 4
+    replace: "mean"
+
+# Freq Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 10
+    drop_length_high: 20
+    drop_count_low: 4
+    drop_count_high: 4
+    replace: "mean"
+    dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>]
+
+compute_features: !new:speechbrain.lobes.features.Fbank
+    sample_rate: !ref <sample_rate>
+    n_fft: !ref <n_fft>
+    n_mels: !ref <n_mels>
+
+############################## Logging and Pretrainer ##########################
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats
+
+# The pretrainer allows a mapping between pretrained files and instances that
+# are declared in the yaml. E.g here, we will download the file lm.ckpt
+# and it will be loaded into "lm" which is pointing to the <lm_model> defined
+# before.
+pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
+    collect_in: !ref <save_folder>
+    loadables:
+        lm: !ref <lm_model>
+        tokenizer: !ref <tokenizer>
+    paths:
+        lm: !ref <pretrained_lm_tokenizer_path>/lm.ckpt
+        tokenizer: !ref <pretrained_lm_tokenizer_path>/tokenizer.ckpt
diff --git a/recipes/LibriSpeech/ASR/transformer/hparams/train_hf_whisper.yaml b/recipes/LibriSpeech/ASR/transformer/hparams/train_hf_whisper.yaml
index aea438f02a69123dd3baf8865dae274a1acebbd6..4891ca61746a15f7af937e5e6bf53ef7ed1c372f 100644
--- a/recipes/LibriSpeech/ASR/transformer/hparams/train_hf_whisper.yaml
+++ b/recipes/LibriSpeech/ASR/transformer/hparams/train_hf_whisper.yaml
@@ -35,11 +35,12 @@ test_csv:
 
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
-# Training parameters
+############################## Training Parameters #############################
+
 number_of_epochs: 1
 lr_whisper: 0.00003
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 
 # With data_parallel batch_size is split into N jobs
@@ -61,7 +62,7 @@ min_decode_ratio: 0.0
 max_decode_ratio: 1.0
 test_beam_size: 8
 
-# Model parameters
+####################### Model Parameters #######################################
 freeze_whisper: False
 
 
@@ -74,18 +75,45 @@ valid_loader_kwargs:
 test_loader_kwargs:
     batch_size: !ref <test_batch_size>
 
-
-#
-# Functions and classes
-#
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
-whisper: !new:speechbrain.lobes.models.huggingface_whisper.HuggingFaceWhisper
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0  # Min frequency band dropout probability
+    drop_freq_high: 1  # Max frequency band dropout probability
+    drop_freq_count_low: 1  # Min number of frequency bands to drop
+    drop_freq_count_high: 3  # Max number of frequency bands to drop
+    drop_freq_width: 0.05  # Width of frequency bands to drop
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1
+    drop_length_high: 5
+    drop_count_low: 1000
+    drop_count_high: 2000
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    concat_original: True
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
+whisper: !new:speechbrain.lobes.models.huggingface_transformers.whisper.Whisper
     source: !ref <whisper_hub>
     freeze: !ref <freeze_whisper>
     save_path: !ref <whisper_folder>
@@ -99,18 +127,20 @@ nll_loss: !name:speechbrain.nnet.losses.nll_loss
 modules:
     whisper: !ref <whisper>
 
+############################## Decoding & optimiser ############################
+
 whisper_opt_class: !name:torch.optim.AdamW
     lr: !ref <lr_whisper>
     weight_decay: 0.01
 
-valid_greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
+valid_search: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
     model: !ref <whisper>
     bos_index: !ref <timestamp_index>
     eos_index: !ref <eos_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
 
-test_beam_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
+test_search: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
     module: [!ref <whisper>]
     bos_index: !ref <timestamp_index>
     eos_index: !ref <eos_index>
@@ -124,6 +154,8 @@ lr_annealing_whisper: !new:speechbrain.nnet.schedulers.NewBobScheduler
     annealing_factor: 0.9
     patient: 0
 
+############################## Logging and Pretrainer ##########################
+
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
     checkpoints_dir: !ref <save_folder>
     recoverables:
diff --git a/recipes/LibriSpeech/ASR/transformer/hparams/transformer.yaml b/recipes/LibriSpeech/ASR/transformer/hparams/transformer.yaml
index d8407b8a0e99b0dbc451a9c578a4a3bf02a803e0..173453e9d4b992e109f5314261aeb129e14ddc79 100644
--- a/recipes/LibriSpeech/ASR/transformer/hparams/transformer.yaml
+++ b/recipes/LibriSpeech/ASR/transformer/hparams/transformer.yaml
@@ -42,9 +42,10 @@ test_csv:
 
 ckpt_interval_minutes: 30 # save checkpoint every N min
 
-# Training parameters
+####################### Training Parameters ####################################
+
 # To make Transformers converge, the global bath size should be large enough.
-# The global batch size is computed as batch_size * n_gpus * gradient_accumulation.
+# The global batch size is computed as batch_size * n_gpus * grad_accumulation_factor.
 # Empirically, we found that this value should be >= 128.
 # Please, set your parameters accordingly.
 number_of_epochs: 100
@@ -55,6 +56,8 @@ max_grad_norm: 5.0
 loss_reduction: 'batchmean'
 sorting: random
 num_workers: 4
+precision: fp32 # bf16, fp16 or fp32
+avg_checkpoints: 10 # Number of checkpoints to average for evaluation
 
 # index
 pad_index: 0
@@ -64,17 +67,26 @@ eos_index: 2
 # This setup works well for V100 32GB GPU, adapts it to your needs.
 # Or turn it off (but training speed will decrease)
 dynamic_batching: True
-max_batch_len: 600
-max_batch_len_val: 100 # we reduce it as the beam is much wider (VRAM)
+max_batch_length_train: 600
+max_batch_length_val: 100 # we reduce it as the beam is much wider (VRAM)
 num_bucket: 200
+shuffle: True  # if true re-creates batches at each epoch shuffling examples.
+batch_ordering: random
+max_batch_ex: 128
+
+dynamic_batch_sampler_train:
+    max_batch_length: !ref <max_batch_length_train>
+    num_buckets: !ref <num_bucket>
+    shuffle: !ref <shuffle>
+    batch_ordering: !ref <batch_ordering>
+    max_batch_ex: !ref <max_batch_ex>
 
-dynamic_batch_sampler:
-    max_batch_len: !ref <max_batch_len>
-    max_batch_len_val: !ref <max_batch_len_val>
+dynamic_batch_sampler_valid:
+    max_batch_length: !ref <max_batch_length_val>
     num_buckets: !ref <num_bucket>
-    shuffle_ex: True # if true re-creates batches at each epoch shuffling examples.
-    batch_ordering: random
-    max_batch_ex: 128
+    shuffle: !ref <shuffle>
+    batch_ordering: !ref <batch_ordering>
+    max_batch_ex: !ref <max_batch_ex>
 
 # stages related parameters
 lr_adam: 0.001
@@ -105,7 +117,8 @@ test_dataloader_opts:
         padding_kwargs:
             value: !ref <pad_index>
 
-####################### Model parameters ###########################
+####################### Model Parameters #######################################
+
 # Transformer
 d_model: 512
 nhead: 4
@@ -126,10 +139,12 @@ max_decode_ratio: 1.0
 valid_search_interval: 10
 valid_beam_size: 10
 test_beam_size: 66
+
+# Scoring parameters
 lm_weight: 0.60
 ctc_weight_decode: 0.40
 
-############################## models ################################
+############################## Models ##########################################
 
 CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
     input_shape: (8, 10, 80)
@@ -197,34 +212,50 @@ Adam: !name:torch.optim.Adam
     eps: 0.000000001
 
 
-valid_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
-    bos_index: !ref <bos_index>
+####################### Decoding & optimiser ###################################
+
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
     eos_index: !ref <eos_index>
     blank_index: !ref <blank_index>
+    ctc_fc: !ref <ctc_lin>
+
+transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer
+    language_model: !ref <lm_model>
+    temperature: 1.15
+
+scorer_valid_search: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+
+scorer_test_search: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <ctc_scorer>, !ref <transformerlm_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+        transformerlm: !ref <lm_weight>
+
+valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <valid_beam_size>
-    ctc_weight: !ref <ctc_weight_decode>
     using_eos_threshold: False
-    length_normalization: False
-
+    length_normalization: True
+    scorer: !ref <scorer_valid_search>
 
-test_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
+test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>]
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <test_beam_size>
-    ctc_weight: !ref <ctc_weight_decode>
-    lm_weight: !ref <lm_weight>
-    lm_modules: !ref <lm_model>
     temperature: 1.15
-    temperature_lm: 1.15
     using_eos_threshold: False
     length_normalization: True
+    scorer: !ref <scorer_test_search>
 
 log_softmax: !new:torch.nn.LogSoftmax
     dim: -1
@@ -256,27 +287,49 @@ normalize: !new:speechbrain.processing.features.InputNormalization
     norm_type: global
     update_until_epoch: 4
 
-augmentation: !new:speechbrain.lobes.augment.SpecAugment
-    time_warp: False
-    time_warp_window: 5
-    time_warp_mode: bicubic
-    freq_mask: True
-    n_freq_mask: 4
-    time_mask: True
-    n_time_mask: 4
-    replace_with_zero: False
-    freq_mask_width: 15
-    time_mask_width: 20
-
-speed_perturb: !new:speechbrain.processing.speech_augmentation.SpeedPerturb
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
     orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
+# Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 15
+    drop_length_high: 25
+    drop_count_low: 4
+    drop_count_high: 4
+    replace: "mean"
+
+# Freq Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: 10
+    drop_length_high: 20
+    drop_count_low: 4
+    drop_count_high: 4
+    replace: "mean"
+    dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>]
+
 compute_features: !new:speechbrain.lobes.features.Fbank
     sample_rate: !ref <sample_rate>
     n_fft: !ref <n_fft>
     n_mels: !ref <n_mels>
 
+############################## Logging and Pretrainer ##########################
+
 train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
     save_file: !ref <train_log>
 
diff --git a/recipes/LibriSpeech/ASR/transformer/train.py b/recipes/LibriSpeech/ASR/transformer/train.py
index 35a9f79fb0cb74b850d9f8b9fa0d78832460e122..292d7cc4249f6108f0cfba6f732073260ab80d9c 100644
--- a/recipes/LibriSpeech/ASR/transformer/train.py
+++ b/recipes/LibriSpeech/ASR/transformer/train.py
@@ -54,22 +54,15 @@ class ASR(sb.core.Brain):
         wavs, wav_lens = batch.sig
         tokens_bos, _ = batch.tokens_bos
 
-        # Add augmentation if specified
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.modules, "env_corrupt"):
-                wavs_noise = self.modules.env_corrupt(wavs, wav_lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                wav_lens = torch.cat([wav_lens, wav_lens])
-                tokens_bos = torch.cat([tokens_bos, tokens_bos], dim=0)
-
         # compute features
         feats = self.hparams.compute_features(wavs)
         current_epoch = self.hparams.epoch_counter.current
         feats = self.modules.normalize(feats, wav_lens, epoch=current_epoch)
 
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "augmentation"):
-                feats = self.hparams.augmentation(feats)
+        # Add feature augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "fea_augment"):
+            feats, fea_lens = self.hparams.fea_augment(feats, wav_lens)
+            tokens_bos = self.hparams.fea_augment.replicate_labels(tokens_bos)
 
         # forward modules
         src = self.modules.CNN(feats)
@@ -88,17 +81,26 @@ class ASR(sb.core.Brain):
 
         # Compute outputs
         hyps = None
-        if stage == sb.Stage.TRAIN:
-            hyps = None
-        elif stage == sb.Stage.VALID:
-            hyps = None
-            current_epoch = self.hparams.epoch_counter.current
-            if current_epoch % self.hparams.valid_search_interval == 0:
-                # for the sake of efficiency, we only perform beamsearch with limited capacity
-                # and no LM to give user some idea of how the AM is doing
-                hyps, _ = self.hparams.valid_search(enc_out.detach(), wav_lens)
-        elif stage == sb.Stage.TEST:
-            hyps, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
+        current_epoch = self.hparams.epoch_counter.current
+        is_valid_search = (
+            stage == sb.Stage.VALID
+            and current_epoch % self.hparams.valid_search_interval == 0
+        )
+        is_test_search = stage == sb.Stage.TEST
+
+        if any([is_valid_search, is_test_search]):
+            # Note: For valid_search, for the sake of efficiency, we only perform beamsearch with
+            # limited capacity and no LM to give user some idea of how the AM is doing
+
+            # Decide searcher for inference: valid or test search
+            if stage == sb.Stage.VALID:
+                hyps, _, _, _ = self.hparams.valid_search(
+                    enc_out.detach(), wav_lens
+                )
+            else:
+                hyps, _, _, _ = self.hparams.test_search(
+                    enc_out.detach(), wav_lens
+                )
 
         return p_ctc, p_seq, wav_lens, hyps
 
@@ -111,13 +113,18 @@ class ASR(sb.core.Brain):
         tokens_eos, tokens_eos_lens = batch.tokens_eos
         tokens, tokens_lens = batch.tokens
 
-        if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
-            tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
-            tokens_eos_lens = torch.cat(
-                [tokens_eos_lens, tokens_eos_lens], dim=0
-            )
-            tokens = torch.cat([tokens, tokens], dim=0)
-            tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0)
+        if stage == sb.Stage.TRAIN:
+            # Labels must be extended if parallel augmentation or concatenated
+            # augmentation was performed on the input (increasing the time dimension)
+            if hasattr(self.hparams, "fea_augment"):
+                (
+                    tokens,
+                    tokens_lens,
+                    tokens_eos,
+                    tokens_eos_lens,
+                ) = self.hparams.fea_augment.replicate_multiple_labels(
+                    tokens, tokens_lens, tokens_eos, tokens_eos_lens
+                )
 
         loss_seq = self.hparams.seq_cost(
             p_seq, tokens_eos, length=tokens_eos_lens
@@ -157,20 +164,13 @@ class ASR(sb.core.Brain):
             max_key=max_key, min_key=min_key
         )
         ckpt = sb.utils.checkpoints.average_checkpoints(
-            ckpts, recoverable_name="model", device=self.device
+            ckpts, recoverable_name="model",
         )
 
         self.hparams.model.load_state_dict(ckpt, strict=True)
         self.hparams.model.eval()
         print("Loaded the average")
 
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        with torch.no_grad():
-            predictions = self.compute_forward(batch, stage=stage)
-            loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
         if stage != sb.Stage.TRAIN:
@@ -194,7 +194,7 @@ class ASR(sb.core.Brain):
                 stage_stats["WER"] = self.wer_metric.summarize("error_rate")
 
         # log stats and save checkpoint at end-of-epoch
-        if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+        if stage == sb.Stage.VALID:
 
             lr = self.hparams.noam_annealing.current_lr
             steps = self.optimizer_step
@@ -214,7 +214,7 @@ class ASR(sb.core.Brain):
             self.checkpointer.save_and_keep_only(
                 meta={"ACC": stage_stats["ACC"], "epoch": epoch},
                 max_keys=["ACC"],
-                num_to_keep=10,
+                num_to_keep=self.hparams.avg_checkpoints,
             )
 
         elif stage == sb.Stage.TEST:
@@ -235,52 +235,10 @@ class ASR(sb.core.Brain):
                 num_to_keep=1,
             )
 
-    def fit_batch(self, batch):
-
-        should_step = self.step % self.grad_accumulation_factor == 0
-        # Managing automatic mixed precision
-        if self.auto_mix_prec:
-            with torch.autocast(torch.device(self.device).type):
-                outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-
-            # Losses are excluded from mixed precision to avoid instabilities
-            loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-            with self.no_sync(not should_step):
-                self.scaler.scale(
-                    loss / self.grad_accumulation_factor
-                ).backward()
-            if should_step:
-                self.scaler.unscale_(self.optimizer)
-                if self.check_gradients(loss):
-                    self.scaler.step(self.optimizer)
-                self.scaler.update()
-                self.zero_grad()
-                self.optimizer_step += 1
-                self.hparams.noam_annealing(self.optimizer)
-        else:
-            if self.bfloat16_mix_prec:
-                with torch.autocast(
-                    device_type=torch.device(self.device).type,
-                    dtype=torch.bfloat16,
-                ):
-                    outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-                    loss = self.compute_objectives(
-                        outputs, batch, sb.Stage.TRAIN
-                    )
-            else:
-                outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-                loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-            with self.no_sync(not should_step):
-                (loss / self.grad_accumulation_factor).backward()
-            if should_step:
-                if self.check_gradients(loss):
-                    self.optimizer.step()
-                self.zero_grad()
-                self.optimizer_step += 1
-                self.hparams.noam_annealing(self.optimizer)
-
-        self.on_fit_batch_end(batch, outputs, loss, should_step)
-        return loss.detach().cpu()
+    def on_fit_batch_end(self, batch, outputs, loss, should_step):
+        """At the end of the optimizer step, apply noam annealing."""
+        if should_step:
+            self.hparams.noam_annealing(self.optimizer)
 
 
 def dataio_prepare(hparams):
@@ -388,26 +346,18 @@ def dataio_prepare(hparams):
     if hparams["dynamic_batching"]:
         from speechbrain.dataio.sampler import DynamicBatchSampler  # noqa
 
-        dynamic_hparams = hparams["dynamic_batch_sampler"]
-        num_buckets = dynamic_hparams["num_buckets"]
+        dynamic_hparams_train = hparams["dynamic_batch_sampler_train"]
+        dynamic_hparams_valid = hparams["dynamic_batch_sampler_valid"]
 
         train_batch_sampler = DynamicBatchSampler(
             train_data,
-            dynamic_hparams["max_batch_len"],
-            num_buckets=num_buckets,
             length_func=lambda x: x["duration"],
-            shuffle=dynamic_hparams["shuffle_ex"],
-            batch_ordering=dynamic_hparams["batch_ordering"],
-            max_batch_ex=dynamic_hparams["max_batch_ex"],
+            **dynamic_hparams_train,
         )
-
         valid_batch_sampler = DynamicBatchSampler(
             valid_data,
-            dynamic_hparams["max_batch_len_val"],
-            num_buckets=num_buckets,
             length_func=lambda x: x["duration"],
-            shuffle=dynamic_hparams["shuffle_ex"],
-            batch_ordering=dynamic_hparams["batch_ordering"],
+            **dynamic_hparams_valid,
         )
 
     return (
@@ -426,7 +376,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -468,7 +417,7 @@ if __name__ == "__main__":
     # We download the pretrained LM from HuggingFace (or elsewhere depending on
     # the path given in the YAML file). The tokenizer is loaded at the same time.
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     # Trainer initialization
     asr_brain = ASR(
diff --git a/recipes/LibriSpeech/ASR/transformer/train_bayesspeech.py b/recipes/LibriSpeech/ASR/transformer/train_bayesspeech.py
new file mode 100644
index 0000000000000000000000000000000000000000..1feef3b5b8294a926dc9603b24b852068c857cae
--- /dev/null
+++ b/recipes/LibriSpeech/ASR/transformer/train_bayesspeech.py
@@ -0,0 +1,555 @@
+#!/usr/bin/env python3
+"""Recipe for training a Bayesian Transformer ASR system (https://arxiv.org/abs/2301.11276)
+with LibriSpeech via Bayes by Backprop (https://arxiv.org/abs/1505.05424).
+The system employs an encoder, a decoder, and an attention mechanism between them.
+Decoding is performed with (CTC/Att joint) beamsearch coupled with a neural language model.
+
+To run this recipe, do the following:
+> python train_bayesspeech.py hparams/transformer_bayesspeech.yaml
+
+With the default hyperparameters, the system employs a convolutional frontend and a transformer.
+The decoder is based on a Transformer decoder. Beamsearch coupled with a Transformer
+language model is used  on the top of decoder probabilities.
+
+Linear layers are turned into Bayesian linear layers by placing a normal prior and a normal
+variational posterior upon their weights and biases. The Bayesian neural network is trained
+to minimize the evidence lower bound (ELBO), which is a trade-off between the simplicity
+of the prior (complexity loss) and the complexity of the data (likelihood loss).
+The likelihood loss is the standard loss function used in non-Bayesian ASR transformers
+(CTC + negative-log likelihood), the complexity loss is the Kullback-Leibler divergence between
+variational posterior and prior. Sub-word units estimated with Byte Pairwise Encoding (BPE) are
+used as basic recognition tokens. Training is performed on the full LibriSpeech dataset (960 h).
+
+The best model is the average of the checkpoints from last 5 epochs.
+
+The experiment file is flexible enough to support a large variety of
+different systems. By properly changing the parameter files, you can try
+different encoders, decoders, tokens (e.g, characters instead of BPE),
+training split (e.g, train-clean 100 rather than the full one), and many
+other possible variations.
+
+
+Authors
+ * Jianyuan Zhong 2020
+ * Mirco Ravanelli 2020
+ * Peter Plantinga 2020
+ * Samuele Cornell 2020, 2021, 2022
+ * Titouan Parcollet 2021, 2022
+ * Luca Della Libera 2023
+"""
+
+import os
+import sys
+import torch
+import logging
+from pathlib import Path
+import speechbrain as sb
+from hyperpyyaml import load_hyperpyyaml
+from speechbrain.utils.distributed import run_on_main, if_main_process
+
+logger = logging.getLogger(__name__)
+
+
+# Define training procedure
+class ASR(sb.core.Brain):
+    def compute_forward(self, batch, stage):
+        """Forward computations from the waveform batches to the output probabilities."""
+        batch = batch.to(self.device)
+        wavs, wav_lens = batch.sig
+        tokens_bos, _ = batch.tokens_bos
+
+        # compute features
+        feats = self.hparams.compute_features(wavs)
+        current_epoch = self.hparams.epoch_counter.current
+        feats = self.modules.normalize(feats, wav_lens, epoch=current_epoch)
+
+        # Add feature augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "fea_augment"):
+            feats, fea_lens = self.hparams.fea_augment(feats, wav_lens)
+            tokens_bos = self.hparams.fea_augment.replicate_labels(tokens_bos)
+
+        # forward modules
+        src = self.modules.CNN(feats)
+
+        enc_out, pred = self.modules.Transformer(
+            src, tokens_bos, wav_lens, pad_idx=self.hparams.pad_index,
+        )
+
+        # output layer for ctc log-probabilities
+        logits = self.modules.ctc_lin(enc_out)
+        p_ctc = self.hparams.log_softmax(logits)
+
+        # output layer for seq2seq log-probabilities
+        pred = self.modules.seq_lin(pred)
+        p_seq = self.hparams.log_softmax(pred)
+
+        # Compute outputs
+        hyps = None
+        current_epoch = self.hparams.epoch_counter.current
+        is_valid_search = (
+            stage == sb.Stage.VALID
+            and current_epoch % self.hparams.valid_search_interval == 0
+        )
+        is_test_search = stage == sb.Stage.TEST
+
+        if any([is_valid_search, is_test_search]):
+            # Note: For valid_search, for the sake of efficiency, we only perform beamsearch with
+            # limited capacity and no LM to give user some idea of how the AM is doing
+
+            # Decide searcher for inference: valid or test search
+            if stage == sb.Stage.VALID:
+                hyps, _, _, _ = self.hparams.valid_search(
+                    enc_out.detach(), wav_lens
+                )
+            else:
+                hyps, _, _, _ = self.hparams.test_search(
+                    enc_out.detach(), wav_lens
+                )
+
+        return p_ctc, p_seq, wav_lens, hyps
+
+    def compute_objectives(self, predictions, batch, stage):
+        """Computes the loss (CTC+NLL) given predictions and targets."""
+
+        (p_ctc, p_seq, wav_lens, hyps,) = predictions
+
+        ids = batch.id
+        tokens_eos, tokens_eos_lens = batch.tokens_eos
+        tokens, tokens_lens = batch.tokens
+
+        if stage == sb.Stage.TRAIN:
+            if hasattr(self.hparams, "fea_augment"):
+                tokens = self.hparams.fea_augment.replicate_labels(tokens)
+                tokens_lens = self.hparams.fea_augment.replicate_labels(
+                    tokens_lens
+                )
+                tokens_eos = self.hparams.fea_augment.replicate_labels(
+                    tokens_eos
+                )
+                tokens_eos_lens = self.hparams.fea_augment.replicate_labels(
+                    tokens_eos_lens
+                )
+
+        loss_seq = self.hparams.seq_cost(
+            p_seq, tokens_eos, length=tokens_eos_lens
+        ).sum()
+
+        loss_ctc = self.hparams.ctc_cost(
+            p_ctc, tokens, wav_lens, tokens_lens
+        ).sum()
+
+        loss = (
+            self.hparams.ctc_weight * loss_ctc
+            + (1 - self.hparams.ctc_weight) * loss_seq
+            + self.hparams.kl_div_weight * self.modules.Transformer.kl_div
+        )
+
+        if stage != sb.Stage.TRAIN:
+            current_epoch = self.hparams.epoch_counter.current
+            valid_search_interval = self.hparams.valid_search_interval
+            if current_epoch % valid_search_interval == 0 or (
+                stage == sb.Stage.TEST
+            ):
+                # Decode token terms to words
+                predicted_words = [
+                    tokenizer.decode_ids(utt_seq).split(" ") for utt_seq in hyps
+                ]
+                target_words = [wrd.split(" ") for wrd in batch.wrd]
+                self.wer_metric.append(ids, predicted_words, target_words)
+
+            # compute the accuracy of the one-step-forward prediction
+            self.acc_metric.append(p_seq, tokens_eos, tokens_eos_lens)
+        return loss
+
+    def on_evaluate_start(self, max_key=None, min_key=None):
+        """perform checkpoint averge if needed"""
+        super().on_evaluate_start()
+
+        ckpts = self.checkpointer.find_checkpoints(
+            max_key=max_key, min_key=min_key
+        )
+        ckpt = sb.utils.checkpoints.average_checkpoints(
+            ckpts, recoverable_name="model",
+        )
+
+        self.hparams.model.load_state_dict(ckpt, strict=True)
+        self.hparams.model.eval()
+        print("Loaded the average")
+
+    def on_stage_start(self, stage, epoch):
+        """Gets called at the beginning of each epoch"""
+        if stage != sb.Stage.TRAIN:
+            self.acc_metric = self.hparams.acc_computer()
+            self.wer_metric = self.hparams.error_rate_computer()
+
+    def on_stage_end(self, stage, stage_loss, epoch):
+        """Gets called at the end of a epoch."""
+        # Compute/store important stats
+        stage_stats = {"loss": stage_loss}
+        if stage == sb.Stage.TRAIN:
+            self.train_stats = stage_stats
+        else:
+            stage_stats["ACC"] = self.acc_metric.summarize()
+            current_epoch = self.hparams.epoch_counter.current
+            valid_search_interval = self.hparams.valid_search_interval
+            if (
+                current_epoch % valid_search_interval == 0
+                or stage == sb.Stage.TEST
+            ):
+                stage_stats["WER"] = self.wer_metric.summarize("error_rate")
+
+        # log stats and save checkpoint at end-of-epoch
+        if stage == sb.Stage.VALID:
+
+            lr = self.hparams.noam_annealing.current_lr
+            steps = self.optimizer_step
+            optimizer = self.optimizer.__class__.__name__
+
+            epoch_stats = {
+                "epoch": epoch,
+                "lr": lr,
+                "steps": steps,
+                "optimizer": optimizer,
+            }
+            self.hparams.train_logger.log_stats(
+                stats_meta=epoch_stats,
+                train_stats=self.train_stats,
+                valid_stats=stage_stats,
+            )
+            self.checkpointer.save_and_keep_only(
+                meta={"ACC": stage_stats["ACC"], "epoch": epoch},
+                max_keys=["ACC"],
+                num_to_keep=self.hparams.avg_checkpoints,
+            )
+
+        elif stage == sb.Stage.TEST:
+            self.hparams.train_logger.log_stats(
+                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
+                test_stats=stage_stats,
+            )
+            if if_main_process():
+                with open(self.hparams.test_wer_file, "w") as w:
+                    self.wer_metric.write_stats(w)
+
+            # save the averaged checkpoint at the end of the evaluation stage
+            # delete the rest of the intermediate checkpoints
+            # ACC is set to 1.1 so checkpointer only keeps the averaged checkpoint
+            self.checkpointer.save_and_keep_only(
+                meta={"ACC": 1.1, "epoch": epoch},
+                max_keys=["ACC"],
+                num_to_keep=1,
+            )
+
+    def on_fit_batch_end(self, batch, outputs, loss, should_step):
+        """At the end of the optimizer step, apply noam annealing."""
+        if should_step:
+            self.hparams.noam_annealing(self.optimizer)
+
+
+def dataio_prepare(hparams):
+    """This function prepares the datasets to be used in the brain class.
+    It also defines the data processing pipeline through user-defined functions."""
+    data_folder = hparams["data_folder"]
+
+    train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
+        csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
+    )
+
+    if hparams["sorting"] == "ascending":
+        # we sort training data to speed up training and get better results.
+        train_data = train_data.filtered_sorted(sort_key="duration")
+        # when sorting do not shuffle in dataloader ! otherwise is pointless
+        hparams["train_dataloader_opts"]["shuffle"] = False
+
+    elif hparams["sorting"] == "descending":
+        train_data = train_data.filtered_sorted(
+            sort_key="duration", reverse=True
+        )
+        # when sorting do not shuffle in dataloader ! otherwise is pointless
+        hparams["train_dataloader_opts"]["shuffle"] = False
+
+    elif hparams["sorting"] == "random":
+        pass
+
+    else:
+        raise NotImplementedError(
+            "sorting must be random, ascending or descending"
+        )
+    valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
+        csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
+    )
+    valid_data = valid_data.filtered_sorted(sort_key="duration")
+
+    # test is separate
+    test_datasets = {}
+    for csv_file in hparams["test_csv"]:
+        name = Path(csv_file).stem
+        test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
+            csv_path=csv_file, replacements={"data_root": data_folder}
+        )
+        test_datasets[name] = test_datasets[name].filtered_sorted(
+            sort_key="duration"
+        )
+
+    datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
+    valtest_datasets = [valid_data] + [i for k, i in test_datasets.items()]
+
+    # We get the tokenizer as we need it to encode the labels when creating
+    # mini-batches.
+    tokenizer = hparams["tokenizer"]
+
+    # 2. Define audio pipeline:
+    @sb.utils.data_pipeline.takes("wav")
+    @sb.utils.data_pipeline.provides("sig")
+    def audio_pipeline(wav):
+        sig = sb.dataio.dataio.read_audio(wav)
+        return sig
+
+    sb.dataio.dataset.add_dynamic_item(valtest_datasets, audio_pipeline)
+
+    @sb.utils.data_pipeline.takes("wav")
+    @sb.utils.data_pipeline.provides("sig")
+    def audio_pipeline_train(wav):
+        # Speed Perturb is done here so it is multi-threaded with the
+        # workers of the dataloader (faster).
+        if "speed_perturb" in hparams:
+            sig = sb.dataio.dataio.read_audio(wav)
+
+            sig = hparams["speed_perturb"](sig.unsqueeze(0)).squeeze(0)
+        else:
+            sig = sb.dataio.dataio.read_audio(wav)
+        return sig
+
+    sb.dataio.dataset.add_dynamic_item([train_data], audio_pipeline_train)
+
+    # 3. Define text pipeline:
+    @sb.utils.data_pipeline.takes("wrd")
+    @sb.utils.data_pipeline.provides(
+        "wrd", "tokens_list", "tokens_bos", "tokens_eos", "tokens"
+    )
+    def text_pipeline(wrd):
+        yield wrd
+        tokens_list = tokenizer.encode_as_ids(wrd)
+        yield tokens_list
+        tokens_bos = torch.LongTensor([hparams["bos_index"]] + (tokens_list))
+        yield tokens_bos
+        tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
+        yield tokens_eos
+        tokens = torch.LongTensor(tokens_list)
+        yield tokens
+
+    sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
+
+    # 4. Set output:
+    sb.dataio.dataset.set_output_keys(
+        datasets, ["id", "sig", "wrd", "tokens_bos", "tokens_eos", "tokens"],
+    )
+
+    # 5. If Dynamic Batching is used, we instantiate the needed samplers.
+    train_batch_sampler = None
+    valid_batch_sampler = None
+    if hparams["dynamic_batching"]:
+        from speechbrain.dataio.sampler import DynamicBatchSampler  # noqa
+
+        dynamic_hparams_train = hparams["dynamic_batch_sampler_train"]
+        dynamic_hparams_valid = hparams["dynamic_batch_sampler_valid"]
+
+        train_batch_sampler = DynamicBatchSampler(
+            train_data,
+            length_func=lambda x: x["duration"],
+            **dynamic_hparams_train,
+        )
+        valid_batch_sampler = DynamicBatchSampler(
+            valid_data,
+            length_func=lambda x: x["duration"],
+            **dynamic_hparams_valid,
+        )
+
+    return (
+        train_data,
+        valid_data,
+        test_datasets,
+        tokenizer,
+        train_batch_sampler,
+        valid_batch_sampler,
+    )
+
+
+if __name__ == "__main__":
+    # CLI:
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # create ddp_group with the right communication protocol
+    sb.utils.distributed.ddp_init_group(run_opts)
+
+    # 1.  # Dataset prep (parsing Librispeech)
+    from librispeech_prepare import prepare_librispeech  # noqa
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    # multi-gpu (ddp) save data preparation
+    run_on_main(
+        prepare_librispeech,
+        kwargs={
+            "data_folder": hparams["data_folder"],
+            "tr_splits": hparams["train_splits"],
+            "dev_splits": hparams["dev_splits"],
+            "te_splits": hparams["test_splits"],
+            "save_folder": hparams["output_folder"],
+            "merge_lst": hparams["train_splits"],
+            "merge_name": "train.csv",
+            "skip_prep": hparams["skip_prep"],
+        },
+    )
+
+    # here we create the datasets objects as well as tokenization and encoding
+    (
+        train_data,
+        valid_data,
+        test_datasets,
+        tokenizer,
+        train_bsampler,
+        valid_bsampler,
+    ) = dataio_prepare(hparams)
+
+    # We download the pretrained LM from HuggingFace (or elsewhere depending on
+    # the path given in the YAML file). The tokenizer is loaded at the same time.
+    run_on_main(hparams["pretrainer"].collect_files)
+    hparams["pretrainer"].load_collected()
+
+    # ###################################################################
+    # Define Bayesian modules
+    # ###################################################################
+    from speechbrain.nnet.attention import PositionalwiseFeedForward
+
+    try:
+        from bayestorch.distributions import (
+            get_log_scale_normal,
+            get_softplus_inv_scale_normal,
+        )
+        from bayestorch.nn import VariationalPosteriorModule
+    except ImportError:
+        raise ImportError(
+            "Please install BayesTorch to use BayesSpeech (e.g. `pip install bayestorch>=0.0.3`)"
+        )
+
+    # Minimize number of modifications to existing training/evaluation loops
+    # NOTE: differently from https://arxiv.org/abs/2301.11276, we employ the standard
+    # reparameterization trick instead of the local reparameterization trick
+    class BBBModule(VariationalPosteriorModule):
+        def forward(self, *args, **kwargs):
+            if self.training:
+                output, self.kl_div = super().forward(
+                    *args, num_mc_samples=1, return_kl_div=True, **kwargs
+                )
+                return output
+            output, self.kl_div = (
+                super().forward(
+                    *args,
+                    num_mc_samples=hparams["num_eval_mc_samples"],
+                    **kwargs,
+                ),
+                0.0,
+            )
+            return output
+
+    parameters = []
+    for module in hparams["modules"]["Transformer"].modules():
+        if isinstance(module, PositionalwiseFeedForward):
+            parameters += list(module.parameters())
+    prior_builder, prior_kwargs = get_log_scale_normal(
+        parameters, log_scale=hparams["normal_prior_log_scale"],
+    )
+    posterior_builder, posterior_kwargs = get_softplus_inv_scale_normal(
+        parameters,
+        softplus_inv_scale=hparams["normal_posterior_softplus_inv_scale"],
+        requires_grad=True,
+    )
+    hparams["Transformer"] = hparams["modules"]["Transformer"] = BBBModule(
+        hparams["modules"]["Transformer"],
+        prior_builder,
+        prior_kwargs,
+        posterior_builder,
+        posterior_kwargs,
+        parameters,
+    )
+    hparams["model"] = torch.nn.ModuleList(
+        [hparams["CNN"], hparams["seq_lin"], hparams["ctc_lin"]]
+    )
+    hparams["ctc_scorer"].ctc_fc = hparams["ctc_lin"]
+    hparams["test_search"].modules = hparams["valid_search"].modules = [
+        hparams["Transformer"],
+        hparams["seq_lin"],
+    ]
+    hparams["checkpointer"].recoverables["model"] = hparams["model"]
+    hparams["checkpointer"].add_recoverable(
+        "Transformer", hparams["Transformer"],
+    )
+    # ###################################################################
+
+    # Trainer initialization
+    asr_brain = ASR(
+        modules=hparams["modules"],
+        opt_class=hparams["Adam"],
+        hparams=hparams,
+        run_opts=run_opts,
+        checkpointer=hparams["checkpointer"],
+    )
+
+    # adding objects to trainer:
+    asr_brain.tokenizer = hparams["tokenizer"]
+    train_dataloader_opts = hparams["train_dataloader_opts"]
+    valid_dataloader_opts = hparams["valid_dataloader_opts"]
+
+    if train_bsampler is not None:
+        collate_fn = None
+        if "collate_fn" in train_dataloader_opts:
+            collate_fn = train_dataloader_opts["collate_fn"]
+
+        train_dataloader_opts = {
+            "batch_sampler": train_bsampler,
+            "num_workers": hparams["num_workers"],
+        }
+
+        if collate_fn is not None:
+            train_dataloader_opts["collate_fn"] = collate_fn
+
+    if valid_bsampler is not None:
+        collate_fn = None
+        if "collate_fn" in valid_dataloader_opts:
+            collate_fn = valid_dataloader_opts["collate_fn"]
+
+        valid_dataloader_opts = {"batch_sampler": valid_bsampler}
+
+        if collate_fn is not None:
+            valid_dataloader_opts["collate_fn"] = collate_fn
+
+    # Training
+    asr_brain.fit(
+        asr_brain.hparams.epoch_counter,
+        train_data,
+        valid_data,
+        train_loader_kwargs=train_dataloader_opts,
+        valid_loader_kwargs=valid_dataloader_opts,
+    )
+
+    # Testing
+    if not os.path.exists(hparams["output_wer_folder"]):
+        os.makedirs(hparams["output_wer_folder"])
+
+    for k in test_datasets.keys():  # keys are test_clean, test_other etc
+        asr_brain.hparams.test_wer_file = os.path.join(
+            hparams["output_wer_folder"], f"wer_{k}.txt"
+        )
+        asr_brain.evaluate(
+            test_datasets[k],
+            max_key="ACC",
+            test_loader_kwargs=hparams["test_dataloader_opts"],
+        )
diff --git a/recipes/LibriSpeech/ASR/transformer/train_with_whisper.py b/recipes/LibriSpeech/ASR/transformer/train_with_whisper.py
index 311b504538f11e4754fbddae3f0f210c3c7d4bef..d5c5789bc8c25262127479920b6d2de5899a6496 100644
--- a/recipes/LibriSpeech/ASR/transformer/train_with_whisper.py
+++ b/recipes/LibriSpeech/ASR/transformer/train_with_whisper.py
@@ -35,10 +35,13 @@ class ASR(sb.Brain):
         wavs, wav_lens = batch.sig
         bos_tokens, bos_tokens_lens = batch.tokens_bos
 
-        # Add augmentation if specified
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+            bos_tokens = self.hparams.wav_augment.replicate_labels(bos_tokens)
+            bos_tokens_lens = self.hparams.wav_augment.replicate_labels(
+                bos_tokens_lens
+            )
 
         # We compute the padding mask and replace the values with the pad_token_id
         # that the Whisper decoder expect to see.
@@ -55,10 +58,16 @@ class ASR(sb.Brain):
         log_probs = self.hparams.log_softmax(logits)
 
         hyps = None
-        if stage == sb.Stage.VALID:
-            hyps, _ = self.hparams.valid_greedy_searcher(enc_out, wav_lens)
-        elif stage == sb.Stage.TEST:
-            hyps, _ = self.hparams.test_beam_searcher(enc_out, wav_lens)
+        if stage == sb.Stage.VALID or stage == sb.Stage.TEST:
+            # Decide searcher for inference: valid or test search
+            if stage == sb.Stage.VALID:
+                hyps, _, _, _ = self.hparams.valid_search(
+                    enc_out.detach(), wav_lens
+                )
+            else:
+                hyps, _, _, _ = self.hparams.test_search(
+                    enc_out.detach(), wav_lens
+                )
 
         return log_probs, hyps, wav_lens
 
@@ -70,6 +79,13 @@ class ASR(sb.Brain):
         ids = batch.id
         tokens_eos, tokens_eos_lens = batch.tokens_eos
 
+        # Label Augmentation
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            tokens_eos = self.hparams.wav_augment.replicate_labels(tokens_eos)
+            tokens_eos_lens = self.hparams.wav_augment.replicate_labels(
+                tokens_eos_lens
+            )
+
         loss = self.hparams.nll_loss(
             log_probs, tokens_eos, length=tokens_eos_lens,
         )
@@ -77,6 +93,8 @@ class ASR(sb.Brain):
         if stage != sb.Stage.TRAIN:
             tokens, tokens_lens = batch.tokens
 
+            hyps = [hyp[0] if len(hyp) > 0 else [] for hyp in hyps]
+
             # Decode token terms to words
             predicted_words = self.tokenizer.batch_decode(
                 hyps, skip_special_tokens=True
@@ -241,7 +259,6 @@ if __name__ == "__main__":
     # CLI:
     hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
 
-    # If distributed_launch=True then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -278,17 +295,11 @@ if __name__ == "__main__":
     tokenizer.set_prefix_tokens(hparams["language"], "transcribe", False)
 
     # we need to prepare the tokens for searchers
-    hparams["valid_greedy_searcher"].set_decoder_input_tokens(
-        tokenizer.prefix_tokens
-    )
-    hparams["valid_greedy_searcher"].set_language_token(
-        tokenizer.prefix_tokens[1]
-    )
+    hparams["valid_search"].set_decoder_input_tokens(tokenizer.prefix_tokens)
+    hparams["valid_search"].set_language_token(tokenizer.prefix_tokens[1])
 
-    hparams["test_beam_searcher"].set_decoder_input_tokens(
-        tokenizer.prefix_tokens
-    )
-    hparams["test_beam_searcher"].set_language_token(tokenizer.prefix_tokens[1])
+    hparams["test_search"].set_decoder_input_tokens(tokenizer.prefix_tokens)
+    hparams["test_search"].set_language_token(tokenizer.prefix_tokens[1])
 
     # here we create the datasets objects as well as tokenization and encoding
     train_data, valid_data, test_datasets = dataio_prepare(hparams, tokenizer)
diff --git a/recipes/LibriSpeech/G2P/README.md b/recipes/LibriSpeech/G2P/README.md
index 5462d320a467bf030a28e53645f88c4cb40334ff..61a621f06b032ba96878be4045271d3d2f768d3b 100644
--- a/recipes/LibriSpeech/G2P/README.md
+++ b/recipes/LibriSpeech/G2P/README.md
@@ -109,6 +109,10 @@ Pretrained language models can be found at the following URLs:
 * **RNN**: https://www.dropbox.com/sh/pig0uk80xxii7cg/AACQ1rrRLYthvpNZ5FadPLtRa?dl=0
 * **Transformer**: https://www.dropbox.com/sh/tkf6di10edpz4i6/AAArnGAkE0bEEOvOGfc6KWuma?dl=0
 
+
+The best model is available on HuggingFace:
+https://huggingface.co/speechbrain/soundchoice-g2p
+
 Training Time
 -------------
 All reference times are given for a Quattro P5000 GPU. These are rough estimations only - exact training times will vary depending on the hyperparameters chosen and system configuration.
diff --git a/recipes/LibriSpeech/G2P/hparams/hparams_g2p_rnn.yaml b/recipes/LibriSpeech/G2P/hparams/hparams_g2p_rnn.yaml
index 516d943713ff88b9ff71e8cfc9e1ea9a6009c837..f487ffbe1ce7aa22c06cfd7fd11e4c6568371a68 100644
--- a/recipes/LibriSpeech/G2P/hparams/hparams_g2p_rnn.yaml
+++ b/recipes/LibriSpeech/G2P/hparams/hparams_g2p_rnn.yaml
@@ -95,7 +95,7 @@ homograph_loss_weight: 2.0
 lr: 0.002
 save_for_pretrained: True
 
-# Model parameters
+####################### Model Parameters #######################################
 output_neurons: !apply:speechbrain.utils.hparams.choice
     value: !ref <phn_tokenize>
     choices:
@@ -342,60 +342,79 @@ lm_model: !new:speechbrain.lobes.models.RNNLM.RNNLM
 opt_class: !name:torch.optim.Adam
     lr: !ref <lr>
 
+# Scorer
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
+    eos_index: !ref <eos_index>
+    blank_index: !ref <blank_index>
+    ctc_fc: !ref <ctc_lin>
+
+transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer
+    language_model: !ref <lm_model>
+    temperature: !ref <beam_search_temperature_lm>
+
+# Scorer
+coverage_scorer: !new:speechbrain.decoders.scorer.CoverageScorer
+    vocab_size: !ref <output_neurons>
+
+
+scorer_lm: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <transformerlm_scorer>, !ref <ctc_scorer>, !ref <coverage_scorer>]
+    weights:
+        ctc: !ref <beam_search_ctc_weight_decode>
+        transformerlm: !ref <beam_search_lm_weight>
+        coverage: !ref <beam_search_coverage_penalty>
+
+
+scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <ctc_scorer>, !ref <coverage_scorer>]
+    weights:
+        ctc: !ref <beam_search_ctc_weight_decode>
+        coverage: !ref <beam_search_coverage_penalty>
+
 beam_searcher: !new:speechbrain.decoders.S2SRNNBeamSearcher
     embedding: !ref <emb>
     decoder: !ref <dec>
     linear: !ref <lin>
-    ctc_linear: !ref <ctc_lin>
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <beam_search_min_decode_ratio>
     max_decode_ratio: !ref <beam_search_max_decode_ratio>
     beam_size: !ref <beam_search_beam_size>
     eos_threshold: !ref <beam_search_eos_threshold>
     using_max_attn_shift: !ref <beam_search_using_max_attn_shift>
     max_attn_shift: !ref <beam_search_max_attn_shift>
-    coverage_penalty: !ref <beam_search_coverage_penalty>
-    ctc_weight: !ref <beam_search_ctc_weight_decode>
+    temperature: !ref <beam_search_temperature>
+    scorer: !ref <scorer>
 
 beam_searcher_valid: !new:speechbrain.decoders.S2SRNNBeamSearcher
     embedding: !ref <emb>
     decoder: !ref <dec>
     linear: !ref <lin>
-    ctc_linear: !ref <ctc_lin>
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <beam_search_min_decode_ratio>
     max_decode_ratio: !ref <beam_search_max_decode_ratio>
     beam_size: !ref <beam_search_beam_size_valid>
     eos_threshold: !ref <beam_search_eos_threshold>
     using_max_attn_shift: !ref <beam_search_using_max_attn_shift>
     max_attn_shift: !ref <beam_search_max_attn_shift>
-    coverage_penalty: !ref <beam_search_coverage_penalty>
-    ctc_weight: !ref <beam_search_ctc_weight_decode>
+    temperature: !ref <beam_search_temperature>
+    scorer: !ref <scorer>
 
-beam_searcher_lm: !new:speechbrain.decoders.seq2seq.S2SRNNBeamSearchLM
+beam_searcher_lm: !new:speechbrain.decoders.seq2seq.S2SRNNBeamSearcher
     embedding: !ref <emb>
     decoder: !ref <dec>
     linear: !ref <lin>
-    ctc_linear: !ref <ctc_lin>
-    language_model: !ref <lm_model>
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <beam_search_min_decode_ratio>
     max_decode_ratio: !ref <beam_search_max_decode_ratio>
     beam_size: !ref <beam_search_beam_size>
     eos_threshold: !ref <beam_search_eos_threshold>
     using_max_attn_shift: !ref <beam_search_using_max_attn_shift>
     max_attn_shift: !ref <beam_search_max_attn_shift>
-    coverage_penalty: !ref <beam_search_coverage_penalty>
-    ctc_weight: !ref <beam_search_ctc_weight_decode>
-    lm_weight: !ref <beam_search_lm_weight>
     temperature: !ref <beam_search_temperature>
-    temperature_lm: !ref <beam_search_temperature_lm>
+    scorer: !ref <scorer_lm>
 
 
 lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
diff --git a/recipes/LibriSpeech/G2P/hparams/hparams_g2p_transformer.yaml b/recipes/LibriSpeech/G2P/hparams/hparams_g2p_transformer.yaml
index f4e0ec7a6b6ec9d83362b7cdb8d5cd88a04b89a5..e1c0f44c79ad48982beb34c64f81cff2ae326deb 100644
--- a/recipes/LibriSpeech/G2P/hparams/hparams_g2p_transformer.yaml
+++ b/recipes/LibriSpeech/G2P/hparams/hparams_g2p_transformer.yaml
@@ -95,7 +95,7 @@ lr_dont_halve_until_epoch: 1
 lr_patience: 1
 save_for_pretrained: True
 
-# Model parameters
+####################### Model Parameters #######################################
 output_neurons: !apply:speechbrain.utils.hparams.choice
     value: !ref <phn_tokenize>
     choices:
@@ -367,69 +367,73 @@ opt_class: !name:torch.optim.Adam
     lr: !ref <lr>
     betas: (0.99, 0.998)
 
-beam_searcher: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules:
-        - !ref <model>
-        - !ref <lin>
-        - !ref <ctc_lin>
-    bos_index: !ref <bos_index>
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
     eos_index: !ref <eos_index>
     blank_index: !ref <blank_index>
+    ctc_fc: !ref <ctc_lin>
+
+coverage_scorer: !new:speechbrain.decoders.scorer.CoverageScorer
+    vocab_size: !ref <output_neurons>
+
+transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer
+    language_model: !ref <lm_model>
+    temperature: !ref <beam_search_temperature_lm>
+
+scorer_lm: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <transformerlm_scorer>, !ref <ctc_scorer>, !ref <coverage_scorer>]
+    weights:
+        ctc: !ref <beam_search_ctc_weight_decode>
+        transformerlm: !ref <beam_search_lm_weight>
+        coverage: !ref <beam_search_coverage_penalty>
+
+
+scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <ctc_scorer>, !ref <coverage_scorer>]
+    weights:
+        ctc: !ref <beam_search_ctc_weight_decode>
+        coverage: !ref <beam_search_coverage_penalty>
+
+beam_searcher: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <model>, !ref <lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
     min_decode_ratio: !ref <beam_search_min_decode_ratio>
     max_decode_ratio: !ref <beam_search_max_decode_ratio>
     beam_size: !ref <beam_search_beam_size>
     eos_threshold: !ref <beam_search_eos_threshold>
-    using_max_attn_shift: !ref <beam_search_using_max_attn_shift>
+    using_max_attn_shift: False
+    length_normalization: True
     max_attn_shift: !ref <beam_search_max_attn_shift>
-    coverage_penalty: !ref <beam_search_coverage_penalty>
-    ctc_weight: !ref <beam_search_ctc_weight_decode>
-    using_eos_threshold: False
-    length_normalization: False
-
-
-beam_searcher_lm: !new:speechbrain.decoders.seq2seq.S2STransformerBeamSearch
-    modules:
-        - !ref <model>
-        - !ref <lin>
-        - !ref <ctc_lin>
-    lm_modules:
-        - !ref <lm_model>
+    temperature: !ref <beam_search_temperature>
+    scorer: !ref <scorer>
+
+beam_searcher_lm: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <model>, !ref <lin>]
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <beam_search_min_decode_ratio>
     max_decode_ratio: !ref <beam_search_max_decode_ratio>
     beam_size: !ref <beam_search_beam_size>
     eos_threshold: !ref <beam_search_eos_threshold>
-    using_max_attn_shift: !ref <beam_search_using_max_attn_shift>
+    using_max_attn_shift: False
+    length_normalization: True
     max_attn_shift: !ref <beam_search_max_attn_shift>
-    coverage_penalty: !ref <beam_search_coverage_penalty>
-    ctc_weight: !ref <beam_search_ctc_weight_decode>
-    lm_weight: !ref <beam_search_lm_weight>
     temperature: !ref <beam_search_temperature>
-    temperature_lm: !ref <beam_search_temperature_lm>
-    using_eos_threshold: False
-    length_normalization: False
+    scorer: !ref <scorer_lm>
 
-
-beam_searcher_valid: !new:speechbrain.decoders.S2STransformerBeamSearch
-    modules:
-        - !ref <model>
-        - !ref <lin>
-        - !ref <ctc_lin>
+beam_searcher_valid: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <model>, !ref <lin>]
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <beam_search_min_decode_ratio>
     max_decode_ratio: !ref <beam_search_max_decode_ratio>
     beam_size: !ref <beam_search_beam_size_valid>
     eos_threshold: !ref <beam_search_eos_threshold>
-    using_max_attn_shift: !ref <beam_search_using_max_attn_shift>
+    using_max_attn_shift: False
+    length_normalization: True
     max_attn_shift: !ref <beam_search_max_attn_shift>
-    coverage_penalty: !ref <beam_search_coverage_penalty>
-    ctc_weight: !ref <beam_search_ctc_weight_decode>
-    using_eos_threshold: False
-    length_normalization: False
+    temperature: !ref <beam_search_temperature>
+    scorer: !ref <scorer_lm>
 
 
 lr_annealing: !new:speechbrain.nnet.schedulers.ReduceLROnPlateau
diff --git a/recipes/LibriSpeech/G2P/hparams/hparams_lm_rnn.yaml b/recipes/LibriSpeech/G2P/hparams/hparams_lm_rnn.yaml
index 077ed019b96f458b9e2640b897bfaf6469dd4cb8..7e1b7bc4af0d06d9258a2c5b85be9d7074e3ce38 100644
--- a/recipes/LibriSpeech/G2P/hparams/hparams_lm_rnn.yaml
+++ b/recipes/LibriSpeech/G2P/hparams/hparams_lm_rnn.yaml
@@ -50,11 +50,11 @@ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
 tokenizer_file: <output_folder>/save/phoneme_tokenizer.model
 
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 20
 batch_size: 80
 lr: 0.001
-accu_steps: 1 # Gradient accumulation to simulate large batch training
+grad_accumulation_factor: 1 # Gradient accumulation to simulate large batch training
 ckpt_interval_minutes: 15 # save checkpoint every N min
 
 # Dataloader options
@@ -68,7 +68,7 @@ valid_dataloader_opts:
 test_dataloader_opts:
     batch_size: 1
 
-# Model parameters
+####################### Model Parameters #######################################
 model_dim: !apply:speechbrain.utils.hparams.choice
     value: !ref <phn_tokenize>
     choices:
diff --git a/recipes/LibriSpeech/G2P/hparams/hparams_lm_transformer.yaml b/recipes/LibriSpeech/G2P/hparams/hparams_lm_transformer.yaml
index d6b9da36964decb14e8eb6d6c1bce3107f8aaaa2..5e319e3d861df3c7a832f6219aaf142d35832a80 100644
--- a/recipes/LibriSpeech/G2P/hparams/hparams_lm_transformer.yaml
+++ b/recipes/LibriSpeech/G2P/hparams/hparams_lm_transformer.yaml
@@ -39,11 +39,11 @@ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
 # Tokenizer model (you must use the same tokenizer for LM and ASR training)
 tokenizer_file: <output_folder>/save/phoneme_tokenizer.model
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 20
 batch_size: 80
 lr: 0.001
-accu_steps: 1 # Gradient accumulation to simulate large batch training
+grad_accumulation_factor: 1 # Gradient accumulation to simulate large batch training
 ckpt_interval_minutes: 15 # save checkpoint every N min
 
 # Dataloader options
@@ -57,7 +57,7 @@ valid_dataloader_opts:
 test_dataloader_opts:
     batch_size: 1
 
-# Model parameters
+####################### Model Parameters #######################################
 
 emb_dim: 256 # dimension of the embeddings
 transformer_num_heads: 4
diff --git a/recipes/LibriSpeech/G2P/train.py b/recipes/LibriSpeech/G2P/train.py
index a6cc8033c70731180650907529790c6f70152c47..f714e266ce099fd5a2910822ca1769617d9d89f5 100644
--- a/recipes/LibriSpeech/G2P/train.py
+++ b/recipes/LibriSpeech/G2P/train.py
@@ -18,7 +18,6 @@ import datasets
 import logging
 import os
 import random
-import torch
 import speechbrain as sb
 import sys
 from enum import Enum
@@ -26,7 +25,7 @@ from collections import namedtuple
 from hyperpyyaml import load_hyperpyyaml
 from functools import partial
 from speechbrain.utils.distributed import run_on_main
-from speechbrain.pretrained.training import save_for_pretrained
+from speechbrain.utils.pretrained import save_for_pretrained
 from speechbrain.lobes.models.g2p.dataio import (
     enable_eos_bos,
     grapheme_pipeline,
@@ -154,7 +153,6 @@ class G2PBrain(sb.Brain):
             step = self.train_step["name"]
             logger.info(f"Attempting to restore checkpoint for step {step}")
             result = self.checkpointer.recover_if_possible(
-                device=torch.device(self.device),
                 min_key=min_key,
                 max_key=max_key,
                 ckpt_predicate=(lambda ckpt: ckpt.meta.get("step") == step),
@@ -166,9 +164,7 @@ class G2PBrain(sb.Brain):
                     step,
                 )
                 result = self.checkpointer.recover_if_possible(
-                    device=torch.device(self.device),
-                    min_key=min_key,
-                    max_key=max_key,
+                    min_key=min_key, max_key=max_key,
                 )
                 if result:
                     logger.info(
@@ -221,7 +217,8 @@ class G2PBrain(sb.Brain):
                 if stage == sb.Stage.VALID
                 else self.beam_searcher
             )
-            hyps, scores = beam_searcher(encoder_out, char_lens)
+
+            hyps, _, _, _ = beam_searcher(encoder_out, char_lens)
 
         return G2PPredictions(p_seq, char_lens, hyps, ctc_logprobs, attn)
 
@@ -386,24 +383,6 @@ class G2PBrain(sb.Brain):
         current_epoch = self.epoch_counter.current
         return current_epoch <= self.train_step["ctc_epochs"]
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-        loss.backward()
-
-        if self.check_gradients(loss):
-            self.optimizer.step()
-        self.optimizer.zero_grad()
-
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
         self.seq_metrics = self.hparams.seq_stats()
@@ -1139,7 +1118,7 @@ def load_dependencies(hparams, run_opts):
     deps_pretrainer = hparams.get("deps_pretrainer")
     if deps_pretrainer:
         run_on_main(deps_pretrainer.collect_files)
-        deps_pretrainer.load_collected(device=run_opts["device"])
+        deps_pretrainer.load_collected()
 
 
 def check_tensorboard(hparams):
diff --git a/recipes/LibriSpeech/G2P/train_lm.py b/recipes/LibriSpeech/G2P/train_lm.py
index 2bc0acd9991aaa85aa7928e890eeda7c204731ac..b05a4feb637bebb515d04981d8de401fb799c205 100644
--- a/recipes/LibriSpeech/G2P/train_lm.py
+++ b/recipes/LibriSpeech/G2P/train_lm.py
@@ -92,37 +92,9 @@ class LM(sb.core.Brain):
         )
         return loss
 
-    def fit_batch(self, batch):
-        """Runs all the steps needed to train the model on a single batch.
-
-        Arguments
-        ---------
-        batch : PaddedBatch
-            This batch object contains all the relevant tensors for computation.
-
-        Returns
-        -------
-        Loss : torch.Tensor
-            A tensor containing the loss (single real number).
-        """
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-
-        # Loss backpropagation (gradient computation)
-        (loss / self.hparams.accu_steps).backward()
-
-        # Manage gradient accumulation
-        if self.step % self.hparams.accu_steps == 0:
-
-            # Gradient clipping & early stop if loss is not fini
-            self.check_gradients(loss)
-
-            # Update the parameters
-            self.optimizer.step()
-
-            # Reset the gradient
-            self.optimizer.zero_grad()
-
+    def on_fit_batch_end(self, batch, outputs, loss, should_step):
+        """At the end of the optimizer step, apply noam annealing."""
+        if should_step:
             if isinstance(
                 self.hparams.lr_annealing, sb.nnet.schedulers.NoamScheduler
             ) or isinstance(
@@ -131,8 +103,6 @@ class LM(sb.core.Brain):
             ):
                 self.hparams.lr_annealing(self.optimizer)
 
-        return loss
-
     def on_stage_end(self, stage, stage_loss, epoch):
         """Gets called at the end of an epoch.
 
diff --git a/recipes/LibriSpeech/LM/README.md b/recipes/LibriSpeech/LM/README.md
index 836bbe0477e8a17ce31a64401890efb5b9d6fe34..63458f49462c065252c41bb34b37f1f029fcb1e7 100644
--- a/recipes/LibriSpeech/LM/README.md
+++ b/recipes/LibriSpeech/LM/README.md
@@ -1,7 +1,7 @@
 # Language Model with LibriSpeech
 This folder contains recipes for training language models for the LibriSpeech Dataset.
-It supports both an RNN-based LM and a Transformer-based LM.
-The scripts rely on the HuggingFace dataset, which manages data reading and loading from
+It supports n-gram LM, RNN-based LM, and Transformer-based LM.
+The scripts is relying on the HuggingFace dataset for RNN/Transformer based LM, which manages data reading and loading from
 large text corpora.
 
 You can download LibriSpeech at http://www.openslr.org/12
@@ -14,10 +14,31 @@ Before proceeding, ensure you have installed the necessary additional dependenci
 pip install -r extra_requirements.txt
 ```
 
+If you want to train an n-gram, in this recipe we are using  the popular KenLM library. Let's start by installing the Ubuntu library prerequisites. For a complete guide on how to install required dependencies, please refer to [this](https://kheafield.com/code/kenlm/dependencies/) link:
+ ```
+ sudo apt install build-essential cmake libboost-system-dev libboost-thread-dev libboost-program-options-dev libboost-test-dev libeigen3-dev zlib1g-dev libbz2-dev liblzma-dev
+ ```
+
+ Next, we need to start downloading and unpacking the KenLM repo.
+ ```
+ wget -O - https://kheafield.com/code/kenlm.tar.gz | tar xz
+ ```
+
+KenLM is written in C++, so we'll make use of cmake to build the binaries.
+ ```
+mkdir kenlm/build && cd kenlm/build && cmake .. && make -j2
+ ```
+
+Now, make sure that the executables are added to your .bashrc file. To do it,
+- Open the ~/.bashrc file in a text editor.
+- Scroll to the end of the file and add the following line:  ```export PATH=$PATH:/your/path/to/kenlm/build/bin ```
+- Save it and type:  `source ~/.bashrc `
+
 # How to run:
 ```shell
 python train.py hparams/RNNLM.yaml
 python train.py hparams/transformer.yaml
+python train_ngram.py hparams/train_ngram.yaml  --data_folder=your/data/folder
 ```
 
 | Release | hyperparams file | Test PP | Model link | GPUs |
@@ -25,7 +46,8 @@ python train.py hparams/transformer.yaml
 | 20-05-22 | RNNLM.yaml (1k BPE) | --.-- | [link](https://www.dropbox.com/sh/8xpybezuv70ibcg/AAByv2NuNv_ZFXuDdG89-MVPa?dl=0) | 1xV100 32GB |
 | 20-05-22 | RNNLM.yaml (5k BPE) | --.-- | [link](https://www.dropbox.com/sh/8462ef441wvava2/AABNfHr07J_0SsdaM1yO5qkxa?dl=0) | 1xV100 32GB |
 | 20-05-22 | transformer.yaml | --.-- | [link](https://www.dropbox.com/sh/6uwqlw2tvv3kiy6/AACgvTR5jihyMrugBrpZPFNha?dl=0) | 1xV100 32GB |
-
+| 22-01-24 | 4-gram - train_ngram.yaml | --.-- | [link](https://www.dropbox.com/scl/fi/kkd5jrwthpahn4t7e7sgk/4gram_lm.arpa?rlkey=mc820i9bugpi3oxtwwd6ulz0b&dl=0) | --.-- |
+| 22-01-24 | 3-gram - train_ngram.yaml | --.-- | [link](https://www.dropbox.com/scl/fi/juryiq2e50bsbdy1qx540/3gram_lm.arpa?rlkey=3ntfnkn6zxda9memm5zh1mmt9&dl=0) | --.-- |
 
 # Training time
 Training a LM takes a lot of time. In our case, it take 3/4 weeks on 4 TESLA V100. Use the pre-trained model to avoid training it from scratch
diff --git a/recipes/LibriSpeech/LM/hparams/RNNLM.yaml b/recipes/LibriSpeech/LM/hparams/RNNLM.yaml
index bbc5943faed7fe039a2e66733624bce1cfca43da..0896de96032620c15d8b0e0cf19960aed1b953c8 100644
--- a/recipes/LibriSpeech/LM/hparams/RNNLM.yaml
+++ b/recipes/LibriSpeech/LM/hparams/RNNLM.yaml
@@ -29,11 +29,11 @@ test_transcripts_pattern: "test*/**/*.trans.txt"
 # Tokenizer model
 tokenizer_file: https://www.dropbox.com/s/o7gnouwdoqchotj/1000_unigram.model?dl=1
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 20
 batch_size: 80
 lr: 0.001
-accu_steps: 1 # Gradient accumulation to simulate large batch training
+grad_accumulation_factor: 1 # Gradient accumulation to simulate large batch training
 ckpt_interval_minutes: 15 # save checkpoint every N min
 
 # Dataloader options
@@ -47,7 +47,7 @@ valid_dataloader_opts:
 test_dataloader_opts:
     batch_size: 1
 
-# Model parameters
+####################### Model Parameters #######################################
 emb_size: 128
 activation: !name:torch.nn.LeakyReLU
 dropout: 0.0
diff --git a/recipes/LibriSpeech/LM/hparams/train_ngram.yaml b/recipes/LibriSpeech/LM/hparams/train_ngram.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e83d79c52c5b16d9cff5fd9f5551231de812d98e
--- /dev/null
+++ b/recipes/LibriSpeech/LM/hparams/train_ngram.yaml
@@ -0,0 +1,24 @@
+#########
+# Recipe for Training kenLM on LibriSpeech Data.
+# It is used to boost any CTC or CTC/joint attention models.
+#
+# Author:
+#  - Adel Moumen 2024
+################################
+# Seed needs to be set at top of yaml, before objects with parameters are made
+output_folder: !ref results/n_gram_lm/
+# Data files
+data_folder: !PLACEHOLDER # e.g, /localscratch/LibriSpeech
+train_splits: ["train-clean-100", "train-clean-360", "train-other-500"]
+dev_splits: []
+test_splits: []
+train_csv: !ref <output_folder>/train.csv
+lang_dir: !ref <output_folder>/lang
+vocab_file: !ref <output_folder>/librispeech-vocab.txt
+sil_prob: 0.
+add_word_boundary: True
+caching: False
+skip_prep: False
+arpa_order: 3
+prune_level: [0, 1, 2]
+output_arpa: !ref <output_folder>/<arpa_order>-gram.arpa
diff --git a/recipes/LibriSpeech/LM/hparams/transformer.yaml b/recipes/LibriSpeech/LM/hparams/transformer.yaml
index e3d8373417329e3aa2d6a5a5f6bbcb9e2ff6fb62..50123a4c3cdbfabbc53f006aa6b132fa3392e9e0 100644
--- a/recipes/LibriSpeech/LM/hparams/transformer.yaml
+++ b/recipes/LibriSpeech/LM/hparams/transformer.yaml
@@ -29,11 +29,11 @@ test_transcripts_pattern: "test*/**/*.trans.txt"
 # Tokenizer model
 tokenizer_file: speechbrain/asr-transformer-transformerlm-librispeech/tokenizer.ckpt
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 20
 batch_size: 16
 lr: 10
-accu_steps: 8 # Gradient accumulation to simulate large batch training
+grad_accumulation_factor: 8 # Gradient accumulation to simulate large batch training
 ckpt_interval_minutes: 15 # save checkpoint every N min
 
 # Dataloader options
diff --git a/recipes/LibriSpeech/LM/train.py b/recipes/LibriSpeech/LM/train.py
index 3b3b846cab1d01a9abfbbbdf4a0aa98a95f3bb83..e4912d0b0aeefe27175172f3abc89a069fbcebf8 100644
--- a/recipes/LibriSpeech/LM/train.py
+++ b/recipes/LibriSpeech/LM/train.py
@@ -42,20 +42,9 @@ class LM(sb.core.Brain):
         )
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-
-        (loss / self.hparams.accu_steps).backward()
-
-        if self.step % self.hparams.accu_steps == 0:
-            # gradient clipping & early stop if loss is not fini
-            self.check_gradients(loss)
-
-            self.optimizer.step()
-            self.optimizer.zero_grad()
-
+    def on_fit_batch_end(self, batch, outputs, loss, should_step):
+        """At the end of the optimizer step, apply noam annealing."""
+        if should_step:
             if isinstance(
                 self.hparams.lr_annealing, sb.nnet.schedulers.NoamScheduler
             ) or isinstance(
@@ -64,15 +53,13 @@ class LM(sb.core.Brain):
             ):
                 self.hparams.lr_annealing(self.optimizer)
 
-        return loss
-
     def on_stage_end(self, stage, stage_loss, epoch):
         """Gets called at the end of a epoch."""
         stage_stats = {"loss": stage_loss}
         if stage == sb.Stage.TRAIN:
             self.train_stats = stage_stats
 
-        if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+        if stage == sb.Stage.VALID:
             if not (
                 isinstance(
                     self.hparams.lr_annealing, sb.nnet.schedulers.NoamScheduler
@@ -173,7 +160,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -190,7 +176,7 @@ if __name__ == "__main__":
     # We download the tokenizer from HuggingFace (or elsewhere depending on
     # the path given in the YAML file).
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     lm_brain = LM(
         modules=hparams["modules"],
diff --git a/recipes/LibriSpeech/LM/train_ngram.py b/recipes/LibriSpeech/LM/train_ngram.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ff17262b555b91080be4b2b32a3bb251ec84a6d
--- /dev/null
+++ b/recipes/LibriSpeech/LM/train_ngram.py
@@ -0,0 +1,176 @@
+"""
+Recipe to train kenlm ngram model.
+
+To run this recipe, do the following:
+> python train.py hparams/train.yaml --data_folder=/path/to/LibriSpeech
+
+Authors
+ * Adel Moumen 2024
+ * Pierre Champion 2023
+"""
+
+import os
+import sys
+import logging
+import speechbrain as sb
+from speechbrain.utils.distributed import run_on_main
+from hyperpyyaml import load_hyperpyyaml
+import speechbrain.k2_integration as sbk2
+from speechbrain.utils.data_utils import (
+    download_file,
+    get_list_from_csv,
+)
+
+logger = logging.getLogger(__name__)
+OPEN_SLR_11_LINK = "http://www.openslr.org/resources/11/"
+
+
+def download_librispeech_lm_training_text(destination):
+    """Download librispeech lm training and unpack it.
+
+    Arguments
+    ---------
+    destination : str
+        Place to put dataset.
+    """
+    f = "librispeech-lm-norm.txt.gz"
+    download_file(
+        OPEN_SLR_11_LINK + f, os.path.join(destination, f), unpack=True
+    )
+
+
+def dataprep_lm_training(
+    lm_dir,
+    output_arpa,
+    csv_files,
+    external_lm_corpus,
+    vocab_file,
+    arpa_order=3,
+    prune_level=[0, 1, 2],
+):
+    """Prepare lm txt corpus file for lm training with kenlm (https://github.com/kpu/kenlm)
+    Does nothing if output_arpa exists.
+    Else display to the user how to use kenlm in command line, then exit
+    (return code 1), the user has to run the command manually.
+    Instruction on how to compile kenlm (lmplz binary) is available in the
+    above link.
+
+    Arguments
+    ---------
+    lm_dir : str
+        Path to where to store txt corpus
+    output_arpa : str
+        File to write arpa lm
+    csv_files : List[str]
+        CSV files to use to increase lm txt corpus
+    external_lm_corpus : List[str]
+        (Big) text dataset corpus
+    vocab_file : str
+       N-grams that contain vocabulary items not in this file be pruned.
+    arpa_order : int
+        Order of the arpa lm
+    prune_level : List[int]
+        The numbers must be non-decreasing and the last number will be extended to any higher order.
+        For example, --prune 0 disables pruning (the default) while --prune 0 0 1 prunes singletons for orders three and higher.
+        Please refer to https://kheafield.com/code/kenlm/estimation/ for more details.
+    """
+    download_librispeech_lm_training_text(lm_dir)
+    column_text_key = "wrd"  # defined in librispeech_prepare.py
+    lm_corpus = os.path.join(lm_dir, "libri_lm_corpus.txt")
+    line_seen = set()
+    with open(lm_corpus, "w") as corpus:
+        for file in csv_files:
+            for line in get_list_from_csv(file, column_text_key):
+                corpus.write(line + "\n")
+                line_seen.add(line + "\n")
+        for file in external_lm_corpus:
+            with open(file) as f:
+                for line in f:
+                    if line not in line_seen:
+                        corpus.write(line)
+    prune_level = " ".join(map(str, prune_level))
+    cmd = f"lmplz -o {arpa_order} --prune {prune_level} --limit_vocab_file {vocab_file} < {lm_corpus} | sed  '1,20s/<unk>/<UNK>/1' > {output_arpa}"
+    logger.critical(
+        "RUN the following kenlm command to build a 3-gram arpa LM (https://github.com/kpu/kenlm):"
+    )
+    logger.critical(f"$ {cmd}")
+    sys.exit(0)
+
+
+if __name__ == "__main__":
+    # Load hyperparameters file with command-line overrides
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    # Dataset prep (parsing Librispeech)
+    import librispeech_prepare
+
+    # multi-gpu (ddp) save data preparation
+    run_on_main(
+        librispeech_prepare.prepare_librispeech,
+        kwargs={
+            "data_folder": hparams["data_folder"],
+            "tr_splits": hparams["train_splits"],
+            "dev_splits": hparams["dev_splits"],
+            "te_splits": hparams["test_splits"],
+            "save_folder": hparams["output_folder"],
+            "merge_lst": hparams["train_splits"],
+            "merge_name": "train.csv",
+            "skip_prep": hparams["skip_prep"],
+        },
+    )
+
+    # Download the vocabulary file for librispeech
+    librispeech_prepare.download_librispeech_vocab_text(
+        destination=hparams["vocab_file"]
+    )
+
+    # Create the lexicon.txt for k2
+    run_on_main(
+        sbk2.lexicon.prepare_char_lexicon,
+        kwargs={
+            "lang_dir": hparams["lang_dir"],
+            "vocab_files": [hparams["vocab_file"]],
+            "extra_csv_files": [hparams["output_folder"] + "/train.csv"]
+            if not hparams["skip_prep"]
+            else [],
+            "add_word_boundary": hparams["add_word_boundary"],
+        },
+    )
+
+    caching = (
+        {"cache": False}
+        if "caching" in hparams and hparams["caching"] is False
+        else {}
+    )
+
+    # Create the lang directory for k2
+    run_on_main(
+        sbk2.prepare_lang.prepare_lang,
+        kwargs={
+            "lang_dir": hparams["lang_dir"],
+            "sil_prob": hparams["sil_prob"],
+            **caching,
+        },
+    )
+
+    dataprep_lm_training(
+        lm_dir=hparams["output_folder"],
+        output_arpa=hparams["output_arpa"],
+        csv_files=[hparams["train_csv"]],
+        external_lm_corpus=[
+            os.path.join(hparams["output_folder"], "librispeech-lm-norm.txt")
+        ],
+        vocab_file=os.path.join(hparams["lang_dir"], "words.txt"),
+        arpa_order=hparams["arpa_order"],
+        prune_level=hparams["prune_level"],
+    )
diff --git a/recipes/LibriSpeech/Tokenizer/hparams/1K_unigram_subword_bpe.yaml b/recipes/LibriSpeech/Tokenizer/hparams/1K_unigram_subword_bpe.yaml
index f610d707b6637d8baf083bfc4b468008043616e7..9dda21f82781ccbcd5bda6fa686a8c5f93eb2cfc 100644
--- a/recipes/LibriSpeech/Tokenizer/hparams/1K_unigram_subword_bpe.yaml
+++ b/recipes/LibriSpeech/Tokenizer/hparams/1K_unigram_subword_bpe.yaml
@@ -16,14 +16,17 @@ skip_prep: False
 train_csv: !ref <output_folder>/train.csv
 valid_csv: !ref <output_folder>/dev-clean.csv
 
-# Training parameters
+####################### Training Parameters ####################################
 token_type: unigram  # ["unigram", "bpe", "char"]
 token_output: 1000  # index(blank/eos/bos/unk) = 0
 character_coverage: 1.0
 csv_read: wrd
-
+bos_id: 1
+eos_id: 2
 
 tokenizer: !name:speechbrain.tokenizers.SentencePiece.SentencePiece
+   bos_id: !ref <bos_id>
+   eos_id: !ref <eos_id>
    model_dir: !ref <output_folder>
    vocab_size: !ref <token_output>
    annotation_train: !ref <train_csv>
diff --git a/recipes/LibriSpeech/Tokenizer/hparams/5K_unigram_subword_bpe.yaml b/recipes/LibriSpeech/Tokenizer/hparams/5K_unigram_subword_bpe.yaml
index c312ce5bbc3c8bc2843966c77d7fc5e97e0a5831..1f328c6f1682dcf2f25b34632b96ecdb96e5b45a 100644
--- a/recipes/LibriSpeech/Tokenizer/hparams/5K_unigram_subword_bpe.yaml
+++ b/recipes/LibriSpeech/Tokenizer/hparams/5K_unigram_subword_bpe.yaml
@@ -16,7 +16,7 @@ skip_prep: False
 train_csv: !ref <output_folder>/train.csv
 valid_csv: !ref <output_folder>/dev-clean.csv
 
-# Training parameters
+####################### Training Parameters ####################################
 token_type: unigram  # ["unigram", "bpe", "char"]
 token_output: 5000  # index(blank/eos/bos/unk) = 0
 character_coverage: 1.0
diff --git a/recipes/LibriSpeech/Tokenizer/train.py b/recipes/LibriSpeech/Tokenizer/train.py
index 3a523876ccab2ec018960228c84c82c19717460c..b580f18e16275791bbfa9b6d9c5e0d266e1c5717 100644
--- a/recipes/LibriSpeech/Tokenizer/train.py
+++ b/recipes/LibriSpeech/Tokenizer/train.py
@@ -27,7 +27,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
diff --git a/recipes/LibriSpeech/librispeech_prepare.py b/recipes/LibriSpeech/librispeech_prepare.py
index 70b7726756be0525da4cfc572b617c59a424569e..bc8b0ffa7dbd058bf25b3444814929e217e5f8d9 100644
--- a/recipes/LibriSpeech/librispeech_prepare.py
+++ b/recipes/LibriSpeech/librispeech_prepare.py
@@ -5,7 +5,11 @@ Download: http://www.openslr.org/12
 
 Author
 ------
-Mirco Ravanelli, Ju-Chieh Chou, Loren Lugosch 2020
+ * Mirco Ravanelli, 2020
+ * Ju-Chieh Chou, 2020
+ * Loren Lugosch, 2020
+ * Pierre Champion, 2023
+ * Adel Moumen, 2024
 """
 
 import os
@@ -15,7 +19,10 @@ from collections import Counter
 from dataclasses import dataclass
 import functools
 import logging
-from speechbrain.utils.data_utils import download_file, get_all_files
+from speechbrain.utils.data_utils import (
+    download_file,
+    get_all_files,
+)
 from speechbrain.dataio.dataio import (
     load_pkl,
     save_pkl,
@@ -27,6 +34,13 @@ from speechbrain.utils.parallel import parallel_map
 logger = logging.getLogger(__name__)
 OPT_FILE = "opt_librispeech_prepare.pkl"
 SAMPLERATE = 16000
+OPEN_SLR_11_LINK = "http://www.openslr.org/resources/11/"
+OPEN_SLR_11_NGRAM_MODELs = [
+    "3-gram.arpa.gz",
+    "3-gram.pruned.1e-7.arpa.gz",
+    "3-gram.pruned.3e-7.arpa.gz",
+    "4-gram.arpa.gz",
+]
 
 
 def prepare_librispeech(
@@ -145,13 +159,13 @@ def prepare_librispeech(
 
     # Create lexicon.csv and oov.csv
     if create_lexicon:
-        create_lexicon_and_oov_csv(all_texts, data_folder, save_folder)
+        create_lexicon_and_oov_csv(all_texts, save_folder)
 
     # saving options
     save_pkl(conf, save_opt)
 
 
-def create_lexicon_and_oov_csv(all_texts, data_folder, save_folder):
+def create_lexicon_and_oov_csv(all_texts, save_folder):
     """
     Creates lexicon csv files useful for training and testing a
     grapheme-to-phoneme (G2P) model.
@@ -160,8 +174,6 @@ def create_lexicon_and_oov_csv(all_texts, data_folder, save_folder):
     ---------
     all_text : dict
         Dictionary containing text from the librispeech transcriptions
-    data_folder : str
-        Path to the folder where the original LibriSpeech dataset is stored.
     save_folder : str
         The directory where to store the csv files.
     Returns
@@ -370,7 +382,7 @@ def skip(splits, save_folder, conf):
     splits : list
         A list of the splits expected in the preparation.
     save_folder : str
-        The location of the seave directory
+        The location of the save directory
     conf : dict
         The configuration options to ensure they haven't changed.
 
@@ -454,3 +466,55 @@ def check_librispeech_folders(data_folder, splits):
                 "Librispeech dataset)" % split_folder
             )
             raise OSError(err_msg)
+
+
+def download_librispeech_vocab_text(destination):
+    """Download librispeech vocab file and unpack it.
+
+    Arguments
+    ---------
+    destination : str
+        Place to put vocab file.
+    """
+    f = "librispeech-vocab.txt"
+    download_file(OPEN_SLR_11_LINK + f, destination)
+
+
+def download_openslr_librispeech_lm(destination, rescoring_lm=True):
+    """Download openslr librispeech lm and unpack it.
+
+    Arguments
+    ---------
+    destination : str
+        Place to put lm.
+    rescoring_lms : bool
+        Also download bigger 4grams model
+    """
+    os.makedirs(destination, exist_ok=True)
+    for f in OPEN_SLR_11_NGRAM_MODELs:
+        if f.startswith("4") and not rescoring_lm:
+            continue
+        d = os.path.join(destination, f)
+        download_file(OPEN_SLR_11_LINK + f, d, unpack=True)
+
+
+def download_sb_librispeech_lm(destination, rescoring_lm=True):
+    """Download sb librispeech lm and unpack it.
+
+    Arguments
+    ---------
+    destination : str
+        Place to put lm.
+    rescoring_lms : bool
+        Also download bigger 4grams model
+    """
+    os.makedirs(destination, exist_ok=True)
+    download_file(
+        "https://www.dropbox.com/scl/fi/3fkkdlliavhveb5n3nsow/3gram_lm.arpa?rlkey=jgdrluppfut1pjminf3l3y106&dl=1",
+        os.path.join(destination, "3-gram_sb.arpa"),
+    )
+    if rescoring_lm:
+        download_file(
+            "https://www.dropbox.com/scl/fi/roz46ee0ah2lvy5csno4z/4gram_lm.arpa?rlkey=2wt8ozb1mqgde9h9n9rp2yppz&dl=1",
+            os.path.join(destination, "4-gram_sb.arpa"),
+        )
diff --git a/recipes/LibriSpeech/quantization/README.md b/recipes/LibriSpeech/quantization/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b1114c85fd9b62169665ed580bc497d8a0a2e0e5
--- /dev/null
+++ b/recipes/LibriSpeech/quantization/README.md
@@ -0,0 +1,49 @@
+
+# K-means (Quantization)
+This folder contains recipes for training K-means clustering model for the LibriSpeech Dataset.
+The model serves to quantize self-supervised representations into discrete representation. Thus representations can be used as a discrete audio input for various tasks including classification, ASR and speech generation.
+It supports  kmeans model using the features from  HuBERT, WAVLM or Wav2Vec.
+
+You can download LibriSpeech at http://www.openslr.org/12
+
+## Installing Extra Dependencies
+
+Before proceeding, ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal:
+
+```
+pip install -r extra_requirements.txt
+```
+
+# How to run:
+```shell
+python train.py hparams/train_with_{SSL_model}.yaml
+```
+
+# Results
+
+The output folders with checkpoints and logs can be found [here](https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0).
+
+The checkpoints can be also found at [this](https://huggingface.co/speechbrain/SSL_Quantization) HuggingFace repository.
+
+
+
+# **About SpeechBrain**
+- Website: https://speechbrain.github.io/
+- Code: https://github.com/speechbrain/speechbrain/
+- HuggingFace: https://huggingface.co/speechbrain/
+
+
+# **Citing SpeechBrain**
+Please, cite SpeechBrain if you use it for your research or business.
+
+```bibtex
+@misc{speechbrain,
+  title={{SpeechBrain}: A General-Purpose Speech Toolkit},
+  author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
+  year={2021},
+  eprint={2106.04624},
+  archivePrefix={arXiv},
+  primaryClass={eess.AS},
+  note={arXiv:2106.04624}
+}
+```
diff --git a/recipes/LibriSpeech/quantization/extra-requirements.txt b/recipes/LibriSpeech/quantization/extra-requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d5e06028d853376200623c16ebcf4992f4ae60c2
--- /dev/null
+++ b/recipes/LibriSpeech/quantization/extra-requirements.txt
@@ -0,0 +1 @@
+scikit-learn
diff --git a/recipes/LibriSpeech/quantization/hparams/train_with_hubert.yaml b/recipes/LibriSpeech/quantization/hparams/train_with_hubert.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6dd20b0e396894dd10ccdbb99c604cc04fe47cc6
--- /dev/null
+++ b/recipes/LibriSpeech/quantization/hparams/train_with_hubert.yaml
@@ -0,0 +1,57 @@
+################################
+# Recipe for Training K-Means Clustering on LibriSpeech Data
+# Using Self-Supervised Model-Based Representations
+#
+# It is used for creating discrete audio representations from LibriSpeech data.
+#
+# Author: Pooneh Mousavi (2023)
+################################
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1986
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/LibriSpeech/clustering/hubert/<seed>
+save_folder: !ref <output_folder>/save
+
+# Data files
+data_folder: !PLACEHOLDER # e,g./path/to/LibriSpeech
+train_splits: ["train-clean-100", "train-clean-360", "train-other-500"]
+dev_splits: []
+test_splits: []
+skip_prep: False
+ckpt_interval_minutes: 25 # save checkpoint every N min
+train_csv: !ref <output_folder>/train.csv
+sample_rate: 16000
+
+ssl_hub: facebook/hubert-base-ls960
+freeze_feature_extractor: True
+freeze_ssl: True
+ssl_folder: !ref <save_folder>/hubert_checkpoint
+ssl_layer_num: 7
+batch_size: 128 # batch_size for loading and extracting features. It is different from kmeans_batch_size.
+
+sorting: ascending
+
+# Dataloader options
+train_dataloader_opts:
+   batch_size: !ref <batch_size>
+
+ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT
+   source: !ref <ssl_hub>
+   output_norm: False
+   freeze: !ref <freeze_ssl>
+   freeze_feature_extractor: !ref <freeze_feature_extractor>
+   output_all_hiddens: True
+   save_path: !ref <ssl_folder>
+
+
+####################
+# Model Parameters #
+####################
+num_clusters: 128
+init: k-means++
+max_iter: 100
+kmeans_batch_size: 1000 # should be >= num_clusters
+tol: 0.0
+max_no_improvement: 100
+n_init: 20
+reassignment_ratio: 0.0
diff --git a/recipes/LibriSpeech/quantization/hparams/train_with_wav2vec.yaml b/recipes/LibriSpeech/quantization/hparams/train_with_wav2vec.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9eac469a866ea414a38ca0e9bdb32bd00bca8319
--- /dev/null
+++ b/recipes/LibriSpeech/quantization/hparams/train_with_wav2vec.yaml
@@ -0,0 +1,57 @@
+################################
+# Recipe for Training K-Means Clustering on LibriSpeech Data
+# Using Self-Supervised Model-Based Representations
+#
+# It is used for creating discrete audio representations from LibriSpeech data.
+#
+# Author: Pooneh Mousavi (2023)
+################################
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1986
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/LibriSpeech/clustering/wav2vec/<seed>
+save_folder: !ref <output_folder>/save
+
+# Data files
+data_folder: !PLACEHOLDERS # e,g./path/to/LibriSpeech
+train_splits: ["train-clean-100", "train-clean-360", "train-other-500"]
+dev_splits: []
+test_splits: []
+skip_prep: False
+ckpt_interval_minutes: 25 # save checkpoint every N min
+train_csv: !ref <output_folder>/train.csv
+sample_rate: 16000
+
+ssl_hub: facebook/wav2vec2-large-960h-lv60-self
+freeze_feature_extractor: True
+freeze_ssl: True
+ssl_folder: !ref <save_folder>/wav2vec_checkpoint
+ssl_layer_num: 7
+batch_size: 64 # batch_size for loading and extracting features. It is different from kmeans_batch_size.
+
+sorting: ascending
+
+# Dataloader options
+train_dataloader_opts:
+   batch_size: !ref <batch_size>
+
+ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
+   source: !ref <ssl_hub>
+   output_norm: False
+   freeze: !ref <freeze_ssl>
+   freeze_feature_extractor: !ref <freeze_feature_extractor>
+   output_all_hiddens: True
+   save_path: !ref <ssl_folder>
+
+
+####################
+# Model Parameters #
+####################
+num_clusters: 128
+init: k-means++
+max_iter: 100
+kmeans_batch_size: 1000 # should be >= num_clusters
+tol: 0.0
+max_no_improvement: 100
+n_init: 20
+reassignment_ratio: 0.0
diff --git a/recipes/LibriSpeech/quantization/hparams/train_with_wavlm.yaml b/recipes/LibriSpeech/quantization/hparams/train_with_wavlm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..971ab94b1ac5e116577a5043f2cf6c093038e894
--- /dev/null
+++ b/recipes/LibriSpeech/quantization/hparams/train_with_wavlm.yaml
@@ -0,0 +1,57 @@
+################################
+# Recipe for Training K-Means Clustering on LibriSpeech Data
+# Using Self-Supervised Model-Based Representations
+#
+# It is used for creating discrete audio representations from LibriSpeech data.
+#
+# Author: Pooneh Mousavi (2023)
+################################
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1986
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/LibriSpeech/clustering/wavlm/<seed>
+save_folder: !ref <output_folder>/save
+
+# Data files
+data_folder: !PLACEHOLDERS # e,g./path/to/LibriSpeech
+train_splits: ["train-clean-100", "train-clean-360", "train-other-500"]
+dev_splits: []
+test_splits: []
+skip_prep: False
+ckpt_interval_minutes: 25 # save checkpoint every N min
+train_csv: !ref  <output_folder>/train.csv
+sample_rate: 16000
+
+ssl_hub: microsoft/wavlm-large
+freeze_feature_extractor: True
+freeze_ssl: True
+ssl_folder: !ref <save_folder>/wavlm_checkpoint
+ssl_layer_num: 7
+batch_size: 32 # batch_size for loading and extracting features. It is different from kmeans_batch_size.
+
+sorting: ascending
+
+# Dataloader options
+train_dataloader_opts:
+   batch_size: !ref <batch_size>
+
+ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM
+   source: !ref <ssl_hub>
+   output_norm: False
+   freeze: !ref <freeze_ssl>
+   freeze_feature_extractor: !ref <freeze_feature_extractor>
+   output_all_hiddens: True
+   save_path: !ref <ssl_folder>
+
+
+####################
+# Model Parameters #
+####################
+num_clusters: 128
+init: k-means++
+max_iter: 100
+kmeans_batch_size: 1000 # should be >= num_clusters
+tol: 0.0
+max_no_improvement: 100
+n_init: 20
+reassignment_ratio: 0.0
diff --git a/recipes/LibriSpeech/quantization/librispeech_prepare.py b/recipes/LibriSpeech/quantization/librispeech_prepare.py
new file mode 120000
index 0000000000000000000000000000000000000000..a3126ec94ac2f948af13432942ec52e7c9c1e4d6
--- /dev/null
+++ b/recipes/LibriSpeech/quantization/librispeech_prepare.py
@@ -0,0 +1 @@
+../librispeech_prepare.py
\ No newline at end of file
diff --git a/recipes/LibriSpeech/quantization/train.py b/recipes/LibriSpeech/quantization/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..fca6a7841919575d8d43848a41453d183a7c4ac4
--- /dev/null
+++ b/recipes/LibriSpeech/quantization/train.py
@@ -0,0 +1,148 @@
+"""
+Recipe  to train K-means clustering model on self-supervised representations.
+
+To run this recipe, do the following:
+> python train.py hparams/train_with_[SSL-model].yaml --data_folder=/path/to/LibriSPeech
+Author
+ * Pooneh Mousavi 2023
+"""
+
+import os
+import sys
+import logging
+import speechbrain as sb
+from speechbrain.utils.distributed import run_on_main
+from hyperpyyaml import load_hyperpyyaml
+from torch.utils.data import DataLoader
+from speechbrain.dataio.dataloader import LoopedLoader
+from speechbrain.utils.kmeans import fetch_kmeans_model, train, save_model
+import torchaudio
+
+logger = logging.getLogger(__name__)
+
+
+def dataio_prepare(hparams):
+    """This function prepares the datasets to be used in the brain class.
+    It also defines the data processing pipeline through user-defined functions."""
+    data_folder = hparams["data_folder"]
+
+    train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
+        csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
+    )
+
+    if hparams["sorting"] == "ascending":
+        # we sort training data to speed up training and get better results.
+        train_data = train_data.filtered_sorted(sort_key="duration")
+        # when sorting do not shuffle in dataloader ! otherwise is pointless
+        hparams["train_dataloader_opts"]["shuffle"] = False
+
+    elif hparams["sorting"] == "descending":
+        train_data = train_data.filtered_sorted(
+            sort_key="duration", reverse=True
+        )
+        # when sorting do not shuffle in dataloader ! otherwise is pointless
+        hparams["train_dataloader_opts"]["shuffle"] = False
+
+    elif hparams["sorting"] == "random":
+        pass
+
+    else:
+        raise NotImplementedError(
+            "sorting must be random, ascending or descending"
+        )
+
+    datasets = [train_data]
+
+    # 2. Define audio pipeline:
+    @sb.utils.data_pipeline.takes("wav")
+    @sb.utils.data_pipeline.provides("sig")
+    def audio_pipeline(wav):
+        sig = sb.dataio.dataio.read_audio(wav)
+        info = torchaudio.info(wav)
+        resampled = torchaudio.transforms.Resample(
+            info.sample_rate, hparams["sample_rate"],
+        )(sig)
+        return resampled
+
+    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
+
+    # 4. Set output:
+    sb.dataio.dataset.set_output_keys(
+        datasets, ["id", "sig"],
+    )
+    return train_data
+
+
+if __name__ == "__main__":
+    # Load hyperparameters file with command-line overrides
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    # Dataset prep (parsing Librispeech)
+    from librispeech_prepare import prepare_librispeech  # noqa
+
+    # multi-gpu (ddp) save data preparation
+    run_on_main(
+        prepare_librispeech,
+        kwargs={
+            "data_folder": hparams["data_folder"],
+            "tr_splits": hparams["train_splits"],
+            "dev_splits": hparams["dev_splits"],
+            "te_splits": hparams["test_splits"],
+            "save_folder": hparams["output_folder"],
+            "merge_lst": hparams["train_splits"],
+            "merge_name": "train.csv",
+            "skip_prep": hparams["skip_prep"],
+        },
+    )
+
+    # Load SSL model
+    hparams["ssl_model"] = hparams["ssl_model"].to(run_opts["device"])
+
+    # Make training Dataloader
+    train_set = dataio_prepare(hparams)
+    if not (
+        isinstance(train_set, DataLoader) or isinstance(train_set, LoopedLoader)
+    ):
+        train_set = sb.dataio.dataloader.make_dataloader(
+            train_set, **hparams["train_dataloader_opts"]
+        )
+
+    # Load pretrained KMeans model if it exists. Otherwise,  create new one.
+    checkpoint_path = os.path.join(
+        hparams["save_folder"], f"kmeans_{hparams['num_clusters']}.pt"
+    )
+    kmeans_model = fetch_kmeans_model(
+        n_clusters=hparams["num_clusters"],
+        init=hparams["init"],
+        max_iter=hparams["max_iter"],
+        batch_size=hparams["batch_size"],
+        tol=hparams["tol"],
+        max_no_improvement=hparams["max_no_improvement"],
+        n_init=hparams["n_init"],
+        reassignment_ratio=hparams["reassignment_ratio"],
+        random_state=hparams["seed"],
+        checkpoint_path=checkpoint_path,
+    )
+
+    # Train and save Kmeans model
+    train(
+        kmeans_model,
+        train_set,
+        hparams["ssl_model"],
+        hparams["ssl_layer_num"],
+        kmeans_batch_size=hparams["kmeans_batch_size"],
+        device=run_opts["device"],
+    )
+
+    logger.info(f"Saving kmeans model at {checkpoint_path}.")
+    save_model(kmeans_model, checkpoint_path)
diff --git a/recipes/LibriSpeech/self-supervised-learning/wav2vec2/README.md b/recipes/LibriSpeech/self-supervised-learning/wav2vec2/README.md
index a097bdb5385b3dd4c7691e7636568626198e12d2..f20ed3b19c75e17e4efc76c9e11b58270deac6a1 100644
--- a/recipes/LibriSpeech/self-supervised-learning/wav2vec2/README.md
+++ b/recipes/LibriSpeech/self-supervised-learning/wav2vec2/README.md
@@ -30,6 +30,6 @@ The checkpoint generated by this pretraining is a standard PyTorch checkpoint. I
 Training wav2vec 2.0 models is crazy w.r.t compute resources. For instance, this recipe only trains a BASE wav2vec 2.0 architecture, and it already requires 16 Tesla V100 for 7 days. Of course, you can scale this to your needs (e.g., you can work with 2 GPUs only), but it will take ages! Welcome to the wav2vec 2.0 world!
 
 Here is a list of the most important advices:
-- To train w2v2 models, it is **extremely** important to have an effective batch size as high as possible. For instance, the original BASE model is trained with batches containing 1.6H of speech. This means that (duration_per_minibatch * nb_gpu * gradient_accumulation) must be at least equal to 1.6H.
+- To train w2v2 models, it is **extremely** important to have an effective batch size as high as possible. For instance, the original BASE model is trained with batches containing 1.6H of speech. This means that (duration_per_minibatch * nb_gpu * grad_accumulation_factor) must be at least equal to 1.6H.
 - Do not train on sequences longer than 20s, this will blow your VRAM up and is useless for now. Indeed training with shorter sentences (10s) may work just as well.
 - Set the `n_warmup_steps` steps in such a way that it corresponds to 10% of the total training steps. The number of steps correspond to the actual number of call to .backward w.r.t the batch size.
diff --git a/recipes/LibriSpeech/self-supervised-learning/wav2vec2/hparams/wav2vec2_base.yaml b/recipes/LibriSpeech/self-supervised-learning/wav2vec2/hparams/wav2vec2_base.yaml
index 0fa622217fb627683a704d8c5d7819609960f7cb..13ce0d2203fd62dfc84afa29a9ddd7c04cfd66d3 100644
--- a/recipes/LibriSpeech/self-supervised-learning/wav2vec2/hparams/wav2vec2_base.yaml
+++ b/recipes/LibriSpeech/self-supervised-learning/wav2vec2/hparams/wav2vec2_base.yaml
@@ -21,7 +21,7 @@ skip_prep: False
 avoid_if_longer_than: 30.0
 avoid_if_shorter_than: 1.5
 log_interval: 1000 # Logging every N optimizer steps
-auto_mix_prec: True
+precision: fp16 # bf16, fp16 or fp32
 max_grad_norm: 100.
 
 # The training will either stops at number_of_epochs or optimizer_step_limit
@@ -30,8 +30,16 @@ number_of_epochs: 3000
 optimizer_step_limit: 400000
 
 # Dynamic Batching parameters
-train_num_buckets: 70
-seconds_per_batch: 200 # Fits in a 32GB GPUs (V100)
+max_batch_length: 200 # Fits in a 32GB GPUs (V100)
+num_buckets: 70
+shuffle: True # if true re-creates batches at each epoch shuffling examples.
+batch_ordering: random
+
+dynamic_batch_sampler_train:
+   max_batch_length: !ref <max_batch_length>
+   num_buckets: !ref <num_buckets>
+   shuffle: !ref <shuffle>
+   batch_ordering: !ref <batch_ordering>
 
 train_dataloader_options:
    num_workers: 4
@@ -40,7 +48,7 @@ test_dataloader_options:
    batch_size: 8 # DynamicBatching not used at testing time
    num_workers: 4
 
-# Training parameters
+####################### Training Parameters ####################################
 lr: 0.0005
 warmup: 30000
 # This is equivalent to optimizer_step_limit - warmup
@@ -55,7 +63,7 @@ mask_prob: 0.65
 mask_length: 10
 num_negatives: 100
 
-# Model parameters
+####################### Model Parameters #######################################
 embedding_dim: 768
 extractor_dim: 512
 final_dim: 256
diff --git a/recipes/LibriSpeech/self-supervised-learning/wav2vec2/train_sb_wav2vec2.py b/recipes/LibriSpeech/self-supervised-learning/wav2vec2/train_sb_wav2vec2.py
index 291c4c0b6354cee3478c715a29af031bc87d14eb..72494f387504f5c68652cd2636d49a76bb5cfea4 100644
--- a/recipes/LibriSpeech/self-supervised-learning/wav2vec2/train_sb_wav2vec2.py
+++ b/recipes/LibriSpeech/self-supervised-learning/wav2vec2/train_sb_wav2vec2.py
@@ -29,6 +29,7 @@ from speechbrain.dataio.dataloader import SaveableDataLoader
 from speechbrain.dataio.sampler import DynamicBatchSampler
 from speechbrain.lobes.models.wav2vec import w2v_mask_collate_fn
 from speechbrain.lobes.models.wav2vec import sample_negatives
+from speechbrain.core import AMPConfig
 
 logger = logging.getLogger(__name__)
 
@@ -84,7 +85,7 @@ class W2V2Brain(sb.core.Brain):
         loss, accuracy = self.hparams.loss(embeddings, targets, negs)
 
         # This is only used for logging purpose
-        if stage != sb.Stage.TRAIN and sb.utils.distributed.if_main_process():
+        if stage != sb.Stage.TRAIN:
             self.acc_metric.append(accuracy)
 
         objectives = {
@@ -120,47 +121,33 @@ class W2V2Brain(sb.core.Brain):
         return objectives
 
     def fit_batch(self, batch):
-        should_step = self.step % self.grad_accumulation_factor == 0
+        amp = AMPConfig.from_name(self.precision)
+        should_step = (self.step % self.grad_accumulation_factor) == 0
+
         # Managing automatic mixed precision
-        if self.auto_mix_prec:
-            with self.no_sync(not should_step):
-                with torch.cuda.amp.autocast():
+        with self.no_sync(not should_step):
+            if self.use_amp:
+                with torch.autocast(
+                    dtype=amp.dtype, device_type=torch.device(self.device).type,
+                ):
                     outputs = self.compute_forward(batch, Stage.TRAIN)
                     objectives = self.compute_objectives(
                         outputs, batch, Stage.TRAIN
                     )
-
-                self.scaler.scale(
-                    objectives["backprop_loss"] / self.grad_accumulation_factor
-                ).backward()
-
-                objectives["total_loss"] = objectives["backprop_loss"].detach()
-                if should_step:
-                    self.scaler.unscale_(self.optimizer)
-                    if self.check_gradients(objectives["backprop_loss"]):
-                        self.scaler.step(self.optimizer)
-                    self.optimizer.zero_grad()
-                    self.optimizer_step += 1
-                    self.scaler.update()
-        else:
-            with self.no_sync(not should_step):
+            else:
                 outputs = self.compute_forward(batch, Stage.TRAIN)
                 objectives = self.compute_objectives(
                     outputs, batch, Stage.TRAIN
                 )
 
-                (
-                    objectives["backprop_loss"] / self.grad_accumulation_factor
-                ).backward()
-                objectives["total_loss"] = objectives["backprop_loss"].detach()
+            self.scaler.scale(
+                objectives["backprop_loss"] / self.grad_accumulation_factor
+            ).backward()
 
-                if should_step:
-                    if self.check_gradients(objectives["backprop_loss"]):
-                        self.optimizer.step()
-                    self.optimizer.zero_grad()
-                    self.optimizer_step += 1
+            objectives["total_loss"] = objectives["backprop_loss"].detach()
 
         if should_step:
+            self.optimizers_step()
             self.on_fit_batch_end(objectives)
 
         return objectives["backprop_loss"].detach()
@@ -292,13 +279,10 @@ def dataio_prepare(hparams):
     sb.dataio.dataset.set_output_keys(datasets, ["id", "sig"])
 
     # We create the DynamicBatch Sampler
+    dynamic_hparams = hparams["dynamic_batch_sampler_train"]
+
     train_sampler = DynamicBatchSampler(
-        train_data,
-        hparams["seconds_per_batch"],
-        num_buckets=hparams["train_num_buckets"],
-        length_func=lambda x: x["duration"],
-        batch_ordering="random",
-        shuffle=True,
+        train_data, **dynamic_hparams, length_func=lambda x: x["duration"],
     )
 
     # We define the custom collation function that is necessary for w2v2 to
@@ -344,6 +328,10 @@ def main():
         overrides=overrides,
     )
 
+    # Update precision to bf16 if the device is CPU and precision is fp16
+    if run_opts.get("device") == "cpu" and hparams.get("precision") == "fp16":
+        hparams["precision"] = "bf16"
+
     from librispeech_prepare import prepare_librispeech
 
     run_on_main(
diff --git a/recipes/LibriTTS/README.md b/recipes/LibriTTS/README.md
index 8061d2fc8441dd7e8f8923bcdfb6af14f0de1ad4..7c9cf1bbbf10da0581f71d194c5320664b53539c 100644
--- a/recipes/LibriTTS/README.md
+++ b/recipes/LibriTTS/README.md
@@ -6,6 +6,27 @@ The LibriTTS dataset is available here: https://www.openslr.org/60/, https://www
 
 The `libritts_prepare.py` file automatically downloads the dataset if not present and has facilities to provide the names of the subsets to be downloaded.
 
+# Zero-Shot Multi-Speaker Tacotron2
+The subfolder "TTS/mstacotron2" contains the recipe for training a zero-shot multi-speaker version of the [Tacotron2](https://arxiv.org/abs/1712.05884) model.
+To run this recipe, go into the `"TTS/mstacotron2"` folder and run:
+
+```bash
+python train.py hparams/train.yaml --data_folder=/path/to/libritts_data --device=cuda:0 --max_grad_norm=1.0
+```
+
+Please ensure that you use absolute paths when specifying the data folder.
+
+Training time required on NVIDIA A100 GPU using LibriTTS train-clean-100 and train-clean-360 subsets: ~ 2 hours 54 minutes per epoch
+
+The training logs are available [here](https://www.dropbox.com/sh/ti2vk7sce8f9fgd/AABcDGWCrBvLX_ZQs76mlJRYa?dl=0).
+
+For now, enhancements are needed for traning the model from scratch when train-clean-360 is included. Inference can be effectuated with `clone_voice_char_input` function in the MSTacotron2 interface.
+
+The pre-trained model (a model fine-tuned from LJSpeech tacotron2) with an easy-inference interface is available on [HuggingFace](https://huggingface.co/speechbrain/tts-mstacotron2-libritts).
+
+**Please Note**: The current model effectively captures speaker identities. Nevertheless, the synthesized speech quality exhibits some metallic characteristics and may include artifacts like overly long pauses.
+We are actively working to enhancing the model and will release updates as soon as improvements are achieved. We warmly welcome contributions from the community to collaboratively make the model even better!
+
 # HiFi GAN (Vocoder)
 The subfolder "vocoder/hifi_gan/" contains the [HiFi GAN vocoder](https://arxiv.org/pdf/2010.05646.pdf).
 The vocoder is a neural network that converts a spectrogram into a waveform (it can be used on top of Tacotron2).
@@ -14,11 +35,13 @@ We suggest using `tensorboard_logger` by setting `use_tensorboard: True` in the
 
 To run this recipe, go into the `"vocoder/hifigan/"` folder and run:
 
-```
+```bash
 python train.py hparams/train.yaml --data_folder=/path/to/LibriTTS
 ```
 
-The recipe will automatically download the librispeech dataset and resamples it as specified.
+The recipe will automatically download the LibriTTS dataset and resamples it as specified.
+
+Training time required on NVIDIA A100 GPU using LibriTTS train-clean-100 and train-clean-360 subsets: ~ 1 hour 50 minutes per epoch
 
 The training logs and checkpoints are available [here](https://www.dropbox.com/sh/gjs1kslxkxz819q/AABPriN4dOoD1qL7NoIyVk0Oa?dl=0).
 
diff --git a/recipes/LibriTTS/TTS/mstacotron2/compute_speaker_embeddings.py b/recipes/LibriTTS/TTS/mstacotron2/compute_speaker_embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..142090cde06791cfb22c11e31afc0e33ee2b5566
--- /dev/null
+++ b/recipes/LibriTTS/TTS/mstacotron2/compute_speaker_embeddings.py
@@ -0,0 +1,117 @@
+import json
+from speechbrain.inference.encoders import MelSpectrogramEncoder
+from speechbrain.inference.classifiers import EncoderClassifier
+import torchaudio
+import pickle
+import logging
+import os
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)
+
+
+def compute_speaker_embeddings(
+    input_filepaths,
+    output_file_paths,
+    data_folder,
+    spk_emb_encoder_path,
+    spk_emb_sr,
+    mel_spec_params,
+    device,
+):
+    """This function processes a JSON file to compute the speaker embeddings
+
+    Arguments
+    ---------
+    input_filepaths : list
+        A list of paths to the JSON files to be processed
+    output_file_paths : list
+        A list of paths to the output pickle files corresponding to the input JSON files
+    data_folder : str
+        Path to the folder where LibriTTS data is stored
+    spk_emb_encoder_path : str
+        Path for the speaker encoder
+    spk_emb_sr : int
+        Sample rate used by the speaker embedding encoder
+    mel_spec_params: dict
+        Information about mel-spectrogram computation
+    device : str
+        Device for to be used for computation
+    """
+
+    # Checks if this phase is already done (if so, skips it)
+    if skip(output_file_paths):
+        logger.info("Preparation completed in previous run, skipping.")
+        return
+
+    # Initializes the speaker encoder
+    spk_emb_encoder = None
+    if mel_spec_params["custom_mel_spec_encoder"]:
+        # To use the custom mel-spectrogram based encoder - for compatibility with future speaker consistency loss work
+        spk_emb_encoder = MelSpectrogramEncoder.from_hparams(
+            source=spk_emb_encoder_path, run_opts={"device": device}
+        )
+    else:
+        # To use the speaker encoders available with SpeechBrain
+        spk_emb_encoder = EncoderClassifier.from_hparams(
+            source=spk_emb_encoder_path, run_opts={"device": device}
+        )
+
+    # Processes data manifests files to create corresponding speaker embedding files
+    for i in range(len(input_filepaths)):
+        logger.info(f"Creating {output_file_paths[i]}.")
+
+        speaker_embeddings = dict()  # Holds speaker embeddings
+
+        json_file = open(input_filepaths[i])
+        json_data = json.load(json_file)
+
+        # Processes all utterances in the data manifest file
+        for utt_id, utt_data in tqdm(json_data.items()):
+            utt_wav_path = utt_data["wav"]
+            utt_wav_path = utt_wav_path.replace("{data_root}", data_folder)
+
+            # Loads and resamples waveforms if required
+            signal, sig_sr = torchaudio.load(utt_wav_path)
+            if sig_sr != spk_emb_sr:
+                signal = torchaudio.functional.resample(
+                    signal, sig_sr, spk_emb_sr
+                )
+            signal = signal.to(device)
+
+            # Computes the speaker embedding
+            if mel_spec_params["custom_mel_spec_encoder"]:
+                spk_emb = spk_emb_encoder.encode_waveform(signal)
+            else:
+                spk_emb = spk_emb_encoder.encode_batch(signal)
+
+            spk_emb = spk_emb.squeeze()
+            spk_emb = spk_emb.detach()
+
+            speaker_embeddings[utt_id] = spk_emb.cpu()
+
+        # Stores the speaker embeddings at the destination
+        with open(output_file_paths[i], "wb") as output_file:
+            pickle.dump(
+                speaker_embeddings,
+                output_file,
+                protocol=pickle.HIGHEST_PROTOCOL,
+            )
+
+        logger.info(f"Created {output_file_paths[i]}.")
+
+
+def skip(filepaths):
+    """
+    Detects if the data preparation has been already done.
+    If the preparation has been done, we can skip it.
+    Returns
+    -------
+    bool
+        if True, the preparation phase can be skipped.
+        if False, it must be done.
+    """
+    for filepath in filepaths:
+        if not os.path.isfile(filepath):
+            return False
+    return True
diff --git a/recipes/LibriTTS/TTS/mstacotron2/hparams/train.yaml b/recipes/LibriTTS/TTS/mstacotron2/hparams/train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..36d22039a19d58847bdfb11c4b679609ebcaa25d
--- /dev/null
+++ b/recipes/LibriTTS/TTS/mstacotron2/hparams/train.yaml
@@ -0,0 +1,283 @@
+############################################################################
+# Model: Zero-Shot Multi-Speaker Tacotron2
+# Tokens: ARPAbet Phonemes
+# Training: LibriTTS
+# Authors: Georges Abous-Rjeili, Artem Ploujnikov, Yingzhi Wang, Pradnya Kandarkar
+# ############################################################################
+
+
+###################################
+# Experiment Parameters and setup #
+###################################
+seed: 1234
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref ./results/tacotron2/<seed>
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+epochs: 700
+keep_checkpoint_interval: 50
+use_tensorboard: False
+
+# Vocoder is used to covert the intermediate mel-spectrogram into the final waveform
+log_audio_samples: True
+vocoder: speechbrain/tts-hifigan-libritts-16kHz
+vocoder_savedir: tmpdir_vocoder_16k
+
+###################################
+# Progress Samples                #
+###################################
+# Progress samples are used to monitor the progress
+# of an ongoing training session by outputting samples
+# of spectrogram, alignment, etc. at regular intervals
+
+# Whether to enable progress samples
+progress_samples: True
+
+# The path where the samples will be stored
+progress_sample_path: !ref <output_folder>/samples
+
+# The interval, in epochs. For instance, if it is set to 5,
+# progress samples will be output every 5 epochs
+progress_samples_interval: 10
+
+# The sample size for raw batch samples saved in batch.pth
+# (useful mostly for model debugging)
+progress_batch_sample_size: 3
+
+#################################
+# Data files and pre-processing #
+#################################
+data_folder: !PLACEHOLDER # e.g, /localscratch/LibriTTS/
+
+# Files to hold the manifest data
+train_json: !ref <save_folder>/train.json
+valid_json: !ref <save_folder>/valid.json
+test_json: !ref <save_folder>/test.json
+
+# Files to hold the speaker embeddings - corresponding to the data manifest files
+train_speaker_embeddings_pickle: !ref <save_folder>/train_speaker_embeddings.pickle
+valid_speaker_embeddings_pickle: !ref <save_folder>/valid_speaker_embeddings.pickle
+test_speaker_embeddings_pickle: !ref <save_folder>/test_speaker_embeddings.pickle
+
+# Data splits
+skip_prep: False
+splits: ["train", "valid", "test"]
+
+# train_split: ["train-clean-100", "train-clean-360"]
+train_split: ["train-clean-100"]
+valid_split: ["dev-clean"]
+test_split: ["test-clean"]
+
+# Use the original preprocessing from nvidia
+# The cleaners to be used (applicable to nvidia only)
+text_cleaners: ['english_cleaners']
+
+# Avoid audios longer than x seconds
+avoid_if_longer_than: 10.0
+################################
+# Audio Parameters             #
+################################
+sample_rate: 16000
+hop_length: 256
+win_length: 1024
+n_mel_channels: 80
+n_fft: 1024
+mel_fmin: 0.0
+mel_fmax: 8000.0
+mel_normalized: False
+power: 1
+norm: "slaney"
+mel_scale: "slaney"
+dynamic_range_compression: True
+
+################################
+# Speaker Embedding Parameters #
+################################
+spk_emb_size: 192
+spk_emb_sample_rate: 16000
+custom_mel_spec_encoder: False
+spk_emb_encoder: speechbrain/spkrec-ecapa-voxceleb
+
+# To use the custom mel-spectrogram based encoder - for compatibility with future speaker consistency loss work
+# 1. Change "custom_mel_spec_encoder" to True
+# 2. Change the path for "spk_emb_encoder".
+# The ECAPA-TDNN model used for the Zero-Shot Multi-Speaker Tacotron2 experiments is available here: speechbrain/spkrec-ecapa-voxceleb-mel-spec
+
+################################
+# Optimization Hyperparameters #
+################################
+learning_rate: 0.001
+weight_decay: 0.000006
+batch_size: 32 #minimum 2
+mask_padding: True
+guided_attention_sigma: 0.2
+guided_attention_weight: 25.0
+guided_attention_weight_half_life: 25.
+guided_attention_hard_stop: 50
+gate_loss_weight: 1.0
+spk_emb_loss_weight: 1.0
+
+train_dataloader_opts:
+  batch_size: !ref <batch_size>
+  drop_last: True  #True #False
+  num_workers: 8
+  collate_fn: !new:speechbrain.lobes.models.MSTacotron2.TextMelCollate
+    speaker_embeddings_pickle: !ref <train_speaker_embeddings_pickle>
+
+valid_dataloader_opts:
+  batch_size: !ref <batch_size>
+  drop_last: True
+  num_workers: 8
+  collate_fn: !new:speechbrain.lobes.models.MSTacotron2.TextMelCollate
+    speaker_embeddings_pickle: !ref <valid_speaker_embeddings_pickle>
+
+test_dataloader_opts:
+  batch_size: !ref <batch_size>
+  drop_last: True
+  num_workers: 8
+  collate_fn: !new:speechbrain.lobes.models.MSTacotron2.TextMelCollate
+    speaker_embeddings_pickle: !ref <test_speaker_embeddings_pickle>
+
+###############################
+# Model Parameters and model  #
+###############################
+n_symbols: 148 #fixed depending on symbols in textToSequence
+symbols_embedding_dim: 1024
+
+# Encoder parameters
+encoder_kernel_size: 5
+encoder_n_convolutions: 6
+encoder_embedding_dim: 1024
+
+# Decoder parameters
+# The number of frames in the target per encoder step
+n_frames_per_step: 1
+decoder_rnn_dim: 2048
+prenet_dim: 512
+max_decoder_steps: 1500
+gate_threshold: 0.5
+p_attention_dropout: 0.1
+p_decoder_dropout: 0.1
+decoder_no_early_stopping: False
+
+# Attention parameters
+attention_rnn_dim: 2048
+attention_dim: 256
+
+# Location Layer parameters
+attention_location_n_filters: 32
+attention_location_kernel_size: 31
+
+# Mel-post processing network parameters
+postnet_embedding_dim: 1024
+postnet_kernel_size: 5
+postnet_n_convolutions: 10
+
+# To compute the mel-spectrogram for an audio
+mel_spectogram: !name:speechbrain.lobes.models.Tacotron2.mel_spectogram
+  sample_rate: !ref <sample_rate>
+  hop_length: !ref <hop_length>
+  win_length: !ref <win_length>
+  n_fft: !ref <n_fft>
+  n_mels: !ref <n_mel_channels>
+  f_min: !ref <mel_fmin>
+  f_max: !ref <mel_fmax>
+  power: !ref <power>
+  normalized: !ref <mel_normalized>
+  norm: !ref <norm>
+  mel_scale: !ref <mel_scale>
+  compression: !ref <dynamic_range_compression>
+
+# Zero-Shot Multi-Speaker Tacotron2 model
+model: !new:speechbrain.lobes.models.MSTacotron2.Tacotron2
+  mask_padding: !ref <mask_padding>
+  n_mel_channels: !ref <n_mel_channels>
+  # Symbols
+  n_symbols: !ref <n_symbols>
+  symbols_embedding_dim: !ref <symbols_embedding_dim>
+  # Encoder
+  encoder_kernel_size: !ref <encoder_kernel_size>
+  encoder_n_convolutions: !ref <encoder_n_convolutions>
+  encoder_embedding_dim: !ref <encoder_embedding_dim>
+  # Attention
+  attention_rnn_dim: !ref <attention_rnn_dim>
+  attention_dim: !ref <attention_dim>
+  # Attention location
+  attention_location_n_filters: !ref <attention_location_n_filters>
+  attention_location_kernel_size: !ref <attention_location_kernel_size>
+  # Decoder
+  n_frames_per_step: !ref <n_frames_per_step>
+  decoder_rnn_dim: !ref <decoder_rnn_dim>
+  prenet_dim: !ref <prenet_dim>
+  max_decoder_steps: !ref <max_decoder_steps>
+  gate_threshold: !ref <gate_threshold>
+  p_attention_dropout: !ref <p_attention_dropout>
+  p_decoder_dropout: !ref <p_decoder_dropout>
+  # Postnet
+  postnet_embedding_dim: !ref <postnet_embedding_dim>
+  postnet_kernel_size: !ref <postnet_kernel_size>
+  postnet_n_convolutions: !ref <postnet_n_convolutions>
+  decoder_no_early_stopping: !ref <decoder_no_early_stopping>
+  # Speaker embeddings
+  spk_emb_size: !ref <spk_emb_size>
+
+# Scheduler for guided attention
+guided_attention_scheduler: !new:speechbrain.nnet.schedulers.StepScheduler
+  initial_value: !ref <guided_attention_weight>
+  half_life: !ref <guided_attention_weight_half_life>
+
+# Loss function
+criterion: !new:speechbrain.lobes.models.MSTacotron2.Loss
+  gate_loss_weight: !ref <gate_loss_weight>
+  guided_attention_weight: !ref <guided_attention_weight>
+  guided_attention_sigma: !ref <guided_attention_sigma>
+  guided_attention_scheduler: !ref <guided_attention_scheduler>
+  guided_attention_hard_stop: !ref <guided_attention_hard_stop>
+  spk_emb_loss_weight: !ref <spk_emb_loss_weight>
+
+# Overall modules used
+modules:
+  model: !ref <model>
+
+# Optimizer
+opt_class: !name:torch.optim.Adam
+  lr: !ref <learning_rate>
+  weight_decay: !ref <weight_decay>
+
+# To keep track of the epochs
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+  limit: !ref <epochs>
+
+# To log training information
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+  save_file: !ref <train_log>
+
+# # Learning rate annealing function
+lr_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler
+  lr_initial: !ref <learning_rate>
+  n_warmup_steps: 4000
+
+# Checkpointer
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+  checkpoints_dir: !ref <save_folder>
+  recoverables:
+    model: !ref <model>
+    counter: !ref <epoch_counter>
+    scheduler: !ref <lr_annealing>
+
+# Progress sample logger
+progress_sample_logger: !new:speechbrain.utils.train_logger.ProgressSampleLogger
+  output_path: !ref <progress_sample_path>
+  batch_sample_size: !ref <progress_batch_sample_size>
+  formats:
+    raw_batch: raw
+
+
+# Pretrained separator - Use when fine-tuning - REMOVE IF NOT REQUIRED
+# tacotron2_model_path: !PLACEHOLDER
+# pretrained_separator: !new:speechbrain.utils.parameter_transfer.Pretrainer
+#   collect_in: !ref <save_folder>
+#   loadables:
+#     model: !ref <model>
+#   paths:
+#     model: !ref <tacotron2_model_path>/model.ckpt
diff --git a/recipes/LibriTTS/TTS/mstacotron2/libritts_prepare.py b/recipes/LibriTTS/TTS/mstacotron2/libritts_prepare.py
new file mode 120000
index 0000000000000000000000000000000000000000..489ab40118933de470d4372a3757a8fc80025071
--- /dev/null
+++ b/recipes/LibriTTS/TTS/mstacotron2/libritts_prepare.py
@@ -0,0 +1 @@
+../../libritts_prepare.py
\ No newline at end of file
diff --git a/recipes/LibriTTS/TTS/mstacotron2/train.py b/recipes/LibriTTS/TTS/mstacotron2/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..12ca750b9dabc84fe382f435988a7258cce7e5d6
--- /dev/null
+++ b/recipes/LibriTTS/TTS/mstacotron2/train.py
@@ -0,0 +1,659 @@
+# -*- coding: utf-8 -*-
+"""
+ Recipe for training the Zero-Shot Multi-Speaker Tacotron Text-To-Speech model, an end-to-end
+ neural text-to-speech (TTS) system
+
+ To run this recipe, do the following:
+ # python train.py --device=cuda:0 --max_grad_norm=1.0 --data_folder=/path_to_data_folder hparams/train.yaml
+
+ Authors
+ * Georges Abous-Rjeili 2021
+ * Artem Ploujnikov 2021
+ * Yingzhi Wang 2022
+ * Pradnya Kandarkar 2023
+"""
+import torch
+import speechbrain as sb
+import sys
+import logging
+from hyperpyyaml import load_hyperpyyaml
+from speechbrain.utils.text_to_sequence import text_to_sequence
+from speechbrain.utils.data_utils import scalarize
+import os
+from speechbrain.inference.vocoders import HIFIGAN
+import torchaudio
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+logger = logging.getLogger(__name__)
+
+
+class Tacotron2Brain(sb.Brain):
+    """The Brain implementation for Tacotron2"""
+
+    def on_fit_start(self):
+        """Gets called at the beginning of ``fit()``, on multiple processes
+        if ``distributed_count > 0`` and backend is ddp and initializes statistics"""
+        self.hparams.progress_sample_logger.reset()
+        self.last_epoch = 0
+        self.last_batch = None
+        self.last_preds = None
+
+        # Instantiate a vocoder if audio samples should be logged
+        if self.hparams.log_audio_samples:
+            self.vocoder = HIFIGAN.from_hparams(
+                source=self.hparams.vocoder,
+                savedir=self.hparams.vocoder_savedir,
+                run_opts={"device": self.device},
+                freeze_params=True,
+            )
+
+        self.last_loss_stats = {}
+        return super().on_fit_start()
+
+    def compute_forward(self, batch, stage):
+        """Computes the forward pass
+
+        Arguments
+        ---------
+        batch: str
+            a single batch
+        stage: speechbrain.Stage
+            the training stage
+
+        Returns
+        -------
+        the model output
+        """
+        effective_batch = self.batch_to_device(batch)
+        inputs, y, num_items, _, _, spk_embs, spk_ids = effective_batch
+
+        _, input_lengths, _, _, _ = inputs
+
+        max_input_length = input_lengths.max().item()
+
+        return self.modules.model(
+            inputs, spk_embs, alignments_dim=max_input_length
+        )
+
+    def fit_batch(self, batch):
+        """Fits a single batch and applies annealing
+
+        Arguments
+        ---------
+        batch: tuple
+            a training batch
+
+        Returns
+        -------
+        loss: torch.Tensor
+            detached loss
+        """
+        result = super().fit_batch(batch)
+        self.hparams.lr_annealing(self.optimizer)
+        return result
+
+    def compute_objectives(self, predictions, batch, stage):
+        """Computes the loss given the predicted and targeted outputs
+
+        Arguments
+        ---------
+        predictions : torch.Tensor
+            The model generated mel-spectrograms and other metrics from `compute_forward`
+        batch : PaddedBatch
+            This batch object contains all the relevant tensors for computation
+        stage : sb.Stage
+            One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST
+
+        Returns
+        -------
+        loss : torch.Tensor
+            A one-element tensor used for backpropagating the gradient
+        """
+        effective_batch = self.batch_to_device(batch)
+        # Hold on to the batch for the inference sample.
+        # This is needed because the infernece sample is run from on_stage_end only,
+        # where batch information is not available
+        self.last_batch = effective_batch
+        self.last_preds = predictions
+        # Hold on to a sample (for logging)
+        self._remember_sample(effective_batch, predictions)
+        # Compute the loss
+        loss = self._compute_loss(predictions, effective_batch, stage)
+        return loss
+
+    def _compute_loss(self, predictions, batch, stage):
+        """Computes the value of the loss function and updates stats
+
+        Arguments
+        ---------
+        predictions: tuple
+            model predictions
+        batch : PaddedBatch
+            This batch object contains all the relevant tensors for computation
+        stage : sb.Stage
+            One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST
+
+        Returns
+        -------
+        loss: torch.Tensor
+            the loss value
+        """
+        inputs, targets, num_items, labels, wavs, spk_embs, spk_ids = batch
+        text_padded, input_lengths, _, max_len, output_lengths = inputs
+
+        # Speaker embedding input to compute speaker consistency loss - WIP
+        spk_emb_input = None
+
+        loss_stats = self.hparams.criterion(
+            predictions,
+            targets,
+            input_lengths,
+            output_lengths,
+            spk_emb_input,
+            self.last_epoch,
+        )
+        self.last_loss_stats[stage] = scalarize(loss_stats)
+        return loss_stats.loss
+
+    def _remember_sample(self, batch, predictions):
+        """Remembers samples of spectrograms and the batch for logging purposes
+
+        Arguments
+        ---------
+        batch: tuple
+            a training batch
+        predictions: tuple
+            predictions (raw output of the Tacotron model)
+        """
+        inputs, targets, num_items, labels, wavs, spk_embs, spk_ids = batch
+        text_padded, input_lengths, _, max_len, output_lengths = inputs
+        mel_target, _ = targets
+        (
+            mel_out,
+            mel_out_postnet,
+            gate_out,
+            alignments,
+            pred_mel_lengths,
+        ) = predictions
+        alignments_max = (
+            alignments[0]
+            .max(dim=-1)
+            .values.max(dim=-1)
+            .values.unsqueeze(-1)
+            .unsqueeze(-1)
+        )
+        alignments_output = alignments[0].T.flip(dims=(1,)) / alignments_max
+        self.hparams.progress_sample_logger.remember(
+            target=self._get_spectrogram_sample(mel_target),
+            output=self._get_spectrogram_sample(mel_out),
+            output_postnet=self._get_spectrogram_sample(mel_out_postnet),
+            alignments=alignments_output,
+            raw_batch=self.hparams.progress_sample_logger.get_batch_sample(
+                {
+                    "text_padded": text_padded,
+                    "input_lengths": input_lengths,
+                    "mel_target": mel_target,
+                    "mel_out": mel_out,
+                    "mel_out_postnet": mel_out_postnet,
+                    "max_len": max_len,
+                    "output_lengths": output_lengths,
+                    "gate_out": gate_out,
+                    "alignments": alignments,
+                    "labels": labels,
+                    "wavs": wavs,
+                    "spk_embs": spk_embs,
+                    "spk_ids": spk_ids,
+                }
+            ),
+        )
+
+    def batch_to_device(self, batch):
+        """Transfers the batch to the target device
+
+        Arguments
+        ---------
+        batch: tuple
+            the batch to use
+
+        Returns
+        -------
+        batch: tuple
+            the batch on the correct device
+        """
+        (
+            text_padded,
+            input_lengths,
+            mel_padded,
+            gate_padded,
+            output_lengths,
+            len_x,
+            labels,
+            wavs,
+            spk_embs,
+            spk_ids,
+        ) = batch
+        text_padded = text_padded.to(self.device, non_blocking=True).long()
+        input_lengths = input_lengths.to(self.device, non_blocking=True).long()
+        max_len = torch.max(input_lengths.data).item()
+        mel_padded = mel_padded.to(self.device, non_blocking=True).float()
+        gate_padded = gate_padded.to(self.device, non_blocking=True).float()
+
+        output_lengths = output_lengths.to(
+            self.device, non_blocking=True
+        ).long()
+        x = (text_padded, input_lengths, mel_padded, max_len, output_lengths)
+        y = (mel_padded, gate_padded)
+        len_x = torch.sum(output_lengths)
+        spk_embs = spk_embs.to(self.device, non_blocking=True).float()
+        return (x, y, len_x, labels, wavs, spk_embs, spk_ids)
+
+    def _get_spectrogram_sample(self, raw):
+        """Converts a raw spectrogram to one that can be saved as an image
+        sample  = sqrt(exp(raw))
+
+        Arguments
+        ---------
+        raw: torch.Tensor
+            the raw spectrogram (as used in the model)
+
+        Returns
+        -------
+        sample: torch.Tensor
+            the spectrogram, for image saving purposes
+        """
+        sample = raw[0]
+        return torch.sqrt(torch.exp(sample))
+
+    def on_stage_end(self, stage, stage_loss, epoch):
+        """Gets called at the end of an epoch
+
+        Arguments
+        ---------
+        stage : sb.Stage
+            One of sb.Stage.TRAIN, sb.Stage.VALID, sb.Stage.TEST
+        stage_loss : float
+            The average loss for all of the data processed in this stage.
+        epoch : int
+            The currently-starting epoch. This is passed
+            `None` during the test stage.
+        """
+
+        # Logs training samples every 10 epochs
+        if stage == sb.Stage.TRAIN and (
+            self.hparams.epoch_counter.current % 10 == 0
+        ):
+            if self.last_batch is None:
+                return
+
+            train_sample_path = os.path.join(
+                self.hparams.progress_sample_path,
+                str(self.hparams.epoch_counter.current),
+            )
+            if not os.path.exists(train_sample_path):
+                os.makedirs(train_sample_path)
+
+            _, targets, _, labels, wavs, spk_embs, spk_ids = self.last_batch
+
+            train_sample_text = os.path.join(
+                self.hparams.progress_sample_path,
+                str(self.hparams.epoch_counter.current),
+                "train_input_text.txt",
+            )
+            with open(train_sample_text, "w") as f:
+                f.write(labels[0])
+
+            train_input_audio = os.path.join(
+                self.hparams.progress_sample_path,
+                str(self.hparams.epoch_counter.current),
+                "train_input_audio.wav",
+            )
+            torchaudio.save(
+                train_input_audio,
+                sb.dataio.dataio.read_audio(wavs[0]).unsqueeze(0),
+                self.hparams.sample_rate,
+            )
+
+            _, mel_out_postnet, _, _, pred_mel_lengths = self.last_preds
+
+            if self.hparams.log_audio_samples:
+                waveform_ss = self.vocoder.decode_batch(mel_out_postnet[0])
+                train_sample_audio = os.path.join(
+                    self.hparams.progress_sample_path,
+                    str(self.hparams.epoch_counter.current),
+                    "train_output_audio.wav",
+                )
+                torchaudio.save(
+                    train_sample_audio,
+                    waveform_ss.squeeze(1).cpu(),
+                    self.hparams.sample_rate,
+                )
+
+            if self.hparams.use_tensorboard:
+                self.tensorboard_logger.log_audio(
+                    f"{stage}/train_audio_target",
+                    sb.dataio.dataio.read_audio(wavs[0]).unsqueeze(0),
+                    self.hparams.sample_rate,
+                )
+                if self.hparams.log_audio_samples:
+                    self.tensorboard_logger.log_audio(
+                        f"{stage}/train_audio_pred",
+                        waveform_ss.squeeze(1),
+                        self.hparams.sample_rate,
+                    )
+                try:
+                    self.tensorboard_logger.log_figure(
+                        f"{stage}/train_mel_target", targets[0][0]
+                    )
+                    self.tensorboard_logger.log_figure(
+                        f"{stage}/train_mel_pred", mel_out_postnet[0]
+                    )
+                except Exception:
+                    # This is to avoid the code from crashing in case of a mel-spectrogram with one frame
+                    pass
+
+        # At the end of validation, we can write
+        if stage == sb.Stage.VALID:
+            # Update learning rate
+            lr = self.optimizer.param_groups[-1]["lr"]
+            self.last_epoch = epoch
+
+            # The train_logger writes a summary to stdout and to the logfile.
+            self.hparams.train_logger.log_stats(  # 1#2#
+                stats_meta={"Epoch": epoch, "lr": lr},
+                train_stats=self.last_loss_stats[sb.Stage.TRAIN],
+                valid_stats=self.last_loss_stats[sb.Stage.VALID],
+            )
+
+            # The tensorboard_logger writes a summary to stdout and to the logfile.
+            if self.hparams.use_tensorboard:
+                self.tensorboard_logger.log_stats(
+                    stats_meta={"Epoch": epoch, "lr": lr},
+                    train_stats=self.last_loss_stats[sb.Stage.TRAIN],
+                    valid_stats=self.last_loss_stats[sb.Stage.VALID],
+                )
+
+            # Save the current checkpoint and delete previous checkpoints.
+            epoch_metadata = {
+                **{"epoch": epoch},
+                **self.last_loss_stats[sb.Stage.VALID],
+            }
+            self.checkpointer.save_and_keep_only(
+                meta=epoch_metadata,
+                min_keys=["loss"],
+                ckpt_predicate=(
+                    lambda ckpt: (
+                        ckpt.meta["epoch"]
+                        % self.hparams.keep_checkpoint_interval
+                        != 0
+                    )
+                )
+                if self.hparams.keep_checkpoint_interval is not None
+                else None,
+            )
+            output_progress_sample = (
+                self.hparams.progress_samples
+                and epoch % self.hparams.progress_samples_interval == 0
+            )
+            if output_progress_sample:
+                self.run_inference_sample(sb.Stage.VALID)
+                self.hparams.progress_sample_logger.save(epoch)
+
+        # We also write statistics about test data to stdout and to the logfile.
+        if stage == sb.Stage.TEST:
+            self.hparams.train_logger.log_stats(
+                {"Epoch loaded": self.hparams.epoch_counter.current},
+                test_stats=self.last_loss_stats[sb.Stage.TEST],
+            )
+            if self.hparams.use_tensorboard:
+                self.tensorboard_logger.log_stats(
+                    {"Epoch loaded": self.hparams.epoch_counter.current},
+                    test_stats=self.last_loss_stats[sb.Stage.TEST],
+                )
+            if self.hparams.progress_samples:
+                self.run_inference_sample(sb.Stage.TEST)
+                self.hparams.progress_sample_logger.save("test")
+
+    def run_inference_sample(self, stage):
+        """Produces a sample in inference mode. This is called when producing
+        samples and can be useful because"""
+
+        if self.last_batch is None:
+            return
+        inputs, targets, _, labels, wavs, spk_embs, spk_ids = self.last_batch
+        text_padded, input_lengths, _, _, _ = inputs
+
+        mel_out, _, _ = self.hparams.model.infer(
+            text_padded[:1], spk_embs[:1], input_lengths[:1]
+        )
+        self.hparams.progress_sample_logger.remember(
+            inference_mel_out=self._get_spectrogram_sample(mel_out)
+        )
+
+        if stage == sb.Stage.VALID:
+            inf_sample_path = os.path.join(
+                self.hparams.progress_sample_path,
+                str(self.hparams.epoch_counter.current),
+            )
+
+            if not os.path.exists(inf_sample_path):
+                os.makedirs(inf_sample_path)
+
+            inf_sample_text = os.path.join(
+                self.hparams.progress_sample_path,
+                str(self.hparams.epoch_counter.current),
+                "inf_input_text.txt",
+            )
+            with open(inf_sample_text, "w") as f:
+                f.write(labels[0])
+
+            inf_input_audio = os.path.join(
+                self.hparams.progress_sample_path,
+                str(self.hparams.epoch_counter.current),
+                "inf_input_audio.wav",
+            )
+            torchaudio.save(
+                inf_input_audio,
+                sb.dataio.dataio.read_audio(wavs[0]).unsqueeze(0),
+                self.hparams.sample_rate,
+            )
+
+            if self.hparams.log_audio_samples:
+                waveform_ss = self.vocoder.decode_batch(mel_out)
+                inf_sample_audio = os.path.join(
+                    self.hparams.progress_sample_path,
+                    str(self.hparams.epoch_counter.current),
+                    "inf_output_audio.wav",
+                )
+                torchaudio.save(
+                    inf_sample_audio,
+                    waveform_ss.squeeze(1).cpu(),
+                    self.hparams.sample_rate,
+                )
+
+            if self.hparams.use_tensorboard:
+                self.tensorboard_logger.log_audio(
+                    f"{stage}/inf_audio_target",
+                    sb.dataio.dataio.read_audio(wavs[0]).unsqueeze(0),
+                    self.hparams.sample_rate,
+                )
+                if self.hparams.log_audio_samples:
+                    self.tensorboard_logger.log_audio(
+                        f"{stage}/inf_audio_pred",
+                        waveform_ss.squeeze(1),
+                        self.hparams.sample_rate,
+                    )
+                try:
+                    self.tensorboard_logger.log_figure(
+                        f"{stage}/inf_mel_target", targets[0][0]
+                    )
+                    self.tensorboard_logger.log_figure(
+                        f"{stage}/inf_mel_pred", mel_out
+                    )
+                except Exception:
+                    # This is to avoid the code from crashing in case of a mel-spectrogram with one frame
+                    pass
+
+
+def dataio_prepare(hparams):
+    # Define audio pipeline:
+
+    @sb.utils.data_pipeline.takes("wav", "label")
+    @sb.utils.data_pipeline.provides("mel_text_pair")
+    def audio_pipeline(wav, label):
+        text_seq = torch.IntTensor(
+            text_to_sequence(label, hparams["text_cleaners"])
+        )
+
+        audio, sig_sr = torchaudio.load(wav)
+        if sig_sr != hparams["sample_rate"]:
+            audio = torchaudio.functional.resample(
+                audio, sig_sr, hparams["sample_rate"]
+            )
+
+        mel = hparams["mel_spectogram"](audio=audio.squeeze())
+
+        len_text = len(text_seq)
+
+        return text_seq, mel, len_text
+
+    datasets = {}
+    data_info = {
+        "train": hparams["train_json"],
+        "valid": hparams["valid_json"],
+        "test": hparams["test_json"],
+    }
+    for dataset in hparams["splits"]:
+        datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
+            json_path=data_info[dataset],
+            replacements={"data_root": hparams["data_folder"]},
+            dynamic_items=[audio_pipeline],
+            output_keys=["mel_text_pair", "wav", "label", "uttid"],
+        )
+
+        datasets[dataset] = datasets[dataset].filtered_sorted(
+            sort_key="duration",
+            key_max_value={"duration": hparams["avoid_if_longer_than"]},
+        )
+
+    return datasets
+
+
+if __name__ == "__main__":
+
+    # Load hyperparameters file with command-line overrides
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # If --distributed_launch then
+    # create ddp_group with the right communication protocol
+    sb.utils.distributed.ddp_init_group(run_opts)
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    # Prepare data
+    if not hparams["skip_prep"]:
+        sys.path.append("../../")
+        from libritts_prepare import prepare_libritts
+
+        sb.utils.distributed.run_on_main(
+            prepare_libritts,
+            kwargs={
+                "data_folder": hparams["data_folder"],
+                "save_json_train": hparams["train_json"],
+                "save_json_valid": hparams["valid_json"],
+                "save_json_test": hparams["test_json"],
+                "sample_rate": hparams["sample_rate"],
+                "train_split": hparams["train_split"],
+                "valid_split": hparams["valid_split"],
+                "test_split": hparams["test_split"],
+                "seed": hparams["seed"],
+                "model_name": hparams["model"].__class__.__name__,
+            },
+        )
+
+    from compute_speaker_embeddings import compute_speaker_embeddings
+
+    sb.utils.distributed.run_on_main(
+        compute_speaker_embeddings,
+        kwargs={
+            "input_filepaths": [
+                hparams["train_json"],
+                hparams["valid_json"],
+                hparams["test_json"],
+            ],
+            "output_file_paths": [
+                hparams["train_speaker_embeddings_pickle"],
+                hparams["valid_speaker_embeddings_pickle"],
+                hparams["test_speaker_embeddings_pickle"],
+            ],
+            "data_folder": hparams["data_folder"],
+            "spk_emb_encoder_path": hparams["spk_emb_encoder"],
+            "spk_emb_sr": hparams["spk_emb_sample_rate"],
+            "mel_spec_params": {
+                "custom_mel_spec_encoder": hparams["custom_mel_spec_encoder"],
+                "sample_rate": hparams["spk_emb_sample_rate"],
+                "hop_length": hparams["hop_length"],
+                "win_length": hparams["win_length"],
+                "n_mel_channels": hparams["n_mel_channels"],
+                "n_fft": hparams["n_fft"],
+                "mel_fmin": hparams["mel_fmin"],
+                "mel_fmax": hparams["mel_fmax"],
+                "mel_normalized": hparams["mel_normalized"],
+                "power": hparams["power"],
+                "norm": hparams["norm"],
+                "mel_scale": hparams["mel_scale"],
+                "dynamic_range_compression": hparams[
+                    "dynamic_range_compression"
+                ],
+            },
+            "device": run_opts["device"],
+        },
+    )
+
+    datasets = dataio_prepare(hparams)
+
+    # Brain class initialization
+    tacotron2_brain = Tacotron2Brain(
+        modules=hparams["modules"],
+        opt_class=hparams["opt_class"],
+        hparams=hparams,
+        run_opts=run_opts,
+        checkpointer=hparams["checkpointer"],
+    )
+
+    # Load pretrained model if pretrained_separator is present in the yaml
+    if "pretrained_separator" in hparams:
+        sb.utils.distributed.run_on_main(
+            hparams["pretrained_separator"].collect_files
+        )
+        hparams["pretrained_separator"].load_collected(
+            device=run_opts["device"]
+        )
+
+    if hparams["use_tensorboard"]:
+        tacotron2_brain.tensorboard_logger = sb.utils.train_logger.TensorboardLogger(
+            save_dir=hparams["output_folder"] + "/tensorboard"
+        )
+
+    # Training
+    tacotron2_brain.fit(
+        tacotron2_brain.hparams.epoch_counter,
+        train_set=datasets["train"],
+        valid_set=datasets["valid"],
+        train_loader_kwargs=hparams["train_dataloader_opts"],
+        valid_loader_kwargs=hparams["valid_dataloader_opts"],
+    )
+
+    # Test
+    if "test" in datasets:
+        tacotron2_brain.evaluate(
+            datasets["test"],
+            test_loader_kwargs=hparams["test_dataloader_opts"],
+        )
diff --git a/recipes/LibriTTS/libritts_prepare.py b/recipes/LibriTTS/libritts_prepare.py
index ea78dfdd82bc1ca17c679726df7f32bec91508f3..9419995b4d520ab8ab2359fa857d1e527b85f621 100644
--- a/recipes/LibriTTS/libritts_prepare.py
+++ b/recipes/LibriTTS/libritts_prepare.py
@@ -1,15 +1,27 @@
+"""
+LibriTTS data preparation
+
+Authors
+ * Pradnya Kandarkar 2022
+"""
+
 from speechbrain.utils.data_utils import get_all_files, download_file
-from speechbrain.processing.speech_augmentation import Resample
 import json
 import os
 import shutil
 import random
 import logging
 import torchaudio
+import torch
+from tqdm import tqdm
+from speechbrain.inference.txt import GraphemeToPhoneme
+from speechbrain.utils.text_to_sequence import _g2p_keep_punctuations
 
 logger = logging.getLogger(__name__)
 LIBRITTS_URL_PREFIX = "https://www.openslr.org/resources/60/"
 
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 
 def prepare_libritts(
     data_folder,
@@ -18,11 +30,17 @@ def prepare_libritts(
     save_json_test,
     sample_rate,
     split_ratio=[80, 10, 10],
-    libritts_subsets=["train-clean-100"],
+    libritts_subsets=None,
+    train_split=None,
+    valid_split=None,
+    test_split=None,
+    seed=1234,
+    model_name=None,
 ):
     """
     Prepares the json files for the LibriTTS dataset.
     Downloads the dataset if it is not found in the `data_folder` as expected.
+
     Arguments
     ---------
     data_folder : str
@@ -41,23 +59,82 @@ def prepare_libritts(
     sample_rate : int
         The sample rate to be used for the dataset
     libritts_subsets: list
-        List of librispeech subsets to use (e.g., dev-clean, train-clean-100, ...).
-    Example
-    -------
-    >>> data_folder = '/path/to/LibriTTS'
-    >>> prepare_libritts(data_folder, 'train.json', 'valid.json', 'test.json', 2050)
+        List of librispeech subsets to use (e.g., dev-clean, train-clean-100, ...) for the experiment.
+        This parameter will be ignored if explicit data splits are provided.
+        Explicit data splits parameters: "train_split", "valid_split", "test_split"
+    train_split : list
+        List of librispeech subsets to use (e.g.,train-clean-100, train-clean-360) for the experiment training stage.
+    valid_split : list
+        List of librispeech subsets to use (e.g., dev-clean) for the experiment validation stage.
+    test_split : list
+        List of librispeech subsets to use (e.g., test-clean) for the experiment testing stage.
+    seed : int
+        Seed value
+    model_name : str
+        Model name (used to prepare additional model specific data)
     """
 
+    # Setting the seed value
+    random.seed(seed)
+
     # Checks if this phase is already done (if so, skips it)
     if skip(save_json_train, save_json_valid, save_json_test):
         logger.info("Preparation completed in previous run, skipping.")
         return
 
+    logger.info(
+        f"Creating {save_json_train}, {save_json_valid}, and {save_json_test}"
+    )
+
+    # If specific splits are provided, creates data manifest files accordingly
+    if train_split:
+        wav_list = prepare_split(data_folder, train_split)
+        create_json(wav_list, save_json_train, sample_rate, model_name)
+    if valid_split:
+        wav_list = prepare_split(data_folder, valid_split)
+        create_json(wav_list, save_json_valid, sample_rate, model_name)
+    if test_split:
+        wav_list = prepare_split(data_folder, test_split)
+        create_json(wav_list, save_json_test, sample_rate, model_name)
+
+    if skip(save_json_train, save_json_valid, save_json_test):
+        logger.info("Preparation completed.")
+        return
+
+    # If specific splits are not provided, and a list of subsets if provided, creates train, valid, test splits
+    # Creates data manifest files according to the data splits
+    if libritts_subsets:
+        wav_list = prepare_split(data_folder, libritts_subsets)
+        # Random split the signal list into train, valid, and test sets.
+        data_split = split_sets(wav_list, split_ratio)
+        # Creating json files
+        create_json(data_split["train"], save_json_train, sample_rate)
+        create_json(data_split["valid"], save_json_valid, sample_rate)
+        create_json(data_split["test"], save_json_test, sample_rate)
+
+
+def prepare_split(data_folder, split_list):
+    """
+    Processes the provided list of LibriTTS subsets and creates a list of all the .wav files present in the subsets.
+    Downloads the LibriTTS subsets as required.
+
+    Arguments
+    ---------
+    data_folder : str
+        Path to the folder where the LibriTTS dataset is stored
+    split_list : list
+        List of librispeech subsets to process (e.g., dev-clean, train-clean-100, ...)
+
+    Returns
+    -------
+    wav_list : list
+        List of all .wav files to be processed
+    """
     extension = [".wav"]  # The expected extension for audio files
     wav_list = list()  # Stores all audio file paths for the dataset
 
-    # For every subset of the dataset, if it doesn't exist, downloads it and sets flag to resample the subset
-    for subset_name in libritts_subsets:
+    # For every subset of the dataset, if it doesn't exist, downloads it
+    for subset_name in split_list:
 
         subset_folder = os.path.join(data_folder, subset_name)
         subset_archive = os.path.join(subset_folder, subset_name + ".tar.gz")
@@ -84,19 +161,10 @@ def prepare_libritts(
         # Collects all files matching the provided extension
         wav_list.extend(get_all_files(subset_folder, match_and=extension))
 
-    logger.info(
-        f"Creating {save_json_train}, {save_json_valid}, and {save_json_test}"
-    )
-
-    # Random split the signal list into train, valid, and test sets.
-    data_split = split_sets(wav_list, split_ratio)
-    # Creating json files
-    create_json(data_split["train"], save_json_train, sample_rate)
-    create_json(data_split["valid"], save_json_valid, sample_rate)
-    create_json(data_split["test"], save_json_test, sample_rate)
+    return wav_list
 
 
-def create_json(wav_list, json_file, sample_rate):
+def create_json(wav_list, json_file, sample_rate, model_name=None):
     """
     Creates the json file given a list of wav files.
     Arguments
@@ -107,39 +175,48 @@ def create_json(wav_list, json_file, sample_rate):
         The path of the output json file
     sample_rate : int
         The sample rate to be used for the dataset
+    model_name : str
+        Model name (used to prepare additional model specific data)
     """
 
+    # Downloads and initializes the G2P model to compute the phonemes if data is being prepared for Tacotron2 experiments
+    if model_name == "Tacotron2":
+        logger.info(
+            "Computing phonemes for labels using SpeechBrain G2P. This may take a while."
+        )
+        g2p = GraphemeToPhoneme.from_hparams(
+            "speechbrain/soundchoice-g2p", run_opts={"device": DEVICE}
+        )
+
     json_dict = {}
-    # Creates a resampler object with orig_freq set to LibriTTS sample rate (24KHz) and  new_freq set to SAMPLERATE
-    resampler = Resample(orig_freq=24000, new_freq=sample_rate)
 
     # Processes all the wav files in the list
-    for wav_file in wav_list:
+    for wav_file in tqdm(wav_list):
 
         # Reads the signal
         signal, sig_sr = torchaudio.load(wav_file)
-        signal = signal.squeeze(0)
-
+        duration = signal.shape[1] / sig_sr
         # Manipulates path to get relative path and uttid
         path_parts = wav_file.split(os.path.sep)
         uttid, _ = os.path.splitext(path_parts[-1])
         relative_path = os.path.join("{data_root}", *path_parts[-6:])
 
-        # Gets the path for the  text files and extracts the input text
-        original_text_path = os.path.join(
-            "/", *path_parts[:-1], uttid + ".original.txt"
+        # Gets the path for the text files and extracts the input text
+        normalized_text_path = os.path.join(
+            "/", *path_parts[:-1], uttid + ".normalized.txt"
         )
-        with open(original_text_path) as f:
-            original_text = f.read()
-            if original_text.__contains__("{"):
-                original_text = original_text.replace("{", "")
-            if original_text.__contains__("}"):
-                original_text = original_text.replace("}", "")
+        with open(normalized_text_path) as f:
+            normalized_text = f.read()
+            if normalized_text.__contains__("{"):
+                normalized_text = normalized_text.replace("{", "")
+            if normalized_text.__contains__("}"):
+                normalized_text = normalized_text.replace("}", "")
 
         # Resamples the audio file if required
         if sig_sr != sample_rate:
-            signal = signal.unsqueeze(0)
-            resampled_signal = resampler(signal)
+            resampled_signal = torchaudio.functional.resample(
+                signal, sig_sr, sample_rate
+            )
             os.unlink(wav_file)
             torchaudio.save(wav_file, resampled_signal, sample_rate=sample_rate)
 
@@ -148,12 +225,20 @@ def create_json(wav_list, json_file, sample_rate):
 
         # Creates an entry for the utterance
         json_dict[uttid] = {
+            "uttid": uttid,
             "wav": relative_path,
+            "duration": duration,
             "spk_id": spk_id,
-            "label": original_text,
+            "label": normalized_text,
             "segment": True if "train" in json_file else False,
         }
 
+        # Characters are used for Tacotron2, phonemes may be needed for other models
+        if model_name != "Tacotron2":
+            # Computes phoneme labels using SpeechBrain G2P and keeps the punctuations
+            phonemes = _g2p_keep_punctuations(g2p, normalized_text)
+            json_dict[uttid].update({"label_phoneme": phonemes})
+
     # Writes the dictionary to the json file
     with open(json_file, mode="w") as json_f:
         json.dump(json_dict, json_f, indent=2)
@@ -215,9 +300,3 @@ def check_folders(*folders):
         if not os.path.exists(folder):
             return False
     return True
-
-
-if __name__ == "__main__":
-    prepare_libritts(
-        "libritts_data", "train.json", "valid.json", "test.json", 16000
-    )
diff --git a/recipes/LibriTTS/vocoder/hifigan/train.py b/recipes/LibriTTS/vocoder/hifigan/train.py
index 5bc6309f6ba4f4c21688396c069d5fd416fad083..021610c66664e0fcbaaaa2dd7cd8882a997b7330 100644
--- a/recipes/LibriTTS/vocoder/hifigan/train.py
+++ b/recipes/LibriTTS/vocoder/hifigan/train.py
@@ -48,8 +48,7 @@ class HifiGanBrain(sb.Brain):
         return (y_g_hat, scores_fake, feats_fake, scores_real, feats_real)
 
     def compute_objectives(self, predictions, batch, stage):
-        """Computes and combines generator and discriminator losses
-        """
+        """Computes and combines generator and discriminator losses"""
         batch = batch.to(self.device)
         x, _ = batch.mel
         y, _ = batch.sig
@@ -64,7 +63,7 @@ class HifiGanBrain(sb.Brain):
 
         y_hat, scores_fake, feats_fake, scores_real, feats_real = predictions
         loss_g = self.hparams.generator_loss(
-            y_hat, y, scores_fake, feats_fake, feats_real
+            stage, y_hat, y, scores_fake, feats_fake, feats_real
         )
         loss_d = self.hparams.discriminator_loss(scores_fake, scores_real)
         loss = {**loss_g, **loss_d}
@@ -72,8 +71,7 @@ class HifiGanBrain(sb.Brain):
         return loss
 
     def fit_batch(self, batch):
-        """Train discriminator and generator adversarially
-        """
+        """Train discriminator and generator adversarially"""
 
         batch = batch.to(self.device)
         y, _ = batch.sig
@@ -104,8 +102,7 @@ class HifiGanBrain(sb.Brain):
         return loss_g.detach().cpu()
 
     def evaluate_batch(self, batch, stage):
-        """Evaluate one batch
-        """
+        """Evaluate one batch"""
         out = self.compute_forward(batch, stage=stage)
         loss = self.compute_objectives(out, batch, stage=stage)
         loss_g = loss["G_loss"]
@@ -153,6 +150,11 @@ class HifiGanBrain(sb.Brain):
                     "scheduler_d", self.scheduler_d
                 )
 
+            self.optimizers_dict = {
+                "optimizer_g": self.optimizer_g,
+                "optimizer_d": self.optimizer_d,
+            }
+
     def _remember_sample(self, batch, predictions):
         """Remembers samples of spectrograms and the batch for logging purposes
 
@@ -167,8 +169,7 @@ class HifiGanBrain(sb.Brain):
         y_hat, scores_fake, feats_fake, scores_real, feats_real = predictions
 
     def on_stage_end(self, stage, stage_loss, epoch):
-        """Gets called at the end of a stage (TRAIN, VALID, Or TEST)
-        """
+        """Gets called at the end of a stage (TRAIN, VALID, Or TEST)"""
         if stage == sb.Stage.VALID:
             # Update learning rate
             self.scheduler_g.step()
@@ -349,7 +350,6 @@ def dataio_prepare(hparams):
 
 
 if __name__ == "__main__":
-
     # Load hyperparameters file with command-line overrides
     hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
 
diff --git a/recipes/MEDIA/ASR/CTC/hparams/train_hf_wav2vec.yaml b/recipes/MEDIA/ASR/CTC/hparams/train_hf_wav2vec.yaml
index 4c69b677683d4be43fa94afefc1aaded3c6b497c..70ef38de7090847faa4b5ead4d5bf721ccd399ee 100644
--- a/recipes/MEDIA/ASR/CTC/hparams/train_hf_wav2vec.yaml
+++ b/recipes/MEDIA/ASR/CTC/hparams/train_hf_wav2vec.yaml
@@ -55,7 +55,7 @@ test_dataloader_options:
 sample_rate: 16000
 feats_dim: 1024
 
-# Training parameters:
+####################### Training Parameters ####################################:
 number_of_epochs: 30
 lr: 1
 lr_wav2vec: 0.0001
@@ -67,7 +67,7 @@ patient: 0
 patient_wav2vec: 0
 sorting: ascending
 
-# Model parameters:
+####################### Model Parameters #######################################
 activation: !name:torch.nn.LeakyReLU
 dnn_blocks: 3
 dnn_neurons: 512
@@ -86,7 +86,7 @@ output_neurons: 67
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
     source: !ref <wav2vec_url>
     output_norm: True
     freeze: !ref <freeze>
diff --git a/recipes/MEDIA/ASR/CTC/train_hf_wav2vec.py b/recipes/MEDIA/ASR/CTC/train_hf_wav2vec.py
index 3d512e4c602986ad68e3db5496ace8cbf5490a39..7c5d2164c242e46b95c670dabb2245f01ecbcacc 100644
--- a/recipes/MEDIA/ASR/CTC/train_hf_wav2vec.py
+++ b/recipes/MEDIA/ASR/CTC/train_hf_wav2vec.py
@@ -82,35 +82,6 @@ class ASR(sb.core.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-
-        stage = sb.Stage.TRAIN
-
-        # Train.
-        predictions = self.compute_forward(batch, stage)
-        loss = self.compute_objectives(predictions, batch, stage)
-
-        # Propagate loss.
-        loss.backward()
-        if self.check_gradients(loss):
-            self.optimizer_wav2vec.step()
-            self.optimizer.step()
-        self.optimizer_wav2vec.zero_grad()
-        self.optimizer.zero_grad()
-
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-
-        # Evaluate.
-        predictions = self.compute_forward(batch, stage=stage)
-        with torch.no_grad():
-            loss = self.compute_objectives(predictions, batch, stage)
-
-        return loss.detach()
-
     def init_optimizers(self):
         """Initializes the wav2vec2 optimizer and model optimizer"""
 
@@ -127,6 +98,11 @@ class ASR(sb.core.Brain):
             )
             self.checkpointer.add_recoverable("optimizer", self.optimizer)
 
+        self.optimizers_dict = {
+            "wav2vec_optimizer": self.optimizer_wav2vec,
+            "model_optimizer": self.optimizer,
+        }
+
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
 
@@ -306,7 +282,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If distributed_launch=True then
     # create ddp_group with the right communication protocol.
     sb.utils.distributed.ddp_init_group(run_opts)
 
diff --git a/recipes/MEDIA/SLU/CTC/hparams/train_hf_wav2vec_full.yaml b/recipes/MEDIA/SLU/CTC/hparams/train_hf_wav2vec_full.yaml
index f0ec5b5a72919a11c3f6f1163cd8aed07cb57870..4f9bad2e7ebead66974cd6f164c3a92f2b6504b9 100644
--- a/recipes/MEDIA/SLU/CTC/hparams/train_hf_wav2vec_full.yaml
+++ b/recipes/MEDIA/SLU/CTC/hparams/train_hf_wav2vec_full.yaml
@@ -57,7 +57,7 @@ test_dataloader_options:
 sample_rate: 16000
 feats_dim: 1024
 
-# Training parameters:
+####################### Training Parameters ####################################:
 number_of_epochs: 30
 lr: 1
 lr_wav2vec: 0.0001
@@ -69,7 +69,7 @@ patient: 0
 patient_wav2vec: 0
 sorting: ascending
 
-# Model parameters:
+####################### Model Parameters #######################################
 activation: !name:torch.nn.LeakyReLU
 dnn_blocks: 3
 dnn_neurons: 512
@@ -88,7 +88,7 @@ output_neurons: 212
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
     source: !ref <wav2vec_url>
     output_norm: True
     freeze: !ref <freeze>
diff --git a/recipes/MEDIA/SLU/CTC/hparams/train_hf_wav2vec_relax.yaml b/recipes/MEDIA/SLU/CTC/hparams/train_hf_wav2vec_relax.yaml
index 17fc76e85b480cacc7b6a930c61fa74cb990dcf5..8631e6e885f0e31dfb9448b51a86acf7b40d918a 100644
--- a/recipes/MEDIA/SLU/CTC/hparams/train_hf_wav2vec_relax.yaml
+++ b/recipes/MEDIA/SLU/CTC/hparams/train_hf_wav2vec_relax.yaml
@@ -57,7 +57,7 @@ test_dataloader_options:
 sample_rate: 16000
 feats_dim: 1024
 
-# Training parameters:
+####################### Training Parameters ####################################:
 number_of_epochs: 30
 lr: 1
 lr_wav2vec: 0.0001
@@ -69,7 +69,7 @@ patient: 0
 patient_wav2vec: 0
 sorting: ascending
 
-# Model parameters:
+####################### Model Parameters #######################################
 activation: !name:torch.nn.LeakyReLU
 dnn_blocks: 3
 dnn_neurons: 512
@@ -88,7 +88,7 @@ output_neurons: 141
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
     source: !ref <wav2vec_url>
     output_norm: True
     freeze: !ref <freeze>
diff --git a/recipes/MEDIA/SLU/CTC/train_hf_wav2vec.py b/recipes/MEDIA/SLU/CTC/train_hf_wav2vec.py
index 05a67032d6e8565c7c76e479c4d060051ea7c47a..31b703bc7df4eab91679d606892102ea9eb80b7e 100644
--- a/recipes/MEDIA/SLU/CTC/train_hf_wav2vec.py
+++ b/recipes/MEDIA/SLU/CTC/train_hf_wav2vec.py
@@ -96,35 +96,6 @@ class SLU(sb.core.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-
-        stage = sb.Stage.TRAIN
-
-        # Train.
-        predictions = self.compute_forward(batch, stage)
-        loss = self.compute_objectives(predictions, batch, stage)
-
-        # Propagate loss.
-        loss.backward()
-        if self.check_gradients(loss):
-            self.optimizer_wav2vec.step()
-            self.optimizer.step()
-        self.optimizer_wav2vec.zero_grad()
-        self.optimizer.zero_grad()
-
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-
-        # Evaluate.
-        predictions = self.compute_forward(batch, stage=stage)
-        with torch.no_grad():
-            loss = self.compute_objectives(predictions, batch, stage)
-
-        return loss.detach()
-
     def init_optimizers(self):
         """Initializes the wav2vec2 optimizer and model optimizer"""
 
@@ -141,6 +112,11 @@ class SLU(sb.core.Brain):
             )
             self.checkpointer.add_recoverable("optimizer", self.optimizer)
 
+        self.optimizers_dict = {
+            "wav2vec_optimizer": self.optimizer_wav2vec,
+            "model_optimizer": self.optimizer,
+        }
+
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
 
@@ -328,7 +304,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If distributed_launch=True then
     # create ddp_group with the right communication protocol.
     sb.utils.distributed.ddp_init_group(run_opts)
 
diff --git a/recipes/MultiWOZ/response_generation/README.md b/recipes/MultiWOZ/response_generation/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..305dd59eac0f2c58a5e67131d1702a5aae07759e
--- /dev/null
+++ b/recipes/MultiWOZ/response_generation/README.md
@@ -0,0 +1,67 @@
+# MultiWOZ Response Generation with LLM Model.
+This folder contains the scripts to finetune a LLM using MultiWOZ for the response generation task.
+You can download MultiWOZ at https://github.com/budzianowski/multiwoz.
+The data will be automatically downloaded in the specified data_folder.
+Supported LLM models are:
+ - GPT
+ - LLAMA2
+
+
+## Installing Extra Dependencies
+
+Before proceeding, ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal:
+> **Note**
+> For the Llama2 recipe, transformers and peft libraries should follow the versions mentioned in the extra_requirements.
+
+```
+cd recipes/MultiWOZ/response_generation/[LLM_model]
+pip install -r extra_requirements.txt
+```
+> **Note**
+> “Llama 2 is licensed under the LLAMA 2 Community License, Copyright © Meta Platforms, Inc. All Rights Reserved.”
+>
+> Use of the llama2 model is governed by the Meta license. In order to download the model weights and tokenizer, please visit the [website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and accept the License before starting training the llama2 model. Getting access to the original weights is usually very fast. Sometimes, It took longer to get access to the HF repo. Before proceeding, make sure that you have access to the HF repo.
+
+After getting access to the HF repo, you should log in to your HF generate a new token, and use this token to :
+```
+pip install huggingface_hub
+python -c "from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('Your_TOKEN)"
+```
+
+# How to run
+```
+cd recipes/MultiWOZ/response_generation/[LLM_model]
+python train_with_[LLM_model].py hparams/train_[LLM_model].yaml --data_folder=/your/data/folder
+```
+The data will be automatically downloaded in the specified data_folder.
+
+
+# Results
+
+| Model | Release | Hyperparams file | Test Cross-entropy Loss | Test PPL | Test BLEU 4| HuggingFace link | Full model link | GPUs |
+|:-------------:|:-------------:|:---------------------------:| :-----:| :-----:| :-----:| :-----:| :--------:|:--------:|
+| GPT2 | 2023-08-15 | train_gpt.yaml |  1.39 |  4.01 | 2.54e-04 |[model](https://huggingface.co/speechbrain/MultiWOZ-GPT-Response_Generation) | [model](https://www.dropbox.com/sh/vm8f5iavohr4zz9/AACrkOxXuxsrvJy4Cjpih9bQa?dl=0) | 1xV100 16GB |
+| LLAMA2 | 2023-10-15 | train_llama2.yaml |  1.13 |  2.90 | 7.45e-04 |[model](https://huggingface.co/speechbrain/MultiWOZ-Llama2-Response_Generation) | [model](https://www.dropbox.com/sh/d093vsje1d7ijj9/AAA-nHEd_MwNEFJfBGLmXxJra?dl=0) | 1xV100 16GB |
+
+
+
+
+# **About SpeechBrain**
+- Website: https://speechbrain.github.io/
+- Code: https://github.com/speechbrain/speechbrain/
+- HuggingFace: https://huggingface.co/speechbrain/
+
+# **Citing**
+Please, cite SpeechBrain if you use it for your research or business.
+
+```bibtex
+@misc{speechbrain,
+  title={{SpeechBrain}: A General-Purpose Speech Toolkit},
+  author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
+  year={2021},
+  eprint={2106.04624},
+  archivePrefix={arXiv},
+  primaryClass={eess.AS},
+  note={arXiv:2106.04624}
+}
+```
diff --git a/recipes/MultiWOZ/response_generation/gpt/extra_requirements.txt b/recipes/MultiWOZ/response_generation/gpt/extra_requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..454a21c03558b3210da0efe42581688e9cfc4b1d
--- /dev/null
+++ b/recipes/MultiWOZ/response_generation/gpt/extra_requirements.txt
@@ -0,0 +1 @@
+sacrebleu
diff --git a/recipes/MultiWOZ/response_generation/gpt/hparams/train_gpt.yaml b/recipes/MultiWOZ/response_generation/gpt/hparams/train_gpt.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5bb3b8ed8fa30c797cd1ad1a3f62bdaac535a0a6
--- /dev/null
+++ b/recipes/MultiWOZ/response_generation/gpt/hparams/train_gpt.yaml
@@ -0,0 +1,136 @@
+# ########################################
+# Model: GPT2LMHeadModel +  NLL
+# Authors:
+    # Pooneh Mousavi 2023
+    # Simone Alghisi 2023
+# ########################################
+
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1995
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+
+# Dataset will be downloaded to the `data_original`
+data_folder: !PLACEHOLDER
+output_folder: !ref results/train_with_gpt2/<seed>
+replacements_path: mapping.pair
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+bleu_4_test_file: !ref <output_folder>/bleu_4_test.txt
+bleu_4_valid_file: !ref <output_folder>/bleu_4_valid.txt
+
+# URL for the gpt2 model
+gpt_hub: gpt2
+gpt_folder: !ref <save_folder>/gpt_checkpoint
+
+# Path where data manifest files will be stored
+train_annotation: !ref <output_folder>/train.json
+valid_annotation: !ref <output_folder>/dev.json
+test_annotation: !ref <output_folder>/test.json
+
+skip_prep: False
+
+# The train logger writes training statistics to a file, as well as stdout.
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+# Special tokens
+bos_token: "BOS"
+eos_token: "EOS"
+
+system_token: "SPK_1"
+user_token: "SPK_2"
+
+special_tokens: [
+    !ref <bos_token>,
+    !ref <eos_token>,
+    !ref <system_token>,
+    !ref <user_token>
+]
+
+attr_to_special_tokens:
+    "bos_token": !ref <bos_token>
+    "eos_token": !ref <eos_token>
+    "additional_special_tokens": [!ref <system_token>, !ref <user_token>]
+
+# history_window, i.e. how many user-system exchanges consider as context.
+max_history: 5
+
+ignore_index: -100
+label_smoothing: 0
+
+####################### Training Parameters ####################################
+number_of_epochs: 4
+batch_size: 8
+test_batch_size: 4
+lr: 1.97125e-4
+
+#freeze GPT model
+freeze_gptmodel: False
+num_beams: 3
+max_new_tokens: 50
+top_k: 45
+top_p: 0.9
+
+
+train_dataloader_options:
+    batch_size: !ref <batch_size>
+    shuffle: True
+    num_workers: 2
+    drop_last: False
+
+test_dataloader_options:
+    batch_size: !ref <test_batch_size>
+    shuffle: True
+    num_workers: 2
+    drop_last: True
+
+# Masks
+padding_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_key_padding_mask
+
+# gpt model
+gpt_model: !new:speechbrain.lobes.models.huggingface_transformers.gpt.GPT
+    source: !ref <gpt_hub>
+    freeze: !ref <freeze_gptmodel>
+    save_path: !ref <gpt_folder>
+    max_new_tokens: !ref <max_new_tokens>
+    num_beams: !ref <num_beams>
+    top_k: !ref  <top_k>
+    top_p: !ref <top_p>
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+modules:
+    gpt_model: !ref <gpt_model>
+
+model: !new:torch.nn.ModuleList
+    - [!ref <gpt_model>]
+
+
+ce_loss: !new:torch.nn.CrossEntropyLoss
+    ignore_index: !ref <ignore_index>
+    label_smoothing: !ref <label_smoothing>
+
+opt_class: !name:torch.optim.AdamW
+    lr: !ref <lr>
+
+
+lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
+    initial_value: !ref <lr>
+    improvement_threshold: 0.0025
+    annealing_factor: 0.9
+    patient: 0
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        gpt_model: !ref <gpt_model>
+        lr_annealing_output: !ref <lr_annealing>
+        counter: !ref <epoch_counter>
+
+
+bleu_4_computer: !name:speechbrain.utils.bleu.BLEUStats
+    max_ngram_order: 4
+
+bleu_2_computer: !name:speechbrain.utils.bleu.BLEUStats
+    max_ngram_order: 2
diff --git a/recipes/MultiWOZ/response_generation/gpt/multiwoz_prepare.py b/recipes/MultiWOZ/response_generation/gpt/multiwoz_prepare.py
new file mode 120000
index 0000000000000000000000000000000000000000..9f91b289b7957c87bf630a71984cfc9deb8310c5
--- /dev/null
+++ b/recipes/MultiWOZ/response_generation/gpt/multiwoz_prepare.py
@@ -0,0 +1 @@
+../multiwoz_prepare.py
\ No newline at end of file
diff --git a/recipes/MultiWOZ/response_generation/gpt/train_with_gpt.py b/recipes/MultiWOZ/response_generation/gpt/train_with_gpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf7c3d314bc87e4c5acaeb8766ea8a27000cbaab
--- /dev/null
+++ b/recipes/MultiWOZ/response_generation/gpt/train_with_gpt.py
@@ -0,0 +1,508 @@
+#!/usr/bin/env python3
+"""
+Recipe for training a gpt_based response generation model with MultiWOZ.
+The system employs GPT2 (https://life-extension.github.io/2020/05/27/GPT%E6%8A%80%E6%9C%AF%E5%88%9D%E6%8E%A2/language-models.pdf).
+This recipe takes the GPT2LMHeadModel to fine-tune for the response generation task on the NLL.
+
+To run this recipe, do the following:
+> python train_with_gpt.py hparams/train_gpt.yaml
+
+Authors
+ * Pooneh Mousavi 2023
+ * Simone Alghisi 2023
+"""
+
+
+import sys
+import speechbrain as sb
+import torch
+from itertools import chain
+from hyperpyyaml import load_hyperpyyaml
+from speechbrain.utils.distributed import run_on_main
+import math
+from speechbrain.dataio.batch import PaddedBatch
+
+
+class ResGenBrain(sb.Brain):
+    def compute_forward(self, batch, stage):
+        """Computation pipeline based on a gpt decoder.
+        """
+        # Get required data from batch
+        batch = batch.to(self.device)
+        input_ids, _ = batch.input_ids
+        token_type_ids, _ = batch.token_type_ids
+
+        # Forward Pass
+        padding_mask = ~self.hparams.padding_mask(
+            input_ids, pad_idx=tokenizer.unk_token_id
+        )
+        outputs = self.modules.gpt_model(
+            input_ids, token_type_ids, padding_mask
+        ).logits
+
+        return outputs
+
+    def compute_objectives(self, predictions, batch, stage):
+        """Computes the NLL-loss using reply as label.
+        """
+        # Get required data from batch
+        batch = batch.to(self.device)
+        ids = batch.id
+        lm_labels, labels_lens = batch.lm_labels
+        history_bos, history_lens = batch.history_bos
+        reply_eos, reply_lens = batch.reply_eos
+        history_token_type, _ = batch.history_token_type
+
+        loss = self.hparams.ce_loss(
+            predictions.flatten(end_dim=-2), lm_labels.flatten()
+        )
+
+        if stage == sb.Stage.VALID:
+            # hyps = None
+            # current_epoch = self.hparams.epoch_counter.current
+            # if current_epoch % self.hparams.valid_search_interval == 0:
+            # history_bos = torch.LongTensor([hparams["bos_index"]] + (history_bos))
+            padding_mask = ~self.hparams.padding_mask(
+                history_bos, pad_idx=tokenizer.unk_token_id
+            )
+            hyps = self.modules.gpt_model.generate(
+                history_bos.detach(),
+                history_token_type.detach(),
+                padding_mask.detach(),
+            )
+        elif stage == sb.Stage.TEST:
+            padding_mask = ~self.hparams.padding_mask(
+                history_bos, pad_idx=tokenizer.unk_token_id
+            )
+            hyps = self.modules.gpt_model.generate(
+                history_bos.detach(),
+                history_token_type.detach(),
+                padding_mask.detach(),
+                "beam",
+            )
+
+        if stage != sb.Stage.TRAIN:
+            reply_truncated = [
+                reply_eos[i][
+                    : int(reply_lens[i].item() * reply_eos.shape[1] - 1)
+                ].detach()
+                for i in range(reply_eos.shape[0])
+            ]
+            predicted_words = tokenizer.batch_decode(
+                hyps[:, history_bos.shape[1] :],
+                skip_special_tokens=True,
+                clean_up_tokenization_spaces=True,
+            )
+            target_words = tokenizer.batch_decode(
+                reply_truncated,
+                skip_special_tokens=True,
+                clean_up_tokenization_spaces=True,
+            )
+            self.bleu_4_metric.append(ids, predicted_words, target_words)
+            self.bleu_2_metric.append(ids, predicted_words, target_words)
+            if stage != sb.Stage.TRAIN:
+                self.hyps.extend(predicted_words)
+                self.references.extend(target_words)
+
+        return loss
+
+    def on_stage_start(self, stage, epoch):
+        """Gets called at the beginning of each epoch"""
+        if stage != sb.Stage.TRAIN:
+            self.bleu_4_metric = self.hparams.bleu_4_computer()
+            self.bleu_2_metric = self.hparams.bleu_2_computer()
+            self.hyps = []
+            self.references = []
+
+    def on_stage_end(self, stage, stage_loss, epoch):
+        """Gets called at the end of an epoch.
+
+        Arguments
+        ---------
+        stage : sb.Stage
+            One of sb.Stage.TRAIN, sb.Stage.VALID, sb.Stage.TEST
+        stage_loss : float
+            The average loss for all of the data processed in this stage.
+        epoch : int
+            The currently-starting epoch. This is passed
+            `None` during the test stage.
+        """
+
+        # Store the train loss until the validation stage.
+        stage_stats = {"loss": stage_loss}
+        stage_stats["PPL"] = math.exp(stage_loss)
+        if stage == sb.Stage.TRAIN:
+            self.train_stats = stage_stats
+        else:
+            stage_stats["BLEU_4"] = self.bleu_4_metric.summarize("BLEU")
+            stage_stats["BLEU_2"] = self.bleu_2_metric.summarize("BLEU")
+        # Perform end-of-iteration things, like annealing, logging, etc.
+        if stage == sb.Stage.VALID:
+            # Update learning rate
+            old_lr, new_lr = self.hparams.lr_annealing(epoch)
+            sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)
+
+            # The train_logger writes a summary to stdout and to the logfile.
+
+            self.hparams.train_logger.log_stats(
+                stats_meta={"epoch": epoch, "lr": old_lr},
+                train_stats=self.train_stats,
+                valid_stats=stage_stats,
+            )
+            # Save the current checkpoint and delete previous checkpoints.
+            self.checkpointer.save_and_keep_only(
+                meta={"PPL": stage_stats["PPL"]}, min_keys=["PPL"],
+            )
+            if epoch == hparams["number_of_epochs"] - 1:
+                with open(self.hparams.bleu_4_valid_file, "w") as w:
+                    self.bleu_4_metric.write_stats(w)
+                    for i in range(len(self.hyps)):
+                        w.write("target: " + str(self.references[i]) + "\n")
+                        w.write("predicted:" + str(self.hyps[i]) + "\n")
+                        w.write(
+                            "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+                        )
+
+        # We also write statistics about test data to stdout and to the logfile.
+        elif stage == sb.Stage.TEST:
+
+            self.hparams.train_logger.log_stats(
+                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
+                test_stats=stage_stats,
+            )
+            with open(self.hparams.bleu_4_test_file, "w") as w:
+                self.bleu_4_metric.write_stats(w)
+                for i in range(len(self.hyps)):
+                    w.write("target: " + str(self.references[i]) + "\n")
+                    w.write("predicted:" + str(self.hyps[i]) + "\n")
+                    w.write(
+                        "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+                    )
+
+    def init_optimizers(self):
+        "Initializes the model optimizer"
+        self.optimizer = self.hparams.opt_class(self.hparams.model.parameters())
+
+        if self.checkpointer is not None:
+            self.checkpointer.add_recoverable("optimizer", self.optimizer)
+
+        self.optimizers_dict = {
+            "optimizer": self.optimizer,
+        }
+
+
+def add_special_tokens_(model, tokenizer, attr_to_special_token,) -> None:
+    orig_num_tokens = len(tokenizer.encoder)
+    num_added_tokens = tokenizer.add_special_tokens(
+        attr_to_special_token  # type: ignore
+    )  # doesn't add if they are already there
+    if num_added_tokens > 0:
+        model.resize_token_embeddings(
+            new_num_tokens=orig_num_tokens + num_added_tokens
+        )
+
+
+def dataio_prep(hparams, tokenizer):
+    """This function prepares the datasets to be used in the brain class.
+    It also defines the data processing pipeline through user-defined
+    functions. We expect `prepare_multiwoz` to have been called before
+    this, so that the `train.json`, `dev.json`,  and `test.json` manifest
+    files are available.
+    Arguments
+    ---------
+    hparams : dict
+        This dictionary is loaded from the `train.yaml` file, and it includes
+        all the hyperparameters needed for dataset construction and loading.
+    Returns
+    -------
+    datasets : dict
+        Contains two keys, "train" and "valid" that correspond
+        to the appropriate DynamicItemDataset object.
+    """
+
+    # convert special tokens to their ids
+    bos, eos, system, user = tokenizer.convert_tokens_to_ids(
+        hparams["special_tokens"]
+    )
+    # history_window, i.e. how many user-system exchanges consider as context (+1 to consider at least the last user turn)
+    history_window = 2 * hparams["max_history"] + 1
+
+    #  Define histoy pipeline:
+    @sb.utils.data_pipeline.takes("history")
+    @sb.utils.data_pipeline.provides(
+        "history",
+        "history_tokens_lists",
+        "history_ids",
+        "history_bos",
+        "history_token_type",
+    )
+    def history_pipeline(history):
+        yield history
+
+        # encode each turn of the history
+        history_tokens_lists = [tokenizer.encode(turn) for turn in history]
+        yield history_tokens_lists
+
+        # add speaker tokens to the history turns (user is even, system is odd)
+        # BEFORE:  [Hi how are you?], [I'm fine, thanks]
+        # AFTER:   [SPK_1 Hi how are you?], [SPK_2 I'm fine, thanks]
+        history_input_lists = [
+            [user if i % 2 == 0 else system] + encoded_turn
+            for i, encoded_turn in enumerate(history_tokens_lists)
+        ]
+
+        history_ids = history_input_lists[-history_window:]
+        # concatenate every token into a single list
+        # list(chain(*[[1, 2], [3, 4], [5]]))
+        # >>> [1, 2, 3, 4, 5]
+        history_ids = torch.LongTensor(list(chain(*history_ids)))
+        # without bos for lm_labels
+        yield history_ids
+
+        # create bos version for the input
+        history_bos = torch.cat((torch.tensor([bos]), history_ids))
+        yield history_bos
+
+        # create a mapping that associates each token in the input to a speaker
+        # INPUT: [SPK_1 Hi    how   are   you? ], [SPK_2 I'm   fine, thanks]
+        # TYPE:  [SPK_1 SPK_1 SPK_1 SPK_1 SPK_1], [SPK_2 SPK_2 SPK_2 SPK_2 ]
+        history_token_type_lists = [
+            [user if i % 2 == 0 else system] * len(encoded_turn)
+            for i, encoded_turn in enumerate(history_input_lists)
+        ]
+        history_token_type = torch.LongTensor(
+            list(
+                chain(
+                    *([[system]] + history_token_type_lists[-history_window:])
+                )
+            )
+        )
+
+        yield history_token_type
+
+    #  Define reply pipeline:
+    @sb.utils.data_pipeline.takes("reply")
+    @sb.utils.data_pipeline.provides(
+        "reply",
+        "reply_tokens_list",
+        "reply_ids",
+        "reply_eos",
+        "reply_token_type",
+    )
+    def reply_pipeline(reply):
+        yield reply
+
+        reply_tokens_list = tokenizer.encode(reply)
+        yield reply_tokens_list
+
+        # specify that the system will say the reply
+        reply_input_list = [system] + reply_tokens_list
+        reply_ids = torch.LongTensor(reply_input_list)
+        yield reply_ids
+
+        # create eos version of the reply for lm_labels
+        reply_eos = torch.cat((reply_ids, torch.tensor([eos])))
+        yield reply_eos
+
+        # specify the speaker for each token in the reply
+        reply_token_type = torch.LongTensor([system] * len(reply_input_list))
+        yield reply_token_type
+
+    # Define input_and_token_type_pipeline
+    @sb.utils.data_pipeline.takes(
+        "history_ids",
+        "history_bos",
+        "history_token_type",
+        "reply_ids",
+        "reply_eos",
+        "reply_token_type",
+    )
+    @sb.utils.data_pipeline.provides("input_ids", "token_type_ids", "lm_labels")
+    def input_and_token_type_pipeline(
+        history_ids,
+        history_bos,
+        history_token_type,
+        reply_ids,
+        reply_eos,
+        reply_token_type,
+    ):
+
+        # put history and reply together
+        # N.B. input_sequence = history_bos + reply_ids, we don't have eos in the input
+        input_ids = torch.cat((history_bos, reply_ids), -1)
+        yield input_ids
+
+        token_type_ids = torch.cat((history_token_type, reply_token_type), -1)
+        yield token_type_ids
+
+        # create the language model label (ground truth) for the current input
+        # -100 is a special tokens that is ignored during the loss computation
+        # the idea is to mask everything except the reply (withouth the speaker token)
+        # N.B. we don't have bos in the input
+        lm_labels = (
+            [hparams["ignore_index"]] * history_ids.shape[0]
+            + [hparams["ignore_index"]]
+            + reply_eos[1:].tolist()
+        )
+        lm_labels = torch.LongTensor(lm_labels)
+
+        yield lm_labels
+
+    # Define datasets. We also connect the dataset with the data processing
+    # functions defined above.
+    datasets = {}
+    data_info = {
+        "train": hparams["train_annotation"],
+        "valid": hparams["valid_annotation"],
+        "test": hparams["test_annotation"],
+    }
+    for dataset in data_info:
+        datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
+            json_path=data_info[dataset],
+            replacements={"data_root": hparams["data_folder"]},
+            dynamic_items=[
+                reply_pipeline,
+                history_pipeline,
+                input_and_token_type_pipeline,
+            ],
+            output_keys=[
+                "id",
+                "input_ids",
+                "token_type_ids",
+                "history_bos",
+                "reply_eos",
+                "history_token_type",
+                "reply_token_type",
+                "lm_labels",
+            ],
+        )
+
+    return datasets
+
+
+# RECIPE BEGINS!
+if __name__ == "__main__":
+
+    # Reading command line arguments.
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+
+    # Initialize ddp (useful only for multi-GPU DDP training).
+    sb.utils.distributed.ddp_init_group(run_opts)
+
+    # Load hyperparameters file with command-line overrides.
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    # Dataset prep (parsing MultiWOZ)
+    from multiwoz_prepare import prepare_mwoz_21
+
+    run_on_main(
+        prepare_mwoz_21,
+        kwargs={
+            "data_folder": hparams["data_folder"],
+            "save_folder": hparams["output_folder"],
+            "replacements_path": hparams["replacements_path"],
+            "skip_prep": hparams["skip_prep"],
+        },
+    )
+
+    # Load tokenizer and add special tokens
+    tokenizer = hparams["gpt_model"].tokenizer
+
+    #  Load pretrained GPT
+    hparams["gpt_model"] = hparams["gpt_model"].to(device=run_opts["device"])
+
+    # Add special tokens to the tokenizer and resize model embedding
+    add_special_tokens_(
+        hparams["gpt_model"].model, tokenizer, hparams["attr_to_special_tokens"]
+    )
+
+    class CustomPaddedBatch(PaddedBatch):
+        """PaddedBatch with custom padding values.
+
+        See the documentation of `speechbrain.dataio.batch.PaddedBatch`.
+
+        """
+
+        def __init__(self, examples, *args, **kwargs):
+            _, _, system, _ = tokenizer.convert_tokens_to_ids(
+                hparams["special_tokens"]
+            )
+            for k in [
+                "input_ids",
+                "history_bos",
+                "lm_labels",
+                "token_type_ids",
+                "history_token_type",
+            ]:
+                max_len = max([len(x[k]) for x in examples])
+                pad_value = 0
+                if k in [
+                    "input_ids",
+                    "history_bos",
+                    "token_type_ids",
+                    "history_token_type",
+                ]:
+                    pad_value = tokenizer.unk_token_id
+                elif k == "lm_labels":
+                    pad_value = hparams["ignore_index"]
+                for example in examples:
+                    x = example[k]
+                    if k in ["history_bos", "history_token_type"]:
+                        x = torch.cat(
+                            (example[k], torch.LongTensor([system])), -1
+                        )
+                        example[k] = torch.nn.functional.pad(
+                            x, [max_len - len(x), 0], value=pad_value
+                        )
+                    else:
+                        example[k] = torch.nn.functional.pad(
+                            x, [0, max_len - len(x)], value=pad_value
+                        )
+            super().__init__(examples, *args, **kwargs)
+
+    hparams["train_dataloader_options"]["collate_fn"] = CustomPaddedBatch
+    hparams["test_dataloader_options"]["collate_fn"] = CustomPaddedBatch
+
+    # Create dataset objects "train", "valid", and "test".
+    datasets = dataio_prep(hparams, tokenizer)
+
+    # Initialize the Brain object to prepare for mask training.
+    res_gen_brain = ResGenBrain(
+        modules=hparams["modules"],
+        opt_class=hparams["opt_class"],
+        hparams=hparams,
+        run_opts=run_opts,
+        checkpointer=hparams["checkpointer"],
+    )
+
+    # We load the pretrained whisper model
+    if "pretrainer" in hparams.keys():
+        run_on_main(hparams["pretrainer"].collect_files)
+        hparams["pretrainer"].load_collected(res_gen_brain.device)
+
+    # The `fit()` method iterates the training loop, calling the methods
+    # necessary to update the parameters of the model. Since all objects
+    # with changing state are managed by the Checkpointer, training can be
+    # stopped at any point, and will be resumed on next call.
+    res_gen_brain.fit(
+        epoch_counter=res_gen_brain.hparams.epoch_counter,
+        train_set=datasets["train"],
+        valid_set=datasets["valid"],
+        train_loader_kwargs=hparams["train_dataloader_options"],
+        valid_loader_kwargs=hparams["test_dataloader_options"],
+    )
+
+    # Load the best checkpoint for evaluation
+    test_stats = res_gen_brain.evaluate(
+        test_set=datasets["test"],
+        min_key="PPL",
+        test_loader_kwargs=hparams["test_dataloader_options"],
+    )
diff --git a/recipes/MultiWOZ/response_generation/llama2/extra_requirements.txt b/recipes/MultiWOZ/response_generation/llama2/extra_requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..bf2240c403ff94d35ffcf931364bb6dbb1d44621
--- /dev/null
+++ b/recipes/MultiWOZ/response_generation/llama2/extra_requirements.txt
@@ -0,0 +1,6 @@
+accelerate
+bitsandbytes
+peft<=0.5.0
+protobuf
+sacrebleu
+transformers<=4.34.0
diff --git a/recipes/MultiWOZ/response_generation/llama2/hparams/train_llama2.yaml b/recipes/MultiWOZ/response_generation/llama2/hparams/train_llama2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..507115e832948e06abc8a8501d66606df252525b
--- /dev/null
+++ b/recipes/MultiWOZ/response_generation/llama2/hparams/train_llama2.yaml
@@ -0,0 +1,119 @@
+# ########################################
+# Model: LLAMA2-chat +  NLL
+# Authors:
+    # Pooneh Mousavi 2023
+# ########################################
+
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1995
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+
+# Dataset will be downloaded to the `data_original`
+data_folder: !PLACEHOLDER
+output_folder: !ref results/train_with_llama2/<seed>
+replacements_path: ../mapping.pair
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+bleu_4_test_file: !ref <output_folder>/bleu_4_test.txt
+bleu_4_valid_file: !ref <output_folder>/bleu_4_valid.txt
+
+# URL for the LLAMA2-chat model
+model_hub: meta-llama/Llama-2-7b-chat-hf
+llama2_folder: !ref <save_folder>/llama2_checkpoint
+
+# Path where data manifest files will be stored
+train_annotation: !ref <output_folder>/train.json
+valid_annotation: !ref <output_folder>/dev.json
+test_annotation: !ref <output_folder>/test.json
+
+skip_prep: False
+
+# The train logger writes training statistics to a file, as well as stdout.
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+ckpt_interval_minutes: 30 # save checkpoint every N min
+
+# history_window, i.e. how many user-system exchanges consider as context.
+max_history: 2
+
+ignore_index: -100
+label_smoothing: 0
+
+####################### Training Parameters ####################################
+number_of_epochs: 4
+batch_size: 1
+test_batch_size: 1
+lr: 2e-4
+
+#freeze  model
+freeze_model: False
+num_beams: 3
+max_new_tokens: 50
+top_k: 45
+top_p: 0.9
+
+
+train_dataloader_options:
+    batch_size: !ref <batch_size>
+    shuffle: True
+    num_workers: 2
+    drop_last: False
+
+test_dataloader_options:
+    batch_size: !ref <test_batch_size>
+    shuffle: True
+    num_workers: 2
+    drop_last: True
+
+# Masks
+padding_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_key_padding_mask
+
+#LLAMA2 model
+llama2_model: !new:speechbrain.lobes.models.huggingface_transformers.llama2.LLAMA2
+    source: !ref <model_hub>
+    freeze: !ref <freeze_model>
+    save_path: !ref <llama2_folder>
+    max_new_tokens: !ref <max_new_tokens>
+    num_beams: !ref <num_beams>
+    top_k: !ref  <top_k>
+    top_p: !ref <top_p>
+    with_peft: True
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+modules:
+    llama2_model: !ref <llama2_model>
+
+model: !new:torch.nn.ModuleList
+    - [!ref <llama2_model>]
+
+
+ce_loss: !new:torch.nn.CrossEntropyLoss
+    ignore_index: !ref <ignore_index>
+    label_smoothing: !ref <label_smoothing>
+
+opt_class: !name:bitsandbytes.optim.PagedAdam32bit
+    lr: !ref <lr>
+
+
+lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
+    initial_value: !ref <lr>
+    improvement_threshold: 0.0025
+    annealing_factor: 0.9
+    patient: 0
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        llama2_model: !ref <llama2_model>
+        lr_annealing_output: !ref <lr_annealing>
+        counter: !ref <epoch_counter>
+
+
+bleu_4_computer: !name:speechbrain.utils.bleu.BLEUStats
+    max_ngram_order: 4
+
+bleu_2_computer: !name:speechbrain.utils.bleu.BLEUStats
+    max_ngram_order: 2
diff --git a/recipes/MultiWOZ/response_generation/llama2/multiwoz_prepare.py b/recipes/MultiWOZ/response_generation/llama2/multiwoz_prepare.py
new file mode 120000
index 0000000000000000000000000000000000000000..9f91b289b7957c87bf630a71984cfc9deb8310c5
--- /dev/null
+++ b/recipes/MultiWOZ/response_generation/llama2/multiwoz_prepare.py
@@ -0,0 +1 @@
+../multiwoz_prepare.py
\ No newline at end of file
diff --git a/recipes/MultiWOZ/response_generation/llama2/train_with_llama2.py b/recipes/MultiWOZ/response_generation/llama2/train_with_llama2.py
new file mode 100644
index 0000000000000000000000000000000000000000..b06ca93b3b3d6295dcd6c2de6eb81e26442f76d0
--- /dev/null
+++ b/recipes/MultiWOZ/response_generation/llama2/train_with_llama2.py
@@ -0,0 +1,456 @@
+#!/usr/bin/env python3
+"""
+Recipe for training a llama2-based  response generation model with MultiWOZ.
+The system employs LLAMA2 (https://arxiv.org/abs/2307.09288).
+This recipe takes the LLAMA2-chat to fine-tune for the response generation task on the NLL.
+
+To run this recipe, do the following:
+> python train_with_llama2.py hparams/train_llama2.yaml
+
+Authors
+ * Pooneh Mousavi 2023
+"""
+
+
+import sys
+import speechbrain as sb
+import torch
+from itertools import chain
+from hyperpyyaml import load_hyperpyyaml
+from speechbrain.utils.distributed import run_on_main
+import math
+from speechbrain.dataio.batch import PaddedBatch
+
+
+class ResGenBrain(sb.Brain):
+    def compute_forward(self, batch, stage):
+        """Computation pipeline based on a LLAMA2.
+        """
+        # Get required data from batch
+        batch = batch.to(self.device)
+        input_ids, _ = batch.input_ids
+
+        # Forward Pass
+        padding_mask = ~self.hparams.padding_mask(
+            input_ids, pad_idx=tokenizer.pad_token_id
+        )
+        outputs = self.modules.llama2_model(input_ids, padding_mask).logits
+
+        return outputs
+
+    def compute_objectives(self, predictions, batch, stage):
+        """Computes the NLL-loss using reply as label.
+        """
+        # Get required data from batch
+        batch = batch.to(self.device)
+        ids = batch.id
+        lm_labels, _ = batch.lm_labels
+        prompt_bos, _ = batch.prompt_bos
+        reply_eos, reply_lens = batch.reply_eos
+
+        loss = self.hparams.ce_loss(
+            predictions.flatten(end_dim=-2), lm_labels.flatten()
+        )
+
+        if stage == sb.Stage.VALID:
+            padding_mask = ~self.hparams.padding_mask(
+                prompt_bos, pad_idx=tokenizer.pad_token_id
+            )
+            hyps = self.modules.llama2_model.generate(
+                prompt_bos.detach(), padding_mask.detach(),
+            )
+        elif stage == sb.Stage.TEST:
+            padding_mask = ~self.hparams.padding_mask(
+                prompt_bos, pad_idx=tokenizer.pad_token_id
+            )
+            hyps = self.modules.llama2_model.generate(
+                prompt_bos.detach(), padding_mask.detach(), "beam",
+            )
+
+        if stage != sb.Stage.TRAIN:
+            reply_truncated = [
+                reply_eos[i][
+                    : int(reply_lens[i].item() * reply_eos.shape[1] - 1)
+                ].detach()
+                for i in range(reply_eos.shape[0])
+            ]
+            predicted_words = tokenizer.batch_decode(
+                hyps[:, prompt_bos.shape[1] :],
+                skip_special_tokens=True,
+                clean_up_tokenization_spaces=True,
+            )
+            target_words = tokenizer.batch_decode(
+                reply_truncated,
+                skip_special_tokens=True,
+                clean_up_tokenization_spaces=True,
+            )
+            self.bleu_4_metric.append(ids, predicted_words, target_words)
+            self.bleu_2_metric.append(ids, predicted_words, target_words)
+            if stage != sb.Stage.TRAIN:
+                self.hyps.extend(predicted_words)
+                self.references.extend(target_words)
+
+        return loss
+
+    def fit_batch(self, batch):
+        """Trains the parameters given a single batch in input"""
+
+        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
+        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
+        loss.backward()
+        if self.check_gradients():
+            self.optimizer.step()
+        self.optimizer.zero_grad()
+
+        return loss.detach()
+
+    def on_stage_start(self, stage, epoch):
+        """Gets called at the beginning of each epoch"""
+        if stage != sb.Stage.TRAIN:
+            self.bleu_4_metric = self.hparams.bleu_4_computer()
+            self.bleu_2_metric = self.hparams.bleu_2_computer()
+            self.hyps = []
+            self.references = []
+
+    def on_stage_end(self, stage, stage_loss, epoch):
+        """Gets called at the end of an epoch.
+
+        Arguments
+        ---------
+        stage : sb.Stage
+            One of sb.Stage.TRAIN, sb.Stage.VALID, sb.Stage.TEST
+        stage_loss : float
+            The average loss for all of the data processed in this stage.
+        epoch : int
+            The currently-starting epoch. This is passed
+            `None` during the test stage.
+        """
+
+        # Store the train loss until the validation stage.
+        stage_stats = {"loss": stage_loss}
+        stage_stats["PPL"] = math.exp(stage_loss)
+        if stage == sb.Stage.TRAIN:
+            self.train_stats = stage_stats
+        else:
+            stage_stats["BLEU_4"] = self.bleu_4_metric.summarize("BLEU")
+            stage_stats["BLEU_2"] = self.bleu_2_metric.summarize("BLEU")
+        # Perform end-of-iteration things, like annealing, logging, etc.
+        if stage == sb.Stage.VALID:
+            # Update learning rate
+            old_lr, new_lr = self.hparams.lr_annealing(epoch)
+            sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)
+
+            # The train_logger writes a summary to stdout and to the logfile.
+
+            self.hparams.train_logger.log_stats(
+                stats_meta={"epoch": epoch, "lr": old_lr},
+                train_stats=self.train_stats,
+                valid_stats=stage_stats,
+            )
+            # Save the current checkpoint and delete previous checkpoints.
+            self.checkpointer.save_and_keep_only(
+                meta={"PPL": stage_stats["PPL"]}, min_keys=["PPL"],
+            )
+            if epoch == hparams["number_of_epochs"] - 1:
+                with open(
+                    self.hparams.bleu_4_valid_file, "w", encoding="utf-8"
+                ) as w:
+                    self.bleu_4_metric.write_stats(w)
+                    for i in range(len(self.hyps)):
+                        w.write("target: " + str(self.references[i]) + "\n")
+                        w.write("predicted:" + str(self.hyps[i]) + "\n")
+                        w.write(
+                            "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+                        )
+
+        # We also write statistics about test data to stdout and to the logfile.
+        elif stage == sb.Stage.TEST:
+
+            self.hparams.train_logger.log_stats(
+                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
+                test_stats=stage_stats,
+            )
+            with open(
+                self.hparams.bleu_4_test_file, "w", encoding="utf-8"
+            ) as w:
+                self.bleu_4_metric.write_stats(w)
+                for i in range(len(self.hyps)):
+                    w.write("target: " + str(self.references[i]) + "\n")
+                    w.write("predicted:" + str(self.hyps[i]) + "\n")
+                    w.write(
+                        "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+                    )
+
+    def init_optimizers(self):
+        "Initializes the model optimizer"
+        self.optimizer = self.hparams.opt_class(self.hparams.model.parameters())
+
+        if self.checkpointer is not None:
+            self.checkpointer.add_recoverable("optimizer", self.optimizer)
+
+    def zero_grad(self, set_to_none=False):
+        self.optimizer.zero_grad(set_to_none)
+
+
+def add_special_tokens_(model, tokenizer, attr_to_special_token,) -> None:
+    orig_num_tokens = len(tokenizer)
+    num_added_tokens = tokenizer.add_special_tokens(
+        attr_to_special_token  # type: ignore
+    )  # doesn't add if they are already there
+    if num_added_tokens > 0:
+        model.resize_token_embeddings(
+            new_num_tokens=orig_num_tokens + num_added_tokens
+        )
+
+
+def dataio_prep(hparams, tokenizer):
+    """This function prepares the datasets to be used in the brain class.
+    It also defines the data processing pipeline through user-defined
+    functions. We expect `prepare_multiwoz` to have been called before
+    this, so that the `train.json`, `dev.json`,  and `test.json` manifest
+    files are available.
+    Arguments
+    ---------
+    hparams : dict
+        This dictionary is loaded from the `train.yaml` file, and it includes
+        all the hyperparameters needed for dataset construction and loading.
+    Returns
+    -------
+    datasets : dict
+        Contains two keys, "train" and "valid" that correspond
+        to the appropriate DynamicItemDataset object.
+    """
+
+    # history_window, i.e. how many user-system exchanges consider as context (+1 to consider at least the last user turn)
+    history_window = 2 * hparams["max_history"] + 1
+
+    #  Define histoy pipeline:
+    @sb.utils.data_pipeline.takes("history")
+    @sb.utils.data_pipeline.provides(
+        "prompts", "propmt_tokens_lists", "prompt_ids", "prompt_bos",
+    )
+    def history_pipeline(history):
+        # add INST tokens to the history turns for turns associated with user.
+        # BEFORE:  [ Hi how are you? ], [I'm fine, thanks]
+        # AFTER:   [[INST] Hi how are you? [/INST]], [I'm fine, thanks]
+
+        def generate_prompt(idx_and_item):
+            index, item = idx_and_item
+            if index % 2 == 0:
+                return "[INST] " + item + " [/INST]"
+            else:
+                return item
+
+        prompts = list(map(generate_prompt, enumerate(history)))
+        yield prompts
+
+        # encode each turn of the history
+        propmt_tokens_lists = [tokenizer.encode(turn) for turn in prompts]
+        yield propmt_tokens_lists
+
+        prompt_ids = propmt_tokens_lists[-history_window:]
+        # concatenate every token into a single list
+        # list(chain(*[[1, 2], [3, 4], [5]]))
+        # >>> [1, 2, 3, 4, 5]
+        prompt_ids = torch.LongTensor(list(chain(*prompt_ids)))
+        # without bos for lm_labels
+        yield prompt_ids
+
+        # # create bos version for the input
+        prompt_bos = torch.cat(
+            (torch.tensor([tokenizer.bos_token_id]), prompt_ids)
+        )
+        yield prompt_bos
+
+    #  Define reply pipeline:
+    @sb.utils.data_pipeline.takes("reply")
+    @sb.utils.data_pipeline.provides(
+        "reply", "reply_tokens_list", "reply_ids", "reply_eos",
+    )
+    def reply_pipeline(reply):
+        yield reply
+
+        reply_tokens_list = tokenizer.encode(reply)
+        yield reply_tokens_list
+
+        reply_ids = torch.LongTensor(reply_tokens_list)
+        yield reply_ids
+
+        # create eos version of the reply for lm_labels
+        reply_eos = torch.cat(
+            (reply_ids, torch.tensor([tokenizer.eos_token_id]))
+        )
+        yield reply_eos
+
+    # Define input_and_token_type_pipeline
+    @sb.utils.data_pipeline.takes(
+        "prompt_ids", "prompt_bos", "reply_ids", "reply_eos",
+    )
+    @sb.utils.data_pipeline.provides("input_ids", "lm_labels")
+    def input_and_token_type_pipeline(
+        prompt_ids, prompt_bos, reply_ids, reply_eos,
+    ):
+
+        # put history and reply together
+        # N.B. input_sequence = history_ids + reply_ids, we don't have eos in the input
+        input_ids = torch.cat((prompt_bos, reply_ids), -1)
+        yield input_ids
+
+        # create the language model label (ground truth) for the current input
+        # -100 is a special tokens that is ignored during the loss computation
+        # the idea is to mask everything except the reply (withouth the speaker token)
+        # N.B. we don't have bos in the input
+        lm_labels = [hparams["ignore_index"]] * prompt_ids.shape[
+            0
+        ] + reply_eos.tolist()
+        lm_labels = torch.LongTensor(lm_labels)
+
+        yield lm_labels
+
+    # Define datasets. We also connect the dataset with the data processing
+    # functions defined above.
+    datasets = {}
+    data_info = {
+        "train": hparams["train_annotation"],
+        "valid": hparams["valid_annotation"],
+        "test": hparams["test_annotation"],
+    }
+    for dataset in data_info:
+        datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
+            json_path=data_info[dataset],
+            replacements={"data_root": hparams["data_folder"]},
+            dynamic_items=[
+                reply_pipeline,
+                history_pipeline,
+                input_and_token_type_pipeline,
+            ],
+            output_keys=[
+                "id",
+                "input_ids",
+                "prompt_bos",
+                "reply_eos",
+                "lm_labels",
+            ],
+        )
+
+    return datasets
+
+
+# RECIPE BEGINS!
+if __name__ == "__main__":
+
+    # Reading command line arguments.
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+
+    # Initialize ddp (useful only for multi-GPU DDP training).
+    sb.utils.distributed.ddp_init_group(run_opts)
+
+    # Load hyperparameters file with command-line overrides.
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    # Dataset prep (parsing MultiWOZ)
+    from multiwoz_prepare import prepare_mwoz_21
+
+    run_on_main(
+        prepare_mwoz_21,
+        kwargs={
+            "data_folder": hparams["data_folder"],
+            "save_folder": hparams["output_folder"],
+            "replacements_path": hparams["replacements_path"],
+            "skip_prep": hparams["skip_prep"],
+        },
+    )
+
+    # Load tokenizer and add special tokens
+    tokenizer = hparams["llama2_model"].tokenizer
+
+    #  Load pretrained LLAMA2
+    hparams["llama2_model"] = hparams["llama2_model"].to(
+        device=run_opts["device"]
+    )
+
+    # Add special tokens to the tokenizer and resize model embedding
+    add_special_tokens_(
+        hparams["llama2_model"].model, tokenizer, {"pad_token": "<pad>"},
+    )
+
+    class CustomPaddedBatch(PaddedBatch):
+        """PaddedBatch with custom padding values.
+
+        See the documentation of `speechbrain.dataio.batch.PaddedBatch`.
+
+        """
+
+        def __init__(self, examples, *args, **kwargs):
+            for k in [
+                "input_ids",
+                "prompt_bos",
+                "lm_labels",
+            ]:
+                max_len = max([len(x[k]) for x in examples])
+                pad_value = 0
+                if k in [
+                    "input_ids",
+                    "prompt_bos",
+                ]:
+                    pad_value = tokenizer.pad_token_id
+                elif k == "lm_labels":
+                    pad_value = hparams["ignore_index"]
+                for example in examples:
+                    x = example[k]
+                    if k in ["prompt_bos"]:
+                        example[k] = torch.nn.functional.pad(
+                            x, [max_len - len(x), 0], value=pad_value
+                        )
+                    else:
+                        example[k] = torch.nn.functional.pad(
+                            x, [0, max_len - len(x)], value=pad_value
+                        )
+            super().__init__(examples, *args, **kwargs)
+
+    hparams["train_dataloader_options"]["collate_fn"] = CustomPaddedBatch
+    hparams["test_dataloader_options"]["collate_fn"] = CustomPaddedBatch
+
+    # Create dataset objects "train", "valid", and "test".
+    datasets = dataio_prep(hparams, tokenizer)
+
+    # Initialize the Brain object to prepare for mask training.
+    res_gen_brain = ResGenBrain(
+        modules=hparams["modules"],
+        opt_class=hparams["opt_class"],
+        hparams=hparams,
+        run_opts=run_opts,
+        checkpointer=hparams["checkpointer"],
+    )
+
+    # We load the pretrained whisper model
+    if "pretrainer" in hparams.keys():
+        run_on_main(hparams["pretrainer"].collect_files)
+        hparams["pretrainer"].load_collected(res_gen_brain.device)
+
+    # The `fit()` method iterates the training loop, calling the methods
+    # necessary to update the parameters of the model. Since all objects
+    # with changing state are managed by the Checkpointer, training can be
+    # stopped at any point, and will be resumed on next call.
+    res_gen_brain.fit(
+        epoch_counter=res_gen_brain.hparams.epoch_counter,
+        train_set=datasets["train"],
+        valid_set=datasets["valid"],
+        train_loader_kwargs=hparams["train_dataloader_options"],
+        valid_loader_kwargs=hparams["test_dataloader_options"],
+    )
+
+    # Load the best checkpoint for evaluation
+    test_stats = res_gen_brain.evaluate(
+        test_set=datasets["test"],
+        min_key="PPL",
+        test_loader_kwargs=hparams["test_dataloader_options"],
+    )
diff --git a/recipes/MultiWOZ/response_generation/mapping.pair b/recipes/MultiWOZ/response_generation/mapping.pair
new file mode 100644
index 0000000000000000000000000000000000000000..34df41d01e93ce27039e721e1ffb55bf9267e5a2
--- /dev/null
+++ b/recipes/MultiWOZ/response_generation/mapping.pair
@@ -0,0 +1,83 @@
+it's	it is
+don't	do not
+doesn't	does not
+didn't	did not
+you'd	you would
+you're	you are
+you'll	you will
+i'm	i am
+they're	they are
+that's	that is
+what's	what is
+couldn't	could not
+i've	i have
+we've	we have
+can't	cannot
+i'd	i would
+i'd	i would
+aren't	are not
+isn't	is not
+wasn't	was not
+weren't	were not
+won't	will not
+there's	there is
+there're	there are
+. .	.
+restaurants	restaurant -s
+hotels	hotel -s
+laptops	laptop -s
+cheaper	cheap -er
+dinners	dinner -s
+lunches	lunch -s
+breakfasts	breakfast -s
+expensively	expensive -ly
+moderately	moderate -ly
+cheaply	cheap -ly
+prices	price -s
+places	place -s
+venues	venue -s
+ranges	range -s
+meals	meal -s
+locations	location -s
+areas	area -s
+policies	policy -s
+children	child -s
+kids	kid -s
+kidfriendly	kid friendly
+cards	card -s
+upmarket	expensive
+inpricey	cheap
+inches	inch -s
+uses	use -s
+dimensions	dimension -s
+driverange	drive range
+includes	include -s
+computers	computer -s
+machines	machine -s
+families	family -s
+ratings	rating -s
+constraints	constraint -s
+pricerange	price range
+batteryrating	battery rating
+requirements	requirement -s
+drives	drive -s
+specifications	specification -s
+weightrange	weight range
+harddrive	hard drive
+batterylife	battery life
+businesses	business -s
+hours	hour -s
+one	1
+two	2
+three	3
+four	4
+five	5
+six	6
+seven	7
+eight	8
+nine	9
+ten	10
+eleven	11
+twelve	12
+anywhere	any where
+good bye	goodbye
diff --git a/recipes/MultiWOZ/response_generation/multiwoz_prepare.py b/recipes/MultiWOZ/response_generation/multiwoz_prepare.py
new file mode 100644
index 0000000000000000000000000000000000000000..201cab3288b1ab9ed71d903f89c92500ffe3cf5a
--- /dev/null
+++ b/recipes/MultiWOZ/response_generation/multiwoz_prepare.py
@@ -0,0 +1,675 @@
+from itertools import product
+from statistics import mean
+from typing import Any, Dict, List, Optional, Set, Tuple
+import json
+import logging
+import os
+import re
+import shutil
+from tqdm import tqdm
+from speechbrain.utils.data_utils import download_file
+
+"""
+Data preparation.
+Download: https://github.com/budzianowski/multiwoz/tree/master/data
+
+The original one can be found at:
+https://github.com/jasonwu0731/trade-dst/blob/master/create_data.py
+Author
+------
+ * Pooneh Mousavi 2023
+ * Simone Alghisi 2023
+"""
+
+logger = logging.getLogger(__name__)
+MULTIWOZ_21_DATASET_URL = (
+    "https://github.com/budzianowski/multiwoz/raw/master/data/MultiWOZ_2.1.zip"
+)
+
+
+def prepare_mwoz_21(
+    data_folder: str, save_folder: str, replacements_path: str, skip_prep=False,
+) -> None:
+
+    """
+    This class prepares the JSON files for the MultiWOZ dataset.
+    Data will be automatically downloaded in the data_folder.
+    Download link: https://github.com/budzianowski/multiwoz/tree/master/data
+
+    Arguments
+    ---------
+    data_folder : str
+        Path to the folder where the original MultiWOZ dataset is stored.
+    save_folder : str
+        The directory where to store the JSON files.
+    replacements_path: str
+        File containing (from, to) pairs, one per line for preprocessing the text.
+    skip_prep: bool
+        If True, data preparation is skipped.
+
+
+    Example
+    -------
+    >>> data_folder = 'data/MultiWOZ_2.1'
+    >>> save_folder = 'MultiWOZ_prepared'
+    >>> replacements_path = 'mapping.pair'
+    >>> prepare_mwoz_21(data_folder, save_folder, replacements_path)
+    """
+
+    if skip_prep:
+        return
+
+    # Saving folder
+    if not os.path.exists(save_folder):
+        os.makedirs(save_folder)
+
+    # Setting ouput files
+    save_train = save_folder + "/train.json"
+    save_dev = save_folder + "/dev.json"
+    save_test = save_folder + "/test.json"
+
+    # If csv already exists, we skip the data preparation
+    if skip(save_train, save_dev, save_test):
+
+        msg = "%s already exists, skipping data preparation!" % (save_train)
+        logger.info(msg)
+
+        msg = "%s already exists, skipping data preparation!" % (save_dev)
+        logger.info(msg)
+
+        msg = "%s already exists, skipping data preparation!" % (save_test)
+        logger.info(msg)
+
+        return
+
+    # Download dataset
+    download_mwoz_21(data_folder)
+    data_folder = os.path.join(data_folder, "MultiWOZ_21")
+
+    # Additional checks to make sure the data folder contains MultiWOZ
+    check_multiwoz_folders(data_folder)
+
+    data_path = os.path.join(data_folder, "data.json")
+    train_split, dev_split, test_split = get_splits(data_folder)
+    # Creating json files for {train, dev, test} data
+    file_pairs = zip(
+        [train_split, dev_split, test_split], [save_train, save_dev, save_test],
+    )
+
+    for split, save_file in file_pairs:
+        build_dialogue_dataset(
+            data_path, split, save_file, replacements_path,
+        )
+
+
+def check_multiwoz_folders(data_folder):
+    """
+    Check if the data folder actually contains the MultiWOZ dataset.
+    If not, raises an error.
+    Returns
+    -------
+    None
+    Raises
+    ------
+    FileNotFoundError
+        If the data folder doesn't contain the MultiWOZ dataset.
+    """
+    files_str = "/data.json"
+    # Checking clips
+    if not os.path.exists(data_folder + files_str):
+        err_msg = (
+            "the folder %s does not exist (it is expected in "
+            "the MultiWOZ dataset)" % (data_folder + files_str)
+        )
+        raise FileNotFoundError(err_msg)
+
+
+def download_mwoz_21(destination):
+    """ Download the dataset repo, unpack it, and remove unnecessary elements.
+    Arguments
+    ---------
+    destination: str
+        Place to put dataset.
+    """
+    mwoz_21_archive = os.path.join(destination, "MultiWOZ_21.zip")
+    download_file(MULTIWOZ_21_DATASET_URL, mwoz_21_archive)
+    shutil.unpack_archive(mwoz_21_archive, destination)
+    shutil.rmtree(os.path.join(destination, "__MACOSX"))
+
+    mwoz_21 = os.path.join(destination, "MultiWOZ_21")
+    os.makedirs(mwoz_21, exist_ok=True)
+
+    mwoz_21_repo = os.path.join(destination, "MultiWOZ_2.1")
+    for relevant_file in ["data.json", "valListFile.txt", "testListFile.txt"]:
+        shutil.move(
+            os.path.join(mwoz_21_repo, relevant_file),
+            os.path.join(mwoz_21, relevant_file),
+        )
+
+    shutil.rmtree(mwoz_21_repo)
+
+
+def skip(save_train, save_dev, save_test):
+    """
+    Detects if the MultiWOZ data preparation has been already done.
+    If the preparation has been done, we can skip it.
+    Returns
+    -------
+    bool
+        if True, the preparation phase can be skipped.
+        if False, it must be done.
+    """
+
+    # Checking folders and save options
+    skip = False
+
+    if (
+        os.path.isfile(save_train)
+        and os.path.isfile(save_dev)
+        and os.path.isfile(save_test)
+    ):
+        skip = True
+
+    return skip
+
+
+def get_splits(dataset_folder) -> Tuple[List[str], List[str], List[str]]:
+    mwoz_21_dialouges = get_json_object(
+        os.path.join(dataset_folder, "data.json")
+    )
+    dialougues_keys: Set[str] = set(mwoz_21_dialouges.keys())
+    tr_split: List[str] = []
+    with open(os.path.join(dataset_folder, "valListFile.txt")) as f:
+        dev_split: List[str] = [key.strip() for key in f]
+    with open(os.path.join(dataset_folder, "testListFile.txt")) as f:
+        te_split: List[str] = [key.strip() for key in f]
+
+    for key in dialougues_keys:
+        if key not in dev_split and key not in te_split:
+            tr_split.append(key)
+
+    return tr_split, dev_split, te_split
+
+
+def build_dialogue_dataset(
+    data_path: str,
+    data_split: List[str],
+    save_file: str,
+    replacements_path: str,
+) -> None:
+    """
+    Returns the dialogue dataset for the corresponding data_path.
+
+    Arguments
+    ---------
+    data_path: str
+     Path to the folder where the original MultiWOZ dataset is stored.
+    data_split: list of str
+        List of strings containing MultiWOZ 2.1 keys of the dialogues
+        associated with a certain split (train, dev, test).
+    save_file: str
+        Path of the file where the dataset will be saved.
+    replacements_path: str
+        Path to file containing (from, to) pairs, one per line.
+
+    Returns
+    -------
+    dataset:
+        dataset, keys are str, values are dictionaries containing the
+        dialogue history, the system reply, and the mean length.
+    """
+    logger.info(f"Prepare {save_file}")
+    encode_dialogue_dataset(
+        save_file, data_path, data_split, replacements_path,
+    )
+
+
+def encode_dialogue_dataset(
+    save_file: str,
+    data_path: str,
+    data_split: List[str],
+    replacements_path: str,
+) -> None:
+    """
+    Wrapper function that loads processed data stored at
+    dst_folder/file_name. If they are not available, it processes the
+    original data and then saves them at dst_folder/file_name.
+
+    Arguments
+    ---------
+    save_file: str
+        Path of the file where the dataset will be saved.
+    data_path: str
+        Path to the folder where the original MultiWOZ dataset is stored.
+    data_split: list of str
+        List of strings containing MultiWOZ 2.1 keys of the dialogues
+        associated with a certain split (train, dev, test).
+    replacements_path: str
+        Path to file containing (from, to) pairs, one per line.
+    """
+
+    replacements = get_replacements(replacements_path)
+    logger.info(f"Extract dialogues from {data_path}")
+    # custom loading function to return the important elements of a dialogue
+    dialogues = load_dialogues(data_path, data_split, replacements)
+
+    logger.info("Create dataset")
+    dataset = create_dialogue_dataset(dialogues)
+    logger.info(f"Save dataset in {save_file}")
+    save_dialogue_dataset(dataset, save_file)
+
+
+def get_replacements(
+    replacements_path: str = "trade/utils/mapping.pair",
+) -> List[Tuple[str, str]]:
+    """
+    Get the replacements from a given file. Used by trade preprocessing.
+
+    Arguments
+    ---------
+    replacements_path: str
+        File containing from, to pairs, one per line.
+
+    Returns
+    -------
+    replacements: List of replacements, i.e. pairs of str
+        Pairs of elements used to substitute the first element with the second.
+    """
+    replacements = []
+    with open(replacements_path, "r") as fin:
+        for line in fin.readlines():
+            tok_from, tok_to = line.replace("\n", "").split("\t")
+            replacements.append((" " + tok_from + " ", " " + tok_to + " "))
+    return replacements
+
+
+def load_dialogues(
+    data_path: str, data_split: List[str], replacements: List[Tuple[str, str]],
+) -> List[List[Dict[str, Any]]]:
+    """
+    Load dialogues from data_path, apply trade pre-processing, revert the
+    subtokenization, and create a dictionary containing the dialogue id,
+    the turn id, and the corrected sequence.
+
+    Arguments
+    ---------
+    data_path: str
+        Path to the json file containing the data.
+    data_split: list of str
+        List of string containing MultiWOZ 2.1 keys of the dialogues
+        associated to a certain split (train, dev, test).
+    replacements_path: str
+        File containing (from, to) pairs, one per line.
+
+    Returns
+    -------
+    dialogues: list of list of dict, keys are str, values could be anything
+        List of dialogues. Each dialogue is a list of turns. Each turn is a
+        dict containing dialogue_idx, turn_idx, and the corrected sequence.
+    """
+
+    def get_preprocessed_seq(
+        original_seq: str, replacements: List[Tuple[str, str]]
+    ) -> str:
+        # apply trade normalization
+        trade_seq = normalize(original_seq, replacements)
+        # merge back subtokens
+        sequence = invert_trade_subtokenization(original_seq, trade_seq)
+        return sequence
+
+    dialogues: List[List[Dict[str, Any]]] = []
+
+    data = get_json_object(data_path)
+
+    for dialogue_idx in tqdm(data_split, desc="Load Dialogues"):
+        dial: List[Dict[str, Any]] = []
+        original_dialogue: dict = data[dialogue_idx]
+        turns: dict = original_dialogue["log"]
+        for i, turn in enumerate(turns):
+            sequence = get_preprocessed_seq(turn["text"], replacements)
+            to_save = {
+                "sequence": sequence,
+                "turn_idx": i,
+                "dialogue_idx": dialogue_idx,
+            }
+            dial.append(to_save)
+        dialogues.append(dial)
+    return dialogues
+
+
+def normalize(text, replacements):
+    # lower case every word
+    text = text.lower()
+
+    # replace white spaces in front and end
+    text = re.sub(r"^\s*|\s*$", "", text)
+
+    # hotel domain pfb30
+    text = re.sub(r"b&b", "bed and breakfast", text)
+    text = re.sub(r"b and b", "bed and breakfast", text)
+
+    # weird unicode bug
+    text = re.sub("(\u2018|\u2019)", "'", text)
+
+    # replace st.
+    text = text.replace(";", ",")
+    text = re.sub(r"$\/", "", text)
+    text = text.replace("/", " and ")
+
+    # replace other special characters
+    text = text.replace("-", " ")
+    text = re.sub(r'["\<>@\(\)]', "", text)  # remove
+
+    # insert white space before and after tokens:
+    for token in ["?", ".", ",", "!"]:
+        text = insertSpace(token, text)
+
+    # insert white space for 's
+    text = insertSpace("'s", text)
+
+    # replace it's, does't, you'd ... etc
+    text = re.sub("^'", "", text)
+    text = re.sub(r"'$", "", text)
+    text = re.sub(r"'\s", " ", text)
+    text = re.sub(r"\s'", " ", text)
+    for fromx, tox in replacements:
+        text = " " + text + " "
+        text = text.replace(fromx, tox)[1:-1]
+
+    # remove multiple spaces
+    text = re.sub(" +", " ", text)
+
+    # concatenate numbers
+    tokens = text.split()
+    i = 1
+    while i < len(tokens):
+        if re.match(r"^\d+$", tokens[i]) and re.match(r"\d+$", tokens[i - 1]):
+            tokens[i - 1] += tokens[i]
+            del tokens[i]
+        else:
+            i += 1
+    text = " ".join(tokens)
+    return text
+
+
+def insertSpace(token, text):
+    sidx = 0
+    while True:
+        sidx = text.find(token, sidx)
+        if sidx == -1:
+            break
+        if (
+            sidx + 1 < len(text)
+            and re.match("[0-9]", text[sidx - 1])
+            and re.match("[0-9]", text[sidx + 1])
+        ):
+            sidx += 1
+            continue
+        if text[sidx - 1] != " ":
+            text = text[:sidx] + " " + text[sidx:]
+            sidx += 1
+        if sidx + len(token) < len(text) and text[sidx + len(token)] != " ":
+            text = text[: sidx + 1] + " " + text[sidx + 1 :]
+        sidx += 1
+    return text
+
+
+TOKEN_EXCEPTIONS = {
+    "childs": "children",
+    "businesss": "businesses",
+    "inchs": "inches",
+}
+PATTERN_EXCEPTIONS = {"breakfasts": "b&bs"}
+
+
+def invert_trade_subtokenization(
+    original_seq: str,
+    trade_seq: str,
+    token_exceptions: Dict[str, str] = TOKEN_EXCEPTIONS,
+    pattern_exceptions: Dict[str, str] = PATTERN_EXCEPTIONS,
+    subtoken_special_chrs: List[str] = [" -", " _"],
+) -> str:
+    """
+    Invert all trade subtokenizations in a string given the original sequence.
+
+    Arguments
+    ---------
+    original_seq: str
+        The original sequence.
+    trade_seq: str
+        The sequence that has been pre-processed by trade.
+    token_exceptions: dict, keys are str, values are str
+        A dictionary to map merged token to their correct counterpart. E.g.
+        child -s is merged into childs, but the correct token is children.
+    pattern_exceptions: dict, keys are str, values are str
+        A dictionary to map patterns to their correct counterpart. E.g.
+        after the pre-processing "b&bs" is mapped to "bed and breakfast -s",
+        making the search of breakfasts impossible if not handled by such
+        exceptions.
+    subtoken_special_chrs: list of str
+        List containing the special characters that are used for subtokens.
+
+    Returns
+    -------
+    corrected_seq: str
+        The sequence corrected, i.e. subtokens replaced by tokens.
+    """
+    regex = "|".join(subtoken_special_chrs)
+    subtoken_pieces = re.split(regex, trade_seq, maxsplit=1)
+    search_after: int = 0
+    while len(subtoken_pieces) > 1:
+        # example: 'the wind is moderate -ly strong'
+        # split: ['the wind is moderate ', 'ly strong']
+        # split[0]: 'the wind is moderate' --> split on whitespace ['the', 'wind', 'is', 'moderate']
+        left_side = subtoken_pieces[0].split()
+        subtoken_left = left_side[-1]
+        # split[1]: 'ly strong' --> split on whitespace ['ly', 'strong']
+        right_side = subtoken_pieces[1].split()
+        subtoken_right = right_side[0]
+        # try merging the subtoken parts to form a token, i.e. moderate + ly
+        token = "".join([subtoken_left, subtoken_right])
+
+        if token in token_exceptions:
+            # if you match an exception, replace the token with the exception
+            token = token_exceptions[token]
+
+        # assume there are no tokens on left and right side of the subtokens' pieces
+        left_token = None  # if token is at the beginnig
+        right_token = None  # if token is at the end
+        # try looking for them
+        if len(left_side) > 1:
+            left_token = left_side[-2]
+        if len(right_side) > 1:
+            right_token = right_side[1]
+
+        # start from a complete match, and progressively remove left and right
+        # tokens to counter TRADE preprocessing of some tokens
+        # The order is
+        # 1. True, True
+        # 2. True, False
+        # 3. False, True
+        # 4. False, False
+        # basically, at the end you try looking only for the merged token
+        pattern: str = ""
+        idx: int = -1
+        for use_left, use_right in product((True, False), (True, False)):
+            pattern = token
+            if (left_token is not None) and use_left:
+                pattern = " ".join([left_token, pattern])
+            if right_token is not None and use_right:
+                pattern = " ".join([pattern, right_token])
+
+            # check if the pattern is in the exceptions
+            if pattern in pattern_exceptions:
+                pattern = pattern_exceptions[pattern]
+            # Search the pattern
+            idx = original_seq[search_after:].lower().find(pattern)
+            if idx > -1:
+                break
+
+        error: str = f"""
+            Pattern search failed in the following case:
+            PATTERN =  \t{pattern}
+            LEFT SIDE = \t{left_side}
+            RIGHT SIDE = \t{right_side}
+            ORIG SEQ = \t{original_seq[search_after:]}
+
+            This may be due to further TRADE pre-processing, or not correct merging operation.
+            To solve this, add a special rule for the token that breaks the code either as a
+            token_exception or a pattern_exception.
+        """
+
+        assert idx > -1, error
+        # move the index to avoid perfect matches with the same token
+        # TODO is probably better to move it of len(left_token + token) or
+        # len(token) depending on the match
+        search_after += idx + 1
+        # reconstruct the sentence with the matched pattern
+        trade_seq = " ".join([*left_side[:-1], token, *right_side[1:]])
+
+        # try splitting the sentence again and repeat the process
+        subtoken_pieces = re.split(regex, trade_seq, maxsplit=1)
+    # Good, no subtokens found: return trade seq
+    return trade_seq
+
+
+def get_json_object(data_path: str) -> dict:
+    """
+    A function to read a json object and return the python
+    dictionary associated to it.
+
+    Arguments
+    ---------
+    data_path: str
+        Path to a json file.
+
+    Returns
+    -------
+    loaded_json: dict
+        A loaded json object.
+    """
+    with open(data_path, "r") as data_file:
+        data = json.load(data_file)
+
+    return data
+
+
+def create_dialogue_dataset(
+    dialogues: List[List[Dict[str, Any]]]
+) -> Dict[str, Dict[str, Any]]:
+    """
+    Creates a dialogue dataset starting from a set of dialogues. Each
+    entry of the dataset contains the dialogue history and the system
+    reply in response to that.
+
+    Arguments
+    ---------
+    dialogues: list of list of dict, keys are str, values could be anything
+        List of dialogues. Each dialogue is a list of turns. Each turn is a
+        dict containing dialogue_idx, turn_idx, and the corrected sequence.
+    kwargs: any
+        Additional arguments for the current function.
+
+    Returns
+    -------
+    dataset: Dict[str, Dict[str, Any]]
+        Dataset, keys are str, values are dictionaries containing the
+        dialogue history and the system reply.
+    """
+
+    def create_dialogue_dataset_entry(
+        turn: Dict[str, Any], history: List[str]
+    ) -> Optional[Dict[str, Any]]:
+        """
+        Creates an entry if the current turn id is odd. An entry is
+        composed of the history, which contains the previous turns
+        of the current dialogue, and the reply of the system.
+
+        Arguments
+        ---------
+        turn: dict, keys are str, values could be anything
+            A dict containing, the dialogue id, the turn id, the sequence,
+            and the mean length.
+        replacements_path: str
+            Path to TRADE file containing (from, to) pairs, one per line.
+        kwargs: any
+            Additional arguments for the current function.
+
+        Returns
+        -------
+        entry: optional dict, keys are str, values could be anything
+            Entry of the dialogue dataset. It is a dict containing the history
+            of the dialogue, i.e. a list of turns, the reply of the system,
+            i.e. a turn, and the mean length.
+        """
+
+        turn_idx = turn["turn_idx"]
+        entry: Optional[Dict[str, Any]] = None
+        if turn_idx % 2 == 0:
+            # user turn, simply append it to the history
+            user_seq: str = turn["sequence"]
+            history.append(user_seq)
+        elif turn_idx % 2 == 1:
+            # system turn, create the dataset entry, and the append it to the history
+            system_seq: str = turn["sequence"]
+            history_mean_length = mean([len(turn) for turn in history])
+            entry = {
+                "history": history.copy(),
+                "reply": system_seq,
+                "length": history_mean_length + len(system_seq),
+            }
+            history.append(system_seq)
+        return entry
+
+    dataset: Dict[str, Dict[str, Any]] = {}
+    for dialogue in tqdm(dialogues, desc="Creating dataset"):
+        history: List[str] = []
+        for turn in dialogue:
+            # custom function to create a dataset entry
+            dataset_entry = create_dialogue_dataset_entry(turn, history)
+            # custom function to create a dataset key
+            key = create_entry_key(turn)
+            if dataset_entry is not None:
+                dataset[key] = dataset_entry
+    return dataset
+
+
+def create_entry_key(turn: Dict[str, Any]) -> str:
+    """
+    Creates the entry key for a given entry by considering dialogue id
+    and turn id for the given turn.
+
+    Arguments
+    ---------
+    turn: dict, keys are str, values could be anything
+        A dict containing, the dialogue id, the turn id, the sequence,
+        and the mean length.
+    kwargs: any
+        Additional arguments for the current function.
+
+    Returns
+    -------
+    key: str
+        The key for the given turn.
+    """
+    dialogue_idx = turn["dialogue_idx"]
+    turn_idx = turn["turn_idx"]
+    return f"{dialogue_idx}_{turn_idx}"
+
+
+def save_dialogue_dataset(
+    dataset: Dict[str, Dict[str, Any]], save_file: str
+) -> None:
+    """
+    Saves the dialogue dataset at dst_folder/file_name as a json file.
+
+    Arguments
+    ---------
+    dataset: Dict[str, Dict[str, Any]]
+        Dataset, keys are str, values are dictionaries containing the
+        dialogue history, the system reply, and the mean length.
+    save_file: str
+        Path to the folder where the original MultiWOZ dataset is stored.
+    """
+    with open(save_file, "w") as f:
+        json.dump(dataset, f, indent=4)
diff --git a/recipes/REAL-M/sisnr-estimation/README.md b/recipes/REAL-M/sisnr-estimation/README.md
index 0c83058c0bcf7ba664890afa7a5aa796ee7cb95f..f918ff5b540602f76c45a9c7775aeb5f27712d96 100644
--- a/recipes/REAL-M/sisnr-estimation/README.md
+++ b/recipes/REAL-M/sisnr-estimation/README.md
@@ -8,7 +8,7 @@
 
 * The paper for the REAL-M dataset can be found on [this arxiv link](https://arxiv.org/pdf/2110.10812.pdf).
 
-* The model is trained with the LibriMix and WHAMR! datasets. You can download LibriMix by following the instructions [here](https://github.com/JorisCos/LibriMix). Instructions on WHAMR! can be found [here](https://wham.whisper.ai/)
+* The model is trained with the LibriMix and WHAMR! datasets. You can download LibriMix by following the instructions [here](https://github.com/JorisCos/LibriMix). Instructions on WHAMR! can be found [here](http://wham.whisper.ai/)
 
 # How to Run
 
diff --git a/recipes/REAL-M/sisnr-estimation/extra_requirements.txt b/recipes/REAL-M/sisnr-estimation/extra_requirements.txt
index 15c5e2d2271c19ac519496f20923c3ea3920c279..73fe73d2cc75c0b06a0f5464f54bbe21814755f2 100644
--- a/recipes/REAL-M/sisnr-estimation/extra_requirements.txt
+++ b/recipes/REAL-M/sisnr-estimation/extra_requirements.txt
@@ -1 +1 @@
-pyroomacoustics
+pyroomacoustics==0.1.4
diff --git a/recipes/REAL-M/sisnr-estimation/hparams/pool_sisnrestimator.yaml b/recipes/REAL-M/sisnr-estimation/hparams/pool_sisnrestimator.yaml
index 2cb35d116a6f61ead24abdcc4ba897e2ce06b112..c23c11c53524db10e8e8110202af01e00cfa611e 100644
--- a/recipes/REAL-M/sisnr-estimation/hparams/pool_sisnrestimator.yaml
+++ b/recipes/REAL-M/sisnr-estimation/hparams/pool_sisnrestimator.yaml
@@ -60,14 +60,14 @@ skip_prep: False
 ckpt_interval_minutes: 60
 
 # Experiment params
-auto_mix_prec: False # Set this to True for mixed precision
+precision: fp32 # bf16, fp16 or fp32 # Set this to True for mixed precision
 
 # for the currently supported datasets (Libri2Mix, WHAMR!), this should be set 2
 num_spks: 2
 noprogressbar: False
 sample_rate: 8000
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.0001
@@ -91,18 +91,39 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
+
 
 # Dataloader options
 dataloader_opts:
diff --git a/recipes/REAL-M/sisnr-estimation/train.py b/recipes/REAL-M/sisnr-estimation/train.py
index f650d3df41c6149899e3c52f3dd1aab3feddd830..0398daab1b2d00e1e24e06a27c0c0e2d022b71b0 100644
--- a/recipes/REAL-M/sisnr-estimation/train.py
+++ b/recipes/REAL-M/sisnr-estimation/train.py
@@ -15,12 +15,12 @@ import speechbrain as sb
 import speechbrain.nnet.schedulers as schedulers
 from speechbrain.utils.distributed import run_on_main
 from hyperpyyaml import load_hyperpyyaml
-from torch.cuda.amp import autocast
 import itertools as it
 from tqdm import tqdm
 import numpy as np
 import logging
 import csv
+from speechbrain.core import AMPConfig
 
 
 # Define training procedure
@@ -62,7 +62,7 @@ class Separation(sb.Brain):
 
                     if self.hparams.use_reverb_augment:
                         targets_rev = [
-                            self.hparams.reverb(targets[:, :, i], None)
+                            self.hparams.reverb(targets[:, :, i])
                             for i in range(self.hparams.num_spks)
                         ]
                         targets_rev = torch.stack(targets_rev, dim=-1)
@@ -83,7 +83,8 @@ class Separation(sb.Brain):
                         targets = targets[:, :min_len, :]
 
                 if self.hparams.use_wavedrop:
-                    mix = self.hparams.wavedrop(mix, mix_lens)
+                    mix = self.hparams.drop_chunk(mix, mix_lens)
+                    mix = self.hparams.drop_freq(mix)
 
                 if self.hparams.limit_training_signal_len:
                     mix, targets = self.cut_signals(mix, targets)
@@ -157,6 +158,8 @@ class Separation(sb.Brain):
 
     def fit_batch(self, batch):
         """Trains one batch"""
+        amp = AMPConfig.from_name(self.precision)
+        should_step = (self.step % self.grad_accumulation_factor) == 0
 
         if self.hparams.use_whamr_train:
             whamr_prob = torch.rand(1).item()
@@ -174,8 +177,48 @@ class Separation(sb.Brain):
         if self.hparams.num_spks == 3:
             targets.append(batch.s3_sig)
 
-        if self.auto_mix_prec:
-            with autocast():
+        with self.no_sync(not should_step):
+            if self.use_amp:
+                with torch.autocast(
+                    dtype=amp.dtype, device_type=torch.device(self.device).type,
+                ):
+                    (
+                        predictions,
+                        snrhat,
+                        snr,
+                        snr_compressed,
+                    ) = self.compute_forward(
+                        mixture, targets, sb.Stage.TRAIN, noise
+                    )
+
+                    snr = snr.reshape(-1)
+                    loss = ((snr_compressed - snrhat).abs()).mean()
+
+                    if (
+                        loss.nelement() > 0
+                        and loss < self.hparams.loss_upper_lim
+                    ):  # the fix for computational problems
+
+                        self.scaler.scale(loss).backward()
+                        if self.hparams.clip_grad_norm >= 0:
+                            self.scaler.unscale_(self.optimizer)
+                            torch.nn.utils.clip_grad_norm_(
+                                self.modules.parameters(),
+                                self.hparams.clip_grad_norm,
+                            )
+                        self.scaler.step(self.optimizer)
+                        self.scaler.update()
+                    else:
+                        self.nonfinite_count += 1
+                        logger.info(
+                            "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
+                                self.nonfinite_count
+                            )
+                        )
+                        loss.data = torch.tensor(0).to(self.device)
+
+            else:
+                # get the oracle snrs, estimated snrs, and the source estimates
                 predictions, snrhat, snr, snr_compressed = self.compute_forward(
                     mixture, targets, sb.Stage.TRAIN, noise
                 )
@@ -186,16 +229,13 @@ class Separation(sb.Brain):
                 if (
                     loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
                 ):  # the fix for computational problems
-
-                    self.scaler.scale(loss).backward()
+                    loss.backward()
                     if self.hparams.clip_grad_norm >= 0:
-                        self.scaler.unscale_(self.optimizer)
                         torch.nn.utils.clip_grad_norm_(
                             self.modules.parameters(),
                             self.hparams.clip_grad_norm,
                         )
-                    self.scaler.step(self.optimizer)
-                    self.scaler.update()
+                    self.optimizer.step()
                 else:
                     self.nonfinite_count += 1
                     logger.info(
@@ -205,33 +245,6 @@ class Separation(sb.Brain):
                     )
                     loss.data = torch.tensor(0).to(self.device)
 
-        else:
-            # get the oracle snrs, estimated snrs, and the source estimates
-            predictions, snrhat, snr, snr_compressed = self.compute_forward(
-                mixture, targets, sb.Stage.TRAIN, noise
-            )
-
-            snr = snr.reshape(-1)
-            loss = ((snr_compressed - snrhat).abs()).mean()
-
-            if (
-                loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
-            ):  # the fix for computational problems
-                loss.backward()
-                if self.hparams.clip_grad_norm >= 0:
-                    torch.nn.utils.clip_grad_norm_(
-                        self.modules.parameters(), self.hparams.clip_grad_norm
-                    )
-                self.optimizer.step()
-            else:
-                self.nonfinite_count += 1
-                logger.info(
-                    "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
-                        self.nonfinite_count
-                    )
-                )
-                loss.data = torch.tensor(0).to(self.device)
-
         self.optimizer.zero_grad()
 
         return loss.detach().cpu()
@@ -309,9 +322,7 @@ class Separation(sb.Brain):
             recombine = True
 
             for i in range(targets.shape[-1]):
-                new_target = self.hparams.speedperturb(
-                    targets[:, :, i], targ_lens
-                )
+                new_target = self.hparams.speed_perturb(targets[:, :, i])
                 new_targets.append(new_target)
                 if i == 0:
                     min_len = new_target.shape[-1]
@@ -643,7 +654,7 @@ if __name__ == "__main__":
             },
         )
 
-        hparams["reverb"] = sb.processing.speech_augmentation.AddReverb(
+        hparams["reverb"] = sb.augment.time_domain.AddReverb(
             os.path.join(hparams["save_folder"], "whamr_rirs.csv")
         )
 
@@ -752,43 +763,47 @@ if __name__ == "__main__":
         checkpointer=hparams["checkpointer"],
     )
 
-    from speechbrain.pretrained import SepformerSeparation as separator
-    from speechbrain.pretrained.interfaces import fetch
+    from speechbrain.inference.separation import (
+        SepformerSeparation as separator,
+    )
+    from speechbrain.utils.fetching import fetch
 
     all_separators = []
     for separator_model in hparams["separators_to_use"]:
+        savedir = hparams["output_folder"] + "/" + separator_model
+
         fetch(
-            separator_model + "_encoder.ckpt",
+            filename=separator_model + "_encoder.ckpt",
             source=hparams["separator_repo"],
-            savedir=separator_model,
+            savedir=savedir,
             save_filename="encoder.ckpt",
         )
 
         fetch(
-            separator_model + "_decoder.ckpt",
+            filename=separator_model + "_decoder.ckpt",
             source=hparams["separator_repo"],
-            savedir=separator_model,
+            savedir=savedir,
             save_filename="decoder.ckpt",
         )
 
         fetch(
-            separator_model + "_masknet.ckpt",
+            filename=separator_model + "_masknet.ckpt",
             source=hparams["separator_repo"],
-            savedir=separator_model,
+            savedir=savedir,
             save_filename="masknet.ckpt",
         )
 
         fetch(
-            separator_model + "_hyperparams.yaml",
+            filename=separator_model + "_hyperparams.yaml",
             source=hparams["separator_repo"],
-            savedir=separator_model,
+            savedir=savedir,
             save_filename="hyperparams.yaml",
         )
 
         separator_loaded = separator.from_hparams(
-            source=separator_model,
-            run_opts={"device": "cuda"},
-            savedir=separator_model,
+            source=savedir,
+            run_opts={"device": run_opts["device"]},
+            savedir=savedir,
         )
 
         all_separators.append(separator_loaded)
diff --git a/recipes/RescueSpeech/ASR/noise-robust/hparams/robust_asr_16k.yaml b/recipes/RescueSpeech/ASR/noise-robust/hparams/robust_asr_16k.yaml
index 7cda1b90d62750a658db395c0b315b86617c3ca2..10e8e58e4a05a72f94d4c2f9d63aada185b3e236 100644
--- a/recipes/RescueSpeech/ASR/noise-robust/hparams/robust_asr_16k.yaml
+++ b/recipes/RescueSpeech/ASR/noise-robust/hparams/robust_asr_16k.yaml
@@ -21,12 +21,11 @@ language: german
 pretrained_whisper_path: speechbrain/whisper_rescuespeech
 pretrained_enhance_path: speechbrain/sepformer_rescuespeech
 
-epochs_before_lr_drop: 0
+epochs_before_lr_drop: 2
 unfreeze_epoch: !ref <epochs_before_lr_drop> + 1
 frozen_models: [encoder, decoder, masknet, whisper]
 unfrozen_models: [masknet, whisper]
 
-
 # Dataset prep parameters
 data_folder: !PLACEHOLDER
 train_tsv_file: !ref <data_folder>/train.tsv
@@ -42,26 +41,24 @@ skip_prep: False
 # longer sentences certainly correspond to "open microphones".
 avoid_if_longer_than: 10.0
 
-## Model parameters- Enhance model
+## Model Parameters- Enhance model
 dereverberate: False
 save_audio: True
-resample: False
+sample_rate: 16000
 enhance_sample_rate: 16000
-lr_enhance: 0.00015
 limit_training_signal_len: False
 training_signal_len: 64000
-use_wavedrop: False
 use_speedperturb: True
 use_freq_domain: False
 use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-## Training parameters- ASR
+######################## Training Parameters ####################################- ASR
 number_of_epochs: 10
 lr_whisper: 0.00003
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 asr_sample_rate: 16000
 ckpt_interval_minutes: 30 # save checkpoint every N min
 checkpoint_avg: 5
@@ -71,8 +68,6 @@ checkpoint_avg: 5
 # Must be 6 per GPU to fit 16GB of VRAM
 batch_size: 2
 test_batch_size: 2
-dataloader_num_workers: 4
-test_num_workers: 4
 
 
 # These values are only used for the searchers.
@@ -105,31 +100,19 @@ test_loader_kwargs:
 
 # Loss weights
 sepformer_weight: 0.1
+asr_weight: 1
+
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
 
-# loss thresholding -- this thresholds the training loss
-threshold_byloss: True
-threshold: -30
-clip_grad_norm: 5
-loss_upper_lim: 999999  # this is the upper limit for an acceptable loss
-optimizer: !name:torch.optim.Adam
-    lr: !ref <lr_enhance>
-    weight_decay: 0
-
-# Functions and classes
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <enhance_sample_rate>
-    speeds: [95, 100, 105]
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
 
 enhance_model: !include:../models/sepformer.yaml
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <asr_sample_rate>
-    speeds: [95, 100, 105]
 
-whisper: !new:speechbrain.lobes.models.huggingface_whisper.HuggingFaceWhisper
+whisper: !new:speechbrain.lobes.models.huggingface_transformers.whisper.Whisper
     source: !ref <whisper_hub>
     freeze: !ref <freeze_whisper>
     save_path: !ref <whisper_folder>
diff --git a/recipes/RescueSpeech/ASR/noise-robust/train.py b/recipes/RescueSpeech/ASR/noise-robust/train.py
index b35980e7f840dea5fce3ff992ded7f39590ba742..aed76d956adb9c7e2dccc605051cf6738aaeed68 100644
--- a/recipes/RescueSpeech/ASR/noise-robust/train.py
+++ b/recipes/RescueSpeech/ASR/noise-robust/train.py
@@ -51,10 +51,8 @@ class ASR(sb.core.Brain):
 
         predictions, clean = self.compute_forward_enhance(batch, stage)
 
-        # Add augmentation if specified
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Enhanced signal is to be fed into ASR
+        wavs = predictions[0]
 
         # We compute the padding mask and replace the values with the pad_token_id
         # that the Whisper decoder expect to see.
@@ -72,9 +70,11 @@ class ASR(sb.core.Brain):
 
         hyps = None
         if stage == sb.Stage.VALID:
-            hyps, _ = self.hparams.valid_greedy_searcher(enc_out, wav_lens)
+            hyps, _, _, _ = self.hparams.valid_greedy_searcher(
+                enc_out, wav_lens
+            )
         elif stage == sb.Stage.TEST:
-            hyps, _ = self.hparams.test_beam_searcher(enc_out, wav_lens)
+            hyps, _, _, _ = self.hparams.test_beam_searcher(enc_out, wav_lens)
 
         return predictions, clean, [log_probs, hyps, wav_lens]
 
@@ -115,9 +115,6 @@ class ASR(sb.core.Brain):
                     # fix the length of clean also
                     clean = clean[:, :min_len, :]
 
-                if self.hparams.use_wavedrop:
-                    noisy = self.hparams.wavedrop(noisy, noisy_lens)
-
                 if self.hparams.limit_training_signal_len:
                     noisy, clean = self.cut_signals(noisy, clean)
 
@@ -161,6 +158,8 @@ class ASR(sb.core.Brain):
         if stage != sb.Stage.TRAIN:
             tokens, tokens_lens = batch.tokens
 
+            hyps = [hyp[0] if len(hyp) > 0 else [] for hyp in hyps]
+
             # Decode token terms to words
             predicted_words = self.tokenizer.batch_decode(
                 hyps, skip_special_tokens=True
@@ -214,13 +213,19 @@ class ASR(sb.core.Brain):
             self.compute_objectives_enhance(predictions, clean)
             * self.hparams.sepformer_weight
         )
-        loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
+        loss = (
+            self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
+            * self.hparams.asr_weight
+        )
         loss = torch.add(enhance_loss, loss)
 
-        loss.backward()
+        if loss.requires_grad:
+            loss.backward()
 
-        if self.check_gradients(loss):
-            self.optimizer.step()
+        torch.nn.utils.clip_grad_norm_(
+            self.modules.parameters(), self.max_grad_norm
+        )
+        self.optimizer.step()
         self.optimizer.zero_grad()
 
         return loss.detach()
@@ -230,7 +235,15 @@ class ASR(sb.core.Brain):
         predictions, clean, outputs = self.compute_forward(batch, stage=stage)
 
         with torch.no_grad():
-            loss = self.compute_objectives(outputs, batch, stage=stage)
+            enhance_loss = (
+                self.compute_objectives_enhance(predictions, clean)
+                * self.hparams.sepformer_weight
+            )
+            loss = (
+                self.compute_objectives(outputs, batch, stage=stage)
+                * self.hparams.asr_weight
+            )
+            loss = torch.add(enhance_loss, loss)
 
         if stage != sb.Stage.TRAIN:
             self.pesq_metric.append(
@@ -278,7 +291,6 @@ class ASR(sb.core.Brain):
                     self.modules[model].train()
                     for p in self.modules[model].parameters():
                         p.requires_grad = True  # Model's weight will be updated
-                    print(model)
                 else:
                     self.modules[model].eval()
                     for p in self.modules[model].parameters():
@@ -354,9 +366,7 @@ class ASR(sb.core.Brain):
             recombine = True
 
             for i in range(clean.shape[-1]):
-                new_target = self.hparams.speedperturb(
-                    clean[:, :, i], targ_lens
-                )
+                new_target = self.hparams.speed_perturb(clean[:, :, i])
                 new_clean.append(new_target)
                 if i == 0:
                     min_len = new_target.shape[-1]
@@ -522,9 +532,14 @@ class ASR(sb.core.Brain):
                         )
 
                     # Write enhanced wavs for sanity check
-                    self.save_audio(
-                        snt_id[0], batch.noisy_sig, clean, predictions[0], batch
-                    )
+                    if self.hparams.save_audio:
+                        self.save_audio(
+                            snt_id[0],
+                            batch.noisy_sig,
+                            clean,
+                            predictions[0],
+                            batch,
+                        )
 
                     psq_mode = (
                         "wb"
@@ -755,7 +770,6 @@ if __name__ == "__main__":
     # CLI:
     hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
diff --git a/recipes/RescueSpeech/README.md b/recipes/RescueSpeech/README.md
index 0fa0aa0ce08841cbdc389249e35c47d8a8313180..b58a3c575e61397d1afbcf028fb2e5f1c664e188 100755
--- a/recipes/RescueSpeech/README.md
+++ b/recipes/RescueSpeech/README.md
@@ -37,7 +37,7 @@ During training, both speech enhancement and ASR is kept unfrozen- i.e. both ASR
 
 | Model | SISNRi | SDRi | PESQ   | STOI  | *WER*   |
 |------ |--------|-------|-------|-------|----   |
-| Whisper (`large-v2`)| 7.334 | 7.871 | 2.085 | 0.857 | **24.20** |
+| Whisper (`large-v2`)| 7.482 | 8.011 | 2.083 | 0.854 | **45.29** |
 
 
 ## Fine-tuned models
@@ -49,7 +49,7 @@ During training, both speech enhancement and ASR is kept unfrozen- i.e. both ASR
 |---|----------------|------------------------------------------------|------------------------------------------------|
 | 1. | Whisper ASR    | [HuggingFace](https://huggingface.co/speechbrain/whisper_rescuespeech)             | [Dropbox](https://www.dropbox.com/sh/dgmgi0b3bfxlfo4/AAAo3EYPXUEMZRTdRDzhw4lea?dl=)             |
 | 2. | Sepformer Enhancement   | [HuggingFace](https://huggingface.co/speechbrain/sepformer_rescuespeech)            | [Dropbox](https://www.dropbox.com/sh/edrna82oarivkzl/AACsiGQXnbAYa_bfTJzjY23qa?dl=0)            |
-| 3. | Sepformer +  Whisper ASR  (fine-tuned)  |  [HuggingFace](https://huggingface.co/sangeet2020/noisy-whisper-resucespeech)            | [Dropbox](https://www.dropbox.com/sh/nk55xm0saa5qfly/AAC6yHgJnQalFMePgKFZqVfPa?dl=0)            |
+| 3. | Sepformer +  Whisper ASR  (fine-tuned)  |  [HuggingFace](https://huggingface.co/sangeet2020/noisy-whisper-resucespeech)            | [Dropbox](https://www.dropbox.com/sh/kqs2ld14fm20cxl/AACiobSLdNtXhm-4Y3IIbTeia?dl=0)            |
 
 
 # **About SpeechBrain**
diff --git a/recipes/RescueSpeech/rescuespeech_prepare.py b/recipes/RescueSpeech/rescuespeech_prepare.py
index f29828ebb26d06283214e489ab47e6ee8b0e18ec..a78a382bbc67c1efa9f1e07c486a48aced51d859 100755
--- a/recipes/RescueSpeech/rescuespeech_prepare.py
+++ b/recipes/RescueSpeech/rescuespeech_prepare.py
@@ -274,12 +274,6 @@ def create_asr_csv(
                 snr_level = item.replace("snr", "")
                 break
 
-        # Setting torchaudio backend to sox-io (needed to read mp3 files)
-        if torchaudio.get_audio_backend() != "sox_io":
-            logger.warning("This recipe needs the sox-io backend of torchaudio")
-            logger.warning("The torchaudio backend is changed to sox_io")
-            torchaudio.set_audio_backend("sox_io")
-
         # Reading the signal (to retrieve duration in seconds)
         if os.path.isfile(clean_fp):
             info = torchaudio.info(clean_fp)
diff --git a/recipes/SLURP/NLU/hparams/train.yaml b/recipes/SLURP/NLU/hparams/train.yaml
index b5fdf700def491da40a2452f1db25637a557a015..7d88d62a9baa3b307c7a9c1c497c67f606146521 100644
--- a/recipes/SLURP/NLU/hparams/train.yaml
+++ b/recipes/SLURP/NLU/hparams/train.yaml
@@ -28,14 +28,14 @@ asr_tokenizer_file: https://www.dropbox.com/s/o7gnouwdoqchotj/1000_unigram.model
 slu_tokenizer_file: https://www.dropbox.com/s/tmwq12r5vgcsif9/58_unigram.model?dl=1
 skip_prep: False
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 20
 batch_size: 16
 lr: 0.0003
 # token_type: unigram # ["unigram", "bpe", "char"]
 sorting: random
 
-# Model parameters
+####################### Model Parameters #######################################
 # sample_rate: 1600
 emb_size: 128
 dec_neurons: 512
@@ -128,9 +128,8 @@ beam_searcher: !new:speechbrain.decoders.S2SRNNBeamSearcher
     beam_size: !ref <slu_beam_size>
     eos_threshold: !ref <eos_threshold>
     temperature: !ref <temperature>
-    using_max_attn_shift: False
     max_attn_shift: 30
-    coverage_penalty: 0.
+    using_max_attn_shift: False
 
 opt_class: !name:torch.optim.Adam
     lr: !ref <lr>
@@ -148,10 +147,6 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
         scheduler: !ref <lr_annealing>
         counter: !ref <epoch_counter>
 
-# augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-#    sample_rate: !ref <sample_rate>
-#    speeds: [95, 100, 105]
-
 log_softmax: !new:speechbrain.nnet.activations.Softmax
     apply_log: True
 
diff --git a/recipes/SLURP/NLU/train.py b/recipes/SLURP/NLU/train.py
index 0ec3b1f67b01c5ab7145acfe2ec66fa35020f6cc..20ec91bbd527f8e83a3d9010b54351b1a83c6e4c 100644
--- a/recipes/SLURP/NLU/train.py
+++ b/recipes/SLURP/NLU/train.py
@@ -40,24 +40,20 @@ class SLU(sb.Brain):
         p_seq = self.hparams.log_softmax(logits)
 
         # Compute outputs
-        if (
-            stage == sb.Stage.TRAIN
-            and self.batch_count % show_results_every != 0
-        ):
+        p_tokens = None
+        if stage == sb.Stage.TRAIN and self.step % show_results_every != 0:
             return p_seq, transcript_tokens_lens
         else:
-            p_tokens, scores = self.hparams.beam_searcher(
+            p_tokens, _, _, _ = self.hparams.beam_searcher(
                 encoder_out, transcript_tokens_lens
             )
+
             return p_seq, transcript_tokens_lens, p_tokens
 
     def compute_objectives(self, predictions, batch, stage):
         """Computes the loss (NLL) given predictions and targets."""
 
-        if (
-            stage == sb.Stage.TRAIN
-            and self.batch_count % show_results_every != 0
-        ):
+        if stage == sb.Stage.TRAIN and self.step % show_results_every != 0:
             p_seq, transcript_tokens_lens = predictions
         else:
             p_seq, transcript_tokens_lens, predicted_tokens = predictions
@@ -76,9 +72,7 @@ class SLU(sb.Brain):
         # (No ctc loss)
         loss = loss_seq
 
-        if (stage != sb.Stage.TRAIN) or (
-            self.batch_count % show_results_every == 0
-        ):
+        if (stage != sb.Stage.TRAIN) or (self.step % show_results_every == 0):
             # Decode token terms to words
             predicted_semantics = [
                 slu_tokenizer.decode_ids(utt_seq).split(" ")
@@ -127,26 +121,8 @@ class SLU(sb.Brain):
             print(" ".join(target_semantics[i]).replace("|", ","))
             print("")
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-        loss.backward()
-        if self.check_gradients(loss):
-            self.optimizer.step()
-        self.optimizer.zero_grad()
-        self.batch_count += 1
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
-        self.batch_count = 0
 
         if stage != sb.Stage.TRAIN:
 
@@ -293,7 +269,6 @@ if __name__ == "__main__":
 
     show_results_every = 100  # plots results every N iterations
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -330,7 +305,7 @@ if __name__ == "__main__":
 
     # We download and pretrain the tokenizer
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     # Brain class initialization
     slu_brain = SLU(
diff --git a/recipes/SLURP/README.md b/recipes/SLURP/README.md
index 0c5c0b16da5a6806c5ddb93ed30c9408ec8eed04..a8fb9b994aa5612ebfd326e40fd7a4e7c6aeb591 100644
--- a/recipes/SLURP/README.md
+++ b/recipes/SLURP/README.md
@@ -48,8 +48,8 @@ The following results were obtained on a 48 GB RTX 8000 (the recipe has also bee
 
 | Model	| scenario (accuracy) | action (accuracy) | intent (accuracy) | Word-F1 | Char-F1 | SLU-F1 | Training time | Model link |
 |:------:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|
-| Direct | 81.73 | 77.11 | 75.05 | 61.24 | 65.42 | 63.26 | 1 hour per epoch | https://www.dropbox.com/sh/c2rnjads9gfd7k2/AADtVZi616qH_jb_owKK7b9Ra?dl=0 |
-| Direct (HuBert) | 91.24 | 88.47 | 87.54 | 72.93 | 77.40 | 75.10 | 4 hours per epoch | https://www.dropbox.com/sh/uygyiir8rfajcmu/AACXbjhM34ZDy2UprWfg-uyVa?dl=0 |
+| Direct | 81.73 | 77.11 | 75.05 | 61.24 | 65.42 | 63.26 | 1 hour per epoch | https://www.dropbox.com/scl/fo/c0rm2ja8oxus8q27om8ve/h?rlkey=irxzl1ea8g7e6ipk0vuc288zh&dl=0 |
+| Direct (HuBert) | 91.24 | 88.47 | 87.54 | 72.93 | 77.40 | 75.10 | 4 hours per epoch | https://www.dropbox.com/scl/fo/c0rm2ja8oxus8q27om8ve/h?rlkey=irxzl1ea8g7e6ipk0vuc288zh&dl=0 |
 
 | Model	| scenario (accuracy) | action (accuracy) | intent (accuracy) | Training time |
 |:---:|:-----:|:-----:|:-----:|:-----:|
diff --git a/recipes/SLURP/Tokenizer/hparams/tokenizer_bpe58.yaml b/recipes/SLURP/Tokenizer/hparams/tokenizer_bpe58.yaml
index 51f805b078449a7594c1a205be73af6c6778e147..bf935024a739e665dad65a5260a57f393f36aa48 100644
--- a/recipes/SLURP/Tokenizer/hparams/tokenizer_bpe58.yaml
+++ b/recipes/SLURP/Tokenizer/hparams/tokenizer_bpe58.yaml
@@ -14,7 +14,7 @@ train_csv: !ref <output_folder>/train-type=direct.csv
 valid_csv: !ref <output_folder>/devel-type=direct.csv
 skip_prep: False
 
-# Training parameters
+####################### Training Parameters ####################################
 token_type: unigram  # ["unigram", "bpe", "char"]
 token_output: 58  # index(blank/eos/bos/unk) = 0
 character_coverage: 1.0
diff --git a/recipes/SLURP/Tokenizer/train.py b/recipes/SLURP/Tokenizer/train.py
index 2479f47569d1d470c5a41b1e49d2207efddbf74d..9e198ea3ea3c71ef6ee2d558993837e9a4591459 100644
--- a/recipes/SLURP/Tokenizer/train.py
+++ b/recipes/SLURP/Tokenizer/train.py
@@ -25,7 +25,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
diff --git a/recipes/SLURP/direct/hparams/train.yaml b/recipes/SLURP/direct/hparams/train.yaml
index 004e2fde28170b9cba2ac9f69c3b9654ac801fdc..038d2e59ea32bf9322087d0df6c4fe4cb3380878 100644
--- a/recipes/SLURP/direct/hparams/train.yaml
+++ b/recipes/SLURP/direct/hparams/train.yaml
@@ -16,18 +16,25 @@ save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
 test_wer_file: !ref <output_folder>/wer_test_real.txt
 
+# Data for augmentation
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+RIR_DATASET_URL: https://www.dropbox.com/scl/fi/linhy77c36mu10965a836/RIRs.zip?rlkey=pg9cu8vrpn2u173vhiqyu743u&dl=1
+
 # Data files
 # The SLURP dataset will be automatically downloaded in the specified data_folder
 data_folder: !PLACEHOLDER # e.g, /localscratch/SLURP
-data_folder_rirs: !ref <data_folder>
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+data_folder_rir: !ref <data_folder>/rir # The impulse responses used for data augmentation will automatically be downloaded here.
 train_splits: ["train_synthetic", "train_real"]
-csv_train: !ref <output_folder>/train-type=direct.csv
-csv_valid: !ref <output_folder>/devel-type=direct.csv
-csv_test: !ref <output_folder>/test-type=direct.csv
+csv_train: !ref <save_folder>/train-type=direct.csv
+csv_valid: !ref <save_folder>/devel-type=direct.csv
+csv_test: !ref <save_folder>/test-type=direct.csv
+noise_annotation: !ref <save_folder>/noise.csv
+rir_annotation: !ref <save_folder>/rir.csv
 tokenizer_file: https://www.dropbox.com/s/tmwq12r5vgcsif9/58_unigram.model?dl=1
 skip_prep: False
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 20
 batch_size: 16
 lr: 0.0003
@@ -35,7 +42,7 @@ lr: 0.0003
 sorting: random
 ckpt_interval_minutes: 15 # save checkpoint every N min
 
-# Model parameters
+####################### Model Parameters #######################################
 sample_rate: 16000
 emb_size: 128
 dec_neurons: 512
@@ -52,17 +59,86 @@ slu_beam_size: 80
 eos_threshold: 1.5
 temperature: 1.25
 
+num_workers: 4
 dataloader_opts:
+    num_workers: !ref <num_workers>
     batch_size: !ref <batch_size>
     shuffle: True
 
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-# Models
-asr_model: !apply:speechbrain.pretrained.EncoderDecoderASR.from_hparams
-    source: speechbrain/asr-crdnn-rnnlm-librispeech
-    run_opts: {"device":"cuda:0"}
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+# Download and prepare the dataset of room impulse responses for augmentation
+prepare_rir_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <RIR_DATASET_URL>
+    dest_folder: !ref <data_folder_rir>
+    ext: wav
+    csv_file: !ref <rir_annotation>
+
+# Add reverberation to input signal
+add_reverb: !new:speechbrain.augment.time_domain.AddReverb
+    csv_file: !ref <rir_annotation>
+    reverb_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Add noise to input signal
+snr_low: 0  # Min SNR for noise augmentation
+snr_high: 15  # Max SNR for noise augmentation
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: !ref <snr_low>
+    snr_high: !ref <snr_high>
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: [90, 95, 105, 110]
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 3
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    concat_original: True
+    shuffle_augmentations: True
+    min_augmentations: 1
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <add_noise>,
+        !ref <add_reverb>,
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
+asr_model_source: speechbrain/asr-crdnn-rnnlm-librispeech
 
 slu_enc: !new:speechbrain.nnet.containers.Sequential
     input_shape: [null, null, !ref <ASR_encoder_dim>]
@@ -94,20 +170,11 @@ seq_lin: !new:speechbrain.nnet.linear.Linear
     input_size: !ref <dec_neurons>
     n_neurons: !ref <output_neurons>
 
-env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <data_folder_rirs>
-    babble_prob: 0.0
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-
 modules:
     slu_enc: !ref <slu_enc>
     output_emb: !ref <output_emb>
     dec: !ref <dec>
     seq_lin: !ref <seq_lin>
-    env_corrupt: !ref <env_corrupt>
 
 model: !new:torch.nn.ModuleList
     - [!ref <slu_enc>, !ref <output_emb>,
@@ -133,9 +200,8 @@ beam_searcher: !new:speechbrain.decoders.S2SRNNBeamSearcher
     beam_size: !ref <slu_beam_size>
     eos_threshold: !ref <eos_threshold>
     temperature: !ref <temperature>
-    using_max_attn_shift: False
     max_attn_shift: 30
-    coverage_penalty: 0.
+    using_max_attn_shift: False
 
 opt_class: !name:torch.optim.Adam
     lr: !ref <lr>
@@ -153,9 +219,6 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
         scheduler: !ref <lr_annealing>
         counter: !ref <epoch_counter>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
 
 log_softmax: !new:speechbrain.nnet.activations.Softmax
     apply_log: True
diff --git a/recipes/SLURP/direct/hparams/train_with_wav2vec2.yaml b/recipes/SLURP/direct/hparams/train_with_wav2vec2.yaml
index 31b22993ea5d07560518123067e03521708e13a4..84222db5f430c3c4df1e5446279b48cd5ceb3eb5 100644
--- a/recipes/SLURP/direct/hparams/train_with_wav2vec2.yaml
+++ b/recipes/SLURP/direct/hparams/train_with_wav2vec2.yaml
@@ -32,7 +32,7 @@ skip_prep: False
 # URL for the wav2vec2 model, you can change to benchmark diffrenet models
 wav2vec2_hub: "facebook/hubert-base-ls960"
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 35
 batch_size: 6
 lr: 0.0003
@@ -47,7 +47,7 @@ freeze_wav2vec2: False
 #set to true to freeze the CONV part of the wav2vec2 model
 freeze_wav2vec2_conv: True
 
-# Model parameters
+####################### Model Parameters #######################################
 sample_rate: 16000
 emb_size: 128
 dec_neurons: 512
@@ -71,7 +71,7 @@ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
 # Models
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec2>
@@ -96,10 +96,39 @@ seq_lin: !new:speechbrain.nnet.linear.Linear
     input_size: !ref <dec_neurons>
     n_neurons: !ref <output_neurons>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    concat_original: True
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
 modules:
     wav2vec2: !ref <wav2vec2>
     output_emb: !ref <output_emb>
@@ -132,7 +161,6 @@ beam_searcher: !new:speechbrain.decoders.S2SRNNBeamSearcher
     temperature: !ref <temperature>
     using_max_attn_shift: False
     max_attn_shift: 30
-    coverage_penalty: 0.
 
 opt_class: !name:torch.optim.Adam
     lr: !ref <lr>
diff --git a/recipes/SLURP/direct/train.py b/recipes/SLURP/direct/train.py
index ad73bc1fef48ee6c5d0a8d9c300a41a160ed729d..63a4cc48b3c384b7be22adf697b4cb5db7de5d32 100644
--- a/recipes/SLURP/direct/train.py
+++ b/recipes/SLURP/direct/train.py
@@ -33,16 +33,13 @@ class SLU(sb.Brain):
         wavs, wav_lens = batch.sig
         tokens_bos, tokens_bos_lens = batch.tokens_bos
 
-        # Add augmentation if specified
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "env_corrupt"):
-                wavs_noise = self.hparams.env_corrupt(wavs, wav_lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                wav_lens = torch.cat([wav_lens, wav_lens])
-                tokens_bos = torch.cat([tokens_bos, tokens_bos], dim=0)
-                tokens_bos_lens = torch.cat([tokens_bos_lens, tokens_bos_lens])
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+            tokens_bos = self.hparams.wav_augment.replicate_labels(tokens_bos)
+            tokens_bos_lens = self.hparams.wav_augment.replicate_labels(
+                tokens_bos_lens
+            )
 
         # ASR encoder forward pass
         with torch.no_grad():
@@ -60,34 +57,28 @@ class SLU(sb.Brain):
         p_seq = self.hparams.log_softmax(logits)
 
         # Compute outputs
-        if (
-            stage == sb.Stage.TRAIN
-            and self.batch_count % show_results_every != 0
-        ):
-            return p_seq, wav_lens
+        if stage == sb.Stage.TRAIN and self.step % show_results_every != 0:
+            return p_seq, wav_lens, None
         else:
-            p_tokens, scores = self.hparams.beam_searcher(encoder_out, wav_lens)
+            p_tokens, _, _, _ = self.hparams.beam_searcher(
+                encoder_out, wav_lens
+            )
+
             return p_seq, wav_lens, p_tokens
 
     def compute_objectives(self, predictions, batch, stage):
         """Computes the loss (NLL) given predictions and targets."""
 
-        if (
-            stage == sb.Stage.TRAIN
-            and self.batch_count % show_results_every != 0
-        ):
-            p_seq, wav_lens = predictions
-        else:
-            p_seq, wav_lens, predicted_tokens = predictions
+        p_seq, wav_lens, predicted_tokens = predictions
 
         ids = batch.id
         tokens_eos, tokens_eos_lens = batch.tokens_eos
         tokens, tokens_lens = batch.tokens
 
-        if hasattr(self.hparams, "env_corrupt") and stage == sb.Stage.TRAIN:
-            tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
-            tokens_eos_lens = torch.cat(
-                [tokens_eos_lens, tokens_eos_lens], dim=0
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            tokens_eos = self.hparams.wav_augment.replicate_labels(tokens_eos)
+            tokens_eos_lens = self.hparams.wav_augment.replicate_labels(
+                tokens_eos_lens
             )
 
         loss_seq = self.hparams.seq_cost(
@@ -97,9 +88,7 @@ class SLU(sb.Brain):
         # (No ctc loss)
         loss = loss_seq
 
-        if (stage != sb.Stage.TRAIN) or (
-            self.batch_count % show_results_every == 0
-        ):
+        if (stage != sb.Stage.TRAIN) or (self.step % show_results_every == 0):
             # Decode token terms to words
             predicted_semantics = [
                 tokenizer.decode_ids(utt_seq).split(" ")
@@ -160,26 +149,8 @@ class SLU(sb.Brain):
             print(" ".join(target_semantics[i]).replace("|", ","))
             print("")
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-        loss.backward()
-        if self.check_gradients(loss):
-            self.optimizer.step()
-        self.optimizer.zero_grad()
-        self.batch_count += 1
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
-        self.batch_count = 0
 
         if stage != sb.Stage.TRAIN:
 
@@ -307,7 +278,6 @@ if __name__ == "__main__":
 
     show_results_every = 100  # plots results every N iterations
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -332,13 +302,23 @@ if __name__ == "__main__":
             "skip_prep": hparams["skip_prep"],
         },
     )
+    run_on_main(hparams["prepare_noise_data"])
+    run_on_main(hparams["prepare_rir_data"])
 
     # here we create the datasets objects as well as tokenization and encoding
     (train_set, valid_set, test_set, tokenizer,) = dataio_prepare(hparams)
 
     # We download and pretrain the tokenizer
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
+
+    # Download pretrained ASR model
+    from speechbrain.inference.ASR import EncoderDecoderASR
+
+    hparams["asr_model"] = EncoderDecoderASR.from_hparams(
+        source=hparams["asr_model_source"],
+        run_opts={"device": run_opts["device"]},
+    )
 
     # Brain class initialization
     slu_brain = SLU(
diff --git a/recipes/SLURP/direct/train_with_wav2vec2.py b/recipes/SLURP/direct/train_with_wav2vec2.py
index 0680f65959d20da0e8c92b9af79b936b89af99d7..c0e97a2b02fc181d3e518e43ccdf84ec4c437a7e 100644
--- a/recipes/SLURP/direct/train_with_wav2vec2.py
+++ b/recipes/SLURP/direct/train_with_wav2vec2.py
@@ -32,10 +32,10 @@ class SLU(sb.Brain):
         wavs, wav_lens = batch.sig
         tokens_bos, tokens_bos_lens = batch.tokens_bos
 
-        # Add augmentation if specified
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+            tokens_bos = self.hparams.wav_augment.replicate_labels(tokens_bos)
 
         #  encoder forward pass
         wav2vec2_out = self.modules.wav2vec2(wavs, wav_lens)
@@ -49,24 +49,19 @@ class SLU(sb.Brain):
         p_seq = self.hparams.log_softmax(logits)
 
         # Compute outputs
-        if (
-            stage == sb.Stage.TRAIN
-            and self.batch_count % show_results_every != 0
-        ):
+        if stage == sb.Stage.TRAIN and self.step % show_results_every != 0:
             return p_seq, wav_lens
         else:
-            p_tokens, scores = self.hparams.beam_searcher(
-                wav2vec2_out, wav_lens
+            hyps, _, _, _ = self.hparams.beam_searcher(
+                wav2vec2_out.detach(), wav_lens
             )
-            return p_seq, wav_lens, p_tokens
+
+            return p_seq, wav_lens, hyps
 
     def compute_objectives(self, predictions, batch, stage):
         """Computes the loss (NLL) given predictions and targets."""
 
-        if (
-            stage == sb.Stage.TRAIN
-            and self.batch_count % show_results_every != 0
-        ):
+        if stage == sb.Stage.TRAIN and self.step % show_results_every != 0:
             p_seq, wav_lens = predictions
         else:
             p_seq, wav_lens, predicted_tokens = predictions
@@ -74,15 +69,20 @@ class SLU(sb.Brain):
         ids = batch.id
         tokens_eos, tokens_eos_lens = batch.tokens_eos
 
+        # Label Augmentation
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            tokens_eos = self.hparams.wav_augment.replicate_labels(tokens_eos)
+            tokens_eos_lens = self.hparams.wav_augment.replicate_labels(
+                tokens_eos_lens
+            )
+
         loss_seq = self.hparams.seq_cost(
             p_seq, tokens_eos, length=tokens_eos_lens
         )
 
         loss = loss_seq
 
-        if (stage != sb.Stage.TRAIN) or (
-            self.batch_count % show_results_every == 0
-        ):
+        if (stage != sb.Stage.TRAIN) or (self.step % show_results_every == 0):
             # Decode token terms to words
             predicted_semantics = [
                 tokenizer.decode_ids(utt_seq).split(" ")
@@ -137,28 +137,8 @@ class SLU(sb.Brain):
             print(" ".join(target_semantics[i]).replace("|", ","))
             print("")
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-        loss.backward()
-        if self.check_gradients(loss):
-            self.wav2vec2_optimizer.step()
-            self.optimizer.step()
-        self.wav2vec2_optimizer.zero_grad()
-        self.optimizer.zero_grad()
-        self.batch_count += 1
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
-        self.batch_count = 0
 
         if stage != sb.Stage.TRAIN:
 
@@ -220,9 +200,10 @@ class SLU(sb.Brain):
             )
             self.checkpointer.add_recoverable("optimizer", self.optimizer)
 
-    def zero_grad(self, set_to_none=False):
-        self.wav2vec2_optimizer.zero_grad(set_to_none)
-        self.optimizer.zero_grad(set_to_none)
+        self.optimizers_dict = {
+            "wav2vec_optimizer": self.wav2vec2_optimizer,
+            "model_optimizer": self.optimizer,
+        }
 
 
 def dataio_prepare(hparams):
@@ -314,7 +295,6 @@ if __name__ == "__main__":
 
     show_results_every = 100  # plots results every N iterations
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -345,7 +325,7 @@ if __name__ == "__main__":
 
     # We download and pretrain the tokenizer
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     # Move the wav2vec2
     hparams["wav2vec2"] = hparams["wav2vec2"].to(run_opts["device"])
diff --git a/recipes/Switchboard/ASR/CTC/hparams/train_with_wav2vec.yaml b/recipes/Switchboard/ASR/CTC/hparams/train_with_wav2vec.yaml
index 8f5cb1a990dc2d89c8178e81dad6f00f86183f69..7741680bdbb25e81dd9c1f0216e8d275d9af44aa 100644
--- a/recipes/Switchboard/ASR/CTC/hparams/train_with_wav2vec.yaml
+++ b/recipes/Switchboard/ASR/CTC/hparams/train_with_wav2vec.yaml
@@ -49,12 +49,12 @@ test_csv:
   - !ref <output_folder>/test_callhome.csv
   - !ref <output_folder>/test.csv
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 30
 lr: 1.0
 lr_wav2vec: 0.0001
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 ckpt_interval_minutes: 30  # save checkpoint every N min
 
@@ -74,7 +74,7 @@ test_dataloader_options:
 token_type: unigram  # ["unigram", "bpe", "char"]
 character_coverage: 1.0
 
-# Model parameters
+####################### Model Parameters #######################################
 wav2vec_output_dim: 1024
 dnn_neurons: 1024
 freeze_wav2vec: False
@@ -89,16 +89,61 @@ blank_index: 0
 bos_index: 1
 eos_index: 2
 
+# Decoding parameters
+test_searcher: !name:speechbrain.decoders.CTCBeamSearcher
+beam_size: 143
+beam_prune_logp: -12.0
+token_prune_min_logp: -1.2
+prune_history: True
+topk: 1
+alpha: 0.8
+beta: 1.2
+# can be downloaded from here https://www.openslr.org/11/ or trained with kenLM
+# It can either be a .bin or .arpa ; note: .arpa is much slower at loading
+# If you don't want to use an LM, comment it out or set it to null
+kenlm_model_path: null
+
 #
 # Functions and classes
 #
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
   limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-  sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+  orig_freq: !ref <sample_rate>
   speeds: [95, 100, 105]
 
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+  drop_freq_low: 0
+  drop_freq_high: 1
+  drop_freq_count_low: 1
+  drop_freq_count_high: 3
+  drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+  drop_length_low: 1000
+  drop_length_high: 2000
+  drop_count_low: 1
+  drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+  concat_original: True
+  min_augmentations: 3
+  max_augmentations: 3
+  augment_prob: 1.0
+  augmentations: [
+    !ref <speed_perturb>,
+    !ref <drop_freq>,
+    !ref <drop_chunk>]
+
+############################## Models ##########################################
+
 enc: !new:speechbrain.nnet.containers.Sequential
   input_shape: [null, null, !ref <wav2vec_output_dim>]
   linear1: !name:speechbrain.nnet.linear.Linear
@@ -121,7 +166,7 @@ enc: !new:speechbrain.nnet.containers.Sequential
   bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
   activation3: !new:torch.nn.LeakyReLU
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
   source: !ref <wav2vec2_hub>
   output_norm: True
   freeze: !ref <freeze_wav2vec>
diff --git a/recipes/Switchboard/ASR/CTC/train_with_wav2vec.py b/recipes/Switchboard/ASR/CTC/train_with_wav2vec.py
index 6cdf7e18a83839456951906d85302925414cb07e..6bd01bc34960ae48ff0c55b5dac26a6e5a1438dc 100644
--- a/recipes/Switchboard/ASR/CTC/train_with_wav2vec.py
+++ b/recipes/Switchboard/ASR/CTC/train_with_wav2vec.py
@@ -49,7 +49,6 @@ class ASR(sb.core.Brain):
         hparams=None,
         run_opts=None,
         checkpointer=None,
-        profiler=None,
         normalize_fn=None,
     ):
 
@@ -61,7 +60,6 @@ class ASR(sb.core.Brain):
             hparams=hparams,
             run_opts=run_opts,
             checkpointer=checkpointer,
-            profiler=profiler,
         )
 
     def compute_forward(self, batch, stage):
@@ -69,12 +67,11 @@ class ASR(sb.core.Brain):
 
         batch = batch.to(self.device)
         wavs, wav_lens = batch.sig
-        tokens_bos, _ = batch.tokens_bos
         wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
 
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
 
         # Forward pass
         feats = self.modules.wav2vec2(wavs, wav_lens)
@@ -92,9 +89,14 @@ class ASR(sb.core.Brain):
         ids = batch.id
         tokens, tokens_lens = batch.tokens
 
+        # Label Augmentation
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            tokens = self.hparams.wav_augment.replicate_labels(tokens)
+            tokens_lens = self.hparams.wav_augment.replicate_labels(tokens_lens)
+
         loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
 
-        if stage != sb.Stage.TRAIN:
+        if stage == sb.Stage.VALID:
             # Decode token terms to words
             sequence = sb.decoders.ctc_greedy_decode(
                 p_ctc, wav_lens, blank_id=self.hparams.blank_index
@@ -102,6 +104,13 @@ class ASR(sb.core.Brain):
 
             predicted_words = self.tokenizer(sequence, task="decode_from_list")
 
+        elif stage == sb.Stage.TEST:
+            # Decode token terms to words
+            sequence = test_searcher(p_ctc, wav_lens)
+            predicted_words = [hyp[0].text.split(" ") for hyp in sequence]
+
+        if stage != sb.Stage.TRAIN:
+
             # Convert indices to words
             target_words = undo_padding(tokens, tokens_lens)
             target_words = self.tokenizer(target_words, task="decode_from_list")
@@ -117,53 +126,6 @@ class ASR(sb.core.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        if self.auto_mix_prec:
-
-            if not self.hparams.wav2vec2.freeze:
-                self.wav2vec_optimizer.zero_grad()
-            self.model_optimizer.zero_grad()
-
-            with torch.cuda.amp.autocast():
-                outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-                loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-
-            self.scaler.scale(loss).backward()
-            if not self.hparams.wav2vec2.freeze:
-                self.scaler.unscale_(self.wav2vec_optimizer)
-            self.scaler.unscale_(self.model_optimizer)
-
-            if self.check_gradients(loss):
-                if not self.hparams.wav2vec2.freeze:
-                    self.scaler.step(self.wav2vec_optimizer)
-                self.scaler.step(self.model_optimizer)
-
-            self.scaler.update()
-        else:
-            outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-
-            loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-            loss.backward()
-
-            if self.check_gradients(loss):
-                if not self.hparams.wav2vec2.freeze:
-                    self.wav2vec_optimizer.step()
-                self.model_optimizer.step()
-
-            if not self.hparams.wav2vec2.freeze:
-                self.wav2vec_optimizer.zero_grad()
-            self.model_optimizer.zero_grad()
-
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        with torch.no_grad():
-            loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
         if stage != sb.Stage.TRAIN:
@@ -219,6 +181,8 @@ class ASR(sb.core.Brain):
     def init_optimizers(self):
         "Initializes the wav2vec2 optimizer and model optimizer"
 
+        self.optimizers_dict = {}
+
         # If the wav2vec encoder is unfrozen, we create the optimizer
         if not self.hparams.wav2vec2.freeze:
             self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
@@ -229,14 +193,27 @@ class ASR(sb.core.Brain):
                     "wav2vec_opt", self.wav2vec_optimizer
                 )
 
+            self.optimizers_dict["wav2vec_optimizer"] = self.wav2vec_optimizer
+
         self.model_optimizer = self.hparams.model_opt_class(
             self.hparams.model.parameters()
         )
 
+        self.optimizers_dict["model_optimizer"] = self.model_optimizer
+
         if self.checkpointer is not None:
             self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
 
 
+def freeze_optimizers(self, optimizers):
+    """Freezes the wav2vec2 optimizer according to the warmup steps"""
+    valid_optimizers = {}
+    if not self.hparams.wav2vec2.freeze:
+        valid_optimizers["wav2vec_optimizer"] = optimizers["wav2vec_optimizer"]
+    valid_optimizers["model_optimizer"] = optimizers["model_optimizer"]
+    return valid_optimizers
+
+
 # Define custom data procedure
 def dataio_prepare(hparams, tokenizer):
     """This function prepares the datasets to be used in the brain class.
@@ -363,7 +340,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If distributed_launch=True then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -423,8 +399,22 @@ if __name__ == "__main__":
         normalize_fn=normalize_fn,
     )
 
-    # Adding objects to trainer.
     asr_brain.tokenizer = tokenizer
+    vocab_list = [
+        tokenizer.sp.id_to_piece(i) for i in range(tokenizer.sp.vocab_size())
+    ]
+    test_searcher = hparams["test_searcher"](
+        blank_index=hparams["blank_index"],
+        vocab_list=vocab_list,
+        alpha=hparams["alpha"],
+        beta=hparams["beta"],
+        beam_size=hparams["beam_size"],
+        beam_prune_logp=hparams["beam_prune_logp"],
+        token_prune_min_logp=hparams["token_prune_min_logp"],
+        prune_history=hparams["prune_history"],
+        topk=hparams["topk"],
+        kenlm_model_path=hparams.get("kenlm_model_path"),
+    )
 
     # Training
     asr_brain.fit(
diff --git a/recipes/Switchboard/ASR/seq2seq/hparams/train_BPE_2000.yaml b/recipes/Switchboard/ASR/seq2seq/hparams/train_BPE_2000.yaml
index 5ee0a84807f49b0fc4b569712ada8b225f284359..743467bcf2ffd6b0ceacb7fc873259b8ddb8f2c1 100644
--- a/recipes/Switchboard/ASR/seq2seq/hparams/train_BPE_2000.yaml
+++ b/recipes/Switchboard/ASR/seq2seq/hparams/train_BPE_2000.yaml
@@ -27,10 +27,11 @@ tokenizer_file: !ref <pretrained_tokenizer_path>/tokenizer.ckpt
 
 # Set the local path to the Switchboard dataset (e.g. /nfs/data/swbd) here.
 data_folder: !PLACEHOLDER
-# noise/ris dataset will automatically be downloaded
-# Set the location to store noisy data for augmentation here.
-data_folder_rirs: rirs-noises
 
+# Data for augmentation
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+noise_annotation: !ref <save_folder>/noise.csv #The data manifest files are created by the data preparation script
 
 # Note that the test set will be created separately using the
 # Hub5/eval2000 dataset
@@ -46,17 +47,17 @@ normalize_words: True
 # allowed to appear in the training data
 max_utt: 300
 ckpt_interval_minutes: 15 # save checkpoint every N min
-train_csv: !ref <output_folder>/train.csv
-valid_csv: !ref <output_folder>/dev.csv
+train_csv: !ref <save_folder>/train.csv
+valid_csv: !ref <save_folder>/dev.csv
 # The test data is split into the full test set (test.csv),
 # the Switchboard portion of the data (test_swbd.csv),
 # and the Callhome portion of the data (test_callhome.csv).
 test_csv:
-   - !ref <output_folder>/test_swbd.csv
-   - !ref <output_folder>/test_callhome.csv
-   - !ref <output_folder>/test.csv
+   - !ref <save_folder>/test_swbd.csv
+   - !ref <save_folder>/test_callhome.csv
+   - !ref <save_folder>/test.csv
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 20
 number_of_ctc_epochs: 5
 batch_size: 10
@@ -66,15 +67,20 @@ sorting: ascending
 dynamic_batching: False
 
 # dynamic batching parameters, if used
+feats_hop_size: 0.01
+max_batch_length: 20000 # in terms of frames
+shuffle: True
+batch_ordering: random
+num_buckets: 20
+
 dynamic_batch_sampler:
-   feats_hop_size: 0.01
-   max_batch_len: 20000 # in terms of frames
-   shuffle_ex: True
-   batch_ordering: random
-   num_buckets: 20
+   max_batch_length: !ref <max_batch_length>
+   shuffle: !ref <shuffle>
+   batch_ordering: !ref <batch_ordering>
+   num_buckets: !ref <num_buckets>
 
 # Feature parameters
-sample_rate: 16000
+sample_rate: 8000
 n_fft: 400
 n_mels: 40
 
@@ -84,16 +90,20 @@ opt_class: !name:torch.optim.Adadelta
    eps: 1.e-8
 
 # Dataloader options
+num_workers: 4
 train_dataloader_opts:
+   num_workers: !ref <num_workers>
    batch_size: !ref <batch_size>
 
 valid_dataloader_opts:
+   num_workers: !ref <num_workers>
    batch_size: !ref <batch_size>
 
 test_dataloader_opts:
+   num_workers: !ref <num_workers>
    batch_size: !ref <batch_size>
 
-# Model parameters
+####################### Model Parameters #######################################
 activation: !name:torch.nn.LeakyReLU
 dropout: 0.15
 cnn_blocks: 2
@@ -125,6 +135,57 @@ max_attn_shift: 240
 ctc_weight_decode: 0.3
 coverage_penalty: 1.8
 temperature: 1.25
+scorer_beam_scale: 0.1
+
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+   URL: !ref <NOISE_DATASET_URL>
+   dest_folder: !ref <data_folder_noise>
+   ext: wav
+   csv_file: !ref <noise_annotation>
+
+############################## Augmentations ###################################
+
+# Add noise to input signal
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+   csv_file: !ref <noise_annotation>
+   snr_low: 0
+   snr_high: 15
+   noise_sample_rate: !ref <sample_rate>
+   clean_sample_rate: !ref <sample_rate>
+   num_workers: !ref <num_workers>
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+   orig_freq: !ref <sample_rate>
+   speeds: [95, 100, 105]
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+   drop_freq_low: 0
+   drop_freq_high: 1
+   drop_freq_count_low: 1
+   drop_freq_count_high: 3
+   drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+   drop_length_low: 1000
+   drop_length_high: 2000
+   drop_count_low: 1
+   drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+   concat_original: True
+   min_augmentations: 4
+   max_augmentations: 4
+   augment_prob: 1.0
+   augmentations: [
+      !ref <add_noise>,
+      !ref <speed_perturb>,
+      !ref <drop_freq>,
+      !ref <drop_chunk>]
 
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>
@@ -137,17 +198,7 @@ compute_features: !new:speechbrain.lobes.features.Fbank
    n_fft: !ref <n_fft>
    n_mels: !ref <n_mels>
 
-env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt
-   openrir_folder: !ref <data_folder_rirs>
-   babble_prob: 0.0
-   reverb_prob: 0.0
-   noise_prob: 1.0
-   noise_snr_low: 0
-   noise_snr_high: 15
-
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-   sample_rate: !ref <sample_rate>
-   speeds: [95, 100, 105]
+############################## Models ##########################################
 
 enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
    input_shape: [null, null, !ref <n_mels>]
@@ -227,45 +278,55 @@ modules:
    ctc_lin: !ref <ctc_lin>
    seq_lin: !ref <seq_lin>
    normalize: !ref <normalize>
-   env_corrupt: !ref <env_corrupt>
    lm_model: !ref <lm_model>
 
 model: !new:torch.nn.ModuleList
    - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <ctc_lin>, !ref <seq_lin>]
 
-valid_search: !new:speechbrain.decoders.S2SRNNBeamSearcher
+# Scorer
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
+   eos_index: !ref <eos_index>
+   blank_index: !ref <blank_index>
+   ctc_fc: !ref <ctc_lin>
+
+coverage_scorer: !new:speechbrain.decoders.scorer.CoverageScorer
+   vocab_size: !ref <output_neurons>
+
+scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
+   full_scorers: [!ref <coverage_scorer>, !ref <ctc_scorer>]
+   weights:
+      coverage: !ref <coverage_penalty>
+      ctc: !ref <ctc_weight_decode>
+   scorer_beam_scale: !ref <scorer_beam_scale>
+
+test_search: !new:speechbrain.decoders.S2SRNNBeamSearcher
    embedding: !ref <emb>
    decoder: !ref <dec>
    linear: !ref <seq_lin>
-   ctc_linear: !ref <ctc_lin>
    bos_index: !ref <bos_index>
    eos_index: !ref <eos_index>
-   blank_index: !ref <blank_index>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
-   beam_size: !ref <valid_beam_size>
+   beam_size: !ref <test_beam_size>
    eos_threshold: !ref <eos_threshold>
    using_max_attn_shift: !ref <using_max_attn_shift>
    max_attn_shift: !ref <max_attn_shift>
-   coverage_penalty: !ref <coverage_penalty>
+   scorer: !ref <scorer>
    temperature: !ref <temperature>
 
-test_search: !new:speechbrain.decoders.S2SRNNBeamSearcher
+valid_search: !new:speechbrain.decoders.S2SRNNBeamSearcher
    embedding: !ref <emb>
    decoder: !ref <dec>
    linear: !ref <seq_lin>
-   ctc_linear: !ref <ctc_lin>
    bos_index: !ref <bos_index>
    eos_index: !ref <eos_index>
-   blank_index: !ref <blank_index>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
-   beam_size: !ref <test_beam_size>
+   beam_size: !ref <valid_beam_size>
    eos_threshold: !ref <eos_threshold>
    using_max_attn_shift: !ref <using_max_attn_shift>
    max_attn_shift: !ref <max_attn_shift>
-   coverage_penalty: !ref <coverage_penalty>
-   ctc_weight: !ref <ctc_weight_decode>
+   scorer: !ref <scorer>
    temperature: !ref <temperature>
 
 lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
diff --git a/recipes/Switchboard/ASR/seq2seq/train.py b/recipes/Switchboard/ASR/seq2seq/train.py
index b4bdf1c3a918fc78e7e54f3d0022dbd68d99ff91..8af7d3de69d1fc66a75dfe6e1ac88f5b673855a8 100644
--- a/recipes/Switchboard/ASR/seq2seq/train.py
+++ b/recipes/Switchboard/ASR/seq2seq/train.py
@@ -55,7 +55,6 @@ class ASR(sb.Brain):
         hparams=None,
         run_opts=None,
         checkpointer=None,
-        profiler=None,
         normalize_fn=None,
     ):
 
@@ -67,7 +66,6 @@ class ASR(sb.Brain):
             hparams=hparams,
             run_opts=run_opts,
             checkpointer=checkpointer,
-            profiler=profiler,
         )
 
     def compute_forward(self, batch, stage):
@@ -77,16 +75,10 @@ class ASR(sb.Brain):
         tokens_bos, _ = batch.tokens_bos
         wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
 
-        # Add augmentation if specified
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.modules, "env_corrupt"):
-                wavs_noise = self.modules.env_corrupt(wavs, wav_lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                wav_lens = torch.cat([wav_lens, wav_lens])
-                tokens_bos = torch.cat([tokens_bos, tokens_bos], dim=0)
-
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+            tokens_bos = self.hparams.wav_augment.replicate_labels(tokens_bos)
 
         # Forward pass
         feats = self.hparams.compute_features(wavs)
@@ -111,9 +103,10 @@ class ASR(sb.Brain):
                 return p_seq, wav_lens
         else:
             if stage == sb.Stage.VALID:
-                p_tokens, scores = self.hparams.valid_search(x, wav_lens)
+                p_tokens, _, _, _ = self.hparams.valid_search(x, wav_lens)
             else:
-                p_tokens, scores = self.hparams.test_search(x, wav_lens)
+                p_tokens, _, _, _ = self.hparams.test_search(x, wav_lens)
+
             return p_seq, wav_lens, p_tokens
 
     def compute_objectives(self, predictions, batch, stage):
@@ -132,13 +125,17 @@ class ASR(sb.Brain):
         tokens_eos, tokens_eos_lens = batch.tokens_eos
         tokens, tokens_lens = batch.tokens
 
-        if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
-            tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
-            tokens_eos_lens = torch.cat(
-                [tokens_eos_lens, tokens_eos_lens], dim=0
+        # Labels must be extended if parallel augmentation or concatenated
+        # augmentation was performed on the input (increasing the time dimension)
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            (
+                tokens,
+                tokens_lens,
+                tokens_eos,
+                tokens_eos_lens,
+            ) = self.hparams.wav_augment.replicate_multiple_labels(
+                tokens, tokens_lens, tokens_eos, tokens_eos_lens
             )
-            tokens = torch.cat([tokens, tokens], dim=0)
-            tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0)
 
         loss_seq = self.hparams.seq_cost(
             p_seq, tokens_eos, length=tokens_eos_lens
@@ -176,23 +173,6 @@ class ASR(sb.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-        loss.backward()
-        if self.check_gradients(loss):
-            self.optimizer.step()
-        self.optimizer.zero_grad()
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        with torch.no_grad():
-            loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
         if stage != sb.Stage.TRAIN:
@@ -346,26 +326,18 @@ def dataio_prepare(hparams):
         from speechbrain.dataio.batch import PaddedBatch  # noqa
 
         dynamic_hparams = hparams["dynamic_batch_sampler"]
-        hop_size = dynamic_hparams["feats_hop_size"]
-
-        num_buckets = dynamic_hparams["num_buckets"]
+        hop_size = hparams["feats_hop_size"]
 
         train_batch_sampler = DynamicBatchSampler(
             train_data,
-            dynamic_hparams["max_batch_len"],
-            num_buckets=num_buckets,
+            **dynamic_hparams,
             length_func=lambda x: int(float(x["duration"]) * (1 / hop_size)),
-            shuffle=dynamic_hparams["shuffle_ex"],
-            batch_ordering=dynamic_hparams["batch_ordering"],
         )
 
         valid_batch_sampler = DynamicBatchSampler(
             valid_data,
-            dynamic_hparams["max_batch_len"],
-            num_buckets=num_buckets,
+            **dynamic_hparams,
             length_func=lambda x: int(float(x["duration"]) * (1 / hop_size)),
-            shuffle=dynamic_hparams["shuffle_ex"],
-            batch_ordering=dynamic_hparams["batch_ordering"],
         )
 
     return (
@@ -382,7 +354,6 @@ if __name__ == "__main__":
     # CLI:
     hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
 
-    # If distributed_launch=True then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -413,6 +384,7 @@ if __name__ == "__main__":
             "max_utt": hparams["max_utt"],
         },
     )
+    run_on_main(hparams["prepare_noise_data"])
 
     # create the dataset objects as well as tokenization and encoding
     (
@@ -426,7 +398,7 @@ if __name__ == "__main__":
     # Depending on the path given in the hparams YAML file,
     # we download the pretrained LM and Tokenizer
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     # Helper function that removes optional/deletable parts of the transcript
     # for cleaner performance metrics
diff --git a/recipes/Switchboard/ASR/transformer/hparams/transformer.yaml b/recipes/Switchboard/ASR/transformer/hparams/transformer.yaml
index d0b15473fd3d62a0c68da05de3f4ccc6b237f0a0..674c037199f6d654512df99d8c32f08aa82cf977 100644
--- a/recipes/Switchboard/ASR/transformer/hparams/transformer.yaml
+++ b/recipes/Switchboard/ASR/transformer/hparams/transformer.yaml
@@ -51,10 +51,10 @@ test_csv:
 
 ckpt_interval_minutes: 30  # save checkpoint every N min
 
-# Training parameters
+####################### Training Parameters ####################################
 # To make Transformers converge, the global batch size should be large enough.
 # The global batch size is computed as:
-# batch_size * n_gpus * gradient_accumulation.
+# batch_size * n_gpus * grad_accumulation_factor.
 # Empirically, we found that this value should be >= 128.
 # Please, set your parameters accordingly.
 number_of_epochs: 100
@@ -64,6 +64,7 @@ grad_accumulation_factor: 2
 max_grad_norm: 5.0
 loss_reduction: batchmean
 sorting: random
+avg_checkpoints: 5
 
 #dynamic_batching: False
 #
@@ -95,7 +96,7 @@ valid_dataloader_opts:
 test_dataloader_opts:
   batch_size: 1
 
-####################### Model parameters  ###########################
+####################### Model Parameters  ###########################
 # Transformer
 transformer_input_size: 1280
 d_model: 256
@@ -130,6 +131,7 @@ eos_threshold: 1.5
 length_normalization: True
 using_max_attn_shift: False
 max_attn_shift: 30
+scorer_beam_scale: 0.3
 
 CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
   input_shape: (8, 10, 80)
@@ -191,36 +193,52 @@ Adam: !name:torch.optim.Adam
   betas: (0.9, 0.98)
   eps: 0.000000001
 
-valid_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-  modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
-  bos_index: !ref <bos_index>
+transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer
+  language_model: !ref <lm_model>
+  temperature: !ref <temperature_lm>
+
+# Scorer
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
   eos_index: !ref <eos_index>
   blank_index: !ref <blank_index>
+  ctc_fc: !ref <ctc_lin>
+
+scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
+  full_scorers: [!ref <transformerlm_scorer>, !ref <ctc_scorer>]
+  weights:
+    transformerlm: !ref <lm_weight>
+    ctc: !ref <ctc_weight_decode>
+  scorer_beam_scale: !ref <scorer_beam_scale>
+
+valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+  modules: [!ref <Transformer>, !ref <seq_lin>]
+  bos_index: !ref <bos_index>
+  eos_index: !ref <eos_index>
   min_decode_ratio: !ref <min_decode_ratio>
   max_decode_ratio: !ref <max_decode_ratio>
   beam_size: !ref <valid_beam_size>
-  ctc_weight: !ref <ctc_weight_decode>
-  using_eos_threshold: False
-  length_normalization: False
+  using_eos_threshold: !ref <using_eos_threshold>
+  length_normalization: !ref <length_normalization>
+  using_max_attn_shift: !ref <using_max_attn_shift>
+  max_attn_shift: !ref <max_attn_shift>
+  scorer: !ref <scorer>
+  temperature: !ref <temperature>
 
-test_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-  modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
+test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+  modules: [!ref <Transformer>, !ref <seq_lin>]
   bos_index: !ref <bos_index>
   eos_index: !ref <eos_index>
-  blank_index: !ref <blank_index>
   min_decode_ratio: !ref <min_decode_ratio>
   max_decode_ratio: !ref <max_decode_ratio>
   beam_size: !ref <test_beam_size>
-  ctc_weight: !ref <ctc_weight_decode>
-  lm_weight: !ref <lm_weight>
-  lm_modules: !ref <lm_model>
-  temperature: !ref <temperature>
-  temperature_lm: !ref <temperature_lm>
   using_eos_threshold: !ref <using_eos_threshold>
-  eos_threshold: !ref <eos_threshold>
-  length_normalization: !ref <length_normalization>
   using_max_attn_shift: !ref <using_max_attn_shift>
   max_attn_shift: !ref <max_attn_shift>
+  eos_threshold: !ref <eos_threshold>
+  temperature: !ref <temperature>
+  length_normalization: !ref <length_normalization>
+  scorer: !ref <scorer>
+
 
 log_softmax: !new:torch.nn.LogSoftmax
   dim: -1
@@ -253,19 +271,41 @@ normalize: !new:speechbrain.processing.features.InputNormalization
   norm_type: global
   update_until_epoch: 4
 
-augmentation: !new:speechbrain.lobes.augment.SpecAugment
-  time_warp: False
-  time_warp_window: 5
-  time_warp_mode: bicubic
-  freq_mask: True
-  n_freq_mask: 4
-  time_mask: True
-  n_time_mask: 4
-  replace_with_zero: False
-  freq_mask_width: 15
-  time_mask_width: 20
-
-speed_perturb: True
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+  orig_freq: !ref <sample_rate>
+  speeds: [95, 100, 105]
+
+# Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+  drop_length_low: 15
+  drop_length_high: 25
+  drop_count_low: 5
+  drop_count_high: 5
+
+# Frequency Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+  drop_length_low: 25
+  drop_length_high: 35
+  drop_count_low: 2
+  drop_count_high: 2
+  dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+  min_augmentations: 3
+  max_augmentations: 3
+  augment_prob: 1.0
+  augmentations: [
+    !ref <time_drop>,
+    !ref <freq_drop>,
+    !ref <time_warp>]
+
+do_speed_perturb: True
 
 compute_features: !new:speechbrain.lobes.features.Fbank
   sample_rate: !ref <sample_rate>
diff --git a/recipes/Switchboard/ASR/transformer/hparams/transformer_finetuned_LM.yaml b/recipes/Switchboard/ASR/transformer/hparams/transformer_finetuned_LM.yaml
index 02491efa5cdd649b45309bfdb0b3db59125fa5c0..8dd221ca407d4e68e34479cbce9e87f49e6ca542 100644
--- a/recipes/Switchboard/ASR/transformer/hparams/transformer_finetuned_LM.yaml
+++ b/recipes/Switchboard/ASR/transformer/hparams/transformer_finetuned_LM.yaml
@@ -51,10 +51,10 @@ test_csv:
 
 ckpt_interval_minutes: 30  # save checkpoint every N min
 
-# Training parameters
+####################### Training Parameters ####################################
 # To make Transformers converge, the global batch size should be large enough.
 # The global batch size is computed as:
-# batch_size * n_gpus * gradient_accumulation.
+# batch_size * n_gpus * grad_accumulation_factor.
 # Empirically, we found that this value should be >= 128.
 # Please, set your parameters accordingly.
 number_of_epochs: 60
@@ -64,6 +64,7 @@ grad_accumulation_factor: 2
 max_grad_norm: 5.0
 loss_reduction: batchmean
 sorting: random
+avg_checkpoints: 5
 
 #dynamic_batching: False
 #
@@ -95,7 +96,7 @@ valid_dataloader_opts:
 test_dataloader_opts:
   batch_size: 1
 
-####################### Model parameters  ###########################
+####################### Model Parameters  ###########################
 # Transformer
 d_model: 512
 nhead: 4
@@ -123,8 +124,9 @@ valid_beam_size: 10
 test_beam_size: 66
 lm_weight: 0.60
 ctc_weight_decode: 0.40
-
-############################## models  ################################
+temperature: 1.15
+temperature_lm: 1.15
+############################## Models  ################################
 
 CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
   input_shape: (8, 10, 80)
@@ -185,33 +187,45 @@ Adam: !name:torch.optim.Adam
   betas: (0.9, 0.98)
   eps: 0.000000001
 
-valid_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-  modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
-  bos_index: !ref <bos_index>
+
+transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer
+  language_model: !ref <lm_model>
+  temperature: !ref <temperature_lm>
+
+# Scorer
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
   eos_index: !ref <eos_index>
   blank_index: !ref <blank_index>
+  ctc_fc: !ref <ctc_lin>
+
+scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
+  full_scorers: [!ref <transformerlm_scorer>, !ref <ctc_scorer>]
+  weights:
+    transformerlm: !ref <lm_weight>
+    ctc: !ref <ctc_weight_decode>
+
+valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+  modules: [!ref <Transformer>, !ref <seq_lin>]
+  bos_index: !ref <bos_index>
+  eos_index: !ref <eos_index>
   min_decode_ratio: !ref <min_decode_ratio>
   max_decode_ratio: !ref <max_decode_ratio>
   beam_size: !ref <valid_beam_size>
-  ctc_weight: !ref <ctc_weight_decode>
   using_eos_threshold: False
-  length_normalization: False
+  length_normalization: True
+  scorer: !ref <scorer>
 
-test_search: !new:speechbrain.decoders.S2STransformerBeamSearch
-  modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
+test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+  modules: [!ref <Transformer>, !ref <seq_lin>]
   bos_index: !ref <bos_index>
   eos_index: !ref <eos_index>
-  blank_index: !ref <blank_index>
   min_decode_ratio: !ref <min_decode_ratio>
   max_decode_ratio: !ref <max_decode_ratio>
   beam_size: !ref <test_beam_size>
-  ctc_weight: !ref <ctc_weight_decode>
-  lm_weight: !ref <lm_weight>
-  lm_modules: !ref <lm_model>
-  temperature: 1.15
-  temperature_lm: 1.15
   using_eos_threshold: False
   length_normalization: True
+  scorer: !ref <scorer>
+  temperature: !ref <temperature>
 
 log_softmax: !new:torch.nn.LogSoftmax
   dim: -1
@@ -244,19 +258,41 @@ normalize: !new:speechbrain.processing.features.InputNormalization
   norm_type: global
   update_until_epoch: 4
 
-augmentation: !new:speechbrain.lobes.augment.SpecAugment
-  time_warp: False
-  time_warp_window: 5
-  time_warp_mode: bicubic
-  freq_mask: True
-  n_freq_mask: 4
-  time_mask: True
-  n_time_mask: 4
-  replace_with_zero: False
-  freq_mask_width: 15
-  time_mask_width: 20
-
-speed_perturb: True
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+  orig_freq: !ref <sample_rate>
+  speeds: [95, 100, 105]
+
+# Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+  drop_length_low: 15
+  drop_length_high: 25
+  drop_count_low: 5
+  drop_count_high: 5
+
+# Frequency Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+  drop_length_low: 25
+  drop_length_high: 35
+  drop_count_low: 2
+  drop_count_high: 2
+  dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+  min_augmentations: 3
+  max_augmentations: 3
+  augment_prob: 1.0
+  augmentations: [
+    !ref <time_drop>,
+    !ref <freq_drop>,
+    !ref <time_warp>]
+
+do_speed_perturb: True
 
 compute_features: !new:speechbrain.lobes.features.Fbank
   sample_rate: !ref <sample_rate>
diff --git a/recipes/Switchboard/ASR/transformer/train.py b/recipes/Switchboard/ASR/transformer/train.py
index 30ccf20173b93cc63e8f10b2acb3226ab171ed4e..76a90a2a4959ed34b69f59689f7f53ea0ee8d7ce 100644
--- a/recipes/Switchboard/ASR/transformer/train.py
+++ b/recipes/Switchboard/ASR/transformer/train.py
@@ -58,7 +58,6 @@ class ASR(sb.core.Brain):
         hparams=None,
         run_opts=None,
         checkpointer=None,
-        profiler=None,
         normalize_fn=None,
     ):
 
@@ -70,7 +69,6 @@ class ASR(sb.core.Brain):
             hparams=hparams,
             run_opts=run_opts,
             checkpointer=checkpointer,
-            profiler=profiler,
         )
 
     def compute_forward(self, batch, stage):
@@ -79,23 +77,15 @@ class ASR(sb.core.Brain):
         wavs, wav_lens = batch.sig
         tokens_bos, _ = batch.tokens_bos
 
-        # Add augmentation if specified
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.modules, "env_corrupt"):
-                wavs_noise = self.modules.env_corrupt(wavs, wav_lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                wav_lens = torch.cat([wav_lens, wav_lens])
-                tokens_bos = torch.cat([tokens_bos, tokens_bos], dim=0)
-
         # compute features
         feats = self.hparams.compute_features(wavs)
         current_epoch = self.hparams.epoch_counter.current
         feats = self.modules.normalize(feats, wav_lens, epoch=current_epoch)
 
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "augmentation"):
-                feats = self.hparams.augmentation(feats)
-
+        # Add feature augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "fea_augment"):
+            feats, fea_lens = self.hparams.fea_augment(feats, wav_lens)
+            tokens_bos = self.hparams.fea_augment.replicate_labels(tokens_bos)
         # forward modules
         src = self.modules.CNN(feats)
 
@@ -121,9 +111,13 @@ class ASR(sb.core.Brain):
             if current_epoch % self.hparams.valid_search_interval == 0:
                 # for the sake of efficiency, we only perform beamsearch with limited capacity
                 # and no LM to give user some idea of how the AM is doing
-                hyps, _ = self.hparams.valid_search(enc_out.detach(), wav_lens)
+                hyps, _, _, _ = self.hparams.valid_search(
+                    enc_out.detach(), wav_lens
+                )
         elif stage == sb.Stage.TEST:
-            hyps, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
+            # for the sake of efficiency, we only perform beamsearch with limited capacity
+            # and no LM to give user some idea of how the AM is doing
+            hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
 
         return p_ctc, p_seq, wav_lens, hyps
 
@@ -136,13 +130,18 @@ class ASR(sb.core.Brain):
         tokens_eos, tokens_eos_lens = batch.tokens_eos
         tokens, tokens_lens = batch.tokens
 
-        if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
-            tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
-            tokens_eos_lens = torch.cat(
-                [tokens_eos_lens, tokens_eos_lens], dim=0
-            )
-            tokens = torch.cat([tokens, tokens], dim=0)
-            tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0)
+        if stage == sb.Stage.TRAIN:
+            # Labels must be extended if parallel augmentation or concatenated
+            # augmentation was performed on the input (increasing the time dimension)
+            if hasattr(self.hparams, "fea_augment"):
+                (
+                    tokens,
+                    tokens_lens,
+                    tokens_eos,
+                    tokens_eos_lens,
+                ) = self.hparams.fea_augment.replicate_multiple_labels(
+                    tokens, tokens_lens, tokens_eos, tokens_eos_lens
+                )
 
         loss_seq = self.hparams.seq_cost(
             p_seq, tokens_eos, length=tokens_eos_lens
@@ -183,39 +182,10 @@ class ASR(sb.core.Brain):
             self.acc_metric.append(p_seq, tokens_eos, tokens_eos_lens)
         return loss
 
-    def fit_batch(self, batch):
-
-        should_step = self.step % self.grad_accumulation_factor == 0
-        # Managing automatic mixed precision
-        if self.auto_mix_prec:
-            self.optimizer.zero_grad()
-            with torch.cuda.amp.autocast():
-                outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-                loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-            self.scaler.scale(loss / self.grad_accumulation_factor).backward()
-            if should_step:
-                self.scaler.unscale_(self.optimizer)
-                if self.check_gradients(loss):
-                    self.scaler.step(self.optimizer)
-                self.scaler.update()
-                self.optimizer_step += 1
-
-                # anneal lr every update
-                self.hparams.noam_annealing(self.optimizer)
-        else:
-            outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-            loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-            (loss / self.grad_accumulation_factor).backward()
-            if should_step:
-                if self.check_gradients(loss):
-                    self.optimizer.step()
-                self.optimizer.zero_grad()
-                self.optimizer_step += 1
-
-                # anneal lr every update
-                self.hparams.noam_annealing(self.optimizer)
-
-        return loss.detach().cpu()
+    def on_fit_batch_end(self, batch, outputs, loss, should_step):
+        """At the end of the optimizer step, apply noam annealing."""
+        if should_step:
+            self.hparams.noam_annealing(self.optimizer)
 
     def evaluate_batch(self, batch, stage):
         """Computations needed for validation/test batches"""
@@ -247,7 +217,7 @@ class ASR(sb.core.Brain):
                 stage_stats["WER"] = self.wer_metric.summarize("error_rate")
 
         # log stats and save checkpoint at end-of-epoch
-        if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+        if stage == sb.Stage.VALID:
 
             lr = self.hparams.noam_annealing.current_lr
             steps = self.optimizer_step
@@ -267,7 +237,7 @@ class ASR(sb.core.Brain):
             self.checkpointer.save_and_keep_only(
                 meta={"ACC": stage_stats["ACC"], "epoch": epoch},
                 max_keys=["ACC"],
-                num_to_keep=5,
+                num_to_keep=self.hparams.avg_checkpoints,
             )
 
         elif stage == sb.Stage.TEST:
@@ -296,7 +266,7 @@ class ASR(sb.core.Brain):
             max_key=max_key, min_key=min_key
         )
         ckpt = sb.utils.checkpoints.average_checkpoints(
-            ckpts, recoverable_name="model", device=self.device
+            ckpts, recoverable_name="model",
         )
 
         self.hparams.model.load_state_dict(ckpt, strict=True)
@@ -420,16 +390,13 @@ def dataio_prepare(hparams):
             else:
                 resampled = resampled[:, 1]
 
-        if hparams["speed_perturb"]:
-            # sig = sb.dataio.dataio.read_audio(wav)
-            # factor = np.random.uniform(0.95, 1.05)
-            # sig = resample(sig.numpy(), 16000, int(16000*factor))
-            speed = sb.processing.speech_augmentation.SpeedPerturb(
-                16000, [x for x in range(95, 105)]
-            )
-            resampled = speed(resampled.unsqueeze(0)).squeeze(
-                0
-            )  # torch.from_numpy(sig)
+        # Speed Perturb is done here so it is multi-threaded with the
+        # workers of the dataloader (faster).
+        if hparams["do_speed_perturb"]:
+            resampled = hparams["speed_perturb"](
+                resampled.unsqueeze(0)
+            ).squeeze(0)
+
         return resampled
 
     sb.dataio.dataset.add_dynamic_item([train_data], audio_pipeline_train)
@@ -471,7 +438,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If distributed_launch=True then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -505,7 +471,7 @@ if __name__ == "__main__":
     # We download the pretrained LM from HuggingFace (or elsewhere depending on
     # the path given in the YAML file). The tokenizer is loaded at the same time.
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     # Helper function that removes optional/deletable parts of the transcript
     # for cleaner performance metrics
diff --git a/recipes/Switchboard/LM/hparams/transformer.yaml b/recipes/Switchboard/LM/hparams/transformer.yaml
index bb8133aa5e82ca3a4c430aa0d80d849fd75da6ae..b501faf55cf7d8830f06abafa2145b02a9d6a465 100644
--- a/recipes/Switchboard/LM/hparams/transformer.yaml
+++ b/recipes/Switchboard/LM/hparams/transformer.yaml
@@ -36,11 +36,11 @@ test_csv: !ref <save_folder>/test.csv
 # (e.g. /path/to/2000_unigram.model)
 tokenizer_file: !PLACEHOLDER
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 100
 batch_size: 164
 lr: 1
-accu_steps: 2 # Gradient accumulation to simulate large batch training
+grad_accumulation_factor: 2 # Gradient accumulation to simulate large batch training
 ckpt_interval_minutes: 15 # save checkpoint every N min
 
 # Dataloader options
diff --git a/recipes/Switchboard/LM/hparams/transformer_finetune.yaml b/recipes/Switchboard/LM/hparams/transformer_finetune.yaml
index c7abd1d209ac1816bd0bc3db92d3043f87266e5c..5b0860e41a432252b513ccb23afe81508d42ad23 100644
--- a/recipes/Switchboard/LM/hparams/transformer_finetune.yaml
+++ b/recipes/Switchboard/LM/hparams/transformer_finetune.yaml
@@ -39,11 +39,11 @@ test_csv: !ref <save_folder>/test.csv
 # instead. E.g if you want to use your own LM / tokenizer.
 pretrained_lm_tokenizer_path: speechbrain/asr-transformer-transformerlm-librispeech
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 5
 batch_size: 128
 lr: 2
-accu_steps: 2
+grad_accumulation_factor: 2
 ckpt_interval_minutes: 15 # save checkpoint every N min
 
 # Dataloader options
diff --git a/recipes/Switchboard/LM/train.py b/recipes/Switchboard/LM/train.py
index 5b6b9a944b59478d1aa04c0e0819e084d3a60ac8..7363ee33ea4f4f6ba2c0e2fd8020c485cfe4bd19 100644
--- a/recipes/Switchboard/LM/train.py
+++ b/recipes/Switchboard/LM/train.py
@@ -39,20 +39,9 @@ class LM(sb.core.Brain):
         )
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-
-        (loss / self.hparams.accu_steps).backward()
-
-        if self.step % self.hparams.accu_steps == 0:
-            # gradient clipping & early stop if loss is not fini
-            self.check_gradients(loss)
-
-            self.optimizer.step()
-            self.optimizer.zero_grad()
-
+    def on_fit_batch_end(self, batch, outputs, loss, should_step):
+        """At the end of the optimizer step, apply noam annealing."""
+        if should_step:
             if isinstance(
                 self.hparams.lr_annealing, sb.nnet.schedulers.NoamScheduler
             ) or isinstance(
@@ -61,15 +50,13 @@ class LM(sb.core.Brain):
             ):
                 self.hparams.lr_annealing(self.optimizer)
 
-        return loss
-
     def on_stage_end(self, stage, stage_loss, epoch):
         """Gets called at the end of a epoch."""
         stage_stats = {"loss": stage_loss}
         if stage == sb.Stage.TRAIN:
             self.train_stats = stage_stats
 
-        if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+        if stage == sb.Stage.VALID:
             if not (
                 isinstance(
                     self.hparams.lr_annealing, sb.nnet.schedulers.NoamScheduler
@@ -148,7 +135,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If distributed_launch=True then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -182,7 +168,7 @@ if __name__ == "__main__":
     # We download the tokenizer and pretrained LM from HuggingFace (or elsewhere depending on
     # the path given in tokenizer_file of the hparams YAML file).
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     lm_brain = LM(
         modules=hparams["modules"],
diff --git a/recipes/Switchboard/Tokenizer/hparams/2K_unigram_subword_bpe.yaml b/recipes/Switchboard/Tokenizer/hparams/2K_unigram_subword_bpe.yaml
index e6c546bdf9deec07bffa0c0546401f1990c8b33a..d07d83e707b0ae05a49e3c2813a79a1a0baec88e 100644
--- a/recipes/Switchboard/Tokenizer/hparams/2K_unigram_subword_bpe.yaml
+++ b/recipes/Switchboard/Tokenizer/hparams/2K_unigram_subword_bpe.yaml
@@ -20,7 +20,7 @@ train_csv: !ref <output_folder>/train_lm.csv
 valid_csv: !ref <output_folder>/dev.csv
 skip_prep: False
 
-# Training parameters
+####################### Training Parameters ####################################
 token_type: unigram  # ["unigram", "bpe", "char"]
 token_output: 2000  # index(blank/eos/bos/unk) = 0
 character_coverage: 1.0
diff --git a/recipes/Switchboard/Tokenizer/train.py b/recipes/Switchboard/Tokenizer/train.py
index 1bbf6c51488b05365eb95f734e611de111b28ba8..95223e92b5f37be7823bc21e948dea01047e1ce3 100644
--- a/recipes/Switchboard/Tokenizer/train.py
+++ b/recipes/Switchboard/Tokenizer/train.py
@@ -27,7 +27,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If distributed_launch=True then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
diff --git a/recipes/TIMIT/ASR/CTC/README.md b/recipes/TIMIT/ASR/CTC/README.md
index f5597064087607f2b81f2335ced5c6a434901eba..c2571e21183fd5611b20c9f87c612e8b9dd33255 100644
--- a/recipes/TIMIT/ASR/CTC/README.md
+++ b/recipes/TIMIT/ASR/CTC/README.md
@@ -2,8 +2,17 @@
 This folder contains the scripts to train a CTC system using TIMIT.
 TIMIT is a speech dataset available from LDC: https://catalog.ldc.upenn.edu/LDC93S1
 
-# How to run
-python train.py hparams/train.yaml --jit
+# Running the Code
+
+To execute the code, use the following command:
+
+```
+python train.py hparams/train.yaml --data_folder=your_data_folder/TIMIT --jit
+```
+
+**Note on Compilation**:
+Enabling the just-in-time (JIT) compiler significantly improves code performance, resulting in a 50-60% speed boost. We highly recommend utilizing the JIT compiler for optimal results.
+This speed improvement is observed specifically when using the CRDNN model.
 
 # Results
 | Release | hyperparams file | Val. PER | Test PER | Model link | GPUs |
diff --git a/recipes/TIMIT/ASR/CTC/hparams/train.yaml b/recipes/TIMIT/ASR/CTC/hparams/train.yaml
index 7c16e0c75a7029a30438b4f6d1038a88ab0e2675..145fa1a3e362af75bd80e3d435bf0c4ce0ac0d7d 100644
--- a/recipes/TIMIT/ASR/CTC/hparams/train.yaml
+++ b/recipes/TIMIT/ASR/CTC/hparams/train.yaml
@@ -14,14 +14,18 @@ train_log: !ref <output_folder>/train_log.txt
 
 # Data files
 data_folder: !PLACEHOLDER  # e.g. /path/to/TIMIT
-open_rir_folder: !ref <data_folder> # where to store noisy data for augment (change it if needed)
-train_annotation: !ref <output_folder>/train.json
-valid_annotation: !ref <output_folder>/dev.json
-test_annotation: !ref <output_folder>/test.json
+train_annotation: !ref <save_folder>/train.json
+valid_annotation: !ref <save_folder>/dev.json
+test_annotation: !ref <save_folder>/test.json
 skip_prep: False # Skip data preparation
 uppercase: False # Must be True when the TIMIT dataset is in the upper-case version
 
-# Training parameters
+# Data for augmentation
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+noise_annotation: !ref <save_folder>/noise.csv #The data manifest files are created by the data preparation script
+
+####################### Training Parameters ####################################
 number_of_epochs: 50
 batch_size: 8
 lr: 1.0
@@ -32,7 +36,7 @@ sample_rate: 16000
 n_fft: 400
 n_mels: 40
 
-# Model parameters
+####################### Model Parameters #######################################
 activation: !name:torch.nn.LeakyReLU
 dropout: 0.15
 cnn_blocks: 2
@@ -49,30 +53,73 @@ output_neurons: 40
 blank_index: 0
 
 # Dataloader options
+num_workers: 4
 train_dataloader_opts:
     batch_size: !ref <batch_size>
+    num_workers: !ref <num_workers>
 
 valid_dataloader_opts:
     batch_size: !ref <batch_size>
+    num_workers: !ref <num_workers>
 
 test_dataloader_opts:
     batch_size: !ref <batch_size>
+    num_workers: !ref <num_workers>
 
 normalize: !new:speechbrain.processing.features.InputNormalization
     norm_type: global
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+# Add noise to input signal
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: 0
+    snr_high: 15
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
-# Can be removed to improve speed
-env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <open_rir_folder>
-    babble_prob: 0.0
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    concat_original: True
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <add_noise>,
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
 
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
@@ -119,7 +166,6 @@ modules:
     model: !ref <model>
     output: !ref <output>
     normalize: !ref <normalize>
-    env_corrupt: !ref <env_corrupt>
 
 jit_module_keys: [model]
 
diff --git a/recipes/TIMIT/ASR/CTC/train.py b/recipes/TIMIT/ASR/CTC/train.py
index efc674ea3ed7d54bf68528cbc98e21dd5371f534..c036b4a5c062aa9cd8ed712657015ddbce14dc4e 100644
--- a/recipes/TIMIT/ASR/CTC/train.py
+++ b/recipes/TIMIT/ASR/CTC/train.py
@@ -5,7 +5,11 @@ Greedy search is using for validation, while beamsearch
 is used at test time to improve the system performance.
 
 To run this recipe, do the following:
-> python train.py hparams/train.yaml --data_folder /path/to/TIMIT
+> python train.py hparams/train.yaml --data_folder /path/to/TIMIT --jit
+
+Note on Compilation:
+Enabling the just-in-time (JIT) compiler with --jit significantly improves code performance,
+resulting in a 50-60% speed boost. We highly recommend utilizing the JIT compiler for optimal results.
 
 Authors
  * Mirco Ravanelli 2020
@@ -14,7 +18,6 @@ Authors
 
 import os
 import sys
-import torch
 import logging
 import speechbrain as sb
 from hyperpyyaml import load_hyperpyyaml
@@ -29,14 +32,10 @@ class ASR_Brain(sb.Brain):
         "Given an input batch it computes the phoneme probabilities."
         batch = batch.to(self.device)
         wavs, wav_lens = batch.sig
-        # Adding optional augmentation when specified:
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "env_corrupt"):
-                wavs_noise = self.hparams.env_corrupt(wavs, wav_lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                wav_lens = torch.cat([wav_lens, wav_lens])
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
 
         feats = self.hparams.compute_features(wavs)
         feats = self.modules.normalize(feats, wav_lens)
@@ -51,9 +50,9 @@ class ASR_Brain(sb.Brain):
         pout, pout_lens = predictions
         phns, phn_lens = batch.phn_encoded
 
-        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "env_corrupt"):
-            phns = torch.cat([phns, phns], dim=0)
-            phn_lens = torch.cat([phn_lens, phn_lens], dim=0)
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            phns = self.hparams.wav_augment.replicate_labels(phns)
+            phn_lens = self.hparams.wav_augment.replicate_labels(phn_lens)
 
         loss = self.hparams.compute_cost(pout, phns, pout_lens, phn_lens)
         self.ctc_metrics.append(batch.id, pout, phns, pout_lens, phn_lens)
@@ -234,6 +233,7 @@ if __name__ == "__main__":
             "uppercase": hparams["uppercase"],
         },
     )
+    run_on_main(hparams["prepare_noise_data"])
 
     # Dataset IO prep: creating Dataset objects and proper encodings for phones
     train_data, valid_data, test_data, label_encoder = dataio_prep(hparams)
diff --git a/recipes/TIMIT/ASR/seq2seq/README.md b/recipes/TIMIT/ASR/seq2seq/README.md
index 2eda1eb03ae8da2c2f21950cfdd14e0e3fba872c..afd5f9b3dcabcb60534d69c61d1d3c5534deac56 100644
--- a/recipes/TIMIT/ASR/seq2seq/README.md
+++ b/recipes/TIMIT/ASR/seq2seq/README.md
@@ -2,8 +2,17 @@
 This folder contains the scripts to train a seq2seq RNNN-based system using TIMIT.
 TIMIT is a speech dataset available from LDC: https://catalog.ldc.upenn.edu/LDC93S1
 
-# How to run
-python train.py hparams/train.yaml
+# Running the Code
+
+To execute the code, use the following command:
+
+```
+python train.py hparams/train.yaml --data_folder=your_data_folder/TIMIT --jit
+```
+
+**Important Note on Compilation**:
+Enabling the just-in-time (JIT) compiler with --jit significantly improves code performance, resulting in a 50-60% speed boost. We highly recommend utilizing the JIT compiler for optimal results.
+This speed improvement is observed specifically when using the CRDNN model.
 
 # Results
 
diff --git a/recipes/TIMIT/ASR/seq2seq/hparams/train.yaml b/recipes/TIMIT/ASR/seq2seq/hparams/train.yaml
index 72aa5e2e6420f8eb6771b5aaba587459ce467524..d61179fa9787d06cdfe128ff4942485725648033 100644
--- a/recipes/TIMIT/ASR/seq2seq/hparams/train.yaml
+++ b/recipes/TIMIT/ASR/seq2seq/hparams/train.yaml
@@ -16,13 +16,13 @@ train_log: !ref <output_folder>/train_log.txt
 
 # Data files
 data_folder: !PLACEHOLDER  # e.g. /path/to/TIMIT
-train_annotation: !ref <output_folder>/train.json
-valid_annotation: !ref <output_folder>/dev.json
-test_annotation: !ref <output_folder>/test.json
+train_annotation: !ref <save_folder>/train.json
+valid_annotation: !ref <save_folder>/dev.json
+test_annotation: !ref <save_folder>/test.json
 skip_prep: False # Skip data preparation
 uppercase: False # Must be True when the TIMIT dataset is in the upper-case version
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 50
 batch_size: 8 # Used if dynamic_batching is False
 lr: 0.0003
@@ -34,7 +34,7 @@ sample_rate: 16000
 n_fft: 400
 n_mels: 40
 
-# Model parameters
+####################### Model Parameters #######################################
 activation: !name:torch.nn.LeakyReLU
 dropout: 0.15
 cnn_blocks: 2
@@ -76,18 +76,53 @@ test_dataloader_opts:
 # For more info, see speechbrain.dataio.sampler.DynamicBatchSampler
 dynamic_batching: False
 
+feats_hop_size: 0.01
+max_batch_length: 5000 # in terms of frames
+num_buckets: 20
+shuffle: False # if true re-creates batches at each epoch shuffling examples.
+batch_ordering: random
+
 dynamic_batch_sampler:
-    feats_hop_size: 0.01
-    max_batch_len: 5000 # in terms of frames
-    num_buckets: 20
-    shuffle_ex: False # if true re-creates batches at each epoch shuffling examples.
-    batch_ordering: random
+    max_batch_length: !ref <max_batch_length>
+    num_buckets: !ref <num_buckets>
+    shuffle: !ref <shuffle>
+    batch_ordering: !ref <batch_ordering>
 
+############################## Augmentations ###################################
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
+
 normalize: !new:speechbrain.processing.features.InputNormalization
     norm_type: global
 
@@ -149,7 +184,8 @@ ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
 seq_cost: !name:speechbrain.nnet.losses.nll_loss
     label_smoothing: 0.1
 
-greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNGreedySearcher
+
+valid_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNGreedySearcher
     embedding: !ref <emb>
     decoder: !ref <dec>
     linear: !ref <seq_lin>
@@ -158,14 +194,12 @@ greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNGreedySearcher
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
 
-beam_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNBeamSearcher
+test_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNBeamSearcher
     embedding: !ref <emb>
     decoder: !ref <dec>
     linear: !ref <seq_lin>
-    ctc_linear: !ref <ctc_lin>
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <beam_size>
diff --git a/recipes/TIMIT/ASR/seq2seq/hparams/train_with_wav2vec2.yaml b/recipes/TIMIT/ASR/seq2seq/hparams/train_with_wav2vec2.yaml
index 509911930b140302fd9fc149715a9e6fe836d574..705f79e9a8cefc25b3138070b6fcd10682381c03 100644
--- a/recipes/TIMIT/ASR/seq2seq/hparams/train_with_wav2vec2.yaml
+++ b/recipes/TIMIT/ASR/seq2seq/hparams/train_with_wav2vec2.yaml
@@ -17,23 +17,23 @@ wav2vec2_hub: "facebook/wav2vec2-large-lv60"
 
 # Data files
 data_folder: !PLACEHOLDER  # e.g. /path/to/TIMIT
-train_annotation: !ref <output_folder>/train.json
-valid_annotation: !ref <output_folder>/dev.json
-test_annotation: !ref <output_folder>/test.json
+train_annotation: !ref <save_folder>/train.json
+valid_annotation: !ref <save_folder>/dev.json
+test_annotation: !ref <save_folder>/test.json
 skip_prep: False # Skip data preparation
 uppercase: False # Must be True when the TIMIT dataset is in the upper-case version
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 20
 batch_size: 8
 lr: 0.0003
 lr_wav2vec: 0.0001
 ctc_weight: 0.2
 sorting: ascending
-auto_mix_prec: False
+precision: fp32 # bf16, fp16 or fp32
 sample_rate: 16000
 
-# Model parameters
+####################### Model Parameters #######################################
 activation: !name:torch.nn.LeakyReLU
 dnn_layers: 2
 dnn_neurons: 1024
@@ -66,10 +66,40 @@ test_dataloader_opts:
     batch_size: !ref <batch_size>
     num_workers: !ref <batch_size>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
@@ -79,7 +109,7 @@ enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
     dnn_blocks: !ref <dnn_layers>
     dnn_neurons: !ref <dnn_neurons>
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec>
@@ -133,7 +163,7 @@ ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
 seq_cost: !name:speechbrain.nnet.losses.nll_loss
     label_smoothing: 0.1
 
-greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNGreedySearcher
+valid_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNGreedySearcher
     embedding: !ref <emb>
     decoder: !ref <dec>
     linear: !ref <seq_lin>
@@ -142,14 +172,12 @@ greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNGreedySearcher
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
 
-beam_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNBeamSearcher
+test_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNBeamSearcher
     embedding: !ref <emb>
     decoder: !ref <dec>
     linear: !ref <seq_lin>
-    ctc_linear: !ref <ctc_lin>
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <beam_size>
diff --git a/recipes/TIMIT/ASR/seq2seq/train.py b/recipes/TIMIT/ASR/seq2seq/train.py
index ced9d5760c7b8db6bf7cdce19ef4525181727a87..7c3efd9fc9488bd74515347e58565b7b476b903a 100644
--- a/recipes/TIMIT/ASR/seq2seq/train.py
+++ b/recipes/TIMIT/ASR/seq2seq/train.py
@@ -6,7 +6,11 @@ Greedy search is using for validation, while beamsearch is used at test time to
 improve the system performance.
 
 To run this recipe, do the following:
-> python train.py hparams/train.yaml --data_folder /path/to/TIMIT
+> python train.py hparams/train.yaml --data_folder /path/to/TIMIT --jit
+
+Note on Compilation:
+Enabling the just-in-time (JIT) compiler with --jit significantly improves code performance,
+resulting in a 50-60% speed boost. We highly recommend utilizing the JIT compiler for optimal results.
 
 Authors
  * Mirco Ravanelli 2020
@@ -26,7 +30,6 @@ from speechbrain.dataio.dataloader import SaveableDataLoader
 from speechbrain.dataio.sampler import DynamicBatchSampler
 from speechbrain.dataio.batch import PaddedBatch
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -38,14 +41,10 @@ class ASR(sb.Brain):
         wavs, wav_lens = batch.sig
         phns_bos, _ = batch.phn_encoded_bos
 
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "env_corrupt"):
-                wavs_noise = self.hparams.env_corrupt(wavs, wav_lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                wav_lens = torch.cat([wav_lens, wav_lens])
-                phns_bos = torch.cat([phns_bos, phns_bos])
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+            phns_bos = self.hparams.wav_augment.replicate_labels(phns_bos)
 
         feats = self.hparams.compute_features(wavs)
         feats = self.modules.normalize(feats, wav_lens)
@@ -62,32 +61,30 @@ class ASR(sb.Brain):
         logits = self.modules.seq_lin(h)
         p_seq = self.hparams.log_softmax(logits)
 
+        hyps = None
         if stage == sb.Stage.VALID:
-            hyps, scores = self.hparams.greedy_searcher(x, wav_lens)
-            return p_ctc, p_seq, wav_lens, hyps
+            hyps, _, _, _ = self.hparams.valid_searcher(x, wav_lens)
 
         elif stage == sb.Stage.TEST:
-            hyps, scores = self.hparams.beam_searcher(x, wav_lens)
-            return p_ctc, p_seq, wav_lens, hyps
+            hyps, _, _, _ = self.hparams.test_searcher(x, wav_lens)
 
-        return p_ctc, p_seq, wav_lens
+        return p_ctc, p_seq, wav_lens, hyps
 
     def compute_objectives(self, predictions, batch, stage):
         "Given the network predictions and targets computed the NLL loss."
-        if stage == sb.Stage.TRAIN:
-            p_ctc, p_seq, wav_lens = predictions
-        else:
-            p_ctc, p_seq, wav_lens, hyps = predictions
+        p_ctc, p_seq, wav_lens, hyps = predictions
 
         ids = batch.id
         phns_eos, phn_lens_eos = batch.phn_encoded_eos
         phns, phn_lens = batch.phn_encoded
 
-        if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
-            phns = torch.cat([phns, phns], dim=0)
-            phn_lens = torch.cat([phn_lens, phn_lens], dim=0)
-            phns_eos = torch.cat([phns_eos, phns_eos], dim=0)
-            phn_lens_eos = torch.cat([phn_lens_eos, phn_lens_eos], dim=0)
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            phns = self.hparams.wav_augment.replicate_labels(phns)
+            phn_lens = self.hparams.wav_augment.replicate_labels(phn_lens)
+            phns_eos = self.hparams.wav_augment.replicate_labels(phns_eos)
+            phn_lens_eos = self.hparams.wav_augment.replicate_labels(
+                phn_lens_eos
+            )
 
         loss_ctc = self.hparams.ctc_cost(p_ctc, phns, wav_lens, phn_lens)
         loss_seq = self.hparams.seq_cost(p_seq, phns_eos, phn_lens_eos)
@@ -104,22 +101,6 @@ class ASR(sb.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-        loss.backward()
-        if self.check_gradients(loss):
-            self.optimizer.step()
-        self.optimizer.zero_grad()
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         "Gets called when a stage (either training, validation, test) starts."
         self.ctc_metrics = self.hparams.ctc_stats()
@@ -278,15 +259,12 @@ def dataio_prep(hparams):
     # Support for dynamic batching
     if hparams["dynamic_batching"]:
         dynamic_hparams = hparams["dynamic_batch_sampler"]
-        hop_size = dynamic_hparams["feats_hop_size"]
+        hop_size = hparams["feats_hop_size"]
 
         batch_sampler = DynamicBatchSampler(
             train_data,
-            dynamic_hparams["max_batch_len"],
-            num_buckets=dynamic_hparams["num_buckets"],
+            **dynamic_hparams,
             length_func=lambda x: x["duration"] * (1 / hop_size),
-            shuffle=dynamic_hparams["shuffle_ex"],
-            batch_ordering=dynamic_hparams["batch_ordering"],
         )
 
         train_data = SaveableDataLoader(
diff --git a/recipes/TIMIT/ASR/seq2seq/train_with_wav2vec2.py b/recipes/TIMIT/ASR/seq2seq/train_with_wav2vec2.py
index 18fdd0a1ed341598d20f57d827d62d7cc885b6b1..c53f0365db7d39e68fdeb199a1b8d77dc5659c91 100644
--- a/recipes/TIMIT/ASR/seq2seq/train_with_wav2vec2.py
+++ b/recipes/TIMIT/ASR/seq2seq/train_with_wav2vec2.py
@@ -33,9 +33,10 @@ class ASR(sb.Brain):
         wavs, wav_lens = batch.sig
         phns_bos, _ = batch.phn_encoded_bos
 
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+            phns_bos = self.hparams.wav_augment.replicate_labels(phns_bos)
 
         feats = self.modules.wav2vec2(wavs, wav_lens)
         x = self.modules.enc(feats)
@@ -51,27 +52,30 @@ class ASR(sb.Brain):
         logits = self.modules.seq_lin(h)
         p_seq = self.hparams.log_softmax(logits)
 
+        hyps = None
         if stage == sb.Stage.VALID:
-            hyps, scores = self.hparams.greedy_searcher(x, wav_lens)
-            return p_ctc, p_seq, wav_lens, hyps
-
+            hyps, _, _, _ = self.hparams.valid_searcher(x, wav_lens)
         elif stage == sb.Stage.TEST:
-            hyps, scores = self.hparams.beam_searcher(x, wav_lens)
-            return p_ctc, p_seq, wav_lens, hyps
+            hyps, _, _, _ = self.hparams.test_searcher(x, wav_lens)
 
-        return p_ctc, p_seq, wav_lens
+        return p_ctc, p_seq, wav_lens, hyps
 
     def compute_objectives(self, predictions, batch, stage):
         "Given the network predictions and targets computed the NLL loss."
-        if stage == sb.Stage.TRAIN:
-            p_ctc, p_seq, wav_lens = predictions
-        else:
-            p_ctc, p_seq, wav_lens, hyps = predictions
+        p_ctc, p_seq, wav_lens, hyps = predictions
 
         ids = batch.id
         phns_eos, phn_lens_eos = batch.phn_encoded_eos
         phns, phn_lens = batch.phn_encoded
 
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            phns = self.hparams.wav_augment.replicate_labels(phns)
+            phn_lens = self.hparams.wav_augment.replicate_labels(phn_lens)
+            phns_eos = self.hparams.wav_augment.replicate_labels(phns_eos)
+            phn_lens_eos = self.hparams.wav_augment.replicate_labels(
+                phn_lens_eos
+            )
+
         loss_ctc = self.hparams.ctc_cost(p_ctc, phns, wav_lens, phn_lens)
         loss_seq = self.hparams.seq_cost(p_seq, phns_eos, phn_lens_eos)
         loss = self.hparams.ctc_weight * loss_ctc
@@ -87,12 +91,6 @@ class ASR(sb.Brain):
 
         return loss
 
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         "Gets called when a stage (either training, validation, test) starts."
         self.ctc_metrics = self.hparams.ctc_stats()
@@ -156,61 +154,6 @@ class ASR(sb.Brain):
                         self.hparams.test_wer_file,
                     )
 
-    def fit_batch(self, batch):
-        """Fit one batch, override to do multiple updates.
-
-        The default implementation depends on a few methods being defined
-        with a particular behavior:
-
-        * ``compute_forward()``
-        * ``compute_objectives()``
-
-        Also depends on having optimizers passed at initialization.
-
-        Arguments
-        ---------
-        batch : list of torch.Tensors
-            Batch of data to use for training. Default implementation assumes
-            this batch has two elements: inputs and targets.
-
-        Returns
-        -------
-        detached loss
-        """
-        # Managing automatic mixed precision
-        if self.auto_mix_prec:
-
-            self.wav2vec_optimizer.zero_grad()
-            self.adam_optimizer.zero_grad()
-
-            with torch.cuda.amp.autocast():
-                outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-                loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-
-            self.scaler.scale(loss).backward()
-            self.scaler.unscale_(self.wav2vec_optimizer)
-            self.scaler.unscale_(self.adam_optimizer)
-
-            if self.check_gradients(loss):
-                self.scaler.step(self.wav2vec_optimizer)
-                self.scaler.step(self.adam_optimizer)
-
-            self.scaler.update()
-        else:
-            outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-
-            loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-            loss.backward()
-
-            if self.check_gradients(loss):
-                self.wav2vec_optimizer.step()
-                self.adam_optimizer.step()
-
-            self.wav2vec_optimizer.zero_grad()
-            self.adam_optimizer.zero_grad()
-
-        return loss.detach().cpu()
-
     def init_optimizers(self):
         "Initializes the wav2vec2 optimizer and model optimizer"
         self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
@@ -226,9 +169,10 @@ class ASR(sb.Brain):
             )
             self.checkpointer.add_recoverable("adam_opt", self.adam_optimizer)
 
-    def zero_grad(self, set_to_none=False):
-        self.wav2vec_optimizer.zero_grad(set_to_none)
-        self.adam_optimizer.zero_grad(set_to_none)
+        self.optimizers_dict = {
+            "wav2vec_opt": self.wav2vec_optimizer,
+            "adam_opt": self.adam_optimizer,
+        }
 
 
 def dataio_prep(hparams):
diff --git a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/README.md b/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/README.md
deleted file mode 100644
index ad8aff73bda6cc9dfd869fd96fbe4668b101b69b..0000000000000000000000000000000000000000
--- a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/README.md
+++ /dev/null
@@ -1,97 +0,0 @@
-## Multi-teacher Knowledge Distillation for CTC/Att models
-This is the implementation of multi-teacher distillation methods to
-joint ctc-attention end-to-end ASR systems. The proposed approaches integrate
-the error rate metric to the teacher selection rather than solely focusing on the observed losses.
-This way, we directly distillate and optimize the student toward the relevant metric for speech recognition.
-For details please refer to: https://arxiv.org/abs/2005.09310
-
-### Results with this recipe
-
-| Distillation Strategy | Valid PER | Test PER | Model link | GPUs |
-|:---------------------------:| :-----:| :-----:| :-----:| :--------:|
-| Weighted | 11.87 | 13.11 | [model](https://www.dropbox.com/sh/h30wdezgw7qocsc/AACsY20GD94Qe-AukFPhRcj2a?dl=0) | 1xV100 16GB |
-| Best | 11.93 | 13.15 | [model](https://www.dropbox.com/sh/6p0szvox5sj8z77/AAALVExCU0YAGXH-nAm3kdkqa?dl=0) | 1xV100 16GB |
-
-The output folders with checkpoints and logs for TIMIT recipes can be found [here](https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0).
-
-## Installing Extra Dependencies
-
-Before proceeding, ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal:
-
-```
-pip install -r extra_requirements.txt
-```
-
-### Training steps
-To speed up student distillation from multiple teachers, we separate the whole procedure into three parts: teacher model training, inference running on teacher models, student distillation.
-
-#### 1. Teacher model training
-Before doing distillation, we require finishing N teacher models training. Here, we propose to set N=10 as in the referenced paper.
-
-Models training can be done in parallel using `train_teacher.py`.
-
-Example:
-```
-python train_teacher.py hparams/teachers/tea0.yaml --data_folder /path-to/data_folder
-```
-
-#### 2. Run inference on all teacher models
-This part run inference on all teacher models and store them on disk using `save_teachers.py`. It is only required that you setup the `tea_models_dir` variable corresponding to the path to a txt file. The latter txt file needs to contain
-a list of paths pointing to each teacher model.ckpt. We decided to work with a file so it can easily scale to hundreds of teachers. Hence, an example of this
-file is:
-
-```
-results/tea0/1234/save/CKPT+2021-01-21+14-50-32+00/model.ckpt
-results/tea1/1234/save/CKPT+2021-01-21+13-55-56+00/model.ckpt
-results/tea2/1234/save/CKPT+2021-01-21+14-25-21+00/model.ckpt
-results/tea3/1234/save/CKPT+2021-01-21+15-02-32+00/model.ckpt
-results/tea4/1234/save/CKPT+2021-01-21+15-47-09+00/model.ckpt
-results/tea5/1234/save/CKPT+2021-01-21+16-02-38+00/model.ckpt
-results/tea6/1234/save/CKPT+2021-01-21+16-05-33+00/model.ckpt
-results/tea7/1234/save/CKPT+2021-01-21+16-03-20+00/model.ckpt
-results/tea8/1234/save/CKPT+2021-01-21+16-25-17+00/model.ckpt
-results/tea9/1234/save/CKPT+2021-01-21+15-48-42+00/model.ckpt
-```
-
-Example:
-```
-python save_teachers.py hparams/save_teachers.yaml --data_folder /path-to/data_folder --tea_models_dir /path-to/tea_model_paths.txt
-```
-
-#### 3. Student distillation
-This is the main part for distillation using `train_kd.py`. Here, the variable `pretrain` might be used to use a pre-trained teacher as the student. Note that if set to `True`, a path to the corresponding `model.ckpt` must be given in `pretrain_st_dir`. Also, `tea_infer_dir` is required, linking to the directory of teacher model inference results. Finally, note that the distillation must be trained on with the exact same input CSV files that are generated by `save_teachers.py`. This ensure that the distillation is perfectly linked to the
-generated teacher predictions! Diverging input CSV files might generate incompatible shape errors!
-
-Example:
-```
-python train_kd.py hparams/train_kd.yaml --data_folder /path-to/data_folder --pretrain_st_dir /path-to/model_directory --tea_infer_dir /path-to/tea_infer_directory
-```
-
-### Distillation strategies
-There are three strategies in the current version that can be switched with the option `strategy` in `hparams/train_kd.yaml`.
-
-- **average**: average losses of teachers when doing distillation.
-- **best**: choosing the best teacher based on WER.
-- **weighted**: assigning weights to teachers based on WER.
-
-
-# **About SpeechBrain**
-- Website: https://speechbrain.github.io/
-- Code: https://github.com/speechbrain/speechbrain/
-- HuggingFace: https://huggingface.co/speechbrain/
-
-
-# **Citing SpeechBrain**
-Please, cite SpeechBrain if you use it for your research or business.
-
-```bibtex
-@misc{speechbrain,
-  title={{SpeechBrain}: A General-Purpose Speech Toolkit},
-  author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
-  year={2021},
-  eprint={2106.04624},
-  archivePrefix={arXiv},
-  primaryClass={eess.AS},
-  note={arXiv:2106.04624}
-}
-```
diff --git a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/extra_requirements.txt b/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/extra_requirements.txt
deleted file mode 100644
index c5a4eac431789a52d7f7d521a84937511cdc400c..0000000000000000000000000000000000000000
--- a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/extra_requirements.txt
+++ /dev/null
@@ -1 +0,0 @@
-h5py
diff --git a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/save_teachers.yaml b/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/save_teachers.yaml
deleted file mode 100644
index 35d190a0208a3f72768b8b5522c34a8c7bb50ac5..0000000000000000000000000000000000000000
--- a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/save_teachers.yaml
+++ /dev/null
@@ -1,486 +0,0 @@
-# Seed needs to be set at top of yaml, before objects with parameters are made
-seed: 1234
-__set_seed: !apply:torch.manual_seed [!ref <seed>]
-output_folder: !ref results/teachers_save/<seed>
-
-# Data files
-data_folder: !PLACEHOLDER  # e.g. /path/to/TIMIT
-train_annotation: !ref <output_folder>/train.json
-valid_annotation: !ref <output_folder>/dev.json
-test_annotation: !ref <output_folder>/test.json
-skip_prep: False
-
-# Dataloader options
-train_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-valid_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-test_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-# Training parameters
-# number_of_epochs: 1
-batch_size: 8
-sorting: ascending
-
-# Feature parameters
-sample_rate: 16000
-n_fft: 400
-n_mels: 40
-
-# Model parameters
-emb_size: 128
-
-# Outputs
-output_neurons: 40
-blank_index: !ref <output_neurons> - 1
-bos_index: !ref <output_neurons> - 1
-eos_index: !ref <output_neurons> - 1
-
-# Decoding parameters
-# min_decode_ratio: 0.0
-# max_decode_ratio: 1.0
-# beam_size: 16
-# eos_threshold: 1.5F
-
-# teacher models
-num_tea: 10
-
-# .txt file containing paths for saved teacher models.
-# e.g. each line is /path/to/model.ckpt
-tea_models_dir: !PLACEHOLDER
-
-# distillation parameters
-# Temperature
-temperature: 1
-# distillation weight alpha
-# alpha: 1
-# different stages in dataset
-stage: ["train", "valid", "test"]
-
-# tea0
-tea0_enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
-    input_shape: [null, null, !ref <n_mels>]
-    activation: !name:torch.nn.LeakyReLU
-    dropout: 0.15
-    cnn_blocks: 2
-    cnn_channels: (128, 256)
-    cnn_kernelsize: (3, 3)
-    time_pooling: True
-    rnn_layers: 4
-    rnn_neurons: 512
-    rnn_bidirectional: True
-    dnn_blocks: 2
-    dnn_neurons: 512
-
-tea0_emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-tea0_dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: 512
-    input_size: !ref <emb_size>
-    rnn_type: gru
-    attn_type: location
-    hidden_size: 256
-    attn_dim: 256
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.5
-
-tea0_ctc_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: 512
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 blank
-
-tea0_seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: 256
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 eos
-
-# tea1
-tea1_enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
-    input_shape: [null, null, !ref <n_mels>]
-    activation: !name:torch.nn.LeakyReLU
-    dropout: 0.3
-    cnn_blocks: 2
-    cnn_channels: (128, 256)
-    cnn_kernelsize: (3, 3)
-    time_pooling: True
-    rnn_layers: 4
-    rnn_neurons: 512
-    rnn_bidirectional: True
-    dnn_blocks: 2
-    dnn_neurons: 512
-
-tea1_emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-tea1_dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: 512
-    input_size: !ref <emb_size>
-    rnn_type: gru
-    attn_type: location
-    hidden_size: 256
-    attn_dim: 256
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.5
-
-tea1_ctc_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: 512
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 blank
-
-tea1_seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: 256
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 eos
-
-# tea2
-tea2_enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
-    input_shape: [null, null, !ref <n_mels>]
-    activation: !name:torch.nn.LeakyReLU
-    dropout: 0.3
-    cnn_blocks: 2
-    cnn_channels: (128, 256)
-    cnn_kernelsize: (3, 3)
-    time_pooling: True
-    rnn_layers: 4
-    rnn_neurons: 512
-    rnn_bidirectional: True
-    dnn_blocks: 2
-    dnn_neurons: 512
-
-tea2_emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-tea2_dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: 512
-    input_size: !ref <emb_size>
-    rnn_type: gru
-    attn_type: location
-    hidden_size: 256
-    attn_dim: 256
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.5
-
-tea2_ctc_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: 512
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 blank
-
-tea2_seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: 256
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 eos
-
-# tea3
-tea3_enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
-    input_shape: [null, null, !ref <n_mels>]
-    activation: !name:torch.nn.LeakyReLU
-    dropout: 0.2
-    cnn_blocks: 2
-    cnn_channels: (128, 256)
-    cnn_kernelsize: (3, 3)
-    time_pooling: True
-    rnn_layers: 5
-    rnn_neurons: 512
-    rnn_bidirectional: True
-    dnn_blocks: 2
-    dnn_neurons: 512
-
-tea3_emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-tea3_dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: 512
-    input_size: !ref <emb_size>
-    rnn_type: lstm
-    attn_type: location
-    hidden_size: 256
-    attn_dim: 256
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.5
-
-tea3_ctc_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: 512
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 blank
-
-tea3_seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: 256
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 eos
-
-# tea4
-tea4_enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
-    input_shape: [null, null, !ref <n_mels>]
-    activation: !name:torch.nn.LeakyReLU
-    dropout: 0.3
-    cnn_blocks: 2
-    cnn_channels: (128, 256)
-    cnn_kernelsize: (3, 3)
-    time_pooling: True
-    rnn_layers: 4
-    rnn_neurons: 512
-    rnn_bidirectional: True
-    dnn_blocks: 2
-    dnn_neurons: 512
-
-tea4_emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-tea4_dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: 512
-    input_size: !ref <emb_size>
-    rnn_type: gru
-    attn_type: location
-    hidden_size: 256
-    attn_dim: 256
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.5
-
-tea4_ctc_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: 512
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 blank
-
-tea4_seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: 256
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 eos
-
-# tea5
-tea5_enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
-    input_shape: [null, null, !ref <n_mels>]
-    activation: !name:torch.nn.LeakyReLU
-    dropout: 0.3
-    cnn_blocks: 2
-    cnn_channels: (128, 256)
-    cnn_kernelsize: (3, 3)
-    time_pooling: True
-    rnn_layers: 4
-    rnn_neurons: 320
-    rnn_bidirectional: True
-    dnn_blocks: 2
-    dnn_neurons: 320
-
-tea5_emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-tea5_dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: 320
-    input_size: !ref <emb_size>
-    rnn_type: lstm
-    attn_type: location
-    hidden_size: 256
-    attn_dim: 256
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.5
-
-tea5_ctc_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: 320
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 blank
-
-tea5_seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: 256
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 eos
-
-# tea6
-tea6_enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
-    input_shape: [null, null, !ref <n_mels>]
-    activation: !name:torch.nn.LeakyReLU
-    dropout: 0.3
-    cnn_blocks: 1
-    cnn_channels: (128, 256)
-    cnn_kernelsize: (3, 3)
-    time_pooling: True
-    rnn_layers: 4
-    rnn_neurons: 320
-    rnn_bidirectional: True
-    dnn_blocks: 2
-    dnn_neurons: 320
-
-tea6_emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-tea6_dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: 320
-    input_size: !ref <emb_size>
-    rnn_type: lstm
-    attn_type: location
-    hidden_size: 256
-    attn_dim: 256
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.5
-
-tea6_ctc_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: 320
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 blank
-
-tea6_seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: 256
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 eos
-
-# tea7
-tea7_enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
-    input_shape: [null, null, !ref <n_mels>]
-    activation: !name:torch.nn.LeakyReLU
-    dropout: 0.15
-    cnn_blocks: 2
-    cnn_channels: (128, 256)
-    cnn_kernelsize: (3, 3)
-    time_pooling: True
-    rnn_layers: 4
-    rnn_neurons: 640
-    rnn_bidirectional: True
-    dnn_blocks: 2
-    dnn_neurons: 512
-
-tea7_emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-tea7_dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: 512
-    input_size: !ref <emb_size>
-    rnn_type: gru
-    attn_type: location
-    hidden_size: 256
-    attn_dim: 256
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.5
-
-tea7_ctc_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: 512
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 blank
-
-tea7_seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: 256
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 eos
-
-# tea8
-tea8_enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
-    input_shape: [null, null, !ref <n_mels>]
-    activation: !name:torch.nn.LeakyReLU
-    dropout: 0.3
-    cnn_blocks: 2
-    cnn_channels: (128, 256)
-    cnn_kernelsize: (3, 3)
-    time_pooling: True
-    rnn_layers: 5
-    rnn_neurons: 512
-    rnn_bidirectional: True
-    dnn_blocks: 2
-    dnn_neurons: 512
-
-tea8_emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-tea8_dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: 512
-    input_size: !ref <emb_size>
-    rnn_type: lstm
-    attn_type: location
-    hidden_size: 256
-    attn_dim: 256
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.5
-
-tea8_ctc_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: 512
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 blank
-
-tea8_seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: 256
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 eos
-
-# tea9
-tea9_enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
-    input_shape: [null, null, !ref <n_mels>]
-    activation: !name:torch.nn.LeakyReLU
-    dropout: 0.15
-    cnn_blocks: 2
-    cnn_channels: (128, 256)
-    cnn_kernelsize: (3, 3)
-    time_pooling: True
-    rnn_layers: 4
-    rnn_neurons: 512
-    rnn_bidirectional: True
-    dnn_blocks: 2
-    dnn_neurons: 512
-
-tea9_emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-tea9_dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: 512
-    input_size: !ref <emb_size>
-    rnn_type: gru
-    attn_type: location
-    hidden_size: 256
-    attn_dim: 256
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.5
-
-tea9_ctc_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: 512
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 blank
-
-tea9_seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: 256
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 eos
-
-
-log_softmax: !new:speechbrain.nnet.activations.Softmax
-    apply_log: True
-
-normalize: !new:speechbrain.processing.features.InputNormalization
-    norm_type: global
-
-# epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
-#    limit: !ref <number_of_epochs>
-
-compute_features: !new:speechbrain.lobes.features.Fbank
-    sample_rate: !ref <sample_rate>
-    n_fft: !ref <n_fft>
-    n_mels: !ref <n_mels>
diff --git a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea0.yaml b/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea0.yaml
deleted file mode 100644
index e1d58938f9b406824d82b0bcc2e4fed61c8341b6..0000000000000000000000000000000000000000
--- a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea0.yaml
+++ /dev/null
@@ -1,192 +0,0 @@
-# Seed needs to be set at top of yaml, before objects with parameters are made
-seed: 1234
-__set_seed: !apply:torch.manual_seed [!ref <seed>]
-output_folder: !ref results/tea0/<seed>
-test_wer_file: !ref <output_folder>/wer_test.txt
-save_folder: !ref <output_folder>/save
-train_log: !ref <output_folder>/train_log.txt
-
-# Data files
-data_folder: !PLACEHOLDER  # e.g. /path/to/TIMIT
-train_annotation: !ref <output_folder>/train.json
-valid_annotation: !ref <output_folder>/dev.json
-test_annotation: !ref <output_folder>/test.json
-skip_prep: False
-
-# Training parameters
-number_of_epochs: 50
-batch_size: 8
-lr: 0.0003
-ctc_weight: 0.2
-sorting: ascending
-
-# Feature parameters
-sample_rate: 16000
-n_fft: 400
-n_mels: 40
-
-# Model parameters
-activation: !name:torch.nn.LeakyReLU
-dropout: 0.15
-cnn_blocks: 2
-cnn_channels: (128, 256)
-cnn_kernelsize: (3, 3)
-rnn_layers: 4
-rnn_neurons: 512
-rnn_bidirectional: True
-dnn_blocks: 2
-dnn_neurons: 512
-emb_size: 128
-dec_neurons: 256
-
-# Outputs
-output_neurons: 40
-blank_index: !ref <output_neurons> - 1
-bos_index: !ref <output_neurons> - 1
-eos_index: !ref <output_neurons> - 1
-
-# Decoding parameters
-min_decode_ratio: 0.0
-max_decode_ratio: 1.0
-beam_size: 16
-# eos_threshold: 1.5
-
-# Dataloader options
-train_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-valid_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-test_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-normalize: !new:speechbrain.processing.features.InputNormalization
-    norm_type: global
-
-epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
-    limit: !ref <number_of_epochs>
-
-compute_features: !new:speechbrain.lobes.features.Fbank
-    sample_rate: !ref <sample_rate>
-    n_fft: !ref <n_fft>
-    n_mels: !ref <n_mels>
-
-enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
-    input_shape: [null, null, !ref <n_mels>]
-    activation: !ref <activation>
-    dropout: !ref <dropout>
-    cnn_blocks: !ref <cnn_blocks>
-    cnn_channels: !ref <cnn_channels>
-    cnn_kernelsize: !ref <cnn_kernelsize>
-    time_pooling: True
-    rnn_layers: !ref <rnn_layers>
-    rnn_neurons: !ref <rnn_neurons>
-    rnn_bidirectional: !ref <rnn_bidirectional>
-    dnn_blocks: !ref <dnn_blocks>
-    dnn_neurons: !ref <dnn_neurons>
-
-emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: !ref <dnn_neurons>
-    input_size: !ref <emb_size>
-    rnn_type: gru
-    attn_type: location
-    hidden_size: !ref <dec_neurons>
-    attn_dim: 256
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.5
-
-ctc_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dnn_neurons>
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 blank
-
-seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dec_neurons>
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 eos
-
-log_softmax: !new:speechbrain.nnet.activations.Softmax
-    apply_log: True
-
-ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
-    blank_index: !ref <blank_index>
-
-seq_cost: !name:speechbrain.nnet.losses.nll_loss
-    label_smoothing: 0.1
-
-greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNGreedySearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-
-beam_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNBeamSearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    ctc_linear: !ref <ctc_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-    beam_size: !ref <beam_size>
-
-model: !new:torch.nn.ModuleList
-    - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <ctc_lin>, !ref <seq_lin>]
-
-opt_class: !name:torch.optim.Adam
-    lr: !ref <lr>
-
-lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
-    initial_value: !ref <lr>
-    improvement_threshold: 0.0025
-    annealing_factor: 0.8
-    patient: 0
-
-modules:
-    enc: !ref <enc>
-    emb: !ref <emb>
-    dec: !ref <dec>
-    ctc_lin: !ref <ctc_lin>
-    seq_lin: !ref <seq_lin>
-    normalize: !ref <normalize>
-
-jit_module_keys: [enc]
-
-checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
-    checkpoints_dir: !ref <save_folder>
-    recoverables:
-        model: !ref <model>
-        normalize: !ref <normalize>
-        lr_annealing: !ref <lr_annealing>
-        counter: !ref <epoch_counter>
-
-train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
-    save_file: !ref <train_log>
-
-ctc_stats: !name:speechbrain.utils.metric_stats.MetricStats
-    metric: !name:speechbrain.nnet.losses.ctc_loss
-        blank_index: !ref <blank_index>
-        reduction: batch
-
-seq_stats: !name:speechbrain.utils.metric_stats.MetricStats
-    metric: !name:speechbrain.nnet.losses.nll_loss
-        label_smoothing: 0.1
-        reduction: batch
-
-per_stats: !name:speechbrain.utils.metric_stats.ErrorRateStats
diff --git a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea1.yaml b/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea1.yaml
deleted file mode 100644
index 1bf9c1ef2435778e01939d8120dfeafc90873c2b..0000000000000000000000000000000000000000
--- a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea1.yaml
+++ /dev/null
@@ -1,188 +0,0 @@
-# Seed needs to be set at top of yaml, before objects with parameters are made
-seed: 1234
-__set_seed: !apply:torch.manual_seed [!ref <seed>]
-output_folder: !ref results/tea1/<seed>
-test_wer_file: !ref <output_folder>/wer_test.txt
-save_folder: !ref <output_folder>/save
-train_log: !ref <output_folder>/train_log.txt
-
-# Data files
-data_folder: !PLACEHOLDER  # e.g. /path/to/TIMIT
-train_annotation: !ref <output_folder>/train.json
-valid_annotation: !ref <output_folder>/dev.json
-test_annotation: !ref <output_folder>/test.json
-skip_prep: False
-
-# Training parameters
-number_of_epochs: 50
-batch_size: 16
-lr: 0.0003
-ctc_weight: 0.2
-sorting: ascending
-
-# Feature parameters
-sample_rate: 16000
-n_fft: 400
-n_mels: 40
-
-# Model parameters
-activation: !name:torch.nn.LeakyReLU
-dropout: 0.3
-cnn_blocks: 2
-cnn_channels: (128, 256)
-cnn_kernelsize: (3, 3)
-rnn_layers: 4
-rnn_neurons: 512
-rnn_bidirectional: True
-dnn_blocks: 2
-dnn_neurons: 512
-emb_size: 128
-dec_neurons: 256
-
-# Outputs
-output_neurons: 40
-blank_index: !ref <output_neurons> - 1
-bos_index: !ref <output_neurons> - 1
-eos_index: !ref <output_neurons> - 1
-
-# Decoding parameters
-min_decode_ratio: 0.0
-max_decode_ratio: 1.0
-beam_size: 16
-# eos_threshold: 1.5
-
-# Dataloader options
-train_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-valid_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-test_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-normalize: !new:speechbrain.processing.features.InputNormalization
-    norm_type: global
-
-epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
-    limit: !ref <number_of_epochs>
-
-compute_features: !new:speechbrain.lobes.features.Fbank
-    sample_rate: !ref <sample_rate>
-    n_fft: !ref <n_fft>
-    n_mels: !ref <n_mels>
-
-enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
-    input_shape: [null, null, !ref <n_mels>]
-    activation: !ref <activation>
-    dropout: !ref <dropout>
-    cnn_blocks: !ref <cnn_blocks>
-    cnn_channels: !ref <cnn_channels>
-    cnn_kernelsize: !ref <cnn_kernelsize>
-    time_pooling: True
-    rnn_layers: !ref <rnn_layers>
-    rnn_neurons: !ref <rnn_neurons>
-    rnn_bidirectional: !ref <rnn_bidirectional>
-    dnn_blocks: !ref <dnn_blocks>
-    dnn_neurons: !ref <dnn_neurons>
-
-emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: !ref <dnn_neurons>
-    input_size: !ref <emb_size>
-    rnn_type: gru
-    attn_type: location
-    hidden_size: !ref <dec_neurons>
-    attn_dim: 256
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.5
-
-ctc_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dnn_neurons>
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 blank
-
-seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dec_neurons>
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 eos
-
-log_softmax: !new:speechbrain.nnet.activations.Softmax
-    apply_log: True
-
-ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
-    blank_index: !ref <blank_index>
-
-seq_cost: !name:speechbrain.nnet.losses.nll_loss
-    label_smoothing: 0.1
-
-greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNGreedySearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-
-beam_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNBeamSearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    ctc_linear: !ref <ctc_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-    beam_size: !ref <beam_size>
-
-model: !new:torch.nn.ModuleList
-    - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <ctc_lin>, !ref <seq_lin>]
-
-opt_class: !name:torch.optim.Adam
-    lr: !ref <lr>
-
-lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
-    initial_value: !ref <lr>
-    improvement_threshold: 0.0025
-    annealing_factor: 0.8
-    patient: 0
-
-modules:
-    enc: !ref <enc>
-    emb: !ref <emb>
-    dec: !ref <dec>
-    ctc_lin: !ref <ctc_lin>
-    seq_lin: !ref <seq_lin>
-    normalize: !ref <normalize>
-
-jit_module_keys: [enc]
-
-checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
-    checkpoints_dir: !ref <save_folder>
-    recoverables:
-        model: !ref <model>
-        normalize: !ref <normalize>
-        lr_annealing: !ref <lr_annealing>
-        counter: !ref <epoch_counter>
-
-train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
-    save_file: !ref <train_log>
-
-ctc_stats: !name:speechbrain.utils.metric_stats.MetricStats
-    metric: !name:speechbrain.nnet.losses.ctc_loss
-        blank_index: !ref <blank_index>
-        reduction: batch
-
-seq_stats: !name:speechbrain.utils.metric_stats.MetricStats
-    metric: !name:speechbrain.nnet.losses.nll_loss
-        label_smoothing: 0.1
-        reduction: batch
-
-per_stats: !name:speechbrain.utils.metric_stats.ErrorRateStats
diff --git a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea2.yaml b/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea2.yaml
deleted file mode 100644
index 00838335872e0ca08087903576ba6fc1d902430e..0000000000000000000000000000000000000000
--- a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea2.yaml
+++ /dev/null
@@ -1,192 +0,0 @@
-# Seed needs to be set at top of yaml, before objects with parameters are made
-seed: 1234
-__set_seed: !apply:torch.manual_seed [!ref <seed>]
-output_folder: !ref results/tea2/<seed>
-test_wer_file: !ref <output_folder>/wer_test.txt
-save_folder: !ref <output_folder>/save
-train_log: !ref <output_folder>/train_log.txt
-
-# Data files
-data_folder: !PLACEHOLDER  # e.g. /path/to/TIMIT
-train_annotation: !ref <output_folder>/train.json
-valid_annotation: !ref <output_folder>/dev.json
-test_annotation: !ref <output_folder>/test.json
-skip_prep: False
-
-# Training parameters
-number_of_epochs: 50
-batch_size: 16
-lr: 0.0003
-ctc_weight: 0.2
-sorting: ascending
-
-# Feature parameters
-sample_rate: 16000
-n_fft: 400
-n_mels: 40
-
-# Model parameters
-activation: !name:torch.nn.LeakyReLU
-dropout: 0.3
-cnn_blocks: 2
-cnn_channels: (128, 256)
-cnn_kernelsize: (3, 3)
-rnn_layers: 4
-rnn_neurons: 512
-rnn_bidirectional: True
-dnn_blocks: 2
-dnn_neurons: 512
-emb_size: 128
-dec_neurons: 256
-
-# Outputs
-output_neurons: 40
-blank_index: !ref <output_neurons> - 1
-bos_index: !ref <output_neurons> - 1
-eos_index: !ref <output_neurons> - 1
-
-# Decoding parameters
-min_decode_ratio: 0.0
-max_decode_ratio: 1.0
-beam_size: 16
-# eos_threshold: 1.5
-
-# Dataloader options
-train_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-valid_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-test_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-normalize: !new:speechbrain.processing.features.InputNormalization
-    norm_type: global
-
-epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
-    limit: !ref <number_of_epochs>
-
-compute_features: !new:speechbrain.lobes.features.Fbank
-    sample_rate: !ref <sample_rate>
-    n_fft: !ref <n_fft>
-    n_mels: !ref <n_mels>
-
-enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
-    input_shape: [null, null, !ref <n_mels>]
-    activation: !ref <activation>
-    dropout: !ref <dropout>
-    cnn_blocks: !ref <cnn_blocks>
-    cnn_channels: !ref <cnn_channels>
-    cnn_kernelsize: !ref <cnn_kernelsize>
-    time_pooling: True
-    rnn_layers: !ref <rnn_layers>
-    rnn_neurons: !ref <rnn_neurons>
-    rnn_bidirectional: !ref <rnn_bidirectional>
-    dnn_blocks: !ref <dnn_blocks>
-    dnn_neurons: !ref <dnn_neurons>
-
-emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: !ref <dnn_neurons>
-    input_size: !ref <emb_size>
-    rnn_type: gru
-    attn_type: location
-    hidden_size: !ref <dec_neurons>
-    attn_dim: 256
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.5
-
-ctc_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dnn_neurons>
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 blank
-
-seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dec_neurons>
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 eos
-
-log_softmax: !new:speechbrain.nnet.activations.Softmax
-    apply_log: True
-
-ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
-    blank_index: !ref <blank_index>
-
-seq_cost: !name:speechbrain.nnet.losses.nll_loss
-    label_smoothing: 0.1
-
-greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNGreedySearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-
-beam_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNBeamSearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    ctc_linear: !ref <ctc_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-    beam_size: !ref <beam_size>
-
-model: !new:torch.nn.ModuleList
-    - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <ctc_lin>, !ref <seq_lin>]
-
-opt_class: !name:torch.optim.Adam
-    lr: !ref <lr>
-
-lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
-    initial_value: !ref <lr>
-    improvement_threshold: 0.0025
-    annealing_factor: 0.8
-    patient: 0
-
-modules:
-    enc: !ref <enc>
-    emb: !ref <emb>
-    dec: !ref <dec>
-    ctc_lin: !ref <ctc_lin>
-    seq_lin: !ref <seq_lin>
-    normalize: !ref <normalize>
-
-jit_module_keys: [enc]
-
-checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
-    checkpoints_dir: !ref <save_folder>
-    recoverables:
-        model: !ref <model>
-        normalize: !ref <normalize>
-        lr_annealing: !ref <lr_annealing>
-        counter: !ref <epoch_counter>
-
-train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
-    save_file: !ref <train_log>
-
-ctc_stats: !name:speechbrain.utils.metric_stats.MetricStats
-    metric: !name:speechbrain.nnet.losses.ctc_loss
-        blank_index: !ref <blank_index>
-        reduction: batch
-
-seq_stats: !name:speechbrain.utils.metric_stats.MetricStats
-    metric: !name:speechbrain.nnet.losses.nll_loss
-        label_smoothing: 0.1
-        reduction: batch
-
-per_stats: !name:speechbrain.utils.metric_stats.ErrorRateStats
diff --git a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea3.yaml b/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea3.yaml
deleted file mode 100644
index 77b684193af64c706956376f6804bd0288a3798a..0000000000000000000000000000000000000000
--- a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea3.yaml
+++ /dev/null
@@ -1,188 +0,0 @@
-# Seed needs to be set at top of yaml, before objects with parameters are made
-seed: 1234
-__set_seed: !apply:torch.manual_seed [!ref <seed>]
-output_folder: !ref results/tea3/<seed>
-test_wer_file: !ref <output_folder>/wer_test.txt
-save_folder: !ref <output_folder>/save
-train_log: !ref <output_folder>/train_log.txt
-
-# Data files
-data_folder: !PLACEHOLDER  # e.g. /path/to/TIMIT
-train_annotation: !ref <output_folder>/train.json
-valid_annotation: !ref <output_folder>/dev.json
-test_annotation: !ref <output_folder>/test.json
-skip_prep: False
-
-# Training parameters
-number_of_epochs: 50
-batch_size: 8
-lr: 0.0003
-ctc_weight: 0.2
-sorting: ascending
-
-# Feature parameters
-sample_rate: 16000
-n_fft: 400
-n_mels: 40
-
-# Model parameters
-activation: !name:torch.nn.LeakyReLU
-dropout: 0.2
-cnn_blocks: 2
-cnn_channels: (128, 256)
-cnn_kernelsize: (3, 3)
-rnn_layers: 5
-rnn_neurons: 512
-rnn_bidirectional: True
-dnn_blocks: 2
-dnn_neurons: 512
-emb_size: 128
-dec_neurons: 256
-
-# Outputs
-output_neurons: 40
-blank_index: !ref <output_neurons> - 1
-bos_index: !ref <output_neurons> - 1
-eos_index: !ref <output_neurons> - 1
-
-# Decoding parameters
-min_decode_ratio: 0.0
-max_decode_ratio: 1.0
-beam_size: 16
-# eos_threshold: 1.5
-
-# Dataloader options
-train_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-valid_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-test_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-normalize: !new:speechbrain.processing.features.InputNormalization
-    norm_type: global
-
-epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
-    limit: !ref <number_of_epochs>
-
-compute_features: !new:speechbrain.lobes.features.Fbank
-    sample_rate: !ref <sample_rate>
-    n_fft: !ref <n_fft>
-    n_mels: !ref <n_mels>
-
-enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
-    input_shape: [null, null, !ref <n_mels>]
-    activation: !ref <activation>
-    dropout: !ref <dropout>
-    cnn_blocks: !ref <cnn_blocks>
-    cnn_channels: !ref <cnn_channels>
-    cnn_kernelsize: !ref <cnn_kernelsize>
-    time_pooling: True
-    rnn_layers: !ref <rnn_layers>
-    rnn_neurons: !ref <rnn_neurons>
-    rnn_bidirectional: !ref <rnn_bidirectional>
-    dnn_blocks: !ref <dnn_blocks>
-    dnn_neurons: !ref <dnn_neurons>
-
-emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: !ref <dnn_neurons>
-    input_size: !ref <emb_size>
-    rnn_type: lstm
-    attn_type: location
-    hidden_size: !ref <dec_neurons>
-    attn_dim: 256
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.5
-
-ctc_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dnn_neurons>
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 blank
-
-seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dec_neurons>
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 eos
-
-log_softmax: !new:speechbrain.nnet.activations.Softmax
-    apply_log: True
-
-ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
-    blank_index: !ref <blank_index>
-
-seq_cost: !name:speechbrain.nnet.losses.nll_loss
-    label_smoothing: 0.1
-
-greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNGreedySearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-
-beam_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNBeamSearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    ctc_linear: !ref <ctc_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-    beam_size: !ref <beam_size>
-
-model: !new:torch.nn.ModuleList
-    - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <ctc_lin>, !ref <seq_lin>]
-
-opt_class: !name:torch.optim.Adam
-    lr: !ref <lr>
-
-lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
-    initial_value: !ref <lr>
-    improvement_threshold: 0.0025
-    annealing_factor: 0.8
-    patient: 0
-
-modules:
-    enc: !ref <enc>
-    emb: !ref <emb>
-    dec: !ref <dec>
-    ctc_lin: !ref <ctc_lin>
-    seq_lin: !ref <seq_lin>
-    normalize: !ref <normalize>
-
-jit_module_keys: [enc]
-
-checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
-    checkpoints_dir: !ref <save_folder>
-    recoverables:
-        model: !ref <model>
-        normalize: !ref <normalize>
-        lr_annealing: !ref <lr_annealing>
-        counter: !ref <epoch_counter>
-
-train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
-    save_file: !ref <train_log>
-
-ctc_stats: !name:speechbrain.utils.metric_stats.MetricStats
-    metric: !name:speechbrain.nnet.losses.ctc_loss
-        blank_index: !ref <blank_index>
-        reduction: batch
-
-seq_stats: !name:speechbrain.utils.metric_stats.MetricStats
-    metric: !name:speechbrain.nnet.losses.nll_loss
-        label_smoothing: 0.1
-        reduction: batch
-
-per_stats: !name:speechbrain.utils.metric_stats.ErrorRateStats
diff --git a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea4.yaml b/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea4.yaml
deleted file mode 100644
index e37b620e896fcbc7efbc20e4d1b653720f98c3f9..0000000000000000000000000000000000000000
--- a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea4.yaml
+++ /dev/null
@@ -1,188 +0,0 @@
-# Seed needs to be set at top of yaml, before objects with parameters are made
-seed: 1234
-__set_seed: !apply:torch.manual_seed [!ref <seed>]
-output_folder: !ref results/tea4/<seed>
-test_wer_file: !ref <output_folder>/wer_test.txt
-save_folder: !ref <output_folder>/save
-train_log: !ref <output_folder>/train_log.txt
-
-# Data files
-data_folder: !PLACEHOLDER  # e.g. /path/to/TIMIT
-train_annotation: !ref <output_folder>/train.json
-valid_annotation: !ref <output_folder>/dev.json
-test_annotation: !ref <output_folder>/test.json
-skip_prep: False
-
-# Training parameters
-number_of_epochs: 50
-batch_size: 8
-lr: 0.0003
-ctc_weight: 0.2
-sorting: ascending
-
-# Feature parameters
-sample_rate: 16000
-n_fft: 400
-n_mels: 40
-
-# Model parameters
-activation: !name:torch.nn.LeakyReLU
-dropout: 0.3
-cnn_blocks: 2
-cnn_channels: (128, 256)
-cnn_kernelsize: (3, 3)
-rnn_layers: 4
-rnn_neurons: 512
-rnn_bidirectional: True
-dnn_blocks: 2
-dnn_neurons: 512
-emb_size: 128
-dec_neurons: 256
-
-# Outputs
-output_neurons: 40
-blank_index: !ref <output_neurons> - 1
-bos_index: !ref <output_neurons> - 1
-eos_index: !ref <output_neurons> - 1
-
-# Decoding parameters
-min_decode_ratio: 0.0
-max_decode_ratio: 1.0
-beam_size: 16
-# eos_threshold: 1.5
-
-# Dataloader options
-train_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-valid_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-test_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-normalize: !new:speechbrain.processing.features.InputNormalization
-    norm_type: global
-
-epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
-    limit: !ref <number_of_epochs>
-
-compute_features: !new:speechbrain.lobes.features.Fbank
-    sample_rate: !ref <sample_rate>
-    n_fft: !ref <n_fft>
-    n_mels: !ref <n_mels>
-
-enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
-    input_shape: [null, null, !ref <n_mels>]
-    activation: !ref <activation>
-    dropout: !ref <dropout>
-    cnn_blocks: !ref <cnn_blocks>
-    cnn_channels: !ref <cnn_channels>
-    cnn_kernelsize: !ref <cnn_kernelsize>
-    time_pooling: True
-    rnn_layers: !ref <rnn_layers>
-    rnn_neurons: !ref <rnn_neurons>
-    rnn_bidirectional: !ref <rnn_bidirectional>
-    dnn_blocks: !ref <dnn_blocks>
-    dnn_neurons: !ref <dnn_neurons>
-
-emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: !ref <dnn_neurons>
-    input_size: !ref <emb_size>
-    rnn_type: gru
-    attn_type: location
-    hidden_size: !ref <dec_neurons>
-    attn_dim: 256
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.5
-
-ctc_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dnn_neurons>
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 blank
-
-seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dec_neurons>
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 eos
-
-log_softmax: !new:speechbrain.nnet.activations.Softmax
-    apply_log: True
-
-ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
-    blank_index: !ref <blank_index>
-
-seq_cost: !name:speechbrain.nnet.losses.nll_loss
-    label_smoothing: 0.1
-
-greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNGreedySearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-
-beam_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNBeamSearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    ctc_linear: !ref <ctc_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-    beam_size: !ref <beam_size>
-
-model: !new:torch.nn.ModuleList
-    - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <ctc_lin>, !ref <seq_lin>]
-
-opt_class: !name:torch.optim.Adam
-    lr: !ref <lr>
-
-lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
-    initial_value: !ref <lr>
-    improvement_threshold: 0.0025
-    annealing_factor: 0.8
-    patient: 0
-
-modules:
-    enc: !ref <enc>
-    emb: !ref <emb>
-    dec: !ref <dec>
-    ctc_lin: !ref <ctc_lin>
-    seq_lin: !ref <seq_lin>
-    normalize: !ref <normalize>
-
-jit_module_keys: [enc]
-
-checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
-    checkpoints_dir: !ref <save_folder>
-    recoverables:
-        model: !ref <model>
-        normalize: !ref <normalize>
-        lr_annealing: !ref <lr_annealing>
-        counter: !ref <epoch_counter>
-
-train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
-    save_file: !ref <train_log>
-
-ctc_stats: !name:speechbrain.utils.metric_stats.MetricStats
-    metric: !name:speechbrain.nnet.losses.ctc_loss
-        blank_index: !ref <blank_index>
-        reduction: batch
-
-seq_stats: !name:speechbrain.utils.metric_stats.MetricStats
-    metric: !name:speechbrain.nnet.losses.nll_loss
-        label_smoothing: 0.1
-        reduction: batch
-
-per_stats: !name:speechbrain.utils.metric_stats.ErrorRateStats
diff --git a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea5.yaml b/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea5.yaml
deleted file mode 100644
index 28c219c5b5f6d98d1567c506dd53600bce4820a2..0000000000000000000000000000000000000000
--- a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea5.yaml
+++ /dev/null
@@ -1,188 +0,0 @@
-# Seed needs to be set at top of yaml, before objects with parameters are made
-seed: 1234
-__set_seed: !apply:torch.manual_seed [!ref <seed>]
-output_folder: !ref results/tea5/<seed>
-test_wer_file: !ref <output_folder>/wer_test.txt
-save_folder: !ref <output_folder>/save
-train_log: !ref <output_folder>/train_log.txt
-
-# Data files
-data_folder: !PLACEHOLDER  # e.g. /path/to/TIMIT
-train_annotation: !ref <output_folder>/train.json
-valid_annotation: !ref <output_folder>/dev.json
-test_annotation: !ref <output_folder>/test.json
-skip_prep: False
-
-# Training parameters
-number_of_epochs: 50
-batch_size: 8
-lr: 0.0003
-ctc_weight: 0.2
-sorting: ascending
-
-# Feature parameters
-sample_rate: 16000
-n_fft: 400
-n_mels: 40
-
-# Model parameters
-activation: !name:torch.nn.LeakyReLU
-dropout: 0.3
-cnn_blocks: 2
-cnn_channels: (128, 256)
-cnn_kernelsize: (3, 3)
-rnn_layers: 4
-rnn_neurons: 320
-rnn_bidirectional: True
-dnn_blocks: 2
-dnn_neurons: 320
-emb_size: 128
-dec_neurons: 256
-
-# Outputs
-output_neurons: 40
-blank_index: !ref <output_neurons> - 1
-bos_index: !ref <output_neurons> - 1
-eos_index: !ref <output_neurons> - 1
-
-# Decoding parameters
-min_decode_ratio: 0.0
-max_decode_ratio: 1.0
-beam_size: 16
-# eos_threshold: 1.5
-
-# Dataloader options
-train_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-valid_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-test_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-normalize: !new:speechbrain.processing.features.InputNormalization
-    norm_type: global
-
-epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
-    limit: !ref <number_of_epochs>
-
-compute_features: !new:speechbrain.lobes.features.Fbank
-    sample_rate: !ref <sample_rate>
-    n_fft: !ref <n_fft>
-    n_mels: !ref <n_mels>
-
-enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
-    input_shape: [null, null, !ref <n_mels>]
-    activation: !ref <activation>
-    dropout: !ref <dropout>
-    cnn_blocks: !ref <cnn_blocks>
-    cnn_channels: !ref <cnn_channels>
-    cnn_kernelsize: !ref <cnn_kernelsize>
-    time_pooling: True
-    rnn_layers: !ref <rnn_layers>
-    rnn_neurons: !ref <rnn_neurons>
-    rnn_bidirectional: !ref <rnn_bidirectional>
-    dnn_blocks: !ref <dnn_blocks>
-    dnn_neurons: !ref <dnn_neurons>
-
-emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: !ref <dnn_neurons>
-    input_size: !ref <emb_size>
-    rnn_type: lstm
-    attn_type: location
-    hidden_size: !ref <dec_neurons>
-    attn_dim: 256
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.5
-
-ctc_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dnn_neurons>
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 blank
-
-seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dec_neurons>
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 eos
-
-log_softmax: !new:speechbrain.nnet.activations.Softmax
-    apply_log: True
-
-ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
-    blank_index: !ref <blank_index>
-
-seq_cost: !name:speechbrain.nnet.losses.nll_loss
-    label_smoothing: 0.1
-
-greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNGreedySearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-
-beam_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNBeamSearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    ctc_linear: !ref <ctc_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-    beam_size: !ref <beam_size>
-
-model: !new:torch.nn.ModuleList
-    - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <ctc_lin>, !ref <seq_lin>]
-
-opt_class: !name:torch.optim.Adam
-    lr: !ref <lr>
-
-lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
-    initial_value: !ref <lr>
-    improvement_threshold: 0.0025
-    annealing_factor: 0.8
-    patient: 0
-
-modules:
-    enc: !ref <enc>
-    emb: !ref <emb>
-    dec: !ref <dec>
-    ctc_lin: !ref <ctc_lin>
-    seq_lin: !ref <seq_lin>
-    normalize: !ref <normalize>
-
-jit_module_keys: [enc]
-
-checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
-    checkpoints_dir: !ref <save_folder>
-    recoverables:
-        model: !ref <model>
-        normalize: !ref <normalize>
-        lr_annealing: !ref <lr_annealing>
-        counter: !ref <epoch_counter>
-
-train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
-    save_file: !ref <train_log>
-
-ctc_stats: !name:speechbrain.utils.metric_stats.MetricStats
-    metric: !name:speechbrain.nnet.losses.ctc_loss
-        blank_index: !ref <blank_index>
-        reduction: batch
-
-seq_stats: !name:speechbrain.utils.metric_stats.MetricStats
-    metric: !name:speechbrain.nnet.losses.nll_loss
-        label_smoothing: 0.1
-        reduction: batch
-
-per_stats: !name:speechbrain.utils.metric_stats.ErrorRateStats
diff --git a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea6.yaml b/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea6.yaml
deleted file mode 100644
index 5a5fbfe45d5087680ecab7d8526134187957e583..0000000000000000000000000000000000000000
--- a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea6.yaml
+++ /dev/null
@@ -1,188 +0,0 @@
-# Seed needs to be set at top of yaml, before objects with parameters are made
-seed: 1234
-__set_seed: !apply:torch.manual_seed [!ref <seed>]
-output_folder: !ref results/tea6/<seed>
-test_wer_file: !ref <output_folder>/wer_test.txt
-save_folder: !ref <output_folder>/save
-train_log: !ref <output_folder>/train_log.txt
-
-# Data files
-data_folder: !PLACEHOLDER  # e.g. /path/to/TIMIT
-train_annotation: !ref <output_folder>/train.json
-valid_annotation: !ref <output_folder>/dev.json
-test_annotation: !ref <output_folder>/test.json
-skip_prep: False
-
-# Training parameters
-number_of_epochs: 50
-batch_size: 8
-lr: 0.0003
-ctc_weight: 0.2
-sorting: ascending
-
-# Feature parameters
-sample_rate: 16000
-n_fft: 400
-n_mels: 40
-
-# Model parameters
-activation: !name:torch.nn.LeakyReLU
-dropout: 0.3
-cnn_blocks: 1
-cnn_channels: (128, 256)
-cnn_kernelsize: (3, 3)
-rnn_layers: 4
-rnn_neurons: 320
-rnn_bidirectional: True
-dnn_blocks: 2
-dnn_neurons: 320
-emb_size: 128
-dec_neurons: 256
-
-# Outputs
-output_neurons: 40
-blank_index: !ref <output_neurons> - 1
-bos_index: !ref <output_neurons> - 1
-eos_index: !ref <output_neurons> - 1
-
-# Decoding parameters
-min_decode_ratio: 0.0
-max_decode_ratio: 1.0
-beam_size: 16
-# eos_threshold: 1.5
-
-# Dataloader options
-train_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-valid_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-test_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-normalize: !new:speechbrain.processing.features.InputNormalization
-    norm_type: global
-
-epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
-    limit: !ref <number_of_epochs>
-
-compute_features: !new:speechbrain.lobes.features.Fbank
-    sample_rate: !ref <sample_rate>
-    n_fft: !ref <n_fft>
-    n_mels: !ref <n_mels>
-
-enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
-    input_shape: [null, null, !ref <n_mels>]
-    activation: !ref <activation>
-    dropout: !ref <dropout>
-    cnn_blocks: !ref <cnn_blocks>
-    cnn_channels: !ref <cnn_channels>
-    cnn_kernelsize: !ref <cnn_kernelsize>
-    time_pooling: True
-    rnn_layers: !ref <rnn_layers>
-    rnn_neurons: !ref <rnn_neurons>
-    rnn_bidirectional: !ref <rnn_bidirectional>
-    dnn_blocks: !ref <dnn_blocks>
-    dnn_neurons: !ref <dnn_neurons>
-
-emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: !ref <dnn_neurons>
-    input_size: !ref <emb_size>
-    rnn_type: lstm
-    attn_type: location
-    hidden_size: !ref <dec_neurons>
-    attn_dim: 256
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.5
-
-ctc_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dnn_neurons>
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 blank
-
-seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dec_neurons>
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 eos
-
-log_softmax: !new:speechbrain.nnet.activations.Softmax
-    apply_log: True
-
-ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
-    blank_index: !ref <blank_index>
-
-seq_cost: !name:speechbrain.nnet.losses.nll_loss
-    label_smoothing: 0.1
-
-greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNGreedySearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-
-beam_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNBeamSearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    ctc_linear: !ref <ctc_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-    beam_size: !ref <beam_size>
-
-model: !new:torch.nn.ModuleList
-    - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <ctc_lin>, !ref <seq_lin>]
-
-opt_class: !name:torch.optim.Adam
-    lr: !ref <lr>
-
-lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
-    initial_value: !ref <lr>
-    improvement_threshold: 0.0025
-    annealing_factor: 0.8
-    patient: 0
-
-modules:
-    enc: !ref <enc>
-    emb: !ref <emb>
-    dec: !ref <dec>
-    ctc_lin: !ref <ctc_lin>
-    seq_lin: !ref <seq_lin>
-    normalize: !ref <normalize>
-
-jit_module_keys: [enc]
-
-checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
-    checkpoints_dir: !ref <save_folder>
-    recoverables:
-        model: !ref <model>
-        normalize: !ref <normalize>
-        lr_annealing: !ref <lr_annealing>
-        counter: !ref <epoch_counter>
-
-train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
-    save_file: !ref <train_log>
-
-ctc_stats: !name:speechbrain.utils.metric_stats.MetricStats
-    metric: !name:speechbrain.nnet.losses.ctc_loss
-        blank_index: !ref <blank_index>
-        reduction: batch
-
-seq_stats: !name:speechbrain.utils.metric_stats.MetricStats
-    metric: !name:speechbrain.nnet.losses.nll_loss
-        label_smoothing: 0.1
-        reduction: batch
-
-per_stats: !name:speechbrain.utils.metric_stats.ErrorRateStats
diff --git a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea7.yaml b/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea7.yaml
deleted file mode 100644
index 4f80971c0007eb856f856368f820767766d5c0a7..0000000000000000000000000000000000000000
--- a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea7.yaml
+++ /dev/null
@@ -1,188 +0,0 @@
-# Seed needs to be set at top of yaml, before objects with parameters are made
-seed: 1234
-__set_seed: !apply:torch.manual_seed [!ref <seed>]
-output_folder: !ref results/tea7/<seed>
-test_wer_file: !ref <output_folder>/wer_test.txt
-save_folder: !ref <output_folder>/save
-train_log: !ref <output_folder>/train_log.txt
-
-# Data files
-data_folder: !PLACEHOLDER  # e.g. /path/to/TIMIT
-train_annotation: !ref <output_folder>/train.json
-valid_annotation: !ref <output_folder>/dev.json
-test_annotation: !ref <output_folder>/test.json
-skip_prep: False
-
-# Training parameters
-number_of_epochs: 50
-batch_size: 8
-lr: 0.0003
-ctc_weight: 0.2
-sorting: ascending
-
-# Feature parameters
-sample_rate: 16000
-n_fft: 400
-n_mels: 40
-
-# Model parameters
-activation: !name:torch.nn.LeakyReLU
-dropout: 0.15
-cnn_blocks: 2
-cnn_channels: (128, 256)
-cnn_kernelsize: (3, 3)
-rnn_layers: 4
-rnn_neurons: 640
-rnn_bidirectional: True
-dnn_blocks: 2
-dnn_neurons: 512
-emb_size: 128
-dec_neurons: 256
-
-# Outputs
-output_neurons: 40
-blank_index: !ref <output_neurons> - 1
-bos_index: !ref <output_neurons> - 1
-eos_index: !ref <output_neurons> - 1
-
-# Decoding parameters
-min_decode_ratio: 0.0
-max_decode_ratio: 1.0
-beam_size: 16
-# eos_threshold: 1.5
-
-# Dataloader options
-train_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-valid_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-test_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-normalize: !new:speechbrain.processing.features.InputNormalization
-    norm_type: global
-
-epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
-    limit: !ref <number_of_epochs>
-
-compute_features: !new:speechbrain.lobes.features.Fbank
-    sample_rate: !ref <sample_rate>
-    n_fft: !ref <n_fft>
-    n_mels: !ref <n_mels>
-
-enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
-    input_shape: [null, null, !ref <n_mels>]
-    activation: !ref <activation>
-    dropout: !ref <dropout>
-    cnn_blocks: !ref <cnn_blocks>
-    cnn_channels: !ref <cnn_channels>
-    cnn_kernelsize: !ref <cnn_kernelsize>
-    time_pooling: True
-    rnn_layers: !ref <rnn_layers>
-    rnn_neurons: !ref <rnn_neurons>
-    rnn_bidirectional: !ref <rnn_bidirectional>
-    dnn_blocks: !ref <dnn_blocks>
-    dnn_neurons: !ref <dnn_neurons>
-
-emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: !ref <dnn_neurons>
-    input_size: !ref <emb_size>
-    rnn_type: gru
-    attn_type: location
-    hidden_size: !ref <dec_neurons>
-    attn_dim: 256
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.5
-
-ctc_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dnn_neurons>
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 blank
-
-seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dec_neurons>
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 eos
-
-log_softmax: !new:speechbrain.nnet.activations.Softmax
-    apply_log: True
-
-ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
-    blank_index: !ref <blank_index>
-
-seq_cost: !name:speechbrain.nnet.losses.nll_loss
-    label_smoothing: 0.1
-
-greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNGreedySearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-
-beam_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNBeamSearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    ctc_linear: !ref <ctc_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-    beam_size: !ref <beam_size>
-
-model: !new:torch.nn.ModuleList
-    - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <ctc_lin>, !ref <seq_lin>]
-
-opt_class: !name:torch.optim.Adam
-    lr: !ref <lr>
-
-lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
-    initial_value: !ref <lr>
-    improvement_threshold: 0.0025
-    annealing_factor: 0.8
-    patient: 0
-
-modules:
-    enc: !ref <enc>
-    emb: !ref <emb>
-    dec: !ref <dec>
-    ctc_lin: !ref <ctc_lin>
-    seq_lin: !ref <seq_lin>
-    normalize: !ref <normalize>
-
-jit_module_keys: [enc]
-
-checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
-    checkpoints_dir: !ref <save_folder>
-    recoverables:
-        model: !ref <model>
-        normalize: !ref <normalize>
-        lr_annealing: !ref <lr_annealing>
-        counter: !ref <epoch_counter>
-
-train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
-    save_file: !ref <train_log>
-
-ctc_stats: !name:speechbrain.utils.metric_stats.MetricStats
-    metric: !name:speechbrain.nnet.losses.ctc_loss
-        blank_index: !ref <blank_index>
-        reduction: batch
-
-seq_stats: !name:speechbrain.utils.metric_stats.MetricStats
-    metric: !name:speechbrain.nnet.losses.nll_loss
-        label_smoothing: 0.1
-        reduction: batch
-
-per_stats: !name:speechbrain.utils.metric_stats.ErrorRateStats
diff --git a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea8.yaml b/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea8.yaml
deleted file mode 100644
index b45227a839d431c412a5a0c69da7457966aeece6..0000000000000000000000000000000000000000
--- a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea8.yaml
+++ /dev/null
@@ -1,188 +0,0 @@
-# Seed needs to be set at top of yaml, before objects with parameters are made
-seed: 1234
-__set_seed: !apply:torch.manual_seed [!ref <seed>]
-output_folder: !ref results/tea8/<seed>
-test_wer_file: !ref <output_folder>/wer_test.txt
-save_folder: !ref <output_folder>/save
-train_log: !ref <output_folder>/train_log.txt
-
-# Data files
-data_folder: !PLACEHOLDER  # e.g. /path/to/TIMIT
-train_annotation: !ref <output_folder>/train.json
-valid_annotation: !ref <output_folder>/dev.json
-test_annotation: !ref <output_folder>/test.json
-skip_prep: False
-
-# Training parameters
-number_of_epochs: 50
-batch_size: 8
-lr: 0.0003
-ctc_weight: 0.2
-sorting: ascending
-
-# Feature parameters
-sample_rate: 16000
-n_fft: 400
-n_mels: 40
-
-# Model parameters
-activation: !name:torch.nn.LeakyReLU
-dropout: 0.3
-cnn_blocks: 2
-cnn_channels: (128, 256)
-cnn_kernelsize: (3, 3)
-rnn_layers: 5
-rnn_neurons: 512
-rnn_bidirectional: True
-dnn_blocks: 2
-dnn_neurons: 512
-emb_size: 128
-dec_neurons: 256
-
-# Outputs
-output_neurons: 40
-blank_index: !ref <output_neurons> - 1
-bos_index: !ref <output_neurons> - 1
-eos_index: !ref <output_neurons> - 1
-
-# Decoding parameters
-min_decode_ratio: 0.0
-max_decode_ratio: 1.0
-beam_size: 16
-# eos_threshold: 1.5
-
-# Dataloader options
-train_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-valid_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-test_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-normalize: !new:speechbrain.processing.features.InputNormalization
-    norm_type: global
-
-epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
-    limit: !ref <number_of_epochs>
-
-compute_features: !new:speechbrain.lobes.features.Fbank
-    sample_rate: !ref <sample_rate>
-    n_fft: !ref <n_fft>
-    n_mels: !ref <n_mels>
-
-enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
-    input_shape: [null, null, !ref <n_mels>]
-    activation: !ref <activation>
-    dropout: !ref <dropout>
-    cnn_blocks: !ref <cnn_blocks>
-    cnn_channels: !ref <cnn_channels>
-    cnn_kernelsize: !ref <cnn_kernelsize>
-    time_pooling: True
-    rnn_layers: !ref <rnn_layers>
-    rnn_neurons: !ref <rnn_neurons>
-    rnn_bidirectional: !ref <rnn_bidirectional>
-    dnn_blocks: !ref <dnn_blocks>
-    dnn_neurons: !ref <dnn_neurons>
-
-emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: !ref <dnn_neurons>
-    input_size: !ref <emb_size>
-    rnn_type: lstm
-    attn_type: location
-    hidden_size: !ref <dec_neurons>
-    attn_dim: 256
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.5
-
-ctc_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dnn_neurons>
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 blank
-
-seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dec_neurons>
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 eos
-
-log_softmax: !new:speechbrain.nnet.activations.Softmax
-    apply_log: True
-
-ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
-    blank_index: !ref <blank_index>
-
-seq_cost: !name:speechbrain.nnet.losses.nll_loss
-    label_smoothing: 0.1
-
-greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNGreedySearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-
-beam_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNBeamSearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    ctc_linear: !ref <ctc_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-    beam_size: !ref <beam_size>
-
-model: !new:torch.nn.ModuleList
-    - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <ctc_lin>, !ref <seq_lin>]
-
-opt_class: !name:torch.optim.Adam
-    lr: !ref <lr>
-
-lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
-    initial_value: !ref <lr>
-    improvement_threshold: 0.0025
-    annealing_factor: 0.8
-    patient: 0
-
-modules:
-    enc: !ref <enc>
-    emb: !ref <emb>
-    dec: !ref <dec>
-    ctc_lin: !ref <ctc_lin>
-    seq_lin: !ref <seq_lin>
-    normalize: !ref <normalize>
-
-jit_module_keys: [enc]
-
-checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
-    checkpoints_dir: !ref <save_folder>
-    recoverables:
-        model: !ref <model>
-        normalize: !ref <normalize>
-        lr_annealing: !ref <lr_annealing>
-        counter: !ref <epoch_counter>
-
-train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
-    save_file: !ref <train_log>
-
-ctc_stats: !name:speechbrain.utils.metric_stats.MetricStats
-    metric: !name:speechbrain.nnet.losses.ctc_loss
-        blank_index: !ref <blank_index>
-        reduction: batch
-
-seq_stats: !name:speechbrain.utils.metric_stats.MetricStats
-    metric: !name:speechbrain.nnet.losses.nll_loss
-        label_smoothing: 0.1
-        reduction: batch
-
-per_stats: !name:speechbrain.utils.metric_stats.ErrorRateStats
diff --git a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea9.yaml b/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea9.yaml
deleted file mode 100644
index 9f3a7f9105dee6e5fba15e28db296722ee648be1..0000000000000000000000000000000000000000
--- a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea9.yaml
+++ /dev/null
@@ -1,188 +0,0 @@
-# Seed needs to be set at top of yaml, before objects with parameters are made
-seed: 1234
-__set_seed: !apply:torch.manual_seed [!ref <seed>]
-output_folder: !ref results/tea9/<seed>
-test_wer_file: !ref <output_folder>/wer_test.txt
-save_folder: !ref <output_folder>/save
-train_log: !ref <output_folder>/train_log.txt
-
-# Data files
-data_folder: !PLACEHOLDER  # e.g. /path/to/TIMIT
-train_annotation: !ref <output_folder>/train.json
-valid_annotation: !ref <output_folder>/dev.json
-test_annotation: !ref <output_folder>/test.json
-skip_prep: False
-
-# Training parameters
-number_of_epochs: 50
-batch_size: 8
-lr: 0.0003
-ctc_weight: 0.2
-sorting: ascending
-
-# Feature parameters
-sample_rate: 16000
-n_fft: 400
-n_mels: 40
-
-# Model parameters
-activation: !name:torch.nn.LeakyReLU
-dropout: 0.15
-cnn_blocks: 2
-cnn_channels: (128, 256)
-cnn_kernelsize: (3, 3)
-rnn_layers: 4
-rnn_neurons: 512
-rnn_bidirectional: True
-dnn_blocks: 2
-dnn_neurons: 512
-emb_size: 128
-dec_neurons: 256
-
-# Outputs
-output_neurons: 40
-blank_index: !ref <output_neurons> - 1
-bos_index: !ref <output_neurons> - 1
-eos_index: !ref <output_neurons> - 1
-
-# Decoding parameters
-min_decode_ratio: 0.0
-max_decode_ratio: 1.0
-beam_size: 16
-# eos_threshold: 1.5
-
-# Dataloader options
-train_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-valid_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-test_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-normalize: !new:speechbrain.processing.features.InputNormalization
-    norm_type: global
-
-epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
-    limit: !ref <number_of_epochs>
-
-compute_features: !new:speechbrain.lobes.features.Fbank
-    sample_rate: !ref <sample_rate>
-    n_fft: !ref <n_fft>
-    n_mels: !ref <n_mels>
-
-enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
-    input_shape: [null, null, !ref <n_mels>]
-    activation: !ref <activation>
-    dropout: !ref <dropout>
-    cnn_blocks: !ref <cnn_blocks>
-    cnn_channels: !ref <cnn_channels>
-    cnn_kernelsize: !ref <cnn_kernelsize>
-    time_pooling: True
-    rnn_layers: !ref <rnn_layers>
-    rnn_neurons: !ref <rnn_neurons>
-    rnn_bidirectional: !ref <rnn_bidirectional>
-    dnn_blocks: !ref <dnn_blocks>
-    dnn_neurons: !ref <dnn_neurons>
-
-emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: !ref <dnn_neurons>
-    input_size: !ref <emb_size>
-    rnn_type: gru
-    attn_type: location
-    hidden_size: !ref <dec_neurons>
-    attn_dim: 256
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.5
-
-ctc_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dnn_neurons>
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 blank
-
-seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dec_neurons>
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 eos
-
-log_softmax: !new:speechbrain.nnet.activations.Softmax
-    apply_log: True
-
-ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
-    blank_index: !ref <blank_index>
-
-seq_cost: !name:speechbrain.nnet.losses.nll_loss
-    label_smoothing: 0.1
-
-greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNGreedySearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-
-beam_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNBeamSearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    ctc_linear: !ref <ctc_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-    beam_size: !ref <beam_size>
-
-model: !new:torch.nn.ModuleList
-    - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <ctc_lin>, !ref <seq_lin>]
-
-opt_class: !name:torch.optim.Adam
-    lr: !ref <lr>
-
-lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
-    initial_value: !ref <lr>
-    improvement_threshold: 0.0025
-    annealing_factor: 0.8
-    patient: 0
-
-modules:
-    enc: !ref <enc>
-    emb: !ref <emb>
-    dec: !ref <dec>
-    ctc_lin: !ref <ctc_lin>
-    seq_lin: !ref <seq_lin>
-    normalize: !ref <normalize>
-
-jit_module_keys: [enc]
-
-checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
-    checkpoints_dir: !ref <save_folder>
-    recoverables:
-        model: !ref <model>
-        normalize: !ref <normalize>
-        lr_annealing: !ref <lr_annealing>
-        counter: !ref <epoch_counter>
-
-train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
-    save_file: !ref <train_log>
-
-ctc_stats: !name:speechbrain.utils.metric_stats.MetricStats
-    metric: !name:speechbrain.nnet.losses.ctc_loss
-        blank_index: !ref <blank_index>
-        reduction: batch
-
-seq_stats: !name:speechbrain.utils.metric_stats.MetricStats
-    metric: !name:speechbrain.nnet.losses.nll_loss
-        label_smoothing: 0.1
-        reduction: batch
-
-per_stats: !name:speechbrain.utils.metric_stats.ErrorRateStats
diff --git a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/train_kd.yaml b/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/train_kd.yaml
deleted file mode 100644
index 7189fa71d01a69959635e1d107079c7e0af345b6..0000000000000000000000000000000000000000
--- a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/train_kd.yaml
+++ /dev/null
@@ -1,221 +0,0 @@
-# Seed needs to be set at top of yaml, before objects with parameters are made
-seed: 1234
-__set_seed: !apply:torch.manual_seed [!ref <seed>]
-output_folder: !ref results/augment_CRDNN/<seed>
-test_wer_file: !ref <output_folder>/wer_test.txt
-save_folder: !ref <output_folder>/save
-train_log: !ref <output_folder>/train_log.txt
-
-# Data files
-data_folder: !PLACEHOLDER  # e.g. /path/to/TIMIT
-train_annotation: !ref <output_folder>/train.json
-valid_annotation: !ref <output_folder>/dev.json
-test_annotation: !ref <output_folder>/test.json
-skip_prep: False
-
-# Path containing the stored inferences of the different teachers
-tea_infer_dir: !PLACEHOLDER
-
-# Training parameters
-number_of_epochs: 50
-batch_size: 8
-lr: 0.0005
-ctc_weight: 0.1
-sorting: ascending
-
-# Feature parameters
-sample_rate: 16000
-n_fft: 400
-n_mels: 40
-
-# teacher models
-num_tea: 10
-
-# distillation parameters
-pretrain: True
-
-# Path to the student model to load the weights from
-pretrain_st_dir: !PLACEHOLDER
-strategy: best  # [average, best, weighted]
-
-# Temperature : smooth the distribution of output probability
-temperature: 1
-# distillation weight alpha
-alpha: 1
-
-# variable name when loading teacher inference
-tea_keys: ["p_ctc_tea", "p_seq_tea", "wer_ctc_tea", "wer_tea"]
-
-# Model parameters
-activation: !name:torch.nn.LeakyReLU
-dropout: 0.25
-cnn_blocks: 2
-cnn_channels: (128, 256)
-cnn_kernelsize: (3, 3)
-rnn_layers: 4
-rnn_neurons: 512
-rnn_bidirectional: True
-dnn_blocks: 2
-dnn_neurons: 512
-emb_size: 128
-dec_neurons: 256
-
-# Outputs
-output_neurons: 40
-blank_index: !ref <output_neurons> - 1
-bos_index: !ref <output_neurons> - 1
-eos_index: !ref <output_neurons> - 1
-
-# Decoding parameters
-min_decode_ratio: 0.0
-max_decode_ratio: 1.0
-beam_size: 16
-# eos_threshold: 1.5
-
-# Dataloader options
-train_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-valid_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-test_dataloader_opts:
-    batch_size: !ref <batch_size>
-
-# Functions
-enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
-    input_shape: [null, null, !ref <n_mels>]
-    activation: !ref <activation>
-    dropout: !ref <dropout>
-    cnn_blocks: !ref <cnn_blocks>
-    cnn_channels: !ref <cnn_channels>
-    cnn_kernelsize: !ref <cnn_kernelsize>
-    time_pooling: True
-    rnn_layers: !ref <rnn_layers>
-    rnn_neurons: !ref <rnn_neurons>
-    rnn_bidirectional: !ref <rnn_bidirectional>
-    dnn_blocks: !ref <dnn_blocks>
-    dnn_neurons: !ref <dnn_neurons>
-
-emb: !new:speechbrain.nnet.embedding.Embedding
-    num_embeddings: !ref <output_neurons>
-    embedding_dim: !ref <emb_size>
-
-dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
-    enc_dim: !ref <dnn_neurons>
-    input_size: !ref <emb_size>
-    rnn_type: gru
-    attn_type: location
-    hidden_size: !ref <dec_neurons>
-    attn_dim: 256
-    num_layers: 1
-    scaling: 1.0
-    channels: 10
-    kernel_size: 100
-    re_init: True
-    dropout: 0.5
-
-ctc_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dnn_neurons>
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 blank
-
-seq_lin: !new:speechbrain.nnet.linear.Linear
-    input_size: !ref <dec_neurons>
-    n_neurons: !ref <output_neurons>  # 39 phonemes + 1 eos
-
-model: !new:torch.nn.ModuleList
-    - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <ctc_lin>, !ref <seq_lin>]
-
-log_softmax: !new:speechbrain.nnet.activations.Softmax
-    apply_log: True
-
-ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
-    blank_index: !ref <blank_index>
-
-seq_cost: !name:speechbrain.nnet.losses.nll_loss
-    label_smoothing: 0.1
-
-greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNGreedySearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-
-beam_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNBeamSearcher
-    embedding: !ref <emb>
-    decoder: !ref <dec>
-    linear: !ref <seq_lin>
-    ctc_linear: !ref <ctc_lin>
-    bos_index: !ref <bos_index>
-    eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
-    min_decode_ratio: !ref <min_decode_ratio>
-    max_decode_ratio: !ref <max_decode_ratio>
-    beam_size: !ref <beam_size>
-
-opt_class: !name:torch.optim.Adam
-    lr: !ref <lr>
-
-lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
-    initial_value: !ref <lr>
-    improvement_threshold: 0.0025
-    annealing_factor: 0.8
-    patient: 0
-
-# Modules to have train/eval/optimizer called on
-modules:
-    enc: !ref <enc>
-    emb: !ref <emb>
-    dec: !ref <dec>
-    ctc_lin: !ref <ctc_lin>
-    seq_lin: !ref <seq_lin>
-    normalize: !ref <normalize>
-
-# Names of modules to be compiled with torch.jit.script
-jit_module_keys: [enc]
-
-epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
-    limit: !ref <number_of_epochs>
-
-normalize: !new:speechbrain.processing.features.InputNormalization
-    norm_type: global
-
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-compute_features: !new:speechbrain.lobes.features.Fbank
-    sample_rate: !ref <sample_rate>
-    n_fft: !ref <n_fft>
-    n_mels: !ref <n_mels>
-
-checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
-    checkpoints_dir: !ref <save_folder>
-    recoverables:
-        model: !ref <model>
-        normalize: !ref <normalize>
-        counter: !ref <epoch_counter>
-        lr_annealing: !ref <lr_annealing>
-
-train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
-    save_file: !ref <train_log>
-
-ctc_stats: !name:speechbrain.utils.metric_stats.MetricStats
-    metric: !name:speechbrain.nnet.losses.ctc_loss
-        blank_index: !ref <blank_index>
-        reduction: batch
-
-seq_stats: !name:speechbrain.utils.metric_stats.MetricStats
-    metric: !name:speechbrain.nnet.losses.nll_loss
-        label_smoothing: 0.1
-        reduction: batch
-
-per_stats: !name:speechbrain.utils.metric_stats.ErrorRateStats
-
-ctc_cost_kd: !name:speechbrain.nnet.losses.ctc_loss_kd
-    blank_index: !ref <blank_index>
-
-seq_cost_kd: !name:speechbrain.nnet.losses.nll_loss_kd
diff --git a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/save_teachers.py b/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/save_teachers.py
deleted file mode 100644
index 83c5ad15a747e64fd507bbe8a7e37bd97c55aabc..0000000000000000000000000000000000000000
--- a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/save_teachers.py
+++ /dev/null
@@ -1,396 +0,0 @@
-#!/usr/bin/env python3
-
-"""Recipe for doing ASR with phoneme targets and joint seq2seq
-and CTC loss on the TIMIT dataset following a knowledge distillation scheme as
-reported in " Distilling Knowledge from Ensembles of Acoustic Models for Joint
-CTC-Attention End-to-End Speech Recognition", Yan Gao et al.
-
-To run this recipe, do the following:
-> python experiment.py hyperparams.yaml --data_folder /path/to/TIMIT
-
-Authors
- * Yan Gao 2021
- * Titouan Parcollet 2021
-"""
-
-import sys
-import torch
-import speechbrain as sb
-from speechbrain.utils.distributed import run_on_main
-from hyperpyyaml import load_hyperpyyaml
-
-from tqdm.contrib import tqdm
-import h5py
-import numpy as np
-
-
-# Define training procedure
-class ASR(sb.Brain):
-    def __init__(self, tea_modules_list=None, hparams=None, run_opts=None):
-        super().__init__(
-            modules=None,
-            opt_class=None,
-            hparams=hparams,
-            run_opts=run_opts,
-            checkpointer=None,
-        )
-
-        # Initialize teacher parameters
-        tea_modules_list_ = []
-        for tea_modules in tea_modules_list:
-            tea_modules_ = torch.nn.ModuleList(tea_modules)
-            tea_modules_ = tea_modules_.to(self.device)
-            tea_modules_list_.append(tea_modules_)
-        self.tea_modules_list = tea_modules_list_
-
-    def compute_forward_tea(self, batch):
-        batch = batch.to(self.device)
-        wavs, wav_lens = batch.sig
-        phns_bos, _ = batch.phn_encoded_bos
-        phns, phn_lens = batch.phn_encoded
-
-        feats = self.hparams.compute_features(wavs)
-        feats = self.hparams.normalize(feats, wav_lens)
-        apply_softmax = torch.nn.Softmax(dim=-1)
-
-        # run inference to each teacher model
-        tea_dict_list = []
-        for num in range(self.hparams.num_tea):
-            tea_dict = {}
-            self.tea_modules_list[num].eval()
-            with torch.no_grad():
-                x_tea = tea_enc_list[num](feats)
-                ctc_logits_tea = tea_ctc_lin_list[num](x_tea)
-
-                # output layer for ctc log-probabilities
-                p_ctc_tea = self.hparams.log_softmax(
-                    ctc_logits_tea / self.hparams.temperature
-                )
-
-                e_in_tea = tea_emb_list[num](phns_bos)
-                h_tea, _ = tea_dec_list[num](e_in_tea, x_tea, wav_lens)
-
-                # output layer for seq2seq log-probabilities
-                seq_logits_tea = tea_seq_lin_list[num](h_tea)
-                p_seq_tea = apply_softmax(
-                    seq_logits_tea / self.hparams.temperature
-                )
-
-                # WER from output layer of CTC
-                sequence_ctc = sb.decoders.ctc_greedy_decode(
-                    p_ctc_tea, wav_lens, blank_id=self.hparams.blank_index
-                )
-
-                phns_decode = sb.utils.data_utils.undo_padding(phns, phn_lens)
-                phns_decode = self.label_encoder.decode_ndim(phns_decode)
-                sequence_decode = self.label_encoder.decode_ndim(sequence_ctc)
-
-                per_stats_ctc = sb.utils.edit_distance.wer_details_for_batch(
-                    batch.id,
-                    phns_decode,
-                    sequence_decode,
-                    compute_alignments=False,
-                )
-
-                wer_ctc_tea = []
-                for item in per_stats_ctc:
-                    wer_ctc_tea.append(item["WER"])
-
-                wer_ctc_tea = exclude_wer(wer_ctc_tea)
-                wer_ctc_tea = np.expand_dims(wer_ctc_tea, axis=0)
-
-                # WER from output layer of CE
-                _, predictions = p_seq_tea.max(dim=-1)
-                hyps = sb.decoders.seq2seq.batch_filter_seq2seq_output(
-                    predictions, eos_id=self.hparams.eos_index
-                )
-                sequence_ce = self.label_encoder.decode_ndim(hyps)
-                per_stats_ce = sb.utils.edit_distance.wer_details_for_batch(
-                    batch.id, phns_decode, sequence_ce, compute_alignments=False
-                )
-
-                wer_tea = []
-                for item in per_stats_ce:
-                    wer_tea.append(item["WER"])
-
-                wer_tea = exclude_wer(wer_tea)
-                wer_tea = np.expand_dims(wer_tea, axis=0)
-
-            # save the variables into dict
-            tea_dict["p_ctc_tea"] = p_ctc_tea.cpu().numpy()
-            tea_dict["p_seq_tea"] = p_seq_tea.cpu().numpy()
-            tea_dict["wer_ctc_tea"] = wer_ctc_tea
-            tea_dict["wer_tea"] = wer_tea
-            tea_dict_list.append(tea_dict)
-
-        return tea_dict_list
-
-    def def_tea_name(self):
-        # define teacher variable name
-        tea_name = []
-        for tea_num in range(self.hparams.num_tea):
-            tea = "t{}".format(tea_num)
-            tea_name.append(tea)
-        return tea_name
-
-    def fit_save(self, train_set, valid_set=None, test_set=None):
-        data_sets = [train_set, valid_set, test_set]
-        stage = self.hparams.stage
-        tea_name = self.def_tea_name()
-
-        # define output file name
-        f_name = "/tea_infer_{}batch.hdf5".format(self.hparams.batch_size)
-        f = h5py.File(self.hparams.output_folder + f_name, "w")
-        for num in range(len(stage)):
-            # create group for each set (train, valid, test).
-            g_sets = f.create_group(stage[num])
-
-            with tqdm(
-                data_sets[num], initial=self.step, dynamic_ncols=True,
-            ) as t:
-                for batch in t:
-                    self.step += 1
-                    # create group for each batch
-                    g_batch = g_sets.create_group(str(self.step))
-
-                    # run inference to each teacher
-                    tea_dict_list = self.compute_forward_tea(batch)
-
-                    for tea_num in range(self.hparams.num_tea):
-                        # create group for each teacher
-                        g_tea = g_batch.create_group(tea_name[tea_num])
-                        g_tea.create_dataset(
-                            "p_ctc_tea",
-                            data=tea_dict_list[tea_num]["p_ctc_tea"],
-                        )
-                        g_tea.create_dataset(
-                            "p_seq_tea",
-                            data=tea_dict_list[tea_num]["p_seq_tea"],
-                        )
-                        g_tea.create_dataset(
-                            "wer_ctc_tea",
-                            data=tea_dict_list[tea_num]["wer_ctc_tea"][0],
-                        )
-                        g_tea.create_dataset(
-                            "wer_tea", data=tea_dict_list[tea_num]["wer_tea"][0]
-                        )
-            self.step = 0
-        f.close()
-
-
-def exclude_wer(wer):
-    """
-    This function is used to exclude the
-    wer values which is more than 100.
-    """
-    wer_list = []
-    for item in wer:
-        if item > 100:
-            item = 100
-        wer_list.append(item)
-    return np.array(wer_list)
-
-
-def data_io_prep(hparams):
-    "Creates the datasets and their data processing pipelines."
-    data_folder = hparams["data_folder"]
-    # 1. Declarations:
-    train_data = sb.dataio.dataset.DynamicItemDataset.from_json(
-        json_path=hparams["train_annotation"],
-        replacements={"data_root": data_folder},
-    )
-    if hparams["sorting"] == "ascending":
-        # we sort training data to speed up training and get better results.
-        train_data = train_data.filtered_sorted(sort_key="duration")
-        # when sorting do not shuffle in dataloader ! otherwise is pointless
-        hparams["train_dataloader_opts"]["shuffle"] = False
-
-    elif hparams["sorting"] == "descending":
-        train_data = train_data.filtered_sorted(
-            sort_key="duration", reverse=True
-        )
-        # when sorting do not shuffle in dataloader ! otherwise is pointless
-        hparams["train_dataloader_opts"]["shuffle"] = False
-
-    elif hparams["sorting"] == "random":
-        pass
-
-    else:
-        raise NotImplementedError(
-            "sorting must be random, ascending or descending"
-        )
-
-    valid_data = sb.dataio.dataset.DynamicItemDataset.from_json(
-        json_path=hparams["valid_annotation"],
-        replacements={"data_root": data_folder},
-    )
-    valid_data = valid_data.filtered_sorted(sort_key="duration")
-
-    test_data = sb.dataio.dataset.DynamicItemDataset.from_json(
-        json_path=hparams["test_annotation"],
-        replacements={"data_root": data_folder},
-    )
-    test_data = test_data.filtered_sorted(sort_key="duration")
-
-    datasets = [train_data, valid_data, test_data]
-    label_encoder = sb.dataio.encoder.CTCTextEncoder()
-
-    # 2. Define audio pipeline:
-    @sb.utils.data_pipeline.takes("wav")
-    @sb.utils.data_pipeline.provides("sig")
-    def audio_pipeline(wav):
-        sig = sb.dataio.dataio.read_audio(wav)
-        return sig
-
-    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
-
-    # 3. Define text pipeline:
-    @sb.utils.data_pipeline.takes("phn")
-    @sb.utils.data_pipeline.provides(
-        "phn_list",
-        "phn_encoded_list",
-        "phn_encoded",
-        "phn_encoded_eos",
-        "phn_encoded_bos",
-    )
-    def text_pipeline(phn):
-        phn_list = phn.strip().split()
-        yield phn_list
-        phn_encoded_list = label_encoder.encode_sequence(phn_list)
-        yield phn_encoded_list
-        phn_encoded = torch.LongTensor(phn_encoded_list)
-        yield phn_encoded
-        phn_encoded_eos = torch.LongTensor(
-            label_encoder.append_eos_index(phn_encoded_list)
-        )
-        yield phn_encoded_eos
-        phn_encoded_bos = torch.LongTensor(
-            label_encoder.prepend_bos_index(phn_encoded_list)
-        )
-        yield phn_encoded_bos
-
-    sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
-
-    # 3. Fit encoder:
-    # NOTE: In this minimal example, also update from valid data
-
-    label_encoder.update_from_didataset(train_data, output_key="phn_list")
-    if (
-        hparams["blank_index"] != hparams["bos_index"]
-        or hparams["blank_index"] != hparams["eos_index"]
-    ):
-        label_encoder.insert_blank(index=hparams["blank_index"])
-
-    if hparams["bos_index"] == hparams["eos_index"]:
-        label_encoder.insert_bos_eos(
-            bos_label="<eos-bos>",
-            eos_label="<eos-bos>",
-            bos_index=hparams["bos_index"],
-        )
-    else:
-        label_encoder.insert_bos_eos(
-            bos_label="<bos>",
-            eos_label="<eos>",
-            bos_index=hparams["bos_index"],
-            eos_index=hparams["eos_index"],
-        )
-
-    # 4. Set output:
-    sb.dataio.dataset.set_output_keys(
-        datasets,
-        ["id", "sig", "phn_encoded", "phn_encoded_eos", "phn_encoded_bos"],
-    )
-
-    return train_data, valid_data, test_data, label_encoder
-
-
-if __name__ == "__main__":
-    # CLI:
-    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
-
-    # Load hyperparameters file with command-line overrides
-    with open(hparams_file) as fin:
-        hparams = load_hyperpyyaml(fin, overrides)
-
-    # Dataset prep (parsing TIMIT and annotation into csv files)
-    from timit_prepare import prepare_timit  # noqa
-
-    # Initialize ddp (useful only for multi-GPU DDP training)
-    sb.utils.distributed.ddp_init_group(run_opts)
-
-    # multi-gpu (ddp) save data preparation
-    run_on_main(
-        prepare_timit,
-        kwargs={
-            "data_folder": hparams["data_folder"],
-            "save_json_train": hparams["train_annotation"],
-            "save_json_valid": hparams["valid_annotation"],
-            "save_json_test": hparams["test_annotation"],
-            "skip_prep": hparams["skip_prep"],
-        },
-    )
-
-    # Dataset IO prep: creating Dataset objects and proper encodings for phones
-    train_data, valid_data, test_data, label_encoder = data_io_prep(hparams)
-
-    # Create experiment directory
-    sb.create_experiment_directory(
-        experiment_directory=hparams["output_folder"],
-        hyperparams_to_save=hparams_file,
-        overrides=overrides,
-    )
-
-    # initialise teacher model variables
-    tea_enc_list = []
-    tea_emb_list = []
-    tea_dec_list = []
-    tea_ctc_lin_list = []
-    tea_seq_lin_list = []
-    for i in range(hparams["num_tea"]):
-        exec("tea_enc_list.append(hparams['tea{}_enc'])".format(i))
-        exec("tea_emb_list.append(hparams['tea{}_emb'])".format(i))
-        exec("tea_dec_list.append(hparams['tea{}_dec'])".format(i))
-        exec("tea_ctc_lin_list.append(hparams['tea{}_ctc_lin'])".format(i))
-        exec("tea_seq_lin_list.append(hparams['tea{}_seq_lin'])".format(i))
-
-    # create ModuleList
-    for i in range(hparams["num_tea"]):
-        exec(
-            "tea{}_modules = torch.nn.ModuleList([tea_enc_list[i], tea_emb_list[i], tea_dec_list[i], tea_ctc_lin_list[i], tea_seq_lin_list[i]])".format(
-                i
-            )
-        )  # i denotes the index of teacher models
-
-    tea_modules_list = []
-    for i in range(hparams["num_tea"]):
-        exec("tea_modules_list.append(tea{}_modules)".format(i))
-
-    # Trainer initialization
-    asr_brain = ASR(
-        tea_modules_list=tea_modules_list, hparams=hparams, run_opts=run_opts
-    )
-    asr_brain.label_encoder = label_encoder
-
-    # load pre-trained weights of teacher models
-    with open(hparams["tea_models_dir"], "r") as f:
-        enter_token = "\n"
-        for i, path in enumerate(f.readlines()):
-            exec(
-                "tea{}_modules.load_state_dict(torch.load(path.strip(enter_token)))".format(
-                    i
-                )
-            )
-
-    # make dataloaders
-    train_set = sb.dataio.dataloader.make_dataloader(
-        train_data, **hparams["train_dataloader_opts"]
-    )
-    valid_set = sb.dataio.dataloader.make_dataloader(
-        valid_data, **hparams["valid_dataloader_opts"]
-    )
-    test_set = sb.dataio.dataloader.make_dataloader(
-        test_data, **hparams["test_dataloader_opts"]
-    )
-
-    # run inference and save results
-    asr_brain.fit_save(train_set, valid_set, test_set)
diff --git a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/timit_prepare.py b/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/timit_prepare.py
deleted file mode 120000
index 9b0f68bc85836025fe561d947e55f14b00390e0c..0000000000000000000000000000000000000000
--- a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/timit_prepare.py
+++ /dev/null
@@ -1 +0,0 @@
-../../timit_prepare.py
\ No newline at end of file
diff --git a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/train_kd.py b/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/train_kd.py
deleted file mode 100644
index e06c34b585c98541d5ca983a5136b4a957e44ed0..0000000000000000000000000000000000000000
--- a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/train_kd.py
+++ /dev/null
@@ -1,528 +0,0 @@
-#!/usr/bin/env python3
-
-"""Recipe for doing ASR with phoneme targets and joint seq2seq
-and CTC loss on the TIMIT dataset following a knowledge distillation scheme as
-reported in " Distilling Knowledge from Ensembles of Acoustic Models for Joint
-CTC-Attention End-to-End Speech Recognition", Yan Gao et al.
-
-To run this recipe, do the following:
-> python experiment.py hyperparams.yaml --data_folder /path/to/TIMIT
-
-Authors
- * Yan Gao 2021
- * Titouan Parcollet 2021
-"""
-
-import sys
-import torch
-import h5py
-import speechbrain as sb
-from speechbrain.utils.distributed import run_on_main, if_main_process
-from hyperpyyaml import load_hyperpyyaml
-
-
-# Define training procedure
-class ASR(sb.Brain):
-    def compute_forward(self, batch, stage):
-        batch = batch.to(self.device)
-        wavs, wav_lens = batch.sig
-        phns_bos, _ = batch.phn_encoded_bos
-
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "env_corrupt"):
-                wavs_noise = self.hparams.env_corrupt(wavs, wav_lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                wav_lens = torch.cat([wav_lens, wav_lens])
-                phns_bos = torch.cat([phns_bos, phns_bos])
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
-
-        feats = self.hparams.compute_features(wavs)
-        feats = self.modules.normalize(feats, wav_lens)
-        x = self.modules.enc(feats)
-
-        # output layer for ctc log-probabilities
-        logits = self.modules.ctc_lin(x)
-        p_ctc = self.hparams.log_softmax(logits)
-
-        e_in = self.modules.emb(phns_bos)
-        h, _ = self.modules.dec(e_in, x, wav_lens)
-
-        # output layer for seq2seq log-probabilities
-        logits = self.modules.seq_lin(h)
-        p_seq = self.hparams.log_softmax(logits)
-
-        if stage == sb.Stage.VALID:
-            hyps, scores = self.hparams.greedy_searcher(x, wav_lens)
-            return p_ctc, p_seq, wav_lens, hyps
-
-        elif stage == sb.Stage.TEST:
-            hyps, scores = self.hparams.beam_searcher(x, wav_lens)
-            return p_ctc, p_seq, wav_lens, hyps
-
-        return p_ctc, p_seq, wav_lens
-
-    def def_tea_name(self):
-        # define teacher variable name
-        tea_name = []
-        for tea_num in range(self.hparams.num_tea):
-            tea = "t{}".format(tea_num)
-            tea_name.append(tea)
-        return tea_name
-
-    def re_format(self, data_dict):
-        item_tea_list = [None, None, None, None]
-        tea_name = self.def_tea_name()
-        for tea_num in range(self.hparams.num_tea):
-            for i in range(4):
-                item_tea = data_dict[str(self.step)][tea_name[tea_num]][
-                    self.hparams.tea_keys[i]
-                ][()]
-
-                if self.hparams.tea_keys[i].startswith("wer"):
-                    item_tea = torch.tensor(item_tea)
-                else:
-                    item_tea = torch.from_numpy(item_tea)
-
-                item_tea = item_tea.to(self.device)
-                item_tea = torch.unsqueeze(item_tea, 0)
-                if tea_num == 0:
-                    item_tea_list[i] = item_tea
-                else:
-                    item_tea_list[i] = torch.cat(
-                        [item_tea_list[i], item_tea], 0
-                    )
-        return item_tea_list
-
-    def compute_objectives(self, predictions, batch, stage):
-        if stage == sb.Stage.TRAIN:
-            p_ctc, p_seq, wav_lens = predictions
-        else:
-            p_ctc, p_seq, wav_lens, hyps = predictions
-
-        ids = batch.id
-        phns_eos, phn_lens_eos = batch.phn_encoded_eos
-        phns, phn_lens = batch.phn_encoded
-
-        if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
-            phns_eos = torch.cat([phns_eos, phns_eos], dim=0)
-            phn_lens_eos = torch.cat([phn_lens_eos, phn_lens_eos], dim=0)
-
-        # normal supervised training
-        loss_ctc_nor = self.hparams.ctc_cost(p_ctc, phns, wav_lens, phn_lens)
-        loss_seq_nor = self.hparams.seq_cost(p_seq, phns_eos, phn_lens_eos)
-
-        # load teacher inference results
-        data_dict = (
-            self.train_dict
-            if stage == sb.Stage.TRAIN
-            else self.valid_dict
-            if stage == sb.Stage.VALID
-            else self.test_dict
-        )
-
-        item_tea_list = self.re_format(data_dict)
-        p_ctc_tea, p_seq_tea, wer_ctc_tea, wer_tea = [
-            item for item in item_tea_list
-        ]
-
-        # Strategy "average": average losses of teachers when doing distillation.
-        # Strategy "best": choosing the best teacher based on WER.
-        # Strategy "weighted": assigning weights to teachers based on WER.
-        if self.hparams.strategy == "best":
-            # tea_ce for kd
-            wer_scores, indx = torch.min(wer_tea, dim=0)
-            indx = list(indx.cpu().numpy())
-
-            # select the best teacher for each sentence
-            tea_seq2seq_pout = None
-            for stn_indx, tea_indx in enumerate(indx):
-                s2s_one = p_seq_tea[tea_indx][stn_indx]
-                s2s_one = torch.unsqueeze(s2s_one, 0)
-                if stn_indx == 0:
-                    tea_seq2seq_pout = s2s_one
-                else:
-                    tea_seq2seq_pout = torch.cat([tea_seq2seq_pout, s2s_one], 0)
-
-        apply_softmax = torch.nn.Softmax(dim=0)
-
-        if (
-            self.hparams.strategy == "best"
-            or self.hparams.strategy == "weighted"
-        ):
-            # mean wer for ctc
-            tea_wer_ctc_mean = wer_ctc_tea.mean(1)
-            tea_acc_main = 100 - tea_wer_ctc_mean
-
-            # normalise weights via Softmax function
-            tea_acc_softmax = apply_softmax(tea_acc_main)
-
-        if self.hparams.strategy == "weighted":
-            # mean wer for ce
-            tea_wer_mean = wer_tea.mean(1)
-            tea_acc_ce_main = 100 - tea_wer_mean
-
-            # normalise weights via Softmax function
-            tea_acc_ce_softmax = apply_softmax(tea_acc_ce_main)
-
-        # kd loss
-        ctc_loss_list = None
-        ce_loss_list = None
-        for tea_num in range(self.hparams.num_tea):
-            # ctc
-            p_ctc_tea_one = p_ctc_tea[tea_num]
-            # calculate CTC distillation loss of one teacher
-            loss_ctc_one = self.hparams.ctc_cost_kd(
-                p_ctc, p_ctc_tea_one, wav_lens, device=self.device
-            )
-            loss_ctc_one = torch.unsqueeze(loss_ctc_one, 0)
-            if tea_num == 0:
-                ctc_loss_list = loss_ctc_one
-            else:
-                ctc_loss_list = torch.cat([ctc_loss_list, loss_ctc_one])
-
-            # ce
-            p_seq_tea_one = p_seq_tea[tea_num]
-            # calculate CE distillation loss of one teacher
-            loss_seq_one = self.hparams.seq_cost_kd(
-                p_seq, p_seq_tea_one, phn_lens_eos
-            )
-            loss_seq_one = torch.unsqueeze(loss_seq_one, 0)
-            if tea_num == 0:
-                ce_loss_list = loss_seq_one
-            else:
-                ce_loss_list = torch.cat([ce_loss_list, loss_seq_one])
-
-        # kd loss
-        if self.hparams.strategy == "average":
-            # get average value of losses from all teachers (CTC and CE loss)
-            ctc_loss_kd = ctc_loss_list.mean(0)
-            seq2seq_loss_kd = ce_loss_list.mean(0)
-        else:
-            # assign weights to different teachers (CTC loss)
-            ctc_loss_kd = (tea_acc_softmax * ctc_loss_list).sum(0)
-            if self.hparams.strategy == "best":
-                # only use the best teacher to compute CE loss
-                seq2seq_loss_kd = self.hparams.seq_cost_kd(
-                    p_seq, tea_seq2seq_pout, phn_lens_eos
-                )
-            if self.hparams.strategy == "weighted":
-                # assign weights to different teachers (CE loss)
-                seq2seq_loss_kd = (tea_acc_ce_softmax * ce_loss_list).sum(0)
-
-        # total loss
-        # combine normal supervised training
-        loss_ctc = (
-            self.hparams.temperature
-            * self.hparams.temperature
-            * self.hparams.alpha
-            * ctc_loss_kd
-            + (1 - self.hparams.alpha) * loss_ctc_nor
-        )
-        loss_seq = (
-            self.hparams.temperature
-            * self.hparams.temperature
-            * self.hparams.alpha
-            * seq2seq_loss_kd
-            + (1 - self.hparams.alpha) * loss_seq_nor
-        )
-
-        loss = (
-            self.hparams.ctc_weight * loss_ctc
-            + (1 - self.hparams.ctc_weight) * loss_seq
-        )
-
-        # Record losses for posterity
-        if stage != sb.Stage.TRAIN:
-            self.ctc_metrics.append(ids, p_ctc, phns, wav_lens, phn_lens)
-            self.seq_metrics.append(ids, p_seq, phns_eos, phn_lens_eos)
-            self.per_metrics.append(
-                ids, hyps, phns, None, phn_lens, self.label_encoder.decode_ndim,
-            )
-
-        return loss
-
-    def fit_batch(self, batch):
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-        loss.backward()
-        if self.check_gradients(loss):
-            self.optimizer.step()
-        self.optimizer.zero_grad()
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        predictions = self.compute_forward(batch, stage=stage)
-        loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
-    def on_stage_start(self, stage, epoch):
-        self.ctc_metrics = self.hparams.ctc_stats()
-        self.seq_metrics = self.hparams.seq_stats()
-
-        if stage != sb.Stage.TRAIN:
-            self.per_metrics = self.hparams.per_stats()
-
-    def on_stage_end(self, stage, stage_loss, epoch):
-        if stage == sb.Stage.TRAIN:
-            self.train_loss = stage_loss
-        else:
-            per = self.per_metrics.summarize("error_rate")
-
-        if stage == sb.Stage.VALID:
-            old_lr, new_lr = self.hparams.lr_annealing(per)
-            sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)
-
-            self.hparams.train_logger.log_stats(
-                stats_meta={"epoch": epoch, "lr": old_lr},
-                train_stats={"loss": self.train_loss},
-                valid_stats={
-                    "loss": stage_loss,
-                    "ctc_loss": self.ctc_metrics.summarize("average"),
-                    "seq_loss": self.seq_metrics.summarize("average"),
-                    "PER": per,
-                },
-            )
-            self.checkpointer.save_and_keep_only(
-                meta={"PER": per}, min_keys=["PER"]
-            )
-
-        if stage == sb.Stage.TEST:
-            self.hparams.train_logger.log_stats(
-                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
-                test_stats={"loss": stage_loss, "PER": per},
-            )
-            if if_main_process():
-                with open(self.hparams.test_wer_file, "w") as w:
-                    w.write("CTC loss stats:\n")
-                    self.ctc_metrics.write_stats(w)
-                    w.write("\nseq2seq loss stats:\n")
-                    self.seq_metrics.write_stats(w)
-                    w.write("\nPER stats:\n")
-                    self.per_metrics.write_stats(w)
-                    print(
-                        "CTC, seq2seq, and PER stats written to file",
-                        self.hparams.test_wer_file,
-                    )
-
-
-def data_io_prep(hparams):
-    "Creates the datasets and their data processing pipelines."
-    data_folder = hparams["data_folder"]
-    # 1. Declarations:
-    train_data = sb.dataio.dataset.DynamicItemDataset.from_json(
-        json_path=hparams["train_annotation"],
-        replacements={"data_root": data_folder},
-    )
-    if hparams["sorting"] == "ascending":
-        # we sort training data to speed up training and get better results.
-        train_data = train_data.filtered_sorted(sort_key="duration")
-        # when sorting do not shuffle in dataloader ! otherwise is pointless
-        hparams["train_dataloader_opts"]["shuffle"] = False
-
-    elif hparams["sorting"] == "descending":
-        train_data = train_data.filtered_sorted(
-            sort_key="duration", reverse=True
-        )
-        # when sorting do not shuffle in dataloader ! otherwise is pointless
-        hparams["train_dataloader_opts"]["shuffle"] = False
-
-    elif hparams["sorting"] == "random":
-        pass
-
-    else:
-        raise NotImplementedError(
-            "sorting must be random, ascending or descending"
-        )
-
-    valid_data = sb.dataio.dataset.DynamicItemDataset.from_json(
-        json_path=hparams["valid_annotation"],
-        replacements={"data_root": data_folder},
-    )
-    valid_data = valid_data.filtered_sorted(sort_key="duration")
-
-    test_data = sb.dataio.dataset.DynamicItemDataset.from_json(
-        json_path=hparams["test_annotation"],
-        replacements={"data_root": data_folder},
-    )
-    test_data = test_data.filtered_sorted(sort_key="duration")
-
-    datasets = [train_data, valid_data, test_data]
-    label_encoder = sb.dataio.encoder.CTCTextEncoder()
-
-    # 2. Define audio pipeline:
-    @sb.utils.data_pipeline.takes("wav")
-    @sb.utils.data_pipeline.provides("sig")
-    def audio_pipeline(wav):
-        sig = sb.dataio.dataio.read_audio(wav)
-        return sig
-
-    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
-
-    # 3. Define text pipeline:
-    @sb.utils.data_pipeline.takes("phn")
-    @sb.utils.data_pipeline.provides(
-        "phn_list",
-        "phn_encoded_list",
-        "phn_encoded",
-        "phn_encoded_eos",
-        "phn_encoded_bos",
-    )
-    def text_pipeline(phn):
-        phn_list = phn.strip().split()
-        yield phn_list
-        phn_encoded_list = label_encoder.encode_sequence(phn_list)
-        yield phn_encoded_list
-        phn_encoded = torch.LongTensor(phn_encoded_list)
-        yield phn_encoded
-        phn_encoded_eos = torch.LongTensor(
-            label_encoder.append_eos_index(phn_encoded_list)
-        )
-        yield phn_encoded_eos
-        phn_encoded_bos = torch.LongTensor(
-            label_encoder.prepend_bos_index(phn_encoded_list)
-        )
-        yield phn_encoded_bos
-
-    sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
-
-    # 3. Fit encoder:
-    # NOTE: In this minimal example, also update from valid data
-
-    label_encoder.update_from_didataset(train_data, output_key="phn_list")
-    if (
-        hparams["blank_index"] != hparams["bos_index"]
-        or hparams["blank_index"] != hparams["eos_index"]
-    ):
-        label_encoder.insert_blank(index=hparams["blank_index"])
-
-    if hparams["bos_index"] == hparams["eos_index"]:
-        label_encoder.insert_bos_eos(
-            bos_label="<eos-bos>",
-            eos_label="<eos-bos>",
-            bos_index=hparams["bos_index"],
-        )
-    else:
-        label_encoder.insert_bos_eos(
-            bos_label="<bos>",
-            eos_label="<eos>",
-            bos_index=hparams["bos_index"],
-            eos_index=hparams["eos_index"],
-        )
-
-    # 4. Set output:
-    sb.dataio.dataset.set_output_keys(
-        datasets,
-        ["id", "sig", "phn_encoded", "phn_encoded_eos", "phn_encoded_bos"],
-    )
-
-    return train_data, valid_data, test_data, label_encoder
-
-
-def load_teachers(hparams):
-    """
-    Load results of inference of teacher models stored on disk.
-    Note: Run experiment_save_teachers.py beforehand to generate .hdf5 files.
-    """
-    path = hparams["tea_infer_dir"] + "/tea_infer_{}batch.hdf5".format(
-        hparams["batch_size"]
-    )
-    f = h5py.File(path, "r")
-    train_dict = f["train"]
-    valid_dict = f["valid"]
-    test_dict = f["test"]
-
-    return train_dict, valid_dict, test_dict
-
-
-def st_load(hparams, asr_brain):
-    """
-    load pre-trained student model and remove decoder layer.
-    """
-    print("loading pre-trained student model...")
-    chpt_path = hparams["pretrain_st_dir"] + "/model.ckpt"
-    weight_dict = torch.load(chpt_path)
-    # del the decoder layer
-    key_list = []
-    for k in weight_dict.keys():
-        key_list.append(k)
-    for k in key_list:
-        if not k.startswith("0"):
-            del weight_dict[k]
-
-    # loading weights
-    asr_brain.hparams.model.load_state_dict(weight_dict, strict=False)
-
-
-if __name__ == "__main__":
-    # CLI:
-    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
-
-    # Load hyperparameters file with command-line overrides
-    with open(hparams_file) as fin:
-        hparams = load_hyperpyyaml(fin, overrides)
-
-    # Dataset prep (parsing TIMIT and annotation into csv files)
-    from timit_prepare import prepare_timit  # noqa
-
-    # Initialize ddp (useful only for multi-GPU DDP training)
-    sb.utils.distributed.ddp_init_group(run_opts)
-
-    # multi-gpu (ddp) save data preparation
-    run_on_main(
-        prepare_timit,
-        kwargs={
-            "data_folder": hparams["data_folder"],
-            "save_json_train": hparams["train_annotation"],
-            "save_json_valid": hparams["valid_annotation"],
-            "save_json_test": hparams["test_annotation"],
-            "skip_prep": hparams["skip_prep"],
-        },
-    )
-
-    # Dataset IO prep: creating Dataset objects and proper encodings for phones
-    train_data, valid_data, test_data, label_encoder = data_io_prep(hparams)
-
-    # Create experiment directory
-    sb.create_experiment_directory(
-        experiment_directory=hparams["output_folder"],
-        hyperparams_to_save=hparams_file,
-        overrides=overrides,
-    )
-
-    # Trainer initialization
-    asr_brain = ASR(
-        modules=hparams["modules"],
-        opt_class=hparams["opt_class"],
-        hparams=hparams,
-        run_opts=run_opts,
-        checkpointer=hparams["checkpointer"],
-    )
-    asr_brain.label_encoder = label_encoder
-
-    # load teacher models
-    train_dict, valid_dict, test_dict = load_teachers(hparams)
-    asr_brain.train_dict = train_dict
-    asr_brain.valid_dict = valid_dict
-    asr_brain.test_dict = test_dict
-
-    if hparams["pretrain"]:
-        # load pre-trained student model except last layer
-        if hparams["epoch_counter"].current == 0:
-            st_load(hparams, asr_brain)
-
-    # Training/validation loop
-    asr_brain.fit(
-        asr_brain.hparams.epoch_counter,
-        train_data,
-        valid_data,
-        train_loader_kwargs=hparams["train_dataloader_opts"],
-        valid_loader_kwargs=hparams["valid_dataloader_opts"],
-    )
-
-    # Test
-    asr_brain.evaluate(
-        test_data,
-        min_key="PER",
-        test_loader_kwargs=hparams["test_dataloader_opts"],
-    )
diff --git a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/train_teacher.py b/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/train_teacher.py
deleted file mode 100644
index 1d247a420398502d324b68eac9acb59357695043..0000000000000000000000000000000000000000
--- a/recipes/TIMIT/ASR/seq2seq_knowledge_distillation/train_teacher.py
+++ /dev/null
@@ -1,337 +0,0 @@
-#!/usr/bin/env python3
-"""Recipe for doing ASR with phoneme targets and joint seq2seq
-and CTC loss on the TIMIT dataset following a knowledge distillation scheme as
-reported in " Distilling Knowledge from Ensembles of Acoustic Models for Joint
-CTC-Attention End-to-End Speech Recognition", Yan Gao et al.
-
-To run this recipe, do the following:
-> python experiment.py hyperparams.yaml --data_folder /path/to/TIMIT
-
-Authors
- * Yan Gao 2021
- * Titouan Parcollet 2021
-"""
-import os
-import sys
-import torch
-import speechbrain as sb
-from speechbrain.utils.distributed import run_on_main, if_main_process
-from hyperpyyaml import load_hyperpyyaml
-
-
-# Define training procedure
-class ASR(sb.Brain):
-    def compute_forward(self, batch, stage):
-        batch = batch.to(self.device)
-        wavs, wav_lens = batch.sig
-        phns_bos, _ = batch.phn_encoded_bos
-
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "env_corrupt"):
-                wavs_noise = self.hparams.env_corrupt(wavs, wav_lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                wav_lens = torch.cat([wav_lens, wav_lens])
-                phns_bos = torch.cat([phns_bos, phns_bos])
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
-
-        feats = self.hparams.compute_features(wavs)
-        feats = self.modules.normalize(feats, wav_lens)
-        x = self.modules.enc(feats)
-
-        # output layer for ctc log-probabilities
-        logits = self.modules.ctc_lin(x)
-        p_ctc = self.hparams.log_softmax(logits)
-
-        e_in = self.modules.emb(phns_bos)
-        h, _ = self.modules.dec(e_in, x, wav_lens)
-
-        # output layer for seq2seq log-probabilities
-        logits = self.modules.seq_lin(h)
-        p_seq = self.hparams.log_softmax(logits)
-
-        if stage == sb.Stage.VALID:
-            hyps, scores = self.hparams.greedy_searcher(x, wav_lens)
-            return p_ctc, p_seq, wav_lens, hyps
-
-        elif stage == sb.Stage.TEST:
-            hyps, scores = self.hparams.beam_searcher(x, wav_lens)
-            return p_ctc, p_seq, wav_lens, hyps
-
-        return p_ctc, p_seq, wav_lens
-
-    def compute_objectives(self, predictions, batch, stage):
-        if stage == sb.Stage.TRAIN:
-            p_ctc, p_seq, wav_lens = predictions
-        else:
-            p_ctc, p_seq, wav_lens, hyps = predictions
-
-        ids = batch.id
-        phns_eos, phn_lens_eos = batch.phn_encoded_eos
-        phns, phn_lens = batch.phn_encoded
-
-        if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
-            phns_eos = torch.cat([phns_eos, phns_eos], dim=0)
-            phn_lens_eos = torch.cat([phn_lens_eos, phn_lens_eos], dim=0)
-
-        loss_ctc = self.hparams.ctc_cost(p_ctc, phns, wav_lens, phn_lens)
-        loss_seq = self.hparams.seq_cost(p_seq, phns_eos, phn_lens_eos)
-        loss = self.hparams.ctc_weight * loss_ctc
-        loss += (1 - self.hparams.ctc_weight) * loss_seq
-
-        # Record losses for posterity
-        if stage != sb.Stage.TRAIN:
-            self.ctc_metrics.append(ids, p_ctc, phns, wav_lens, phn_lens)
-            self.seq_metrics.append(ids, p_seq, phns_eos, phn_lens)
-            self.per_metrics.append(
-                ids, hyps, phns, None, phn_lens, self.label_encoder.decode_ndim,
-            )
-
-        return loss
-
-    def fit_batch(self, batch):
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-        loss.backward()
-        if self.check_gradients(loss):
-            self.optimizer.step()
-        self.optimizer.zero_grad()
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        predictions = self.compute_forward(batch, stage=stage)
-        loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
-    def on_stage_start(self, stage, epoch):
-        self.ctc_metrics = self.hparams.ctc_stats()
-        self.seq_metrics = self.hparams.seq_stats()
-
-        if stage != sb.Stage.TRAIN:
-            self.per_metrics = self.hparams.per_stats()
-
-    def on_stage_end(self, stage, stage_loss, epoch):
-        if stage == sb.Stage.TRAIN:
-            self.train_loss = stage_loss
-        else:
-            per = self.per_metrics.summarize("error_rate")
-
-        if stage == sb.Stage.VALID:
-            old_lr, new_lr = self.hparams.lr_annealing(per)
-            sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)
-
-            self.hparams.train_logger.log_stats(
-                stats_meta={"epoch": epoch, "lr": old_lr},
-                train_stats={"loss": self.train_loss},
-                valid_stats={
-                    "loss": stage_loss,
-                    "ctc_loss": self.ctc_metrics.summarize("average"),
-                    "seq_loss": self.seq_metrics.summarize("average"),
-                    "PER": per,
-                },
-            )
-            self.checkpointer.save_and_keep_only(
-                meta={"PER": per}, min_keys=["PER"]
-            )
-
-        if stage == sb.Stage.TEST:
-            self.hparams.train_logger.log_stats(
-                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
-                test_stats={"loss": stage_loss, "PER": per},
-            )
-            if if_main_process():
-                with open(self.hparams.test_wer_file, "w") as w:
-                    w.write("CTC loss stats:\n")
-                    self.ctc_metrics.write_stats(w)
-                    w.write("\nseq2seq loss stats:\n")
-                    self.seq_metrics.write_stats(w)
-                    w.write("\nPER stats:\n")
-                    self.per_metrics.write_stats(w)
-                    print(
-                        "CTC, seq2seq, and PER stats written to file",
-                        self.hparams.test_wer_file,
-                    )
-
-
-def data_io_prep(hparams):
-    "Creates the datasets and their data processing pipelines."
-    data_folder = hparams["data_folder"]
-    # 1. Declarations:
-    train_data = sb.dataio.dataset.DynamicItemDataset.from_json(
-        json_path=hparams["train_annotation"],
-        replacements={"data_root": data_folder},
-    )
-    if hparams["sorting"] == "ascending":
-        # we sort training data to speed up training and get better results.
-        train_data = train_data.filtered_sorted(sort_key="duration")
-        # when sorting do not shuffle in dataloader ! otherwise is pointless
-        hparams["train_dataloader_opts"]["shuffle"] = False
-
-    elif hparams["sorting"] == "descending":
-        train_data = train_data.filtered_sorted(
-            sort_key="duration", reverse=True
-        )
-        # when sorting do not shuffle in dataloader ! otherwise is pointless
-        hparams["train_dataloader_opts"]["shuffle"] = False
-
-    elif hparams["sorting"] == "random":
-        pass
-
-    else:
-        raise NotImplementedError(
-            "sorting must be random, ascending or descending"
-        )
-
-    valid_data = sb.dataio.dataset.DynamicItemDataset.from_json(
-        json_path=hparams["valid_annotation"],
-        replacements={"data_root": data_folder},
-    )
-    valid_data = valid_data.filtered_sorted(sort_key="duration")
-
-    test_data = sb.dataio.dataset.DynamicItemDataset.from_json(
-        json_path=hparams["test_annotation"],
-        replacements={"data_root": data_folder},
-    )
-    test_data = test_data.filtered_sorted(sort_key="duration")
-
-    datasets = [train_data, valid_data, test_data]
-    label_encoder = sb.dataio.encoder.CTCTextEncoder()
-
-    # 2. Define audio pipeline:
-    @sb.utils.data_pipeline.takes("wav")
-    @sb.utils.data_pipeline.provides("sig")
-    def audio_pipeline(wav):
-        sig = sb.dataio.dataio.read_audio(wav)
-        return sig
-
-    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
-
-    # 3. Define text pipeline:
-    @sb.utils.data_pipeline.takes("phn")
-    @sb.utils.data_pipeline.provides(
-        "phn_list",
-        "phn_encoded_list",
-        "phn_encoded",
-        "phn_encoded_eos",
-        "phn_encoded_bos",
-    )
-    def text_pipeline(phn):
-        phn_list = phn.strip().split()
-        yield phn_list
-        phn_encoded_list = label_encoder.encode_sequence(phn_list)
-        yield phn_encoded_list
-        phn_encoded = torch.LongTensor(phn_encoded_list)
-        yield phn_encoded
-        phn_encoded_eos = torch.LongTensor(
-            label_encoder.append_eos_index(phn_encoded_list)
-        )
-        yield phn_encoded_eos
-        phn_encoded_bos = torch.LongTensor(
-            label_encoder.prepend_bos_index(phn_encoded_list)
-        )
-        yield phn_encoded_bos
-
-    sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
-
-    # 3. Fit encoder:
-    # Load or compute the label encoder
-    label_encoder_file = os.path.join(
-        hparams["save_folder"], "label_encoder.txt"
-    )
-    if os.path.exists(label_encoder_file):
-        label_encoder.load(label_encoder_file)
-    else:
-        label_encoder.update_from_didataset(train_data, output_key="phn_list")
-        if (
-            hparams["blank_index"] != hparams["bos_index"]
-            or hparams["blank_index"] != hparams["eos_index"]
-        ):
-            label_encoder.insert_blank(index=hparams["blank_index"])
-
-        if hparams["bos_index"] == hparams["eos_index"]:
-            label_encoder.insert_bos_eos(
-                bos_label="<eos-bos>",
-                eos_label="<eos-bos>",
-                bos_index=hparams["bos_index"],
-            )
-        else:
-            label_encoder.insert_bos_eos(
-                bos_label="<bos>",
-                eos_label="<eos>",
-                bos_index=hparams["bos_index"],
-                eos_index=hparams["eos_index"],
-            )
-        label_encoder.save(
-            os.path.join(hparams["save_folder"], "label_encoder.txt")
-        )
-
-    # 4. Set output:
-    sb.dataio.dataset.set_output_keys(
-        datasets,
-        ["id", "sig", "phn_encoded", "phn_encoded_eos", "phn_encoded_bos"],
-    )
-
-    return train_data, valid_data, test_data, label_encoder
-
-
-if __name__ == "__main__":
-    # CLI:
-    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
-
-    # Load hyperparameters file with command-line overrides
-    with open(hparams_file) as fin:
-        hparams = load_hyperpyyaml(fin, overrides)
-
-    # Dataset prep (parsing TIMIT and annotation into csv files)
-    from timit_prepare import prepare_timit  # noqa
-
-    # Initialize ddp (useful only for multi-GPU DDP training)
-    sb.utils.distributed.ddp_init_group(run_opts)
-
-    # multi-gpu (ddp) save data preparation
-    run_on_main(
-        prepare_timit,
-        kwargs={
-            "data_folder": hparams["data_folder"],
-            "save_json_train": hparams["train_annotation"],
-            "save_json_valid": hparams["valid_annotation"],
-            "save_json_test": hparams["test_annotation"],
-            "skip_prep": hparams["skip_prep"],
-        },
-    )
-
-    # Dataset IO prep: creating Dataset objects and proper encodings for phones
-    train_data, valid_data, test_data, label_encoder = data_io_prep(hparams)
-
-    # Create experiment directory
-    sb.create_experiment_directory(
-        experiment_directory=hparams["output_folder"],
-        hyperparams_to_save=hparams_file,
-        overrides=overrides,
-    )
-
-    # Trainer initialization
-    asr_brain = ASR(
-        modules=hparams["modules"],
-        opt_class=hparams["opt_class"],
-        hparams=hparams,
-        run_opts=run_opts,
-        checkpointer=hparams["checkpointer"],
-    )
-    asr_brain.label_encoder = label_encoder
-
-    # Training/validation loop
-    asr_brain.fit(
-        asr_brain.hparams.epoch_counter,
-        train_data,
-        valid_data,
-        train_loader_kwargs=hparams["train_dataloader_opts"],
-        valid_loader_kwargs=hparams["valid_dataloader_opts"],
-    )
-
-    # Test
-    asr_brain.evaluate(
-        test_data,
-        min_key="PER",
-        test_loader_kwargs=hparams["test_dataloader_opts"],
-    )
diff --git a/recipes/TIMIT/ASR/transducer/README.md b/recipes/TIMIT/ASR/transducer/README.md
index fa15333110089831e2bdeb7caba1e261fd2f11db..54112e76cf4b742435362248addbcc3ba0260cf1 100644
--- a/recipes/TIMIT/ASR/transducer/README.md
+++ b/recipes/TIMIT/ASR/transducer/README.md
@@ -15,9 +15,13 @@ pip install numba
 # How to run
 Update the path to the dataset in the yaml config file and run the following.
 ```
-python train.py hparams/train.yaml
+python train.py hparams/train.yaml --data_folde=your/data/folder/TIMIT --jit
 ```
 
+**Note on Compilation**:
+Enabling the just-in-time (JIT) compiler with --jit significantly improves code performance, resulting in a 50-60% speed boost. We highly recommend utilizing the JIT compiler for optimal results.
+This speed improvement is observed specifically when using the CRDNN model.
+
 # Results
 
 | Release | hyperparams file | Val. PER | Test PER | Model link | GPUs |
diff --git a/recipes/TIMIT/ASR/transducer/hparams/train.yaml b/recipes/TIMIT/ASR/transducer/hparams/train.yaml
index 32379d6b45d8bad4329e26482804eb079510c05c..204297dc68591031a53ac15128d3194a551071c4 100644
--- a/recipes/TIMIT/ASR/transducer/hparams/train.yaml
+++ b/recipes/TIMIT/ASR/transducer/hparams/train.yaml
@@ -17,14 +17,18 @@ train_log: !ref <output_folder>/train_log.txt
 
 # Data files
 data_folder: !PLACEHOLDER  # e.g. /path/to/TIMIT
-openrir_folder: !ref <data_folder> # where storing the noisy data for augment
-train_annotation: !ref <data_folder>/train.json
-valid_annotation: !ref <data_folder>/dev.json
-test_annotation: !ref <data_folder>/test.json
+train_annotation: !ref <save_folder>/train.json
+valid_annotation: !ref <save_folder>/dev.json
+test_annotation: !ref <save_folder>/test.json
 skip_prep: False # Skip data preparation
 uppercase: False # Must be True when the TIMIT dataset is in the upper-case version
 
-# Training parameters
+# Data for augmentation
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+noise_annotation: !ref <save_folder>/noise.csv #The data manifest files are created by the data preparation script
+
+####################### Training Parameters ####################################
 number_of_epochs: 50
 batch_size: 8
 lr: 1.0
@@ -36,7 +40,7 @@ n_fft: 400
 n_mels: 40
 
 
-# Model parameters
+####################### Model Parameters #######################################
 activation: !name:torch.nn.LeakyReLU
 dropout: 0.15
 cnn_blocks: 2
@@ -51,7 +55,7 @@ dnn_neurons: 512
 dec_neurons: 128
 
 # Outputs
-output_neurons: 40
+output_neurons: 42
 joint_dim: 128
 blank_index: 0
 
@@ -64,14 +68,17 @@ state_beam: 1.0
 expand_beam: 1.0
 
 # Dataloader options
+num_workers: 4
 train_dataloader_opts:
     batch_size: !ref <batch_size>
-
+    num_workers: !ref <num_workers>
 valid_dataloader_opts:
     batch_size: !ref <batch_size>
+    num_workers: !ref <num_workers>
 
 test_dataloader_opts:
     batch_size: !ref <batch_size>
+    num_workers: !ref <num_workers>
 
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
@@ -85,17 +92,57 @@ compute_features: !new:speechbrain.lobes.features.Fbank
 normalize: !new:speechbrain.processing.features.InputNormalization
     norm_type: global
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+# Add noise to input signal
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: 0
+    snr_high: 15
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
-env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <openrir_folder>
-    babble_prob: 0.0
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    concat_original: True
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <add_noise>,
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
 
 
 enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
@@ -146,9 +193,6 @@ output: !new:speechbrain.nnet.linear.Linear
     n_neurons: !ref <output_neurons>  # 42 phonemes + 1 blank
     bias: False
 
-# log_softmax: !new:speechbrain.nnet.activations.Softmax
-#    apply_log: True
-
 compute_cost: !name:speechbrain.nnet.losses.transducer_loss
     use_torchaudio: True
     blank_index: !ref <blank_index>
@@ -200,8 +244,6 @@ modules:
     Tjoint: !ref <Tjoint>
     output: !ref <output>
     normalize: !ref <normalize>
-    env_corrupt: !ref <env_corrupt>
-    augmentation: !ref <augmentation>
 
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
     checkpoints_dir: !ref <save_folder>
diff --git a/recipes/TIMIT/ASR/transducer/hparams/train_wav2vec.yaml b/recipes/TIMIT/ASR/transducer/hparams/train_wav2vec.yaml
index 2c1e16fffda73ce7cf4c195f75716dfb3afe34dd..9ead09f56ecc8d959455233397cc28ab11d875b6 100644
--- a/recipes/TIMIT/ASR/transducer/hparams/train_wav2vec.yaml
+++ b/recipes/TIMIT/ASR/transducer/hparams/train_wav2vec.yaml
@@ -22,36 +22,35 @@ freeze_wav2vec: False
 
 # Data files
 data_folder: !PLACEHOLDER  # e.g. /path/to/TIMIT
-train_annotation: !ref <data_folder>/train.json
-valid_annotation: !ref <data_folder>/dev.json
-test_annotation: !ref <data_folder>/test.json
+train_annotation: !ref <save_folder>/train.json
+valid_annotation: !ref <save_folder>/dev.json
+test_annotation: !ref <save_folder>/test.json
 skip_prep: False # Skip data preparation
 uppercase: False # Must be True when the TIMIT dataset is in the upper-case version
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 20
 batch_size: 8
 lr: 0.0003
 lr_wav2vec: 0.0001
 sorting: ascending # choose between ascending, descending and random
-auto_mix_prec: True
+precision: fp16 # bf16, fp16 or fp32
 
 # Feature parameters
 sample_rate: 16000
 # n_fft: 400
 # n_mels: 40
 
-
-# Model parameters
+####################### Model Parameters #######################################
 activation: !name:torch.nn.LeakyReLU
 # dropout: 0.15
 dnn_blocks: 1
-dnn_neurons: 40
+dnn_neurons: 43
 dec_neurons: 128
 
 # Outputs
-output_neurons: 40
-joint_dim: 40
+output_neurons: 43
+joint_dim: 43
 blank_index: 0
 
 # Decoding parameters
@@ -75,16 +74,41 @@ test_dataloader_opts:
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-# compute_features: !new:speechbrain.lobes.features.Fbank
-#    sample_rate: !ref <sample_rate>
-#    n_fft: !ref <n_fft>
-#    n_mels: !ref <n_mels>
+############################## Augmentations ###################################
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec>
@@ -96,8 +120,6 @@ enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
     dnn_blocks: !ref <dnn_blocks>
     dnn_neurons: !ref <dnn_neurons>
 
-jit_module_keys: [enc]
-
 enc_lin: !new:speechbrain.nnet.linear.Linear
     input_size: !ref <dnn_neurons>
     n_neurons: !ref <joint_dim>
@@ -129,9 +151,6 @@ output: !new:speechbrain.nnet.linear.Linear
     n_neurons: !ref <output_neurons>  # 42 phonemes + 1 blank
     bias: False
 
-#log_softmax: !new:speechbrain.nnet.activations.Softmax
-#    apply_log: True
-
 compute_cost: !name:speechbrain.nnet.losses.transducer_loss
     use_torchaudio: True
     blank_index: !ref <blank_index>
@@ -189,7 +208,6 @@ modules:
     dec_lin: !ref <dec_lin>
     Tjoint: !ref <Tjoint>
     output: !ref <output>
-    augmentation: !ref <augmentation>
 
 checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
     checkpoints_dir: !ref <save_folder>
diff --git a/recipes/TIMIT/ASR/transducer/train.py b/recipes/TIMIT/ASR/transducer/train.py
index ba582b0ffb8dd39a2f61ec7a8a66bd4f000c2a25..661f5be2cfd11a36a97d4ce45e13bd0920eb7966 100644
--- a/recipes/TIMIT/ASR/transducer/train.py
+++ b/recipes/TIMIT/ASR/transducer/train.py
@@ -3,7 +3,11 @@
 Transducer loss on the TIMIT dataset.
 
 To run this recipe, do the following:
-> python train.py hparams/train.yaml --data_folder /path/to/TIMIT
+> python train.py hparams/train.yaml --data_folder /path/to/TIMIT --jit
+
+Note on Compilation:
+Enabling the just-in-time (JIT) compiler with --jit significantly improves code performance,
+resulting in a 50-60% speed boost. We highly recommend utilizing the JIT compiler for optimal results.
 
 
 Authors
@@ -13,11 +17,10 @@ Authors
 """
 import os
 import sys
-import torch
 import logging
 import speechbrain as sb
 from hyperpyyaml import load_hyperpyyaml
-from speechbrain.utils.distributed import run_on_main, if_main_process
+from speechbrain.utils.distributed import run_on_main
 
 logger = logging.getLogger(__name__)
 
@@ -30,18 +33,10 @@ class ASR_Brain(sb.Brain):
         wavs, wav_lens = batch.sig
         phns, phn_lens = batch.phn_encoded
 
-        # Adding optional augmentation when specified:
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "env_corrupt"):
-                wavs_noise = self.hparams.env_corrupt(wavs, wav_lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                wav_lens = torch.cat([wav_lens, wav_lens])
-                batch.sig = wavs, wav_lens
-                phns = torch.cat([phns, phns], dim=0)
-                phn_lens = torch.cat([phn_lens, phn_lens])
-                batch.phn_encoded = phns, phn_lens
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+            phns = self.hparams.wav_augment.replicate_labels(phns)
 
         # Model computations
         feats = self.hparams.compute_features(wavs)
@@ -66,8 +61,8 @@ class ASR_Brain(sb.Brain):
         logits = self.modules.output(joint)
 
         if stage == sb.Stage.VALID:
-            hyps, scores, _, _ = self.hparams.Greedysearcher(x)
-            return logits, hyps
+            hyps, _, _, _ = self.hparams.Greedysearcher(x)
+            return logits, wav_lens, hyps
 
         elif stage == sb.Stage.TEST:
             (
@@ -76,16 +71,22 @@ class ASR_Brain(sb.Brain):
                 nbest_hyps,
                 nbest_scores,
             ) = self.hparams.Beamsearcher(x)
-            return logits, best_hyps
-        return logits
+            return logits, wav_lens, best_hyps
+        return logits, wav_lens
 
     def compute_objectives(self, predictions, batch, stage):
         "Given the network predictions and targets computed the loss."
         ids = batch.id
-        _, wav_lens = batch.sig
         phns, phn_lens = batch.phn_encoded
-        if stage != sb.Stage.TRAIN:
-            predictions, hyps = predictions
+
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            phns = self.hparams.wav_augment.replicate_labels(phns)
+            phn_lens = self.hparams.wav_augment.replicate_labels(phn_lens)
+
+        if stage == sb.Stage.TRAIN:
+            predictions, wav_lens = predictions
+        else:
+            predictions, wav_lens, hyps = predictions
 
         # Transducer loss use logits from RNN-T model.
         loss = self.hparams.compute_cost(predictions, phns, wav_lens, phn_lens)
@@ -132,16 +133,26 @@ class ASR_Brain(sb.Brain):
                 stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                 test_stats={"loss": stage_loss, "PER": per},
             )
-            if if_main_process():
-                with open(self.hparams.test_wer_file, "w") as w:
-                    w.write("Transducer loss stats:\n")
-                    self.transducer_metrics.write_stats(w)
-                    w.write("\nPER stats:\n")
-                    self.per_metrics.write_stats(w)
-                    print(
-                        "Transducer and PER stats written to file",
-                        self.hparams.test_wer_file,
-                    )
+            run_on_main(
+                save_metrics_to_file,
+                args=[
+                    self.hparams.test_wer_file,
+                    self.transducer_metrics,
+                    self.per_metrics,
+                ],
+            )
+
+
+def save_metrics_to_file(wer_file, transducer_metrics, per_metrics):
+    with open(wer_file, "w") as w:
+        w.write("Transducer loss stats:\n")
+        transducer_metrics.write_stats(w)
+        w.write("\nPER stats:\n")
+        per_metrics.write_stats(w)
+        print(
+            "Transducer and PER stats written to file",
+            hparams["test_wer_file"],
+        )
 
 
 def dataio_prep(hparams):
@@ -264,6 +275,7 @@ if __name__ == "__main__":
             "uppercase": hparams["uppercase"],
         },
     )
+    run_on_main(hparams["prepare_noise_data"])
 
     # Dataset IO prep: creating Dataset objects and proper encodings for phones
     train_data, valid_data, test_data, label_encoder = dataio_prep(hparams)
diff --git a/recipes/TIMIT/ASR/transducer/train_wav2vec.py b/recipes/TIMIT/ASR/transducer/train_wav2vec.py
index ae8d89122d40b5c9c0f2fc9d41675468726e0c97..cc6d99faddf07653f12c3b380849c1706e072de6 100644
--- a/recipes/TIMIT/ASR/transducer/train_wav2vec.py
+++ b/recipes/TIMIT/ASR/transducer/train_wav2vec.py
@@ -13,7 +13,6 @@ Authors
 """
 import os
 import sys
-import torch
 import logging
 import speechbrain as sb
 from hyperpyyaml import load_hyperpyyaml
@@ -30,10 +29,10 @@ class ASR_Brain(sb.Brain):
         wavs, wav_lens = batch.sig
         phns, phn_lens = batch.phn_encoded
 
-        # Adding optional augmentation when specified:
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+            phns = self.hparams.wav_augment.replicate_labels(phns)
 
         # Model computations
         feats = self.modules.wav2vec2(wavs, wav_lens)
@@ -57,8 +56,8 @@ class ASR_Brain(sb.Brain):
         logits = self.modules.output(joint)
 
         if stage == sb.Stage.VALID:
-            hyps, scores, _, _ = self.hparams.Greedysearcher(x)
-            return logits, hyps
+            hyps, _, _, _ = self.hparams.Greedysearcher(x)
+            return logits, wav_lens, hyps
 
         elif stage == sb.Stage.TEST:
             (
@@ -67,16 +66,22 @@ class ASR_Brain(sb.Brain):
                 nbest_hyps,
                 nbest_scores,
             ) = self.hparams.Beamsearcher(x)
-            return logits, best_hyps
-        return logits
+            return logits, wav_lens, best_hyps
+        return logits, wav_lens
 
     def compute_objectives(self, predictions, batch, stage):
         "Given the network predictions and targets computed the loss."
         ids = batch.id
-        _, wav_lens = batch.sig
         phns, phn_lens = batch.phn_encoded
-        if stage != sb.Stage.TRAIN:
-            predictions, hyps = predictions
+
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            phns = self.hparams.wav_augment.replicate_labels(phns)
+            phn_lens = self.hparams.wav_augment.replicate_labels(phn_lens)
+
+        if stage == sb.Stage.TRAIN:
+            predictions, wav_lens = predictions
+        else:
+            predictions, wav_lens, hyps = predictions
 
         # Transducer loss use logits from RNN-T model.
         loss = self.hparams.compute_cost(predictions, phns, wav_lens, phn_lens)
@@ -146,61 +151,6 @@ class ASR_Brain(sb.Brain):
                         self.hparams.test_wer_file,
                     )
 
-    def fit_batch(self, batch):
-        """Fit one batch, override to do multiple updates.
-
-        The default implementation depends on a few methods being defined
-        with a particular behavior:
-
-        * ``compute_forward()``
-        * ``compute_objectives()``
-
-        Also depends on having optimizers passed at initialization.
-
-        Arguments
-        ---------
-        batch : list of torch.Tensors
-            Batch of data to use for training. Default implementation assumes
-            this batch has two elements: inputs and targets.
-
-        Returns
-        -------
-        detached loss
-        """
-        # Managing automatic mixed precision
-        if self.auto_mix_prec:
-
-            self.wav2vec_optimizer.zero_grad()
-            self.adam_optimizer.zero_grad()
-
-            with torch.cuda.amp.autocast():
-                outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-                loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-
-            self.scaler.scale(loss).backward()
-            self.scaler.unscale_(self.wav2vec_optimizer)
-            self.scaler.unscale_(self.adam_optimizer)
-
-            if self.check_gradients(loss):
-                self.scaler.step(self.wav2vec_optimizer)
-                self.scaler.step(self.adam_optimizer)
-
-            self.scaler.update()
-        else:
-            outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-
-            loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-            loss.backward()
-
-            if self.check_gradients(loss):
-                self.wav2vec_optimizer.step()
-                self.adam_optimizer.step()
-
-            self.wav2vec_optimizer.zero_grad()
-            self.adam_optimizer.zero_grad()
-
-        return loss.detach()
-
     def init_optimizers(self):
         "Initializes the wav2vec2 optimizer and model optimizer"
         self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
@@ -216,9 +166,10 @@ class ASR_Brain(sb.Brain):
             )
             self.checkpointer.add_recoverable("adam_opt", self.adam_optimizer)
 
-    def zero_grad(self, set_to_none=False):
-        self.wav2vec_optimizer.zero_grad(set_to_none)
-        self.adam_optimizer.zero_grad(set_to_none)
+        self.optimizers_dict = {
+            "wav2vec": self.wav2vec_optimizer,
+            "adam": self.adam_optimizer,
+        }
 
 
 def dataio_prep(hparams):
diff --git a/recipes/TIMIT/Alignment/hparams/train.yaml b/recipes/TIMIT/Alignment/hparams/train.yaml
index 24a2088b291cce0cfe43db8d9ef9f674093bf340..aaf06b7ffc0f61291d5de751c9fe3094876fdef7 100644
--- a/recipes/TIMIT/Alignment/hparams/train.yaml
+++ b/recipes/TIMIT/Alignment/hparams/train.yaml
@@ -20,7 +20,7 @@ valid_annotation: !ref <data_folder>/dev.json
 test_annotation: !ref <data_folder>/test.json
 skip_prep: False # Skip data prep
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 10
 batch_size: 256
 lr: 0.0003
@@ -40,7 +40,7 @@ phn_set: 60 # {60, 48, 39}
 output_neurons: 183
 blank_index: 182
 
-# Model parameters
+####################### Model Parameters #######################################
 activation: !name:torch.nn.LeakyReLU
 dnn_blocks: 1
 dnn_neurons: 2000
@@ -55,9 +55,52 @@ dataloader_options:
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    parallel_augment: False
+    concat_original: False
+    repeat_augment: 1
+    shuffle_augmentations: False
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
 
 compute_features: !new:speechbrain.lobes.features.Fbank
     context: True
diff --git a/recipes/TIMIT/Alignment/train.py b/recipes/TIMIT/Alignment/train.py
index 7cd1a5d1b7f718a90502a8c3a1f4fc6aff9b24af..c874504d1a6ac97cf5703ced757176e93156b3d1 100644
--- a/recipes/TIMIT/Alignment/train.py
+++ b/recipes/TIMIT/Alignment/train.py
@@ -26,14 +26,9 @@ class AlignBrain(sb.Brain):
         batch = batch.to(self.device)
         wavs, wav_lens = batch.sig
 
-        # Adding augmentation when specified:
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.modules, "env_corrupt"):
-                wavs_noise = self.modules.env_corrupt(wavs, wav_lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                wav_lens = torch.cat([wav_lens, wav_lens])
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
 
         feats = self.hparams.compute_features(wavs)
         if hasattr(self.hparams, "normalize"):
@@ -51,9 +46,9 @@ class AlignBrain(sb.Brain):
         phns, phn_lens = batch.phn_encoded
         phn_ends, _ = batch.phn_ends
 
-        if stage == sb.Stage.TRAIN and hasattr(self.modules, "env_corrupt"):
-            phns = torch.cat([phns, phns], dim=0)
-            phn_lens = torch.cat([phn_lens, phn_lens], dim=0)
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            phns = self.hparams.wav_augment.replicate_labels(phns)
+            phn_lens = self.hparams.wav_augment.replicate_labels(phn_lens)
 
         phns, phn_lens = phns.to(self.device), phn_lens.to(self.device)
         phns_orig = sb.utils.data_utils.undo_padding(phns, phn_lens)
diff --git a/recipes/Tedlium2/ASR/transformer/README.md b/recipes/Tedlium2/ASR/transformer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f8b52ec08fff938b12302d08b666dea1bcdaa60e
--- /dev/null
+++ b/recipes/Tedlium2/ASR/transformer/README.md
@@ -0,0 +1,58 @@
+# Tedlium2 ASR with Transformers
+This folder contains the scripts to train a Transformer-based speech recognizer.
+
+You can download Tedlium2 at https://lium.univ-lemans.fr/ted-lium2/
+
+# How to Run:
+
+1. Begin by training the tokenizer:
+
+```shell
+cd ../../Tokenizer
+python train.py hparams/tedlium2_500_bpe.yaml --data_folder /path/to/tedlium2 --clipped_utt_folder /path/to/clipped_folder
+```
+
+Please, read  ../../Tokenizer/README.md before proceeding.
+This training script will handle data preparation and tokenizer training. Note that this script prepares the data in a format suitable for training the ASR model.
+Specifically, it segments the entire TED recording into individual utterance-level recordings, resulting in approximately 46 gigabytes of data.
+The CSV files generated for training, development, and testing are also utilized in ASR training.
+
+**IMPORTANT:** Ensure you complete this step before proceeding to train the ASR Model.
+
+2. Proceed to train the ASR model:
+
+```shell
+python train.py hparams/branchformer_large.yaml --pretrained_tokenizer_file /path/to/tokenizer --data_folder /path/to/tedlium2 --clipped_utt_folder /path/to/clipped_folder
+```
+
+This script relies on the data manifest files prepared in step 1.
+
+
+# Results
+
+| Release | hyperparams file |  Test WER (No LM) | HuggingFace link | Model link | GPUs |
+|:-------------:|:-------------:|:-------------:|:---------------------------:| :-----:| :-----:|
+| 24-10-23 | branchformer_large.yaml | 8.11 | [HuggingFace](https://huggingface.co/speechbrain/asr-branchformer-large-tedlium2) | [DropBox](https://www.dropbox.com/sh/el523uofs96czfi/AADgTd838pKo2aR8fhqVOh-Oa?dl=0) | 1xA100 80GB |
+
+# **About SpeechBrain**
+- Website: https://speechbrain.github.io/
+- Code: https://github.com/speechbrain/speechbrain/
+- HuggingFace: https://huggingface.co/speechbrain/
+
+# Training Time
+
+It takes about 15 minutes per epoch for the branchformer large model.
+
+# **Citing SpeechBrain**
+Please, cite SpeechBrain if you use it for your research or business.
+
+```bibtex
+@misc{speechbrain,
+  title={{SpeechBrain}: A General-Purpose Speech Toolkit},
+  author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
+  year={2021},
+  eprint={2106.04624},
+  archivePrefix={arXiv},
+  primaryClass={eess.AS},
+  note={arXiv:2106.04624}
+}
diff --git a/recipes/Tedlium2/ASR/transformer/hparams/branchformer_large.yaml b/recipes/Tedlium2/ASR/transformer/hparams/branchformer_large.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2f9f924c71069e6dadffcd638d25692efd1597c5
--- /dev/null
+++ b/recipes/Tedlium2/ASR/transformer/hparams/branchformer_large.yaml
@@ -0,0 +1,326 @@
+# ############################################################################
+# Model: E2E ASR with Transformer
+# Encoder: Branchformer Encoder
+# Decoder: Transformer Decoder + (CTC/ATT joint) beamsearch
+# Tokens: unigram
+# losses: CTC + KLdiv (Label Smoothing loss)
+# Training: Tedlium2
+# Authors:  Titouan Parcollet, Shucong Zhang
+# ############################################################################
+# Seed needs to be set at top of yaml, before objects with parameters are made
+
+seed: 3407
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/branchformer_large/<seed>
+output_wer_folder: !ref <output_folder>/
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# Data files
+# IMPORTANT: before running this script, you need to train the tokenizer (refer to ../../Tokenizer/README.md for details).
+# Or use the pretrained tokenizer provided in the DropBox folder.
+#
+# The tokenizer is stored in ../../Tokenizer/results/tokenizer/tokenizer.ckpt
+
+# Please ensure that the tokenizer has been trained before (refer to ../../Tokenizer/README.md for details).
+pretrained_tokenizer_file: !PLACEHOLDER
+clipped_utt_folder: !PLACEHOLDER # folder where to store the clipped utterence-level recordings
+data_folder: !PLACEHOLDER # e.g, /path/to/TEDLIUM_release2
+skip_prep: False
+avoid_if_shorter_than: 1.0
+
+train_csv: !ref <output_folder>/train/train.csv
+valid_csv: !ref <output_folder>/dev/dev.csv
+test_csv:
+    - !ref <output_folder>/test/test.csv
+
+####################### Training Parameters ####################################
+# To make Transformers converge, the global bath size should be large enough.
+# The global batch size is computed as batch_size * n_gpus * grad_accumulation_factor.
+# Empirically, we found that this value should be >= 128.
+# Please, set your parameters accordingly.
+precision: fp32 # bf16, fp16 or fp32
+number_of_epochs: 120
+batch_size: 16 # This works for 2x GPUs with 32GB
+ctc_weight: 0.3
+grad_accumulation_factor: 2
+max_grad_norm: 5.0
+loss_reduction: 'batchmean'
+sorting: random
+num_workers: 4
+
+# stages related parameters
+# stage_one_epochs: 90
+lr_adam: 0.0005
+weight_decay: 0.05
+
+# Feature parameters
+sample_rate: 16000
+n_fft: 400
+n_mels: 80
+win_length: 25
+
+# This setup works well for A100 80GB GPU, adapts it to your needs.
+# Or turn it off (but training speed will decrease)
+dynamic_batching: True
+max_batch_length_train: 800
+max_batch_length_val: 100 # we reduce it as the beam is much wider (VRAM)
+num_bucket: 200
+shuffle: True
+batch_ordering: random
+max_batch_ex: 128
+
+dynamic_batch_sampler_train:
+    max_batch_length: !ref <max_batch_length_train>
+    num_buckets: !ref <num_bucket>
+    shuffle: !ref <shuffle>
+    batch_ordering: !ref <batch_ordering>
+    max_batch_ex: !ref <max_batch_ex>
+
+dynamic_batch_sampler_valid:
+    max_batch_length: !ref <max_batch_length_val>
+    num_buckets: !ref <num_bucket>
+    shuffle: !ref <shuffle>
+    batch_ordering: !ref <batch_ordering>
+    max_batch_ex: !ref <max_batch_ex>
+
+
+# Dataloader options
+train_dataloader_opts:
+    batch_size: !ref <batch_size>
+    shuffle: True
+    num_workers: !ref <num_workers>
+
+valid_dataloader_opts:
+    batch_size: 1
+
+test_dataloader_opts:
+    batch_size: 1
+
+####################### Model Parameters ###########################
+# Transformer
+d_model: 512
+nhead: 8
+num_encoder_layers: 18
+num_decoder_layers: 6
+csgu_linear_units: 3072
+csgu_kernel_size: 31
+transformer_dropout: 0.1
+activation: !name:torch.nn.GELU
+output_neurons: 500
+
+# Outputs
+blank_index: 0
+label_smoothing: 0.1
+pad_index: 0
+bos_index: 1
+eos_index: 2
+
+# Decoding parameters
+min_decode_ratio: 0.0
+max_decode_ratio: 1.0
+valid_search_interval: 10
+valid_beam_size: 20
+test_beam_size: 20
+ctc_weight_decode: 0.3
+
+############################## models ################################
+
+CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
+    input_shape: (8, 10, 80)
+    num_blocks: 2
+    num_layers_per_block: 1
+    out_channels: (64, 32)
+    kernel_sizes: (3, 3)
+    strides: (2, 2)
+    residuals: (False, False)
+
+Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length
+    input_size: 640
+    tgt_vocab: !ref <output_neurons>
+    d_model: !ref <d_model>
+    nhead: !ref <nhead>
+    num_encoder_layers: !ref <num_encoder_layers>
+    num_decoder_layers: !ref <num_decoder_layers>
+    dropout: !ref <transformer_dropout>
+    activation: !ref <activation>
+    branchformer_activation: !ref <activation>
+    encoder_module: branchformer
+    csgu_linear_units: !ref <csgu_linear_units>
+    kernel_size: !ref <csgu_kernel_size>
+    attention_type: RelPosMHAXL
+    normalize_before: True
+    causal: False
+
+tokenizer: !new:sentencepiece.SentencePieceProcessor
+
+ctc_lin: !new:speechbrain.nnet.linear.Linear
+    input_size: !ref <d_model>
+    n_neurons: !ref <output_neurons>
+
+seq_lin: !new:speechbrain.nnet.linear.Linear
+    input_size: !ref <d_model>
+    n_neurons: !ref <output_neurons>
+
+normalize: !new:speechbrain.processing.features.InputNormalization
+    norm_type: global
+    update_until_epoch: 4
+
+modules:
+    CNN: !ref <CNN>
+    Transformer: !ref <Transformer>
+    seq_lin: !ref <seq_lin>
+    ctc_lin: !ref <ctc_lin>
+    normalize: !ref <normalize>
+
+model: !new:torch.nn.ModuleList
+    - [!ref <CNN>, !ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
+
+Adam: !name:torch.optim.AdamW
+    lr: !ref <lr_adam>
+    betas: (0.9, 0.98)
+    eps: 0.000000001
+    weight_decay: !ref <weight_decay>
+
+# Scorer
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
+    eos_index: !ref <eos_index>
+    blank_index: !ref <blank_index>
+    ctc_fc: !ref <ctc_lin>
+
+scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <ctc_scorer>]
+    weights:
+        ctc: !ref <ctc_weight_decode>
+
+
+valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <valid_beam_size>
+    using_eos_threshold: False
+    length_normalization: True
+    scorer: !ref <scorer>
+
+test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
+    modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
+    bos_index: !ref <bos_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <test_beam_size>
+    temperature: 1.15
+    using_eos_threshold: False
+    length_normalization: True
+    scorer: !ref <scorer>
+
+log_softmax: !new:torch.nn.LogSoftmax
+    dim: -1
+
+ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
+    blank_index: !ref <blank_index>
+    reduction: !ref <loss_reduction>
+
+seq_cost: !name:speechbrain.nnet.losses.kldiv_loss
+    label_smoothing: !ref <label_smoothing>
+    reduction: !ref <loss_reduction>
+
+noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler
+    lr_initial: !ref <lr_adam>
+    n_warmup_steps: 30000
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        model: !ref <model>
+        noam_scheduler: !ref <noam_annealing>
+        normalizer: !ref <normalize>
+        counter: !ref <epoch_counter>
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Time Drop
+time_drop_length_low: 20  # Min length for temporal chunk to drop in spectrogram
+time_drop_length_high: 25  # Max length for temporal chunk to drop in spectrogram
+time_drop_count_low: 7  # Min number of chunks to drop in time in the spectrogram
+time_drop_count_high: 7  # Max number of chunks to drop in time in the spectrogram
+time_drop_replace: "mean"  # Method of dropping chunks
+
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: !ref <time_drop_length_low>
+    drop_length_high: !ref <time_drop_length_high>
+    drop_count_low: !ref <time_drop_count_low>
+    drop_count_high: !ref <time_drop_count_high>
+    replace: !ref <time_drop_replace>
+    dim: 1
+
+# Frequency Drop
+freq_drop_length_low: 25  # Min length for chunks to drop in frequency in the spectrogram
+freq_drop_length_high: 30  # Max length for chunks to drop in frequency in the spectrogram
+freq_drop_count_low: 2  # Min number of chunks to drop in frequency in the spectrogram
+freq_drop_count_high: 2  # Max number of chunks to drop in frequency in the spectrogram
+freq_drop_replace: "mean"  # Method of dropping chunks
+
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: !ref <freq_drop_length_low>
+    drop_length_high: !ref <freq_drop_length_high>
+    drop_count_low: !ref <freq_drop_count_low>
+    drop_count_high: !ref <freq_drop_count_high>
+    replace: !ref <freq_drop_replace>
+    dim: 2
+
+# Time warp
+time_warp_window: 5  # Length of time warping window
+time_warp_mode: "bicubic"  # Time warping method
+
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+    warp_window: !ref <time_warp_window>
+    warp_mode: !ref <time_warp_mode>
+    dim: 1
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    parallel_augment: False
+    concat_original: False
+    repeat_augment: 1
+    shuffle_augmentations: False
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>]
+
+compute_features: !new:speechbrain.lobes.features.Fbank
+    sample_rate: !ref <sample_rate>
+    n_fft: !ref <n_fft>
+    win_length: !ref <win_length>
+    n_mels: !ref <n_mels>
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats
+
+# The pretrainer allows a mapping between pretrained files and instances that
+# are declared in the yaml. E.g here, we will download the file lm.ckpt
+# and it will be loaded into "lm" which is pointing to the <lm_model> defined
+# before.
+pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
+    collect_in: !ref <save_folder>
+    loadables:
+        tokenizer: !ref <tokenizer>
+    paths:
+        tokenizer: !ref <pretrained_tokenizer_file>
diff --git a/recipes/Tedlium2/ASR/transformer/tedlium2_prepare.py b/recipes/Tedlium2/ASR/transformer/tedlium2_prepare.py
new file mode 100644
index 0000000000000000000000000000000000000000..c56bee6cba8c2a4099d9e0b38c65995f66c83dea
--- /dev/null
+++ b/recipes/Tedlium2/ASR/transformer/tedlium2_prepare.py
@@ -0,0 +1,226 @@
+"""
+Download link: https://lium.univ-lemans.fr/ted-lium2/
+
+Authors
+ * Shucong Zhang 2023
+ * Adel Moumen 2023
+"""
+
+import os
+import csv
+import logging
+import torchaudio
+import functools
+from speechbrain.utils.parallel import parallel_map
+
+logger = logging.getLogger(__name__)
+
+
+def make_splits(
+    sph_file, stm_file, utt_save_folder, avoid_if_shorter_than,
+):
+    """
+    This function splits the .sph Ted-talk recording into utterences based on the .stm annotation.
+
+    Arguments
+    ---------
+    sph_file : str
+        Path to the sph file containing Ted-talk recording.
+    stm_file : str
+        Path to the stm file containing Ted-talk annotation.
+    utt_save_folder: str
+        The folder stores the clipped individual utterences.
+    avoid_if_shorter_than: int
+        Any utterance shorter than this will be discarded.
+    """
+    # the annotation for JillSobuleMANHATTANINJANUARY_2006.sph is not useful
+    if "JillSobuleMANHATTANINJANUARY_2006" in sph_file:
+        logger.info("JillSobuleMANHATTANINJANUARY_2006.sph is skipped")
+        return
+
+    # load the annotation of the entire speech recording
+    annotation_file = open(stm_file, "r")
+    annotations = annotation_file.readlines()
+
+    # load the original speech recording
+    original_speech, sample_rate = torchaudio.load(sph_file)
+
+    entry = []
+
+    # process the annotation utterence by utterance
+    for i, line in enumerate(annotations):
+        line = line.strip("\n")
+        line = line.split(" ")
+        # parse the annotation
+        talk_id = line[0]
+        spk_id = line[2]
+
+        # start and end point of the utterences in the recording
+        start = float(line[3])
+        end = float(line[4])
+        duration = -start + end
+        # we skip short utterences in case of CNN padding issues
+        if duration < avoid_if_shorter_than:
+            continue
+
+        # transcriptions
+        wrd_list = line[6:]
+        if wrd_list[-1] == "":
+            wrd_list = wrd_list[:-1]
+        transcript = " ".join(wrd_list)
+        if not transcript[-1].isalpha():
+            transcript = transcript[:-1]
+        transcript = transcript.replace(" 've", "'ve")
+        transcript = transcript.replace(" 't", "'t")
+        transcript = transcript.replace(" 'll", "'ll")
+        transcript = transcript.replace(" 'd", "'d")
+        transcript = transcript.replace(" 'm", "'m")
+        transcript = transcript.replace(" 're", "'re")
+        transcript = transcript.replace(" 's", "'s")
+        # skip invalid transcriptions
+        if len(wrd_list) <= 1 or transcript == "ignore_time_segment_in_scoring":
+            continue
+
+        # clip and save the current utterance
+        clipped_save_path = os.path.join(
+            utt_save_folder, talk_id + "-" + str(i) + ".wav"
+        )
+
+        # we avoid duplicated clip and save
+        if not os.path.exists(clipped_save_path):
+            start = float(line[3]) * sample_rate
+            end = float(line[4]) * sample_rate
+            curr_utt = original_speech[:, int(start) : int(end)]
+            torchaudio.save(clipped_save_path, curr_utt, sample_rate)
+        # append to the csv entry list
+        csv_line = [
+            f"{talk_id}-{str(i)}",
+            str(duration),
+            clipped_save_path,
+            spk_id,
+            transcript,
+        ]
+        entry.append(csv_line)
+
+    return entry
+
+
+def process_line(
+    talk_sph, avoid_if_shorter_than, utt_save_folder_split, data_folder, split
+):
+    """ This function processes a single Ted-talk recording.
+
+    Arguments
+    ---------
+    talk_sph : str
+        The name of the Ted-talk recording.
+    avoid_if_shorter_than: int
+        Any utterance shorter than this will be discarded.
+    utt_save_folder_split: str
+        The folder stores the clipped individual utterences.
+    data_folder: str
+        The folder stores the original Ted-talk recordings.
+    split: str
+        The split of the dataset, e.g., train, dev, test.
+    """
+    talk_name = talk_sph[:-4]
+    talk_sph_path = os.path.join(data_folder, split, "sph", talk_sph)
+    talk_stm_path = os.path.join(data_folder, split, "stm", talk_name + ".stm")
+
+    return make_splits(
+        talk_sph_path,
+        talk_stm_path,
+        utt_save_folder_split,
+        avoid_if_shorter_than,
+    )
+
+
+def prepare_tedlium2(
+    data_folder,
+    utt_save_folder,
+    csv_save_folder,
+    skip_prep=False,
+    avoid_if_shorter_than=1,
+):
+    """ This function prepares the Tedlium2 dataset.
+    Download link: https://lium.univ-lemans.fr/ted-lium2/
+
+    Arguments
+    ---------
+    data_folder : str
+        Path to the folder where the original Tedlium2 dataset is stored.
+    utt_save_folder : list
+        Path where to save the clipped utterence-leve recordings.
+    csv_save_folder: str
+        Path where to save the generated .csv files.
+    skip_prep: bool
+        If True, data preparation is skipped.
+    avoid_if_shorter_than: int
+        Any utterance shorter than this will be discarded.
+
+    Example
+    -------
+    >>> data_folder = 'datasets/TEDLIUM_release2'
+    >>> utt_save_folder = 'datasets/TEDLIUM_release2_processed'
+    >>> csv_save_folder = 'TEDLIUM2'
+    >>> prepare_tedlium2(data_folder, utt_save_folder, csv_save_folder)
+    """
+    if skip_prep:
+        return
+
+    splits = [
+        "train",
+        "dev",
+        "test",
+    ]
+
+    for split in splits:
+        utt_save_folder_split = os.path.join(utt_save_folder, split)
+        csv_save_folder_split = os.path.join(csv_save_folder, split)
+        os.makedirs(utt_save_folder_split, exist_ok=True)
+        os.makedirs(csv_save_folder_split, exist_ok=True)
+        new_filename = os.path.join(csv_save_folder_split, split + ".csv")
+        if os.path.exists(new_filename):
+            continue
+        logger.info("Preparing %s..." % new_filename)
+        data_folder_split = os.path.join(data_folder, split)
+        talk_sphs = os.listdir(os.path.join(data_folder_split, "sph"))
+
+        line_processor = functools.partial(
+            process_line,
+            avoid_if_shorter_than=avoid_if_shorter_than,
+            utt_save_folder_split=utt_save_folder_split,
+            data_folder=data_folder,
+            split=split,
+        )
+
+        tmp_csv = os.path.join(csv_save_folder_split, split + ".tmp")
+        final_csv = os.path.join(csv_save_folder_split, split + ".csv")
+        total_line = 0
+        total_duration = 0
+        with open(tmp_csv, mode="w", encoding="utf-8") as csv_f:
+            csv_writer = csv.writer(
+                csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL
+            )
+
+            csv_writer.writerow(["ID", "duration", "wav", "spk_id", "wrd"])
+            for row in parallel_map(line_processor, talk_sphs):
+                if row is None:
+                    continue
+
+                for line in row:
+                    csv_writer.writerow(line)
+                    total_duration += float(line[1])
+                total_line += len(row)
+
+        os.replace(tmp_csv, final_csv)
+
+        msg = "\t%s successfully created!" % (new_filename)
+        logger.info(msg)
+
+        msg = f"Number of samples: {total_line} "
+        logger.info(msg)
+        msg = "Total duration: %s Hours" % (
+            str(round(total_duration / 3600, 2))
+        )
+        logger.info(msg)
diff --git a/recipes/Tedlium2/ASR/transformer/train.py b/recipes/Tedlium2/ASR/transformer/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fd888b3e97845bcdad22f24e9f91e497db37822
--- /dev/null
+++ b/recipes/Tedlium2/ASR/transformer/train.py
@@ -0,0 +1,467 @@
+#!/usr/bin/env python3
+"""Recipe for training a Transformer ASR system with Tedlium2.
+The system employs an encoder, a decoder, and an attention mechanism
+between them. Decoding is performed with (CTC/Att joint) beamsearch.
+
+To run this recipe, do the following:
+> python train.py hparams/branchformer.yaml
+
+With the default hyperparameters, the system employs a convolutional frontend and a Branchformer.
+The decoder is based on a Transformer decoder.
+
+The neural network is trained on both CTC and negative-log likelihood
+targets and sub-word units estimated with Byte Pairwise Encoding (BPE)
+are used as basic recognition tokens. Training is performed on the Tedlium2
+training dataset.
+
+The best model is the average of the checkpoints from last 10 epochs.
+
+The experiment file is flexible enough to support a large variety of
+different systems. By properly changing the parameter files, you can try
+different encoders, decoders, tokens (e.g, characters instead of BPE),
+and many other possible variations.
+
+
+Authors
+ * Jianyuan Zhong 2020
+ * Mirco Ravanelli 2020
+ * Peter Plantinga 2020
+ * Samuele Cornell 2020, 2021, 2022
+ * Titouan Parcollet 2021, 2022
+ * Shucong Zhang 2023
+"""
+
+import os
+import sys
+import torch
+import logging
+from pathlib import Path
+import speechbrain as sb
+from hyperpyyaml import load_hyperpyyaml
+from speechbrain.utils.distributed import run_on_main, if_main_process
+
+logger = logging.getLogger(__name__)
+
+
+# Define training procedure
+class ASR(sb.core.Brain):
+    def compute_forward(self, batch, stage):
+        """Forward computations from the waveform batches to the output probabilities."""
+        batch = batch.to(self.device)
+        wavs, wav_lens = batch.sig
+        tokens_bos, _ = batch.tokens_bos
+
+        # compute features
+        feats = self.hparams.compute_features(wavs)
+        current_epoch = self.hparams.epoch_counter.current
+        feats = self.modules.normalize(feats, wav_lens, epoch=current_epoch)
+
+        # Add feature augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "fea_augment"):
+            feats, fea_lens = self.hparams.fea_augment(feats, wav_lens)
+            tokens_bos = self.hparams.fea_augment.replicate_labels(tokens_bos)
+
+        # forward modules
+        src = self.modules.CNN(feats)
+
+        enc_out, pred = self.modules.Transformer(
+            src, tokens_bos, wav_lens, pad_idx=self.hparams.pad_index,
+        )
+
+        # output layer for ctc log-probabilities
+        logits = self.modules.ctc_lin(enc_out)
+        p_ctc = self.hparams.log_softmax(logits)
+
+        # output layer for seq2seq log-probabilities
+        pred = self.modules.seq_lin(pred)
+        p_seq = self.hparams.log_softmax(pred)
+
+        # Compute outputs
+        hyps = None
+        if stage == sb.Stage.TRAIN:
+            hyps = None
+        elif stage == sb.Stage.VALID:
+            hyps = None
+            current_epoch = self.hparams.epoch_counter.current
+            if current_epoch % self.hparams.valid_search_interval == 0:
+                # for the sake of efficiency, we only perform beamsearch with limited capacity
+                # and no LM to give user some idea of how the AM is doing
+                hyps, _, _, _ = self.hparams.valid_search(
+                    enc_out.detach(), wav_lens
+                )
+        elif stage == sb.Stage.TEST:
+            hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
+
+        return p_ctc, p_seq, wav_lens, hyps
+
+    def compute_objectives(self, predictions, batch, stage):
+        """Computes the loss (CTC+NLL) given predictions and targets."""
+
+        (p_ctc, p_seq, wav_lens, hyps,) = predictions
+
+        ids = batch.id
+        tokens_eos, tokens_eos_lens = batch.tokens_eos
+        tokens, tokens_lens = batch.tokens
+
+        if stage == sb.Stage.TRAIN:
+            if hasattr(self.hparams, "fea_augment"):
+                tokens = self.hparams.fea_augment.replicate_labels(tokens)
+                tokens_lens = self.hparams.fea_augment.replicate_labels(
+                    tokens_lens
+                )
+                tokens_eos = self.hparams.fea_augment.replicate_labels(
+                    tokens_eos
+                )
+                tokens_eos_lens = self.hparams.fea_augment.replicate_labels(
+                    tokens_eos_lens
+                )
+
+        loss_seq = self.hparams.seq_cost(
+            p_seq, tokens_eos, length=tokens_eos_lens
+        ).sum()
+
+        loss_ctc = self.hparams.ctc_cost(
+            p_ctc, tokens, wav_lens, tokens_lens
+        ).sum()
+
+        loss = (
+            self.hparams.ctc_weight * loss_ctc
+            + (1 - self.hparams.ctc_weight) * loss_seq
+        )
+
+        if stage != sb.Stage.TRAIN:
+            current_epoch = self.hparams.epoch_counter.current
+            valid_search_interval = self.hparams.valid_search_interval
+            if current_epoch % valid_search_interval == 0 or (
+                stage == sb.Stage.TEST
+            ):
+                # Decode token terms to words
+                predicted_words = [
+                    tokenizer.decode_ids(utt_seq).split(" ") for utt_seq in hyps
+                ]
+                target_words = [wrd.split(" ") for wrd in batch.wrd]
+                self.wer_metric.append(ids, predicted_words, target_words)
+
+            # compute the accuracy of the one-step-forward prediction
+            self.acc_metric.append(p_seq, tokens_eos, tokens_eos_lens)
+        return loss
+
+    def on_fit_batch_end(self, batch, outputs, loss, should_step):
+        """At the end of the optimizer step, apply noam annealing."""
+        if should_step:
+            self.hparams.noam_annealing(self.optimizer)
+
+    def on_evaluate_start(self, max_key=None, min_key=None):
+        """perform checkpoint averge if needed"""
+        super().on_evaluate_start()
+
+        ckpts = self.checkpointer.find_checkpoints(
+            max_key=max_key, min_key=min_key
+        )
+        ckpt = sb.utils.checkpoints.average_checkpoints(
+            ckpts, recoverable_name="model"
+        )
+
+        self.hparams.model.load_state_dict(ckpt, strict=True)
+        self.hparams.model.eval()
+        logger.info("Loaded the average")
+
+    def on_stage_start(self, stage, epoch):
+        """Gets called at the beginning of each epoch"""
+        if stage != sb.Stage.TRAIN:
+            self.acc_metric = self.hparams.acc_computer()
+            self.wer_metric = self.hparams.error_rate_computer()
+
+    def on_stage_end(self, stage, stage_loss, epoch):
+        """Gets called at the end of a epoch."""
+        # Compute/store important stats
+        stage_stats = {"loss": stage_loss}
+        if stage == sb.Stage.TRAIN:
+            self.train_stats = stage_stats
+        else:
+            stage_stats["ACC"] = self.acc_metric.summarize()
+            current_epoch = self.hparams.epoch_counter.current
+            valid_search_interval = self.hparams.valid_search_interval
+            if (
+                current_epoch % valid_search_interval == 0
+                or stage == sb.Stage.TEST
+            ):
+                stage_stats["WER"] = self.wer_metric.summarize("error_rate")
+
+        # log stats and save checkpoint at end-of-epoch
+        if stage == sb.Stage.VALID:
+
+            lr = self.hparams.noam_annealing.current_lr
+            steps = self.optimizer_step
+            optimizer = self.optimizer.__class__.__name__
+
+            epoch_stats = {
+                "epoch": epoch,
+                "lr": lr,
+                "steps": steps,
+                "optimizer": optimizer,
+            }
+            self.hparams.train_logger.log_stats(
+                stats_meta=epoch_stats,
+                train_stats=self.train_stats,
+                valid_stats=stage_stats,
+            )
+            self.checkpointer.save_and_keep_only(
+                meta={"ACC": stage_stats["ACC"], "epoch": epoch},
+                max_keys=["ACC"],
+                num_to_keep=10,
+            )
+
+        elif stage == sb.Stage.TEST:
+            self.hparams.train_logger.log_stats(
+                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
+                test_stats=stage_stats,
+            )
+            if if_main_process():
+                with open(self.hparams.test_wer_file, "w") as w:
+                    self.wer_metric.write_stats(w)
+
+            # save the averaged checkpoint at the end of the evaluation stage
+            # delete the rest of the intermediate checkpoints
+            # ACC is set to 1.1 so checkpointer only keeps the averaged checkpoint
+            self.checkpointer.save_and_keep_only(
+                meta={"ACC": 1.1, "epoch": epoch},
+                max_keys=["ACC"],
+                num_to_keep=1,
+            )
+
+
+def dataio_prepare(hparams):
+    """This function prepares the datasets to be used in the brain class.
+    It also defines the data processing pipeline through user-defined functions."""
+
+    train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
+        csv_path=hparams["train_csv"]
+    )
+
+    if hparams["sorting"] == "ascending":
+        # we sort training data to speed up training and get better results.
+        train_data = train_data.filtered_sorted(sort_key="duration")
+        # when sorting do not shuffle in dataloader ! otherwise is pointless
+        hparams["train_dataloader_opts"]["shuffle"] = False
+
+    elif hparams["sorting"] == "descending":
+        train_data = train_data.filtered_sorted(
+            sort_key="duration", reverse=True
+        )
+        # when sorting do not shuffle in dataloader ! otherwise is pointless
+        hparams["train_dataloader_opts"]["shuffle"] = False
+
+    elif hparams["sorting"] == "random":
+        pass
+
+    else:
+        raise NotImplementedError(
+            "sorting must be random, ascending or descending"
+        )
+    valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
+        csv_path=hparams["valid_csv"]
+    )
+    valid_data = valid_data.filtered_sorted(sort_key="duration")
+
+    # test is separate
+    test_datasets = {}
+    for csv_file in hparams["test_csv"]:
+        name = Path(csv_file).stem
+        test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
+            csv_path=csv_file
+        )
+        test_datasets[name] = test_datasets[name].filtered_sorted(
+            sort_key="duration"
+        )
+
+    datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
+    valtest_datasets = [valid_data] + [i for k, i in test_datasets.items()]
+
+    # We get the tokenizer as we need it to encode the labels when creating
+    # mini-batches.
+    tokenizer = hparams["tokenizer"]
+
+    # 2. Define audio pipeline:
+    @sb.utils.data_pipeline.takes("wav")
+    @sb.utils.data_pipeline.provides("sig")
+    def audio_pipeline(wav):
+        sig = sb.dataio.dataio.read_audio(wav)
+        return sig
+
+    sb.dataio.dataset.add_dynamic_item(valtest_datasets, audio_pipeline)
+
+    @sb.utils.data_pipeline.takes("wav")
+    @sb.utils.data_pipeline.provides("sig")
+    def audio_pipeline_train(wav):
+        # Speed Perturb is done here so it is multi-threaded with the
+        # workers of the dataloader (faster).
+        if "speed_perturb" in hparams:
+            sig = sb.dataio.dataio.read_audio(wav)
+
+            sig = hparams["speed_perturb"](sig.unsqueeze(0)).squeeze(0)
+        else:
+            sig = sb.dataio.dataio.read_audio(wav)
+        return sig
+
+    sb.dataio.dataset.add_dynamic_item([train_data], audio_pipeline_train)
+
+    # 3. Define text pipeline:
+    @sb.utils.data_pipeline.takes("wrd")
+    @sb.utils.data_pipeline.provides(
+        "wrd", "tokens_list", "tokens_bos", "tokens_eos", "tokens"
+    )
+    def text_pipeline(wrd):
+        yield wrd
+        tokens_list = tokenizer.encode_as_ids(wrd)
+        yield tokens_list
+        tokens_bos = torch.LongTensor([hparams["bos_index"]] + (tokens_list))
+        yield tokens_bos
+        tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
+        yield tokens_eos
+        tokens = torch.LongTensor(tokens_list)
+        yield tokens
+
+    sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
+
+    # 4. Set output:
+    sb.dataio.dataset.set_output_keys(
+        datasets, ["id", "sig", "wrd", "tokens_bos", "tokens_eos", "tokens"],
+    )
+
+    # 5. If Dynamic Batching is used, we instantiate the needed samplers.
+    train_batch_sampler = None
+    valid_batch_sampler = None
+    if hparams["dynamic_batching"]:
+        from speechbrain.dataio.sampler import DynamicBatchSampler  # noqa
+
+        dynamic_hparams_train = hparams["dynamic_batch_sampler_train"]
+        dynamic_hparams_valid = hparams["dynamic_batch_sampler_valid"]
+
+        train_batch_sampler = DynamicBatchSampler(
+            train_data,
+            length_func=lambda x: x["duration"],
+            **dynamic_hparams_train,
+        )
+        valid_batch_sampler = DynamicBatchSampler(
+            valid_data,
+            length_func=lambda x: x["duration"],
+            **dynamic_hparams_valid,
+        )
+
+    return (
+        train_data,
+        valid_data,
+        test_datasets,
+        tokenizer,
+        train_batch_sampler,
+        valid_batch_sampler,
+    )
+
+
+if __name__ == "__main__":
+    # CLI:
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # If --distributed_launch then
+    # create ddp_group with the right communication protocol
+    sb.utils.distributed.ddp_init_group(run_opts)
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    # 1.  # Dataset prep (parsing Tedlium2)
+    from tedlium2_prepare import prepare_tedlium2  # noqa
+
+    # multi-gpu (ddp) save data preparation
+    run_on_main(
+        prepare_tedlium2,
+        kwargs={
+            "data_folder": hparams["data_folder"],
+            "utt_save_folder": hparams["clipped_utt_folder"],
+            "csv_save_folder": hparams["output_folder"],
+            "skip_prep": hparams["skip_prep"],
+            "avoid_if_shorter_than": hparams["avoid_if_shorter_than"],
+        },
+    )
+
+    # here we create the datasets objects as well as tokenization and encoding
+    (
+        train_data,
+        valid_data,
+        test_datasets,
+        tokenizer,
+        train_bsampler,
+        valid_bsampler,
+    ) = dataio_prepare(hparams)
+
+    # We download the pretrained LM from HuggingFace (or elsewhere depending on
+    # the path given in the YAML file). The tokenizer is loaded at the same time.
+    run_on_main(hparams["pretrainer"].collect_files)
+    hparams["pretrainer"].load_collected()
+
+    # Trainer initialization
+    asr_brain = ASR(
+        modules=hparams["modules"],
+        opt_class=hparams["Adam"],
+        hparams=hparams,
+        run_opts=run_opts,
+        checkpointer=hparams["checkpointer"],
+    )
+
+    # adding objects to trainer:
+    asr_brain.tokenizer = hparams["tokenizer"]
+    train_dataloader_opts = hparams["train_dataloader_opts"]
+    valid_dataloader_opts = hparams["valid_dataloader_opts"]
+
+    if train_bsampler is not None:
+        collate_fn = None
+        if "collate_fn" in train_dataloader_opts:
+            collate_fn = train_dataloader_opts["collate_fn"]
+
+        train_dataloader_opts = {
+            "batch_sampler": train_bsampler,
+            "num_workers": hparams["num_workers"],
+        }
+
+        if collate_fn is not None:
+            train_dataloader_opts["collate_fn"] = collate_fn
+
+    if valid_bsampler is not None:
+        collate_fn = None
+        if "collate_fn" in valid_dataloader_opts:
+            collate_fn = valid_dataloader_opts["collate_fn"]
+
+        valid_dataloader_opts = {"batch_sampler": valid_bsampler}
+
+        if collate_fn is not None:
+            valid_dataloader_opts["collate_fn"] = collate_fn
+
+    # Training
+    asr_brain.fit(
+        asr_brain.hparams.epoch_counter,
+        train_data,
+        valid_data,
+        train_loader_kwargs=train_dataloader_opts,
+        valid_loader_kwargs=valid_dataloader_opts,
+    )
+
+    # Testing
+    if not os.path.exists(hparams["output_wer_folder"]):
+        os.makedirs(hparams["output_wer_folder"])
+
+    for k in test_datasets.keys():  # keys are test_clean, test_other etc
+        asr_brain.hparams.test_wer_file = os.path.join(
+            hparams["output_wer_folder"], f"wer_{k}.txt"
+        )
+        asr_brain.evaluate(
+            test_datasets[k],
+            max_key="ACC",
+            test_loader_kwargs=hparams["test_dataloader_opts"],
+        )
diff --git a/recipes/Tedlium2/Tokenizer/README.md b/recipes/Tedlium2/Tokenizer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..202cb99521000d43b1b13a66dd582de262043e63
--- /dev/null
+++ b/recipes/Tedlium2/Tokenizer/README.md
@@ -0,0 +1,44 @@
+# Tokenizer.
+This folder contains the scripts to train a tokenizer using SentencePiece (https://github.com/google/sentencepiece).
+The tokenizer is trained on the top of the Tedlium2 training transcriptions.
+
+You can download Tedlium2 at https://lium.univ-lemans.fr/ted-lium2/
+
+
+# How to Run
+
+To run the training script, follow these steps:
+
+1. Run the following command, replacing `--data_folder` with the path to your downloaded and unpacked Tedlium2 dataset:
+
+```python
+python train.py hparams/tedlium2_500_bpe.yaml --data_folder=/path/to/TEDLIUM --clipped_utt_folder=/path/where/to/store/clipped/TEDLIUM
+```
+
+**IMPORTANT**: Please utilize **absolute paths** for both the `data_folder` and the `clipped_utt_folder` because the generated CSV files will be employed in training the ASR model.
+
+
+2. The script will automatically process the dataset and store a modified version of it in the directory specified by `--clipped_utt_folder`. This modified dataset contains recordings split into individual utterances, making it suitable for Automatic Speech Recognition (ASR) training. You can now use this processed dataset for ASR training as described in the `../ASR/README.md` file.
+
+Make sure to adjust the paths and filenames as needed to match your specific setup and dataset location.
+
+# **About SpeechBrain**
+- Website: https://speechbrain.github.io/
+- Code: https://github.com/speechbrain/speechbrain/
+- HuggingFace: https://huggingface.co/speechbrain/
+
+
+# **Citing SpeechBrain**
+Please, cite SpeechBrain if you use it for your research or business.
+
+```bibtex
+@misc{speechbrain,
+  title={{SpeechBrain}: A General-Purpose Speech Toolkit},
+  author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
+  year={2021},
+  eprint={2106.04624},
+  archivePrefix={arXiv},
+  primaryClass={eess.AS},
+  note={arXiv:2106.04624}
+}
+```
diff --git a/recipes/Tedlium2/Tokenizer/hparams/tedlium2_500_bpe.yaml b/recipes/Tedlium2/Tokenizer/hparams/tedlium2_500_bpe.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..03c91b12692c4415e1c3cdb71b0802467b2505f6
--- /dev/null
+++ b/recipes/Tedlium2/Tokenizer/hparams/tedlium2_500_bpe.yaml
@@ -0,0 +1,31 @@
+# ############################################################################
+# Tokenizer: subword BPE with unigram 500
+# Training: Tedlium2
+# Authors:  Abdel Heba 2021
+#           Shucong Zhang 2023
+# ############################################################################
+
+output_folder: results/tokenizer # folder where to store the BPE ckpt and csv files
+clipped_utt_folder: !PLACEHOLDER # folder where to store the clipped utterence-level recordings
+
+# Data files
+data_folder: !PLACEHOLDER # e.g, /path/to/TEDLIUM_release2
+skip_prep: False
+train_csv: !ref <output_folder>/train/train.csv
+valid_csv: !ref <output_folder>/dev/dev.csv
+
+####################### Training Parameters ####################################
+token_type: bpe  # ["unigram", "bpe", "char"]
+token_output: 500  # index(blank/eos/bos/unk) = 0
+character_coverage: 1.0
+csv_read: wrd
+avoid_if_shorter_than: 1.0
+
+tokenizer: !name:speechbrain.tokenizers.SentencePiece.SentencePiece
+   model_dir: !ref <output_folder>
+   vocab_size: !ref <token_output>
+   annotation_train: !ref <train_csv>
+   annotation_read: !ref <csv_read>
+   model_type: !ref <token_type> # ["unigram", "bpe", "char"]
+   character_coverage: !ref <character_coverage>
+   annotation_list_to_check: [!ref <train_csv>, !ref <valid_csv>]
diff --git a/recipes/Tedlium2/Tokenizer/tedlium2_prepare.py b/recipes/Tedlium2/Tokenizer/tedlium2_prepare.py
new file mode 120000
index 0000000000000000000000000000000000000000..53047f4d8303e1012ba1026f2507b668436f99e1
--- /dev/null
+++ b/recipes/Tedlium2/Tokenizer/tedlium2_prepare.py
@@ -0,0 +1 @@
+../tedlium2_prepare.py
\ No newline at end of file
diff --git a/recipes/Tedlium2/Tokenizer/train.py b/recipes/Tedlium2/Tokenizer/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..af2b21ae93e23936d737997a10b7bfd097e945d0
--- /dev/null
+++ b/recipes/Tedlium2/Tokenizer/train.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env/python3
+"""Recipe for training a BPE tokenizer with Tedlium2.
+The tokenizer converts words into sub-word units that can
+be used to train a language (LM) or an acoustic model (AM).
+When doing a speech recognition experiment you have to make
+sure that the acoustic and language models are trained with
+the same tokenizer. Otherwise, a token mismatch is introduced
+and beamsearch will produce bad results when combining AM and LM.
+
+To run this recipe, do the following:
+> python train.py hyperparams/tedlium2_500_bpe.yaml
+
+Authors
+ * Shucong Zhang 2023
+"""
+
+import sys
+import speechbrain as sb
+from hyperpyyaml import load_hyperpyyaml
+from speechbrain.utils.distributed import run_on_main
+import shutil
+
+if __name__ == "__main__":
+
+    # CLI:
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # create ddp_group with the right communication protocol
+    sb.utils.distributed.ddp_init_group(run_opts)
+
+    # 1.  # Dataset prep (parsing Tedlium2)
+    from tedlium2_prepare import prepare_tedlium2  # noqa
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    # multi-gpu (ddp) save data preparation
+    run_on_main(
+        prepare_tedlium2,
+        kwargs={
+            "data_folder": hparams["data_folder"],
+            "utt_save_folder": hparams["clipped_utt_folder"],
+            "csv_save_folder": hparams["output_folder"],
+            "skip_prep": hparams["skip_prep"],
+            "avoid_if_shorter_than": hparams["avoid_if_shorter_than"],
+        },
+    )
+
+    # Train tokenizer
+    hparams["tokenizer"]()
+
+    output_path = hparams["output_folder"]
+
+    token_output = hparams["token_output"]
+    token_type = hparams["token_type"]
+    bpe_model = f"{output_path}/{token_output}_{token_type}.model"
+
+    tokenizer_ckpt = f"{output_path}/tokenizer.ckpt"
+    shutil.copyfile(bpe_model, tokenizer_ckpt)
diff --git a/recipes/Tedlium2/tedlium2_prepare.py b/recipes/Tedlium2/tedlium2_prepare.py
new file mode 100644
index 0000000000000000000000000000000000000000..c56bee6cba8c2a4099d9e0b38c65995f66c83dea
--- /dev/null
+++ b/recipes/Tedlium2/tedlium2_prepare.py
@@ -0,0 +1,226 @@
+"""
+Download link: https://lium.univ-lemans.fr/ted-lium2/
+
+Authors
+ * Shucong Zhang 2023
+ * Adel Moumen 2023
+"""
+
+import os
+import csv
+import logging
+import torchaudio
+import functools
+from speechbrain.utils.parallel import parallel_map
+
+logger = logging.getLogger(__name__)
+
+
+def make_splits(
+    sph_file, stm_file, utt_save_folder, avoid_if_shorter_than,
+):
+    """
+    This function splits the .sph Ted-talk recording into utterences based on the .stm annotation.
+
+    Arguments
+    ---------
+    sph_file : str
+        Path to the sph file containing Ted-talk recording.
+    stm_file : str
+        Path to the stm file containing Ted-talk annotation.
+    utt_save_folder: str
+        The folder stores the clipped individual utterences.
+    avoid_if_shorter_than: int
+        Any utterance shorter than this will be discarded.
+    """
+    # the annotation for JillSobuleMANHATTANINJANUARY_2006.sph is not useful
+    if "JillSobuleMANHATTANINJANUARY_2006" in sph_file:
+        logger.info("JillSobuleMANHATTANINJANUARY_2006.sph is skipped")
+        return
+
+    # load the annotation of the entire speech recording
+    annotation_file = open(stm_file, "r")
+    annotations = annotation_file.readlines()
+
+    # load the original speech recording
+    original_speech, sample_rate = torchaudio.load(sph_file)
+
+    entry = []
+
+    # process the annotation utterence by utterance
+    for i, line in enumerate(annotations):
+        line = line.strip("\n")
+        line = line.split(" ")
+        # parse the annotation
+        talk_id = line[0]
+        spk_id = line[2]
+
+        # start and end point of the utterences in the recording
+        start = float(line[3])
+        end = float(line[4])
+        duration = -start + end
+        # we skip short utterences in case of CNN padding issues
+        if duration < avoid_if_shorter_than:
+            continue
+
+        # transcriptions
+        wrd_list = line[6:]
+        if wrd_list[-1] == "":
+            wrd_list = wrd_list[:-1]
+        transcript = " ".join(wrd_list)
+        if not transcript[-1].isalpha():
+            transcript = transcript[:-1]
+        transcript = transcript.replace(" 've", "'ve")
+        transcript = transcript.replace(" 't", "'t")
+        transcript = transcript.replace(" 'll", "'ll")
+        transcript = transcript.replace(" 'd", "'d")
+        transcript = transcript.replace(" 'm", "'m")
+        transcript = transcript.replace(" 're", "'re")
+        transcript = transcript.replace(" 's", "'s")
+        # skip invalid transcriptions
+        if len(wrd_list) <= 1 or transcript == "ignore_time_segment_in_scoring":
+            continue
+
+        # clip and save the current utterance
+        clipped_save_path = os.path.join(
+            utt_save_folder, talk_id + "-" + str(i) + ".wav"
+        )
+
+        # we avoid duplicated clip and save
+        if not os.path.exists(clipped_save_path):
+            start = float(line[3]) * sample_rate
+            end = float(line[4]) * sample_rate
+            curr_utt = original_speech[:, int(start) : int(end)]
+            torchaudio.save(clipped_save_path, curr_utt, sample_rate)
+        # append to the csv entry list
+        csv_line = [
+            f"{talk_id}-{str(i)}",
+            str(duration),
+            clipped_save_path,
+            spk_id,
+            transcript,
+        ]
+        entry.append(csv_line)
+
+    return entry
+
+
+def process_line(
+    talk_sph, avoid_if_shorter_than, utt_save_folder_split, data_folder, split
+):
+    """ This function processes a single Ted-talk recording.
+
+    Arguments
+    ---------
+    talk_sph : str
+        The name of the Ted-talk recording.
+    avoid_if_shorter_than: int
+        Any utterance shorter than this will be discarded.
+    utt_save_folder_split: str
+        The folder stores the clipped individual utterences.
+    data_folder: str
+        The folder stores the original Ted-talk recordings.
+    split: str
+        The split of the dataset, e.g., train, dev, test.
+    """
+    talk_name = talk_sph[:-4]
+    talk_sph_path = os.path.join(data_folder, split, "sph", talk_sph)
+    talk_stm_path = os.path.join(data_folder, split, "stm", talk_name + ".stm")
+
+    return make_splits(
+        talk_sph_path,
+        talk_stm_path,
+        utt_save_folder_split,
+        avoid_if_shorter_than,
+    )
+
+
+def prepare_tedlium2(
+    data_folder,
+    utt_save_folder,
+    csv_save_folder,
+    skip_prep=False,
+    avoid_if_shorter_than=1,
+):
+    """ This function prepares the Tedlium2 dataset.
+    Download link: https://lium.univ-lemans.fr/ted-lium2/
+
+    Arguments
+    ---------
+    data_folder : str
+        Path to the folder where the original Tedlium2 dataset is stored.
+    utt_save_folder : list
+        Path where to save the clipped utterence-leve recordings.
+    csv_save_folder: str
+        Path where to save the generated .csv files.
+    skip_prep: bool
+        If True, data preparation is skipped.
+    avoid_if_shorter_than: int
+        Any utterance shorter than this will be discarded.
+
+    Example
+    -------
+    >>> data_folder = 'datasets/TEDLIUM_release2'
+    >>> utt_save_folder = 'datasets/TEDLIUM_release2_processed'
+    >>> csv_save_folder = 'TEDLIUM2'
+    >>> prepare_tedlium2(data_folder, utt_save_folder, csv_save_folder)
+    """
+    if skip_prep:
+        return
+
+    splits = [
+        "train",
+        "dev",
+        "test",
+    ]
+
+    for split in splits:
+        utt_save_folder_split = os.path.join(utt_save_folder, split)
+        csv_save_folder_split = os.path.join(csv_save_folder, split)
+        os.makedirs(utt_save_folder_split, exist_ok=True)
+        os.makedirs(csv_save_folder_split, exist_ok=True)
+        new_filename = os.path.join(csv_save_folder_split, split + ".csv")
+        if os.path.exists(new_filename):
+            continue
+        logger.info("Preparing %s..." % new_filename)
+        data_folder_split = os.path.join(data_folder, split)
+        talk_sphs = os.listdir(os.path.join(data_folder_split, "sph"))
+
+        line_processor = functools.partial(
+            process_line,
+            avoid_if_shorter_than=avoid_if_shorter_than,
+            utt_save_folder_split=utt_save_folder_split,
+            data_folder=data_folder,
+            split=split,
+        )
+
+        tmp_csv = os.path.join(csv_save_folder_split, split + ".tmp")
+        final_csv = os.path.join(csv_save_folder_split, split + ".csv")
+        total_line = 0
+        total_duration = 0
+        with open(tmp_csv, mode="w", encoding="utf-8") as csv_f:
+            csv_writer = csv.writer(
+                csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL
+            )
+
+            csv_writer.writerow(["ID", "duration", "wav", "spk_id", "wrd"])
+            for row in parallel_map(line_processor, talk_sphs):
+                if row is None:
+                    continue
+
+                for line in row:
+                    csv_writer.writerow(line)
+                    total_duration += float(line[1])
+                total_line += len(row)
+
+        os.replace(tmp_csv, final_csv)
+
+        msg = "\t%s successfully created!" % (new_filename)
+        logger.info(msg)
+
+        msg = f"Number of samples: {total_line} "
+        logger.info(msg)
+        msg = "Total duration: %s Hours" % (
+            str(round(total_duration / 3600, 2))
+        )
+        logger.info(msg)
diff --git a/recipes/UrbanSound8k/SoundClassification/hparams/train_ecapa_tdnn.yaml b/recipes/UrbanSound8k/SoundClassification/hparams/train_ecapa_tdnn.yaml
index b30f795128a512ab6665911ee6d801f700d17dab..3ecb1119b460be4c313dd61fc7d6a2a0321bc61f 100644
--- a/recipes/UrbanSound8k/SoundClassification/hparams/train_ecapa_tdnn.yaml
+++ b/recipes/UrbanSound8k/SoundClassification/hparams/train_ecapa_tdnn.yaml
@@ -14,7 +14,8 @@ __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
 # Set up folders for reading from and writing to
 # Dataset must already exist at `audio_data_folder`
 data_folder: !PLACEHOLDER # e.g., /localscratch/UrbanSound8K
-open_rir_folder: <data_folder>/RIRS # Change if needed
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+data_folder_rir: !ref <data_folder>/rir # The impulse responses used for data augmentation will automatically be downloaded here.
 audio_data_folder: !ref <data_folder>/audio
 # TODO the follwing folder will contain the resampled audio files (mono channel and config SR) to train on
 #reasmpled_audio_data_folder: !ref <data_folder>/audio_mono16kHz
@@ -22,14 +23,21 @@ output_folder: !ref ./results/urban_sound/<seed>
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
 
+# Data for augmentation
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+RIR_DATASET_URL: https://www.dropbox.com/scl/fi/linhy77c36mu10965a836/RIRs.zip?rlkey=pg9cu8vrpn2u173vhiqyu743u&dl=1
+
+
 # Tensorboard logs
 use_tensorboard: True
 tensorboard_logs_folder: !ref <output_folder>/tb_logs/
 
 # Path where data manifest files will be stored
-train_annotation: !ref <data_folder>/manifest/train.json
-valid_annotation: !ref <data_folder>/manifest/valid.json
-test_annotation: !ref <data_folder>/manifest/test.json
+train_annotation: !ref <save_folder>/manifest/train.json
+valid_annotation: !ref <save_folder>/manifest/valid.json
+test_annotation: !ref <save_folder>/manifest/test.json
+noise_annotation: !ref <save_folder>/noise.csv
+rir_annotation: !ref <save_folder>/rir.csv
 
 # To standardize results, UrbanSound8k has pre-separated samples into
 # 10 folds for multi-fold validation
@@ -40,7 +48,7 @@ skip_manifest_creation: False
 
 ckpt_interval_minutes: 15 # save checkpoint every N min
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 25
 batch_size: 32
 lr: 0.001
@@ -66,10 +74,11 @@ out_n_neurons: 10
 # because this does not mix samples from folds in train to valid/test, only
 # within train or valid, or test
 shuffle: True
+num_workers: 4
 dataloader_options:
     batch_size: !ref <batch_size>
     shuffle: !ref <shuffle>
-    num_workers: 0
+    num_workers: !ref <num_workers>
 
 # Functions
 compute_features: !new:speechbrain.lobes.features.Fbank
@@ -94,55 +103,88 @@ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
 
-augment_wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [100]
-
-augment_speed: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-add_rev: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <open_rir_folder>
-    openrir_max_noise_len: 3.0  # seconds
-    reverb_prob: 1.0
-    noise_prob: 0.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-add_noise: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <open_rir_folder>
-    openrir_max_noise_len: 3.0  # seconds
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-add_rev_noise: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <open_rir_folder>
-    openrir_max_noise_len: 3.0  # seconds
-    reverb_prob: 1.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-
-# Definition of the augmentation pipeline.
-# If concat_augment = False, the augmentation techniques are applied
-# in sequence. If concat_augment = True, all the augmented signals
-# # are concatenated in a single big batch.
-
-augment_pipeline: [
-    #!ref <augment_wavedrop>,
-    #!ref <augment_speed>,
-    #!ref <add_rev>,
-    #!ref <add_noise>,
-    #!ref <add_rev_noise>
-]
-concat_augment: True
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Add noise to input signal
+snr_low: 0  # Min SNR for noise augmentation
+snr_high: 15  # Max SNR for noise augmentation
+
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: !ref <snr_low>
+    snr_high: !ref <snr_high>
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+
+# Download and prepare the dataset of room impulse responses for augmentation
+prepare_rir_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <RIR_DATASET_URL>
+    dest_folder: !ref <data_folder_rir>
+    ext: wav
+    csv_file: !ref <rir_annotation>
+
+# Add reverberation to input signal
+add_reverb: !new:speechbrain.augment.time_domain.AddReverb
+    csv_file: !ref <rir_annotation>
+    reverb_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    parallel_augment: True
+    concat_original: True
+    repeat_augment: 1
+    shuffle_augmentations: False
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <add_noise>,
+        !ref <add_reverb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
 
 mean_var_norm: !new:speechbrain.processing.features.InputNormalization
     norm_type: sentence
@@ -150,11 +192,6 @@ mean_var_norm: !new:speechbrain.processing.features.InputNormalization
 
 modules:
     compute_features: !ref <compute_features>
-    augment_wavedrop: !ref <augment_wavedrop>
-    augment_speed: !ref <augment_speed>
-    add_rev: !ref <add_rev>
-    add_noise: !ref <add_noise>
-    add_rev_noise: !ref <add_rev_noise>
     embedding_model: !ref <embedding_model>
     classifier: !ref <classifier>
     mean_var_norm: !ref <mean_var_norm>
diff --git a/recipes/UrbanSound8k/SoundClassification/train.py b/recipes/UrbanSound8k/SoundClassification/train.py
index 806e667d8c1b1e1f6c7b3420d2753bafb351b7d7..262f27577ef47028803ce053628ad0efac306991 100755
--- a/recipes/UrbanSound8k/SoundClassification/train.py
+++ b/recipes/UrbanSound8k/SoundClassification/train.py
@@ -43,33 +43,9 @@ class UrbanSound8kBrain(sb.core.Brain):
         batch = batch.to(self.device)
         wavs, lens = batch.sig
 
-        if stage == sb.Stage.TRAIN:
-
-            # Applying the augmentation pipeline
-            wavs_aug_tot = []
-            wavs_aug_tot.append(wavs)
-            for count, augment in enumerate(self.hparams.augment_pipeline):
-
-                # Apply augment
-                wavs_aug = augment(wavs, lens)
-
-                # Managing speed change
-                if wavs_aug.shape[1] > wavs.shape[1]:
-                    wavs_aug = wavs_aug[:, 0 : wavs.shape[1]]
-                else:
-                    zero_sig = torch.zeros_like(wavs)
-                    zero_sig[:, 0 : wavs_aug.shape[1]] = wavs_aug
-                    wavs_aug = zero_sig
-
-                if self.hparams.concat_augment:
-                    wavs_aug_tot.append(wavs_aug)
-                else:
-                    wavs = wavs_aug
-                    wavs_aug_tot[0] = wavs
-
-            wavs = torch.cat(wavs_aug_tot, dim=0)
-            self.n_augment = len(wavs_aug_tot)
-            lens = torch.cat([lens] * self.n_augment)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, lens = self.hparams.wav_augment(wavs, lens)
 
         # Feature extraction and normalization
         feats = self.modules.compute_features(wavs)
@@ -98,8 +74,9 @@ class UrbanSound8kBrain(sb.core.Brain):
         classid, _ = batch.class_string_encoded
 
         # Concatenate labels (due to data augmentation)
-        if stage == sb.Stage.TRAIN:
-            classid = torch.cat([classid] * self.n_augment, dim=0)
+        # Concatenate labels (due to data augmentation)
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            classid = self.hparams.wav_augment.replicate_labels(classid)
 
         loss = self.hparams.compute_cost(predictions, classid, lens)
 
@@ -417,6 +394,8 @@ if __name__ == "__main__":
             "skip_manifest_creation": hparams["skip_manifest_creation"],
         },
     )
+    sb.utils.distributed.run_on_main(hparams["prepare_noise_data"])
+    sb.utils.distributed.run_on_main(hparams["prepare_rir_data"])
 
     # Dataset IO prep: creating Dataset objects and proper encodings for phones
     datasets, label_encoder = dataio_prep(hparams)
diff --git a/recipes/Voicebank/ASR/CTC/README.md b/recipes/Voicebank/ASR/CTC/README.md
index b01bf25e09c3782881aa2041b35e0054401d0c6f..9baa57d1841e99f1dc039d59d203bc7cbb8dc2ff 100644
--- a/recipes/Voicebank/ASR/CTC/README.md
+++ b/recipes/Voicebank/ASR/CTC/README.md
@@ -8,9 +8,13 @@ download and resample the dataset.
 ## How to run
 
 ```bash
-python train.py hparams/train.yaml
+python train.py hparams/train.yaml --data_folder=your/data/folder --jit
 ```
 
+**Note on Compilation**:
+Enabling the just-in-time (JIT) compiler significantly improves code performance, resulting in a 50-60% speed boost. We highly recommend utilizing the JIT compiler for optimal results.
+This speed improvement is observed specifically when using the CRDNN model.
+
 ## Results
 
 | Release  | hyperparams file | input type  | Test PER | Model link    | GPUs        |
diff --git a/recipes/Voicebank/ASR/CTC/hparams/train.yaml b/recipes/Voicebank/ASR/CTC/hparams/train.yaml
index c76c1d745344b870dfb3d88abcf544cd68115f86..a49bae5fa4caebfd13f87b4fc4a8f24db47af522 100644
--- a/recipes/Voicebank/ASR/CTC/hparams/train.yaml
+++ b/recipes/Voicebank/ASR/CTC/hparams/train.yaml
@@ -20,13 +20,14 @@ valid_annotation: !ref <output_folder>/valid.json
 test_annotation: !ref <output_folder>/test.json
 skip_prep: False # Skip data preparation
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 50
 batch_size: 8
 sorting: ascending
 dataloader_options:
     batch_size: !ref <batch_size>
-lr: 1.0
+lr: 0.5
+max_grad_norm: 5.0
 
 # Set this to the path of a pretrained model to load before training
 # pretrained: model_clean_ep3.ckpt
@@ -36,7 +37,7 @@ sample_rate: 16000
 n_fft: 400
 n_mels: 40
 
-# Model parameters
+####################### Model Parameters #######################################
 activation: !name:torch.nn.LeakyReLU
 dropout: 0.15
 cnn_blocks: 2
@@ -60,10 +61,41 @@ compute_features: !new:speechbrain.lobes.features.Fbank
     n_fft: !ref <n_fft>
     n_mels: !ref <n_mels>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    concat_original: True
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
 model: !new:speechbrain.lobes.models.CRDNN.CRDNN
     input_shape: [null, null, !ref <n_mels>]
     activation: !ref <activation>
@@ -104,7 +136,6 @@ modules:
     model: !ref <model>
     output: !ref <output>
     normalize: !ref <normalize>
-    augmentation: !ref <augmentation>
 
 jit_module_keys: [model]
 
diff --git a/recipes/Voicebank/ASR/CTC/train.py b/recipes/Voicebank/ASR/CTC/train.py
index 1d361ea1264fe3129b5e3455584ed8c0a9a4d08d..454f828c0a3a96f501f3ef1bf7b8c0679d88828a 100644
--- a/recipes/Voicebank/ASR/CTC/train.py
+++ b/recipes/Voicebank/ASR/CTC/train.py
@@ -26,7 +26,9 @@ class ASR_Brain(sb.Brain):
         "Given an input batch it computes the phoneme probabilities."
         batch = batch.to(self.device)
         wavs, wav_lens = batch.sig
-        wavs = self.modules.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
         feats = self.hparams.compute_features(wavs)
         feats = self.modules.normalize(feats, wav_lens)
         out = self.modules.model(feats)
@@ -39,6 +41,10 @@ class ASR_Brain(sb.Brain):
         "Given the network predictions and targets computed the CTC loss."
         pout, pout_lens = predictions
         phns, phn_lens = batch.phn_encoded
+        # Label Augmentation
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            phns = self.hparams.wav_augment.replicate_labels(phns)
+            phn_lens = self.hparams.wav_augment.replicate_labels(phn_lens)
         loss = self.hparams.compute_cost(pout, phns, pout_lens, phn_lens)
         self.ctc_metrics.append(batch.id, pout, phns, pout_lens, phn_lens)
 
diff --git a/recipes/Voicebank/MTL/ASR_enhance/hparams/enhance_mimic.yaml b/recipes/Voicebank/MTL/ASR_enhance/hparams/enhance_mimic.yaml
index 2a96e25db7460d473460889dc326f6589fd2eceb..c3391498a7ad7792331fce172dd820cd15ce9cb1 100644
--- a/recipes/Voicebank/MTL/ASR_enhance/hparams/enhance_mimic.yaml
+++ b/recipes/Voicebank/MTL/ASR_enhance/hparams/enhance_mimic.yaml
@@ -18,7 +18,7 @@ valid_annotation: !ref <data_folder>/valid.json
 test_annotation: !ref <data_folder>/test.json
 skip_prep: False
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 50
 batch_size: 8
 lr: 0.0001
diff --git a/recipes/Voicebank/MTL/ASR_enhance/hparams/pretrain_perceptual.yaml b/recipes/Voicebank/MTL/ASR_enhance/hparams/pretrain_perceptual.yaml
index a86fbc4ccf1cec27c6c4979fe386b146b7418e65..d384d026ae5bb20de42f48f9beb572e93bd65a00 100644
--- a/recipes/Voicebank/MTL/ASR_enhance/hparams/pretrain_perceptual.yaml
+++ b/recipes/Voicebank/MTL/ASR_enhance/hparams/pretrain_perceptual.yaml
@@ -18,7 +18,7 @@ valid_annotation: !ref <data_folder>/valid.json
 test_annotation: !ref <data_folder>/test.json
 skip_prep: False
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 20
 ctc_epochs: 4
 batch_size: 8
diff --git a/recipes/Voicebank/MTL/ASR_enhance/hparams/robust_asr.yaml b/recipes/Voicebank/MTL/ASR_enhance/hparams/robust_asr.yaml
index 8eec433834a5a92270eca660dd0ccab0762ff461..1835342c30d97c7335927fd2f6c17f7e8fe2a18d 100644
--- a/recipes/Voicebank/MTL/ASR_enhance/hparams/robust_asr.yaml
+++ b/recipes/Voicebank/MTL/ASR_enhance/hparams/robust_asr.yaml
@@ -12,15 +12,19 @@ stats_file: !ref <output_folder>/stats.txt
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
 
+# Data for augmentation
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+
 # Data files
 data_folder: !PLACEHOLDER  # e.g. /path/to/Voicebank
-data_folder_rirs: !ref <data_folder>
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
 train_annotation: !ref <data_folder>/train.json
 valid_annotation: !ref <data_folder>/valid.json
 test_annotation: !ref <data_folder>/test.json
+noise_annotation: !ref <save_folder>/noise.csv #The data manifest files are created by the data preparation script
 skip_prep: False
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 30
 ctc_epochs: 0
 batch_size: 8
@@ -31,13 +35,18 @@ checkpoint_avg: 5  # average this many checkpoints for eval
 sorting: ascending
 eval_max_key: null
 eval_min_key: null
+
+num_workers: 4
 train_loader_options:
     batch_size: !ref <batch_size>
+    num_workers: !ref <num_workers>
 valid_loader_options:
     batch_size: !ref <batch_size>
+    num_workers: !ref <num_workers>
     shuffle: False
 test_loader_options:
     batch_size: !ref <batch_size>
+    num_workers: !ref <num_workers>
     shuffle: False
 epochs_before_lr_drop: 3
 
@@ -132,24 +141,78 @@ compute_stft: !new:speechbrain.processing.features.STFT
 spectral_magnitude: !name:speechbrain.processing.features.spectral_magnitude
     power: 0.5
 
-env_corr: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <data_folder_rirs>
-    babble_prob: 0.0
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
+############################## Augmentations ###################################
+
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+# Add noise to input signal
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: 0
+    snr_high: 15
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: [95, 100, 105]
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    concat_original: True
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <add_noise>,
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
 
-augment: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
 fbank: !new:speechbrain.lobes.features.Fbank
     n_mels: !ref <n_mels>
     sample_rate: !ref <sample_rate>
 
-beam_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNBeamSearchLM
+coverage_scorer: !new:speechbrain.decoders.scorer.CoverageScorer
+    vocab_size: !ref <output_neurons>
+
+rnnlm_scorer: !new:speechbrain.decoders.scorer.RNNLMScorer
+    language_model: !ref <asr_model[lm_model]>
+    temperature: !ref <temperature_lm>
+
+scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <rnnlm_scorer>,
+                   !ref <coverage_scorer>]
+    weights:
+        rnnlm: !ref <lm_weight>
+        coverage: !ref <coverage_penalty>
+
+beam_searcher: !new:speechbrain.decoders.S2SRNNBeamSearcher
     embedding: !ref <asr_model[tgt_embedding]>
     decoder: !ref <asr_model[recognizer]>
     linear: !ref <asr_model[seq_output]>
-    language_model: !ref <asr_model[lm_model]>
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
     min_decode_ratio: !ref <min_decode_ratio>
@@ -158,10 +221,8 @@ beam_searcher: !new:speechbrain.decoders.seq2seq.S2SRNNBeamSearchLM
     eos_threshold: !ref <eos_threshold>
     using_max_attn_shift: !ref <using_max_attn_shift>
     max_attn_shift: !ref <max_attn_shift>
-    coverage_penalty: !ref <coverage_penalty>
-    lm_weight: !ref <lm_weight>
     temperature: !ref <temperature>
-    temperature_lm: !ref <temperature_lm>
+    scorer: !ref <scorer>
 
 opt_class: !name:torch.optim.AdamW
     lr: !ref <lr>
diff --git a/recipes/Voicebank/MTL/ASR_enhance/train.py b/recipes/Voicebank/MTL/ASR_enhance/train.py
index a07fbfae5123f9c81117c9548f2bcb5a325900c2..e9b83309d2296f7701732f05da0778e1ce22e621 100644
--- a/recipes/Voicebank/MTL/ASR_enhance/train.py
+++ b/recipes/Voicebank/MTL/ASR_enhance/train.py
@@ -20,6 +20,7 @@ import os
 import sys
 import torch
 import torchaudio
+import logging
 import speechbrain as sb
 from pesq import pesq
 from pystoi import stoi
@@ -28,6 +29,8 @@ from hyperpyyaml import load_hyperpyyaml
 from speechbrain.utils.data_utils import undo_padding
 from speechbrain.utils.distributed import run_on_main, if_main_process
 
+logger = logging.getLogger(__name__)
+
 
 def pesq_eval(pred_wav, target_wav):
     return pesq(
@@ -74,7 +77,7 @@ class MTLbrain(sb.Brain):
 
         predictions = {}
         if self.hparams.enhance_type is not None:
-            noisy_wavs, lens = self.prepare_wavs(batch.noisy_sig)
+            noisy_wavs, lens = self.prepare_wavs(batch.noisy_sig, stage)
 
             # Mask with "signal approximation (SA)"
             if self.hparams.enhance_type == "masking":
@@ -89,14 +92,14 @@ class MTLbrain(sb.Brain):
 
         # Generate clean features for ASR pre-training
         if self.hparams.ctc_type == "clean" or self.hparams.seq_type == "clean":
-            clean_wavs, lens = self.prepare_wavs(batch.clean_sig)
+            clean_wavs, lens = self.prepare_wavs(batch.clean_sig, stage)
             clean_feats = self.prepare_feats(clean_wavs)
 
         # Compute seq outputs
         if self.hparams.seq_type is not None:
 
             # Prepare target inputs
-            tokens, token_lens = self.prepare_targets(batch.tokens_bos)
+            tokens, token_lens = self.prepare_targets(batch.tokens_bos, stage)
             tokens = self.modules.tgt_embedding(tokens)
 
             if self.hparams.seq_type == "clean":
@@ -105,8 +108,6 @@ class MTLbrain(sb.Brain):
                 embed = self.modules.src_embedding(clean_feats)
             if self.hparams.seq_type == "joint":
                 asr_feats = predictions["wavs"]
-                if stage == sb.Stage.TRAIN:
-                    asr_feats = self.hparams.augment(asr_feats, lens)
                 asr_feats = self.hparams.fbank(asr_feats)
                 asr_feats = self.hparams.normalizer(asr_feats, lens)
                 embed = self.modules.src_embedding(asr_feats)
@@ -119,9 +120,10 @@ class MTLbrain(sb.Brain):
                 predictions["ctc_pout"] = torch.log_softmax(out, dim=-1)
 
             if stage != sb.Stage.TRAIN:
-                predictions["hyps"], _ = self.hparams.beam_searcher(
-                    embed.detach(), lens
-                )
+                hyps, _, _, _ = self.hparams.beam_searcher(embed.detach(), lens)
+
+                # Convert best hypothesis to list
+                predictions["hyps"] = hyps
 
         elif self.hparams.ctc_type is not None:
             if self.hparams.ctc_type == "clean":
@@ -137,18 +139,12 @@ class MTLbrain(sb.Brain):
 
         return predictions
 
-    def prepare_wavs(self, signal, augment=True):
+    def prepare_wavs(self, signal, stage):
         """Prepare possibly enhanced waveforms"""
         wavs, wav_lens = signal
-
-        if self.stage == sb.Stage.TRAIN and hasattr(self.hparams, "env_corr"):
-            if augment:
-                wavs_noise = self.hparams.env_corr(wavs, wav_lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-            else:
-                wavs = torch.cat([wavs, wavs], dim=0)
-            wav_lens = torch.cat([wav_lens, wav_lens])
-
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
         return wavs, wav_lens
 
     def prepare_feats(self, wavs):
@@ -158,13 +154,13 @@ class MTLbrain(sb.Brain):
         feats = torch.log1p(feats)
         return feats
 
-    def prepare_targets(self, tokens):
+    def prepare_targets(self, tokens, stage):
         """Prepare target by concatenating self if "env_corr" is used"""
         tokens, token_lens = tokens
 
-        if self.stage == sb.Stage.TRAIN and hasattr(self.hparams, "env_corr"):
-            tokens = torch.cat([tokens, tokens], dim=0)
-            token_lens = torch.cat([token_lens, token_lens])
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            tokens = self.hparams.wav_augment.replicate_labels(tokens)
+            token_lens = self.hparams.wav_augment.replicate_labels(token_lens)
 
         return tokens, token_lens
 
@@ -172,7 +168,7 @@ class MTLbrain(sb.Brain):
         """Compute possibly several loss terms: enhance, mimic, ctc, seq"""
 
         # Do not augment targets
-        clean_wavs, lens = self.prepare_wavs(batch.clean_sig, augment=False)
+        clean_wavs, lens = self.prepare_wavs(batch.clean_sig, stage)
         loss = 0
 
         # Compute enhancement loss
@@ -237,7 +233,7 @@ class MTLbrain(sb.Brain):
             not hasattr(self.hparams, "ctc_epochs")
             or self.hparams.epoch_counter.current < self.hparams.ctc_epochs
         ):
-            tokens, token_lens = self.prepare_targets(batch.tokens)
+            tokens, token_lens = self.prepare_targets(batch.tokens, stage)
             ctc_loss = sb.nnet.losses.ctc_loss(
                 predictions["ctc_pout"],
                 tokens,
@@ -264,7 +260,7 @@ class MTLbrain(sb.Brain):
         # Compute nll loss for seq2seq model
         if self.hparams.seq_weight > 0:
 
-            tokens, token_lens = self.prepare_targets(batch.tokens_eos)
+            tokens, token_lens = self.prepare_targets(batch.tokens_eos, stage)
             seq_loss = self.hparams.seq_loss(
                 predictions["seq_pout"], tokens, token_lens
             )
@@ -411,6 +407,7 @@ class MTLbrain(sb.Brain):
             min_key=min_key,
             max_num_checkpoints=self.hparams.checkpoint_avg,
         )
+        logger.info(f"Averaging {len(checkpoints)} Checkpoints...")
         for model in self.modules:
             if (
                 model not in self.hparams.frozen_models
@@ -516,6 +513,8 @@ if __name__ == "__main__":
             "skip_prep": hparams["skip_prep"],
         },
     )
+    if "prepare_noise_data" in hparams:
+        run_on_main(hparams["prepare_noise_data"])
 
     # Load pretrained models
     for model in ["asr", "enhance", "perceptual"]:
diff --git a/recipes/Voicebank/dereverb/MetricGAN-U/train.py b/recipes/Voicebank/dereverb/MetricGAN-U/train.py
index 6fcddee47cdae75121f4afb0cd837172729986a7..293ebd226973310b14fb43d3f4396dad97597020 100644
--- a/recipes/Voicebank/dereverb/MetricGAN-U/train.py
+++ b/recipes/Voicebank/dereverb/MetricGAN-U/train.py
@@ -34,7 +34,7 @@ from speechbrain.dataio.sampler import ReproducibleWeightedRandomSampler
 
 ### For DNSMSOS
 # URL for the web service
-SCORING_URI = "https://dnsmos-4.azurewebsites.net/score"
+SCORING_URI = "https://github.com/microsoft/DNS-Challenge"
 # If the service is authenticated, set the key or token
 AUTH_KEY = ""
 if AUTH_KEY == "":
@@ -418,8 +418,10 @@ class MetricGanBrain(sb.Brain):
                 )
                 self.d_optimizer.zero_grad()
                 loss.backward()
-                if self.check_gradients(loss):
-                    self.d_optimizer.step()
+                torch.nn.utils.clip_grad_norm_(
+                    self.modules.parameters(), self.max_grad_norm
+                )
+                self.d_optimizer.step()
                 loss_tracker += loss.detach() / 3
         elif self.sub_stage == SubStage.HISTORICAL:
             loss = self.compute_objectives(
@@ -427,8 +429,10 @@ class MetricGanBrain(sb.Brain):
             )
             self.d_optimizer.zero_grad()
             loss.backward()
-            if self.check_gradients(loss):
-                self.d_optimizer.step()
+            torch.nn.utils.clip_grad_norm_(
+                self.modules.parameters(), self.max_grad_norm
+            )
+            self.d_optimizer.step()
             loss_tracker += loss.detach()
         elif self.sub_stage == SubStage.GENERATOR:
             for name, param in self.modules.generator.named_parameters():
@@ -442,8 +446,10 @@ class MetricGanBrain(sb.Brain):
             )
             self.g_optimizer.zero_grad()
             loss.backward()
-            if self.check_gradients(loss):
-                self.g_optimizer.step()
+            torch.nn.utils.clip_grad_norm_(
+                self.modules.parameters(), self.max_grad_norm
+            )
+            self.g_optimizer.step()
             loss_tracker += loss.detach()
 
         return loss_tracker
diff --git a/recipes/Voicebank/dereverb/spectral_mask/train.py b/recipes/Voicebank/dereverb/spectral_mask/train.py
index 0aa16e6d2d2f3b15c1d7c58d0a5ba76f9b6dd406..cce0e85677ed756814aa3c45e2f3fd4043d51229 100644
--- a/recipes/Voicebank/dereverb/spectral_mask/train.py
+++ b/recipes/Voicebank/dereverb/spectral_mask/train.py
@@ -148,6 +148,9 @@ class SEBrain(sb.Brain):
         self.optimizer = self.hparams.g_opt_class(
             self.modules.generator.parameters()
         )
+        self.optimizers_dict = {
+            "optimizer": self.optimizer,
+        }
 
 
 def dataio_prep(hparams):
diff --git a/recipes/Voicebank/enhance/MetricGAN-U/train.py b/recipes/Voicebank/enhance/MetricGAN-U/train.py
index 01704d89305f54cf7be1974ebe9e9a94d6eeb22a..ea41e1e3533b045edbcf6b05d61c29767e9fa703 100644
--- a/recipes/Voicebank/enhance/MetricGAN-U/train.py
+++ b/recipes/Voicebank/enhance/MetricGAN-U/train.py
@@ -34,7 +34,7 @@ from speechbrain.dataio.sampler import ReproducibleWeightedRandomSampler
 
 ### For DNSMSOS
 # URL for the web service
-SCORING_URI = "https://dnsmos-4.azurewebsites.net/score"
+SCORING_URI = "https://github.com/microsoft/DNS-Challenge"
 # If the service is authenticated, set the key or token
 AUTH_KEY = ""
 if AUTH_KEY == "":
@@ -412,8 +412,10 @@ class MetricGanBrain(sb.Brain):
                 )
                 self.d_optimizer.zero_grad()
                 loss.backward()
-                if self.check_gradients(loss):
-                    self.d_optimizer.step()
+                torch.nn.utils.clip_grad_norm_(
+                    self.modules.parameters(), self.max_grad_norm
+                )
+                self.d_optimizer.step()
                 loss_tracker += loss.detach() / 3
         elif self.sub_stage == SubStage.HISTORICAL:
             loss = self.compute_objectives(
@@ -421,8 +423,10 @@ class MetricGanBrain(sb.Brain):
             )
             self.d_optimizer.zero_grad()
             loss.backward()
-            if self.check_gradients(loss):
-                self.d_optimizer.step()
+            torch.nn.utils.clip_grad_norm_(
+                self.modules.parameters(), self.max_grad_norm
+            )
+            self.d_optimizer.step()
             loss_tracker += loss.detach()
         elif self.sub_stage == SubStage.GENERATOR:
             for name, param in self.modules.generator.named_parameters():
@@ -436,8 +440,10 @@ class MetricGanBrain(sb.Brain):
             )
             self.g_optimizer.zero_grad()
             loss.backward()
-            if self.check_gradients(loss):
-                self.g_optimizer.step()
+            torch.nn.utils.clip_grad_norm_(
+                self.modules.parameters(), self.max_grad_norm
+            )
+            self.g_optimizer.step()
             loss_tracker += loss.detach()
 
         return loss_tracker
diff --git a/recipes/Voicebank/enhance/MetricGAN/hparams/train.yaml b/recipes/Voicebank/enhance/MetricGAN/hparams/train.yaml
index b9dd9a223709988e7f46f56cd38b2f4226e6b993..ce0b58a8c296bf8de0c52c407375ef46d3db6396 100644
--- a/recipes/Voicebank/enhance/MetricGAN/hparams/train.yaml
+++ b/recipes/Voicebank/enhance/MetricGAN/hparams/train.yaml
@@ -12,7 +12,6 @@ seed: 4234
 __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
 
 data_folder: !PLACEHOLDER # e.g, /data/member1/user_jasonfu/noisy-vctk-16k
-train_clean_folder: !ref <data_folder>/clean_trainset_28spk_wav_16k/
 
 MetricGAN_folder: !ref <output_folder>/enhanced_wavs
 output_folder: !ref ./results/MetricGAN/<seed>
diff --git a/recipes/Voicebank/enhance/MetricGAN/train.py b/recipes/Voicebank/enhance/MetricGAN/train.py
index 12bdac6a22354903ad13494992ed34f19bddd05b..2da4c849529237467c5b993c0145f51282ed2f4b 100644
--- a/recipes/Voicebank/enhance/MetricGAN/train.py
+++ b/recipes/Voicebank/enhance/MetricGAN/train.py
@@ -296,8 +296,10 @@ class MetricGanBrain(sb.Brain):
                 )
                 self.d_optimizer.zero_grad()
                 loss.backward()
-                if self.check_gradients(loss):
-                    self.d_optimizer.step()
+                torch.nn.utils.clip_grad_norm_(
+                    self.modules.parameters(), self.max_grad_norm
+                )
+                self.d_optimizer.step()
                 loss_tracker += loss.detach() / 3
         elif self.sub_stage == SubStage.HISTORICAL:
             loss = self.compute_objectives(
@@ -305,8 +307,10 @@ class MetricGanBrain(sb.Brain):
             )
             self.d_optimizer.zero_grad()
             loss.backward()
-            if self.check_gradients(loss):
-                self.d_optimizer.step()
+            torch.nn.utils.clip_grad_norm_(
+                self.modules.parameters(), self.max_grad_norm
+            )
+            self.d_optimizer.step()
             loss_tracker += loss.detach()
         elif self.sub_stage == SubStage.GENERATOR:
             for name, param in self.modules.generator.named_parameters():
@@ -322,8 +326,10 @@ class MetricGanBrain(sb.Brain):
 
             self.g_optimizer.zero_grad()
             loss.backward()
-            if self.check_gradients(loss):
-                self.g_optimizer.step()
+            torch.nn.utils.clip_grad_norm_(
+                self.modules.parameters(), self.max_grad_norm
+            )
+            self.g_optimizer.step()
             loss_tracker += loss.detach()
 
         return loss_tracker
diff --git a/recipes/Voicebank/enhance/SEGAN/train.py b/recipes/Voicebank/enhance/SEGAN/train.py
index f9cbcadb8f1ac07e3a259839c5b02360866231f9..9f91cfb3384776e870611a993fca6c6cdf02fd2e 100644
--- a/recipes/Voicebank/enhance/SEGAN/train.py
+++ b/recipes/Voicebank/enhance/SEGAN/train.py
@@ -167,8 +167,10 @@ class SEBrain(sb.Brain):
         out_d1 = self.compute_forward_d(noisy_wavs, clean_wavs)
         loss_d1 = self.compute_objectives_d1(out_d1, batch)
         loss_d1.backward()
-        if self.check_gradients(loss_d1):
-            self.optimizer_d.step()
+        torch.nn.utils.clip_grad_norm_(
+            self.modules.parameters(), self.max_grad_norm
+        )
+        self.optimizer_d.step()
         self.optimizer_d.zero_grad()
 
         # second training step
@@ -181,8 +183,10 @@ class SEBrain(sb.Brain):
         out_d2 = self.compute_forward_d(out_g2, clean_wavs)
         loss_d2 = self.compute_objectives_d2(out_d2, batch)
         loss_d2.backward(retain_graph=True)
-        if self.check_gradients(loss_d2):
-            self.optimizer_d.step()
+        torch.nn.utils.clip_grad_norm_(
+            self.modules.parameters(), self.max_grad_norm
+        )
+        self.optimizer_d.step()
         self.optimizer_d.zero_grad()
 
         # third (last) training step
@@ -198,8 +202,10 @@ class SEBrain(sb.Brain):
             z_logvar=z_logvar,
         )
         loss_g3.backward()
-        if self.check_gradients(loss_g3):
-            self.optimizer_g.step()
+        torch.nn.utils.clip_grad_norm_(
+            self.modules.parameters(), self.max_grad_norm
+        )
+        self.optimizer_g.step()
         self.optimizer_g.zero_grad()
         self.optimizer_d.zero_grad()
 
diff --git a/recipes/Voicebank/voicebank_prepare.py b/recipes/Voicebank/voicebank_prepare.py
index 89ccadc47adf896c13a261c1dae0a75635bfdaf0..333f0427d5897ec7a13eaa3a2dab302251cf862b 100644
--- a/recipes/Voicebank/voicebank_prepare.py
+++ b/recipes/Voicebank/voicebank_prepare.py
@@ -432,7 +432,7 @@ def download_vctk(destination, tmp_dir=None, device="cpu"):
         "clean_trainset_28spk_wav",
     ]
 
-    downsampler = Resample(orig_freq=48000, new_freq=16000)
+    downsampler = Resample(orig_freq=48000, new_freq=16000).to(device)
 
     for directory in dirs:
         logger.info("Resampling " + directory)
diff --git a/recipes/VoxCeleb/SpeakerRec/README.md b/recipes/VoxCeleb/SpeakerRec/README.md
index e23b38ba83010c37a1143613c83629435e22293c..7150220b06f669a64405fb07512158fd3457c631 100644
--- a/recipes/VoxCeleb/SpeakerRec/README.md
+++ b/recipes/VoxCeleb/SpeakerRec/README.md
@@ -83,26 +83,30 @@ Below results are all obtained with the official verification split of voxceleb1
 [Speaker verification results (in EER) on VoxCeleb1-O, with score normalization]
 | System          | Dataset    | EER  | Model/Log Link |
 |-----------------|------------|------| -----|
-| Xvector + PLDA  | VoxCeleb 1,2 | 3.23% | https://www.dropbox.com/sh/mau2nrt6i81ctfc/AAAUkAECzVaVWUMjD3mytjgea?dl=0 |
+| Xvector + PLDA  | VoxCeleb 1,2 | 3.23% | https://www.dropbox.com/sh/ab1ma1lnmskedo8/AADsmgOLPdEjSF6wV3KyhNG1a?dl=0 |
 | ECAPA-TDNN      | VoxCeleb 1,2 | 0.80% | https://www.dropbox.com/sh/ab1ma1lnmskedo8/AADsmgOLPdEjSF6wV3KyhNG1a?dl=0 |
-| ResNet TDNN     | VoxCeleb 1,2 | 0.95% | https://www.dropbox.com/sh/yvqn7tn6iqztx9k/AAAhhhbOCUJ47C0LbcpUlzYUa?dl=0 |
+| ResNet TDNN     | VoxCeleb 1,2 | 0.95% | https://www.dropbox.com/sh/ab1ma1lnmskedo8/AADsmgOLPdEjSF6wV3KyhNG1a?dl=0 |
 
 [Speaker verification results (in EER), no score normalization]
 | System          | Dataset    | VoxCeleb1-O  | VoxCeleb1-E  | VoxCeleb1-H  | Model/Log Link |
 |-----------------|------------|------|------|------| -----|
 | ECAPA-TDNN      | VoxCeleb 1,2 | 0.90% | - | - | https://www.dropbox.com/sh/ab1ma1lnmskedo8/AADsmgOLPdEjSF6wV3KyhNG1a?dl=0 |
 | ECAPA-TDNN      | VoxCeleb 2 | 1.30% | 1.98% | 3.62% | (to be updated) |
-| ResNet TDNN     | VoxCeleb 1,2 | 1.05% | - | - | https://www.dropbox.com/sh/yvqn7tn6iqztx9k/AAAhhhbOCUJ47C0LbcpUlzYUa?dl=0  |
+| ResNet TDNN     | VoxCeleb 1,2 | 1.05% | - | - | https://www.dropbox.com/sh/ab1ma1lnmskedo8/AADsmgOLPdEjSF6wV3KyhNG1a?dl=0  |
 
 
 ## PreTrained Model + Easy-Inference
 You can perform the easy-inference of various models provided on [HuggingFace](https://huggingface.co) via the links below. They are specified in the hyperparameter yaml files as well.
+
+**NOTE: If you would like to store the embeddings for future use, please check `extract_speaker_embeddings.py` for the gist.**
+
 | System          | Hugging Face model link |
 |-----------------|-------------------------|
 | Xvector         | https://huggingface.co/speechbrain/spkrec-xvect-voxceleb |
 | ECAPA-TDNN      | https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb |
 | ResNet TDNN     | https://huggingface.co/speechbrain/spkrec-resnet-voxceleb |
 
+
 # **About SpeechBrain**
 - Website: https://speechbrain.github.io/
 - Code: https://github.com/speechbrain/speechbrain/
diff --git a/recipes/VoxCeleb/SpeakerRec/extract_speaker_embeddings.py b/recipes/VoxCeleb/SpeakerRec/extract_speaker_embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..658e5e08d4210ce769cf64866ee6830873a3fa9c
--- /dev/null
+++ b/recipes/VoxCeleb/SpeakerRec/extract_speaker_embeddings.py
@@ -0,0 +1,132 @@
+#!/usr/bin/python3
+"""Recipe for extracting speaker embeddings for other purpose. This
+is more like a script that copes with modern usage of speaker embed-
+ding vectors.
+
+The input of this script is a training list like below
+(we recommend having full absolute path for wav paths)
+----------
+utt1 $wav1_path
+...
+uttN $wavN_path
+
+The extracted embeddings are stored as numpy files in the output
+folder. The name of each numpy file is its utterance name.
+NOTE: This may result in a large number of files in a single folder.
+
+To run this recipe, use the following command:
+> python extract_speaker_embeddings.py {input_training_list} {output_folder} {hyperparameter_file}
+
+Using your own hyperparameter file or one of the following:
+    hparams/verification_ecapa.yaml (for the ecapa+tdnn system)
+    hparams/verification_resnet.yaml (for the resnet tdnn system)
+    hparams/verification_plda_xvector.yaml (for the xvector system)
+
+Author
+    * Mirco Ravanelli 2020
+    * Hwidong Na 2020
+    * Nauman Dawalatabad 2020
+    * Xuechen Liu 2023
+"""
+import os
+import sys
+
+import numpy as np
+import torch
+import logging
+import torchaudio
+import speechbrain as sb
+from hyperpyyaml import load_hyperpyyaml
+
+from speechbrain.utils.distributed import run_on_main
+from speechbrain.utils.data_utils import download_file
+
+
+def compute_embeddings_single(wavs, wav_lens, params):
+    """Compute speaker embeddings.
+
+    Arguments
+    ---------
+    wavs : Torch.Tensor
+        Tensor containing the speech waveform (batch, time).
+        Make sure the sample rate is fs=16000 Hz.
+    wav_lens: Torch.Tensor
+        Tensor containing the relative length for each sentence
+        in the length (e.g., [0.8 0.6 1.0])
+    """
+    with torch.no_grad():
+        feats = params["compute_features"](wavs)
+        feats = params["mean_var_norm"](feats, wav_lens)
+        embeddings = params["embedding_model"](feats, wav_lens)
+    return embeddings.squeeze(1)
+
+
+def compute_embeddings(params, wav_scp, outdir):
+    """Compute speaker embeddings.
+
+    Arguments
+    ---------
+    params: dict
+        The parameter files storing info about model, data, etc
+    wav_scp : str
+        The wav.scp file in Kaldi, in the form of "$utt $wav_path"
+    outdir: str
+        The output directory where we store the embeddings in per-
+        numpy manner.
+    """
+    with torch.no_grad():
+        with open(wav_scp, "r") as wavscp:
+            for line in wavscp:
+                utt, wav_path = line.split()
+                out_file = "{}/{}.npy".format(outdir, utt)
+                wav, _ = torchaudio.load(wav_path)
+                data = wav.transpose(0, 1).squeeze(1).unsqueeze(0)
+                lens = torch.Tensor([data.shape[1]])
+                data, lens = (
+                    data.to(run_opts["device"]),
+                    lens.to(run_opts["device"]),
+                )
+                embedding = compute_embeddings_single(
+                    data, lens, params
+                ).squeeze()
+
+                out_embedding = embedding.detach().cpu().numpy()
+                np.save(out_file, out_embedding)
+                del out_embedding, wav, data
+
+
+if __name__ == "__main__":
+    in_list = sys.argv[1]
+    out_dir = sys.argv[2]
+    os.makedirs(out_dir, exist_ok=True)
+
+    # Logger setup
+    logger = logging.getLogger(__name__)
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    sys.path.append(os.path.dirname(current_dir))
+
+    # Load hyperparameters file with command-line overrides
+    params_file, run_opts, overrides = sb.core.parse_arguments(sys.argv[3:])
+    if "data_folder:" not in overrides:
+        # By default it is a PLACEHOLDER (we need to replace it with a dummy path)
+        overrides += "\ndata_folder: ."
+    if "output_folder:" not in overrides:
+        # Ensure to put the saved model in the output folder
+        overrides += f"\noutput_folder: {out_dir}"
+
+    with open(params_file) as fin:
+        params = load_hyperpyyaml(fin, overrides)
+    run_on_main(params["pretrainer"].collect_files)
+    params["pretrainer"].load_collected(run_opts["device"])
+    params["embedding_model"].eval()
+    params["embedding_model"].to(run_opts["device"])
+
+    # Download verification list (to exlude verification sentences from train)
+    veri_file_path = os.path.join(
+        params["save_folder"], os.path.basename(params["verification_file"])
+    )
+    download_file(params["verification_file"], veri_file_path)
+
+    print("Begin embedding extraction......")
+    compute_embeddings(params, in_list, out_dir)
+    print("The embeddings have been extracted and stored at {}".format(out_dir))
diff --git a/recipes/VoxCeleb/SpeakerRec/hparams/train_ecapa_tdnn.yaml b/recipes/VoxCeleb/SpeakerRec/hparams/train_ecapa_tdnn.yaml
index 69af4bf114eb383212f103e2e1679588fe62bec3..c93ba21ecb1109e9cf67ce5c02a4255fff16192a 100644
--- a/recipes/VoxCeleb/SpeakerRec/hparams/train_ecapa_tdnn.yaml
+++ b/recipes/VoxCeleb/SpeakerRec/hparams/train_ecapa_tdnn.yaml
@@ -10,13 +10,18 @@ output_folder: !ref results/ecapa_augment/<seed>
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
 
+# Data for augmentation
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+RIR_DATASET_URL: https://www.dropbox.com/scl/fi/linhy77c36mu10965a836/RIRs.zip?rlkey=pg9cu8vrpn2u173vhiqyu743u&dl=1
+
 # Data files
 data_folder: !PLACEHOLDER  # e.g. /path/to/Voxceleb
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+data_folder_rir: !ref <data_folder>/rir # The impulse responses used for data augmentation will automatically be downloaded here.
 train_annotation: !ref <save_folder>/train.csv
 valid_annotation: !ref <save_folder>/dev.csv
-
-# Folder to extract data augmentation files
-rir_folder: !ref <data_folder> # Change it if needed
+noise_annotation: !ref <save_folder>/noise.csv
+rir_annotation: !ref <save_folder>/rir.csv
 
 # Use the following links for the official voxceleb splits:
 # VoxCeleb1 (cleaned): https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt
@@ -51,10 +56,11 @@ deltas: False
 # Number of speakers
 out_n_neurons: 7205 #1211 for vox1  # 5994 for vox2, 7205 for vox1+vox2
 
+num_workers: 4
 dataloader_options:
     batch_size: !ref <batch_size>
     shuffle: !ref <shuffle>
-    num_workers: 2
+    num_workers: !ref <num_workers>
 
 # Functions
 compute_features: !new:speechbrain.lobes.features.Fbank
@@ -79,56 +85,66 @@ classifier: !new:speechbrain.lobes.models.ECAPA_TDNN.Classifier
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-
-augment_wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [100]
-
-augment_speed: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-add_rev: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    openrir_max_noise_len: 3.0  # seconds
-    reverb_prob: 1.0
-    noise_prob: 0.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-add_noise: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    openrir_max_noise_len: 3.0  # seconds
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-add_rev_noise: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    openrir_max_noise_len: 3.0  # seconds
-    reverb_prob: 1.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-
-# Definition of the augmentation pipeline.
-# If concat_augment = False, the augmentation techniques are applied
-# in sequence. If concat_augment = True, all the augmented signals
-# # are concatenated in a single big batch.
-
-augment_pipeline: [
-    !ref <augment_wavedrop>,
-    !ref <augment_speed>,
-    !ref <add_rev>,
-    !ref <add_noise>,
-    !ref <add_rev_noise>
-]
-concat_augment: True
+############################## Augmentations ###################################
+
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+
+# Add noise to input signal
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: 0
+    snr_high: 15
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Download and prepare the dataset of room impulse responses for augmentation
+prepare_rir_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <RIR_DATASET_URL>
+    dest_folder: !ref <data_folder_rir>
+    ext: wav
+    csv_file: !ref <rir_annotation>
+
+# Add reverberation to input signal
+add_reverb: !new:speechbrain.augment.time_domain.AddReverb
+    csv_file: !ref <rir_annotation>
+    reverb_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    parallel_augment: True
+    concat_original: True
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <add_noise>,
+        !ref <add_reverb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
 
 mean_var_norm: !new:speechbrain.processing.features.InputNormalization
     norm_type: sentence
@@ -136,11 +152,6 @@ mean_var_norm: !new:speechbrain.processing.features.InputNormalization
 
 modules:
     compute_features: !ref <compute_features>
-    augment_wavedrop: !ref <augment_wavedrop>
-    augment_speed: !ref <augment_speed>
-    add_rev: !ref <add_rev>
-    add_noise: !ref <add_noise>
-    add_rev_noise: !ref <add_rev_noise>
     embedding_model: !ref <embedding_model>
     classifier: !ref <classifier>
     mean_var_norm: !ref <mean_var_norm>
diff --git a/recipes/VoxCeleb/SpeakerRec/hparams/train_ecapa_tdnn_mel_spec.yaml b/recipes/VoxCeleb/SpeakerRec/hparams/train_ecapa_tdnn_mel_spec.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a2c1999b7ffdb09f6b181624d878a08bdb9848b2
--- /dev/null
+++ b/recipes/VoxCeleb/SpeakerRec/hparams/train_ecapa_tdnn_mel_spec.yaml
@@ -0,0 +1,208 @@
+# ################################
+# Model: Speaker identification with ECAPA
+# Authors: Hwidong Na & Mirco Ravanelli
+# ################################
+
+# Basic parameters
+seed: 1986
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/ecapa_augment/<seed>
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# Data files
+data_folder: !PLACEHOLDER  # e.g. /path/to/Voxceleb
+train_annotation: !ref <save_folder>/train.csv
+valid_annotation: !ref <save_folder>/dev.csv
+noise_annotation: !ref <save_folder>/noise.csv
+rir_annotation: !ref <save_folder>/rir.csv
+data_folder_noise: !ref <data_folder>/noise
+data_folder_rir: !ref <data_folder>/rir
+# Folder to extract data augmentation files
+# Data for augmentation
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+RIR_DATASET_URL: https://www.dropbox.com/scl/fi/linhy77c36mu10965a836/RIRs.zip?rlkey=pg9cu8vrpn2u173vhiqyu743u&dl=1
+
+
+# Use the following links for the official voxceleb splits:
+# VoxCeleb1 (cleaned): https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt
+# VoxCeleb1-H (cleaned): https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/list_test_hard2.txt
+# VoxCeleb1-E (cleaned): https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/list_test_all2.txt.
+# VoxCeleb1-E and VoxCeleb1-H lists are drawn from the VoxCeleb1 training set.
+# Therefore you cannot use any files in VoxCeleb1 for training if you are using these lists for testing.
+verification_file: https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt
+
+split_ratio: [90, 10]
+skip_prep: False
+ckpt_interval_minutes: 15 # save checkpoint every N min
+
+# Training parameters
+number_of_epochs: 10
+batch_size: 32
+lr: 0.001
+base_lr: 0.00000001
+max_lr: !ref <lr>
+step_size: 65000
+sample_rate: 16000
+sentence_len: 3.0 # seconds
+shuffle: True
+random_chunk: True
+
+# Feature parameters
+hop_length: 256
+win_length: 1024
+n_mel_channels: 80
+n_fft: 1024
+mel_fmin: 0.0
+mel_fmax: 8000.0
+mel_normalized: False
+power: 1
+norm: "slaney"
+mel_scale: "slaney"
+dynamic_range_compression: True
+
+# Number of speakers
+out_n_neurons: 7205 #1211 for vox1  # 5994 for vox2, 7205 for vox1+vox2
+
+num_workers: 4
+dataloader_options:
+    batch_size: !ref <batch_size>
+    shuffle: !ref <shuffle>
+    num_workers: !ref <num_workers>
+
+# Functions
+use_tacotron2_mel_spec: True
+
+compute_features: !name:speechbrain.lobes.models.Tacotron2.mel_spectogram
+    sample_rate: !ref <sample_rate>
+    hop_length: !ref <hop_length>
+    win_length: !ref <win_length>
+    n_fft: !ref <n_fft>
+    n_mels: !ref <n_mel_channels>
+    f_min: !ref <mel_fmin>
+    f_max: !ref <mel_fmax>
+    power: !ref <power>
+    normalized: !ref <mel_normalized>
+    norm: !ref <norm>
+    mel_scale: !ref <mel_scale>
+    compression: !ref <dynamic_range_compression>
+
+# Modules
+embedding_model: !new:speechbrain.lobes.models.ECAPA_TDNN.ECAPA_TDNN
+    input_size: !ref <n_mel_channels>
+    channels: [1024, 1024, 1024, 1024, 3072]
+    kernel_sizes: [5, 3, 3, 3, 1]
+    dilations: [1, 2, 3, 4, 1]
+    groups: [1, 1, 1, 1, 1]
+    attention_channels: 128
+    lin_neurons: 192
+
+classifier: !new:speechbrain.lobes.models.ECAPA_TDNN.Classifier
+    input_size: 192
+    out_neurons: !ref <out_n_neurons>
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+############################## Augmentations ###################################
+
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+
+# Add noise to input signal
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: 0
+    snr_high: 15
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Download and prepare the dataset of room impulse responses for augmentation
+prepare_rir_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <RIR_DATASET_URL>
+    dest_folder: !ref <data_folder_rir>
+    ext: wav
+    csv_file: !ref <rir_annotation>
+
+# Add reverberation to input signal
+add_reverb: !new:speechbrain.augment.time_domain.AddReverb
+    csv_file: !ref <rir_annotation>
+    reverb_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    parallel_augment: True
+    concat_original: True
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <add_noise>,
+        !ref <add_reverb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+mean_var_norm: !new:speechbrain.processing.features.InputNormalization
+    norm_type: sentence
+    std_norm: False
+
+modules:
+    embedding_model: !ref <embedding_model>
+    classifier: !ref <classifier>
+    mean_var_norm: !ref <mean_var_norm>
+
+compute_cost: !new:speechbrain.nnet.losses.LogSoftmaxWrapper
+    loss_fn: !new:speechbrain.nnet.losses.AdditiveAngularMargin
+        margin: 0.2
+        scale: 30
+
+# compute_error: !name:speechbrain.nnet.losses.classification_error
+
+opt_class: !name:torch.optim.Adam
+    lr: !ref <lr>
+    weight_decay: 0.000002
+
+lr_annealing: !new:speechbrain.nnet.schedulers.CyclicLRScheduler
+    base_lr: !ref <base_lr>
+    max_lr: !ref <max_lr>
+    step_size: !ref <step_size>
+
+# Logging + checkpoints
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+error_stats: !name:speechbrain.utils.metric_stats.MetricStats
+    metric: !name:speechbrain.nnet.losses.classification_error
+        reduction: batch
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        embedding_model: !ref <embedding_model>
+        classifier: !ref <classifier>
+        normalizer: !ref <mean_var_norm>
+        counter: !ref <epoch_counter>
+        lr_annealing: !ref <lr_annealing>
diff --git a/recipes/VoxCeleb/SpeakerRec/hparams/train_resnet.yaml b/recipes/VoxCeleb/SpeakerRec/hparams/train_resnet.yaml
index 40120be180bdc35f76874d7112469b2239e70f6c..a20786574601e7d61b37848db29cf66493d2255e 100644
--- a/recipes/VoxCeleb/SpeakerRec/hparams/train_resnet.yaml
+++ b/recipes/VoxCeleb/SpeakerRec/hparams/train_resnet.yaml
@@ -11,14 +11,18 @@ output_folder: !ref results/resnet/<seed>
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
 
+# Data for augmentation
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+RIR_DATASET_URL: https://www.dropbox.com/scl/fi/linhy77c36mu10965a836/RIRs.zip?rlkey=pg9cu8vrpn2u173vhiqyu743u&dl=1
+
 # Data files
-# data_folder: !PLACEHOLDER  # e.g. /path/to/Voxceleb
-data_folder: /home/smg/xuecliu/speechbrain/recipes/VoxCeleb/SpeakerRec/data/voxceleb
+data_folder: !PLACEHOLDER  # e.g. /path/to/Voxceleb
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+data_folder_rir: !ref <data_folder>/rir # The impulse responses used for data augmentation will automatically be downloaded here.
 train_annotation: !ref <save_folder>/train.csv
 valid_annotation: !ref <save_folder>/dev.csv
-
-# Folder to extract data augmentation files
-rir_folder: !ref data # Change it if needed
+noise_annotation: !ref <save_folder>/noise.csv
+rir_annotation: !ref <save_folder>/rir.csv
 
 # Use the following links for the official voxceleb splits:
 # VoxCeleb1 (cleaned): https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt
@@ -51,13 +55,14 @@ right_frames: 0
 deltas: False
 
 # Number of speakers
-# 1211 for vox1, 5994 for vox2, 7205 for vox1+vox2
+# 1211 for vox1, 5994 for vox2, 7205 for vox1+vox2
 out_n_neurons: 7205
 
+num_workers: 4
 dataloader_options:
     batch_size: !ref <batch_size>
     shuffle: !ref <shuffle>
-    num_workers: 2
+    num_workers: !ref <num_workers>
 
 # Functions
 compute_features: !new:speechbrain.lobes.features.Fbank
@@ -80,56 +85,66 @@ classifier: !new:speechbrain.lobes.models.ECAPA_TDNN.Classifier
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-
-augment_wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [100]
-
-augment_speed: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-add_rev: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    openrir_max_noise_len: 3.0  # seconds
-    reverb_prob: 1.0
-    noise_prob: 0.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-add_noise: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    openrir_max_noise_len: 3.0  # seconds
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-add_rev_noise: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    openrir_max_noise_len: 3.0  # seconds
-    reverb_prob: 1.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-
-# Definition of the augmentation pipeline.
-# If concat_augment = False, the augmentation techniques are applied
-# in sequence. If concat_augment = True, all the augmented signals
-# # are concatenated in a single big batch.
-
-augment_pipeline: [
-    !ref <augment_wavedrop>,
-    !ref <augment_speed>,
-    !ref <add_rev>,
-    !ref <add_noise>,
-    !ref <add_rev_noise>
-]
-concat_augment: True
+############################## Augmentations ###################################
+
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+
+# Add noise to input signal
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: 0
+    snr_high: 15
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Download and prepare the dataset of room impulse responses for augmentation
+prepare_rir_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <RIR_DATASET_URL>
+    dest_folder: !ref <data_folder_rir>
+    ext: wav
+    csv_file: !ref <rir_annotation>
+
+# Add reverberation to input signal
+add_reverb: !new:speechbrain.augment.time_domain.AddReverb
+    csv_file: !ref <rir_annotation>
+    reverb_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    parallel_augment: True
+    concat_original: True
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <add_noise>,
+        !ref <add_reverb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
 
 mean_var_norm: !new:speechbrain.processing.features.InputNormalization
     norm_type: sentence
@@ -137,11 +152,6 @@ mean_var_norm: !new:speechbrain.processing.features.InputNormalization
 
 modules:
     compute_features: !ref <compute_features>
-    augment_wavedrop: !ref <augment_wavedrop>
-    augment_speed: !ref <augment_speed>
-    add_rev: !ref <add_rev>
-    add_noise: !ref <add_noise>
-    add_rev_noise: !ref <add_rev_noise>
     embedding_model: !ref <embedding_model>
     classifier: !ref <classifier>
     mean_var_norm: !ref <mean_var_norm>
diff --git a/recipes/VoxCeleb/SpeakerRec/hparams/train_x_vectors.yaml b/recipes/VoxCeleb/SpeakerRec/hparams/train_x_vectors.yaml
index 48afdfa2b0a55071defae95778f2c59386e2da59..ab628c681b9729ea2cf431593f524c151c599570 100644
--- a/recipes/VoxCeleb/SpeakerRec/hparams/train_x_vectors.yaml
+++ b/recipes/VoxCeleb/SpeakerRec/hparams/train_x_vectors.yaml
@@ -10,13 +10,18 @@ output_folder: !ref results/xvect_augment/<seed>
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
 
+# Data for augmentation
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+RIR_DATASET_URL: https://www.dropbox.com/scl/fi/linhy77c36mu10965a836/RIRs.zip?rlkey=pg9cu8vrpn2u173vhiqyu743u&dl=1
+
 # Data files
 data_folder: !PLACEHOLDER  # e.g. /path/to/Voxceleb
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+data_folder_rir: !ref <data_folder>/rir # The impulse responses used for data augmentation will automatically be downloaded here.
 train_annotation: !ref <save_folder>/train.csv
 valid_annotation: !ref <save_folder>/dev.csv
-
-# Folder to extract data augmentation files
-rir_folder: !ref <data_folder> # Change it if needed
+noise_annotation: !ref <save_folder>/noise.csv
+rir_annotation: !ref <save_folder>/rir.csv
 
 # Use the following links for the official voxceleb splits:
 # VoxCeleb1 (cleaned): https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt
@@ -51,10 +56,11 @@ deltas: False
 out_n_neurons: 7205 #1211 for vox1  # 5994 for vox2, 7205 for vox1+vox2
 emb_dim: 512
 
+num_workers: 4
 dataloader_options:
     batch_size: !ref <batch_size>
     shuffle: !ref <shuffle>
-    num_workers: 0
+    num_workers: !ref <num_workers>
 
 # Functions
 compute_features: !new:speechbrain.lobes.features.Fbank
@@ -82,55 +88,66 @@ classifier: !new:speechbrain.lobes.models.Xvector.Classifier
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-
-augment_wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [100]
-
-augment_speed: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-add_rev: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    openrir_max_noise_len: 3.0  # seconds
-    reverb_prob: 1.0
-    noise_prob: 0.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-add_noise: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    openrir_max_noise_len: 3.0  # seconds
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-add_rev_noise: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    openrir_max_noise_len: 3.0  # seconds
-    reverb_prob: 1.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-
-# Definition of the augmentation pipeline.
-# If concat_augment = False, the augmentation techniques are applied
-# in sequence. If concat_augment = True, all the augmented signals
-# are concatenated in a single big batch.
-augment_pipeline: [
-    !ref <augment_wavedrop>,
-    !ref <augment_speed>,
-    !ref <add_rev>,
-    !ref <add_noise>,
-    !ref <add_rev_noise>
-]
-concat_augment: True
+############################## Augmentations ###################################
+
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+
+# Add noise to input signal
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: 0
+    snr_high: 15
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Download and prepare the dataset of room impulse responses for augmentation
+prepare_rir_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <RIR_DATASET_URL>
+    dest_folder: !ref <data_folder_rir>
+    ext: wav
+    csv_file: !ref <rir_annotation>
+
+# Add reverberation to input signal
+add_reverb: !new:speechbrain.augment.time_domain.AddReverb
+    csv_file: !ref <rir_annotation>
+    reverb_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    parallel_augment: True
+    concat_original: True
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <add_noise>,
+        !ref <add_reverb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
 
 mean_var_norm: !new:speechbrain.processing.features.InputNormalization
     norm_type: sentence
@@ -138,11 +155,6 @@ mean_var_norm: !new:speechbrain.processing.features.InputNormalization
 
 modules:
     compute_features: !ref <compute_features>
-    augment_wavedrop: !ref <augment_wavedrop>
-    augment_speed: !ref <augment_speed>
-    add_rev: !ref <add_rev>
-    add_noise: !ref <add_noise>
-    add_rev_noise: !ref <add_rev_noise>
     embedding_model: !ref <embedding_model>
     classifier: !ref <classifier>
     mean_var_norm: !ref <mean_var_norm>
diff --git a/recipes/VoxCeleb/SpeakerRec/speaker_verification_cosine.py b/recipes/VoxCeleb/SpeakerRec/speaker_verification_cosine.py
index e954932e1a9b08b1b70ad0992d11994587e2e68c..71b7bf16cf8772cbe0e66a53ea88e9cebf67b944 100755
--- a/recipes/VoxCeleb/SpeakerRec/speaker_verification_cosine.py
+++ b/recipes/VoxCeleb/SpeakerRec/speaker_verification_cosine.py
@@ -254,7 +254,7 @@ if __name__ == "__main__":
     # We download the pretrained LM from HuggingFace (or elsewhere depending on
     # the path given in the YAML file). The tokenizer is loaded at the same time.
     run_on_main(params["pretrainer"].collect_files)
-    params["pretrainer"].load_collected(run_opts["device"])
+    params["pretrainer"].load_collected()
     params["embedding_model"].eval()
     params["embedding_model"].to(run_opts["device"])
 
diff --git a/recipes/VoxCeleb/SpeakerRec/train_speaker_embeddings.py b/recipes/VoxCeleb/SpeakerRec/train_speaker_embeddings.py
index 0c88253cc54e54f401984d2b5d8d95407a54c05b..d5a2a0cffd35cd3f8e1587bc8340c4383d5e2c54 100755
--- a/recipes/VoxCeleb/SpeakerRec/train_speaker_embeddings.py
+++ b/recipes/VoxCeleb/SpeakerRec/train_speaker_embeddings.py
@@ -36,34 +36,19 @@ class SpeakerBrain(sb.core.Brain):
         batch = batch.to(self.device)
         wavs, lens = batch.sig
 
-        if stage == sb.Stage.TRAIN:
-            # Applying the augmentation pipeline
-            wavs_aug_tot = []
-            wavs_aug_tot.append(wavs)
-            for count, augment in enumerate(self.hparams.augment_pipeline):
-                # Apply augment
-                wavs_aug = augment(wavs, lens)
-
-                # Managing speed change
-                if wavs_aug.shape[1] > wavs.shape[1]:
-                    wavs_aug = wavs_aug[:, 0 : wavs.shape[1]]
-                else:
-                    zero_sig = torch.zeros_like(wavs)
-                    zero_sig[:, 0 : wavs_aug.shape[1]] = wavs_aug
-                    wavs_aug = zero_sig
-
-                if self.hparams.concat_augment:
-                    wavs_aug_tot.append(wavs_aug)
-                else:
-                    wavs = wavs_aug
-                    wavs_aug_tot[0] = wavs
-
-            wavs = torch.cat(wavs_aug_tot, dim=0)
-            self.n_augment = len(wavs_aug_tot)
-            lens = torch.cat([lens] * self.n_augment)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, lens = self.hparams.wav_augment(wavs, lens)
 
         # Feature extraction and normalization
-        feats = self.modules.compute_features(wavs)
+        if (
+            hasattr(self.hparams, "use_tacotron2_mel_spec")
+            and self.hparams.use_tacotron2_mel_spec
+        ):
+            feats = self.hparams.compute_features(audio=wavs)
+            feats = torch.transpose(feats, 1, 2)
+        else:
+            feats = self.modules.compute_features(wavs)
         feats = self.modules.mean_var_norm(feats, lens)
 
         # Embeddings + speaker classifier
@@ -79,8 +64,8 @@ class SpeakerBrain(sb.core.Brain):
         spkid, _ = batch.spk_id_encoded
 
         # Concatenate labels (due to data augmentation)
-        if stage == sb.Stage.TRAIN:
-            spkid = torch.cat([spkid] * self.n_augment, dim=0)
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            spkid = self.hparams.wav_augment.replicate_labels(spkid)
 
         loss = self.hparams.compute_cost(predictions, spkid, lens)
 
@@ -223,6 +208,8 @@ if __name__ == "__main__":
             "skip_prep": hparams["skip_prep"],
         },
     )
+    sb.utils.distributed.run_on_main(hparams["prepare_noise_data"])
+    sb.utils.distributed.run_on_main(hparams["prepare_rir_data"])
 
     # Dataset IO prep: creating Dataset objects and proper encodings for phones
     train_data, valid_data, label_encoder = dataio_prep(hparams)
diff --git a/recipes/VoxLingua107/lang_id/hparams/train_ecapa.yaml b/recipes/VoxLingua107/lang_id/hparams/train_ecapa.yaml
index 42258884ebb2ab7c1bdc91345d8525725209ab0d..db29a301b80588ce8b2f0d82b0f4d4da16afa24b 100644
--- a/recipes/VoxLingua107/lang_id/hparams/train_ecapa.yaml
+++ b/recipes/VoxLingua107/lang_id/hparams/train_ecapa.yaml
@@ -10,8 +10,6 @@ output_folder: !ref results/epaca/<seed>
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
 data_folder: !PLACEHOLDER
-rir_folder: !ref <data_folder>
-# skip_prep: False
 
 shards_url: /data/voxlingua107_shards
 train_meta: !ref <shards_url>/train/meta.json
@@ -19,6 +17,15 @@ val_meta: !ref <shards_url>/dev/meta.json
 train_shards: !ref <shards_url>/train/shard-{000000..000507}.tar
 val_shards: !ref <shards_url>/dev/shard-000000.tar
 
+# Data for augmentation
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+RIR_DATASET_URL: https://www.dropbox.com/scl/fi/linhy77c36mu10965a836/RIRs.zip?rlkey=pg9cu8vrpn2u173vhiqyu743u&dl=1
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+data_folder_rir: !ref <data_folder>/rir # The impulse responses used for data augmentation will automatically be downloaded here.
+noise_annotation: !ref <save_folder>/noise.csv
+rir_annotation: !ref <save_folder>/rir.csv
+
+
 # Set to directory on a large disk if you are training on Webdataset shards hosted on the web
 shard_cache_dir:
 
@@ -40,15 +47,65 @@ deltas: False
 # Number of languages
 out_n_neurons: 107
 
+num_workers: 4
+batch_size: 128
+batch_size_val: 32
 train_dataloader_options:
-    num_workers: 4
-    batch_size: 128
+    num_workers: !ref <num_workers>
+    batch_size: !ref <batch_size>
 
 val_dataloader_options:
     num_workers: 1
-    batch_size: 32
-
-# Functions
+    batch_size: !ref <batch_size_val>
+
+############################## Augmentations ###################################
+
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+# Download and prepare the dataset of room impulse responses for augmentation
+prepare_rir_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <RIR_DATASET_URL>
+    dest_folder: !ref <data_folder_rir>
+    ext: wav
+    csv_file: !ref <rir_annotation>
+
+# Add reverberation to input signal
+add_reverb: !new:speechbrain.augment.time_domain.AddReverb
+    csv_file: !ref <rir_annotation>
+    reverb_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Add noise to input signal
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: 0
+    snr_high: 15
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    concat_original: True
+    shuffle_augmentations: True
+    min_augmentations: 1
+    max_augmentations: 3
+    augmentations: [
+        !ref <add_reverb>,
+        !ref <add_noise>,
+        !ref <speed_perturb>]
+
+    # Functions
 compute_features: !new:speechbrain.lobes.features.Fbank
     n_mels: !ref <n_mels>
     left_frames: !ref <left_frames>
@@ -74,39 +131,12 @@ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
 
-augment_speed: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [90, 100, 110]
-
-
-add_rev_noise: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    openrir_max_noise_len: 3.0  # seconds
-    reverb_prob: 0.5
-    noise_prob: 0.8
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-# Definition of the augmentation pipeline.
-# If concat_augment = False, the augmentation techniques are applied
-# in sequence. If concat_augment = True, all the augmented signals
-# # are concatenated in a single big batch.
-augment_pipeline: [
-    !ref <augment_speed>,
-    !ref <add_rev_noise>
-]
-
-concat_augment: False
-
 mean_var_norm: !new:speechbrain.processing.features.InputNormalization
     norm_type: sentence
     std_norm: False
 
 modules:
     compute_features: !ref <compute_features>
-    augment_speed: !ref <augment_speed>
-    add_rev_noise: !ref <add_rev_noise>
     embedding_model: !ref <embedding_model>
     classifier: !ref <classifier>
     mean_var_norm: !ref <mean_var_norm>
diff --git a/recipes/VoxLingua107/lang_id/train.py b/recipes/VoxLingua107/lang_id/train.py
index b261c0f9c32939081900d0c71deb13b837d92341..8c880002df47668aa5a4782cc44b0e32d0f57aa4 100644
--- a/recipes/VoxLingua107/lang_id/train.py
+++ b/recipes/VoxLingua107/lang_id/train.py
@@ -46,33 +46,9 @@ class LanguageBrain(sb.core.Brain):
         batch = batch.to(self.device)
         wavs, lens = batch.sig
 
-        if stage == sb.Stage.TRAIN:
-
-            # Applying the augmentation pipeline
-            wavs_aug_tot = []
-            wavs_aug_tot.append(wavs)
-            for count, augment in enumerate(self.hparams.augment_pipeline):
-
-                # Apply augment
-                wavs_aug = augment(wavs, lens)
-
-                # Managing speed change
-                if wavs_aug.shape[1] > wavs.shape[1]:
-                    wavs_aug = wavs_aug[:, 0 : wavs.shape[1]]
-                else:
-                    zero_sig = torch.zeros_like(wavs)
-                    zero_sig[:, 0 : wavs_aug.shape[1]] = wavs_aug
-                    wavs_aug = zero_sig
-
-                if self.hparams.concat_augment:
-                    wavs_aug_tot.append(wavs_aug)
-                else:
-                    wavs = wavs_aug
-                    wavs_aug_tot[0] = wavs
-
-            wavs = torch.cat(wavs_aug_tot, dim=0)
-            self.n_augment = len(wavs_aug_tot)
-            lens = torch.cat([lens] * self.n_augment)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, lens = self.hparams.wav_augment(wavs, lens)
 
         # Feature extraction and normalization
         feats = self.modules.compute_features(wavs)
@@ -92,8 +68,8 @@ class LanguageBrain(sb.core.Brain):
         langid = batch.lang_id_encoded
 
         # Concatenate labels (due to data augmentation)
-        if stage == sb.Stage.TRAIN:
-            langid = torch.cat([langid] * self.n_augment, dim=0)
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            langid = self.hparams.wav_augment.replicate_labels(langid)
 
         # breakpoint()
         loss = self.hparams.compute_cost(predictions, langid.unsqueeze(1), lens)
@@ -237,6 +213,10 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
+    # Data preparation for augmentation
+    sb.utils.distributed.run_on_main(hparams["prepare_noise_data"])
+    sb.utils.distributed.run_on_main(hparams["prepare_rir_data"])
+
     (
         train_data,
         valid_data,
diff --git a/recipes/WHAMandWHAMR/enhancement/hparams/cnntransformer-wham-DM.yaml b/recipes/WHAMandWHAMR/enhancement/hparams/cnntransformer-wham-DM.yaml
index 28013c7635b31350fd72277af9f7e7d0d1ebc90d..843c9fb0917fcb5a1aed07668a0426be2d59d558 100644
--- a/recipes/WHAMandWHAMR/enhancement/hparams/cnntransformer-wham-DM.yaml
+++ b/recipes/WHAMandWHAMR/enhancement/hparams/cnntransformer-wham-DM.yaml
@@ -38,14 +38,14 @@ test_data: !ref <save_folder>/whamorg_tt.csv
 skip_prep: False
 
 # Experiment params
-auto_mix_prec: True # Set it to True for mixed precision
+precision: fp16 # bf16, fp16 or fp32 # Set it to True for mixed precision
 num_spks: 1 # set to 3 for wsj0-3mix
 noprogressbar: False
 save_audio: True # Save estimated sources on disk
 sample_rate: 8000
 n_audio_to_save: 20
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 8
 lr: 0.0001
@@ -72,18 +72,39 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/WHAMandWHAMR/enhancement/hparams/cnntransformer-whamr-DM.yaml b/recipes/WHAMandWHAMR/enhancement/hparams/cnntransformer-whamr-DM.yaml
index 5a49e2155ff46919fafec99896412231b51e6ee8..b55f05a1169c344207f7fee4a885a425be200b5a 100644
--- a/recipes/WHAMandWHAMR/enhancement/hparams/cnntransformer-whamr-DM.yaml
+++ b/recipes/WHAMandWHAMR/enhancement/hparams/cnntransformer-whamr-DM.yaml
@@ -38,14 +38,14 @@ test_data: !ref <save_folder>/whamr_tt.csv
 skip_prep: False
 
 # Experiment params
-auto_mix_prec: True # Set it to True for mixed precision
+precision: fp16 # bf16, fp16 or fp32 # Set it to True for mixed precision
 num_spks: 1 # set to 3 for wsj0-3mix
 noprogressbar: False
 save_audio: True # Save estimated sources on disk
 sample_rate: 8000
 n_audio_to_save: 20
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 8
 lr: 0.0001
@@ -73,18 +73,39 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/WHAMandWHAMR/enhancement/hparams/convtasnet-whamr-DM.yaml b/recipes/WHAMandWHAMR/enhancement/hparams/convtasnet-whamr-DM.yaml
index 8a25b3a6ad84c79fd376e321453212f3ebf12407..545428d5ac31f32ee81dd2af7a616040cfed231e 100644
--- a/recipes/WHAMandWHAMR/enhancement/hparams/convtasnet-whamr-DM.yaml
+++ b/recipes/WHAMandWHAMR/enhancement/hparams/convtasnet-whamr-DM.yaml
@@ -37,14 +37,14 @@ test_data: !ref <save_folder>/whamr_tt.csv
 skip_prep: False
 
 # Experiment params
-auto_mix_prec: False # Set it to True for mixed precision
+precision: fp32 # bf16, fp16 or fp32
 num_spks: 1 # set to 3 for wsj0-3mix
 noprogressbar: False
 save_audio: True # Save estimated sources on disk
 sample_rate: 8000
 n_audio_to_save: 20
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 10
 lr: 0.00015
@@ -72,18 +72,38 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/WHAMandWHAMR/enhancement/hparams/dprnn-whamr-DM.yaml b/recipes/WHAMandWHAMR/enhancement/hparams/dprnn-whamr-DM.yaml
index 7525c1860482db8562e8004d2e447b5941652ca9..d974b03c11edf0c0db0698f3b7f30d81a4c9a5a0 100644
--- a/recipes/WHAMandWHAMR/enhancement/hparams/dprnn-whamr-DM.yaml
+++ b/recipes/WHAMandWHAMR/enhancement/hparams/dprnn-whamr-DM.yaml
@@ -37,14 +37,14 @@ test_data: !ref <save_folder>/whamr_tt.csv
 skip_prep: False
 
 # Experiment params
-auto_mix_prec: False # Set it to True for mixed precision
+precision: fp32 # bf16, fp16 or fp32
 num_spks: 1 # set to 3 for wsj0-3mix
 noprogressbar: False
 save_audio: True # Save estimated sources on disk
 sample_rate: 8000
 n_audio_to_save: 20
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -72,18 +72,39 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
+
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/WHAMandWHAMR/enhancement/hparams/sepformer-wham.yaml b/recipes/WHAMandWHAMR/enhancement/hparams/sepformer-wham.yaml
index faaa2bb5d8970b6d49aa0f418158123e00f71e9a..df1935306b47512947cd04962744c0217192811f 100644
--- a/recipes/WHAMandWHAMR/enhancement/hparams/sepformer-wham.yaml
+++ b/recipes/WHAMandWHAMR/enhancement/hparams/sepformer-wham.yaml
@@ -38,13 +38,13 @@ skip_prep: False
 
 
 # Experiment params
-auto_mix_prec: True # Set it to True for mixed precision
+precision: fp16 # bf16, fp16 or fp32 # Set it to True for mixed precision
 num_spks: 1 # set to 3 for wsj0-3mix
 noprogressbar: False
 save_audio: False # Save estimated sources on disk
 sample_rate: 8000
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -66,18 +66,39 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
+
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr-16k-DM.yaml b/recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr-16k-DM.yaml
index 17b39b7814d0e2737a857c1a47498febe9b5ab9c..dc27834919fca3b9cff38b3a12e65fb56b18c653 100644
--- a/recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr-16k-DM.yaml
+++ b/recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr-16k-DM.yaml
@@ -39,14 +39,14 @@ test_data: !ref <save_folder>/whamr_tt.csv
 skip_prep: False
 
 # Experiment params
-auto_mix_prec: True # Set it to True for mixed precision
+precision: fp16 # bf16, fp16 or fp32 # Set it to True for mixed precision
 num_spks: 1 # set to 3 for wsj0-3mix
 noprogressbar: False
 save_audio: True # Save estimated sources on disk
 sample_rate: 16000
 n_audio_to_save: 20
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -74,18 +74,39 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
+
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr-16k.yaml b/recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr-16k.yaml
index 6ac93e65ac90bf5ce6abac8e6aa66663af17a08a..d11332c7e0702ea5cae1c778a4f8b7f90c13ed58 100644
--- a/recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr-16k.yaml
+++ b/recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr-16k.yaml
@@ -37,14 +37,14 @@ test_data: !ref <save_folder>/whamr_tt.csv
 skip_prep: False
 
 # Experiment params
-auto_mix_prec: True # Set it to True for mixed precision
+precision: fp16 # bf16, fp16 or fp32 # Set it to True for mixed precision
 num_spks: 1 # set to 3 for wsj0-3mix
 noprogressbar: False
 save_audio: True # Save estimated sources on disk
 sample_rate: 16000
 n_audio_to_save: 20
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -72,18 +72,39 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
+
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr-DM.yaml b/recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr-DM.yaml
index 01794cfda89e4c0f140410fef21c82732578b3f5..3721698bced2416d6f97521978eb18934a0f3a6c 100644
--- a/recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr-DM.yaml
+++ b/recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr-DM.yaml
@@ -37,14 +37,14 @@ test_data: !ref <save_folder>/whamr_tt.csv
 skip_prep: False
 
 # Experiment params
-auto_mix_prec: True # Set it to True for mixed precision
+precision: fp16 # bf16, fp16 or fp32 # Set it to True for mixed precision
 num_spks: 1 # set to 3 for wsj0-3mix
 noprogressbar: False
 save_audio: True # Save estimated sources on disk
 sample_rate: 8000
 n_audio_to_save: 20
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -72,18 +72,39 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
+
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr.yaml b/recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr.yaml
index de2cfaaed2526fed5d1179e2859e50a41d6362be..14a38a06b6b442e3d1b9086e6889a0125a884ef1 100644
--- a/recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr.yaml
+++ b/recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr.yaml
@@ -37,14 +37,14 @@ test_data: !ref <save_folder>/whamr_tt.csv
 skip_prep: False
 
 # Experiment params
-auto_mix_prec: True # Set it to True for mixed precision
+precision: fp16 # bf16, fp16 or fp32 # Set it to True for mixed precision
 num_spks: 1 # set to 3 for wsj0-3mix
 noprogressbar: False
 save_audio: True # Save estimated sources on disk
 sample_rate: 8000
 n_audio_to_save: 20
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -72,18 +72,38 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/WHAMandWHAMR/enhancement/train.py b/recipes/WHAMandWHAMR/enhancement/train.py
index 5a3e4e3af084a9b19d38b4643e5c1c038ac2eed9..887754565ca1f95716b5a45715082e9350eab377 100755
--- a/recipes/WHAMandWHAMR/enhancement/train.py
+++ b/recipes/WHAMandWHAMR/enhancement/train.py
@@ -26,7 +26,6 @@ import torchaudio
 import speechbrain as sb
 import speechbrain.nnet.schedulers as schedulers
 from speechbrain.utils.distributed import run_on_main
-from torch.cuda.amp import autocast
 from hyperpyyaml import load_hyperpyyaml
 import numpy as np
 from tqdm import tqdm
@@ -35,6 +34,7 @@ import logging
 from speechbrain.processing.features import spectral_magnitude
 from speechbrain.utils.metric_stats import MetricStats
 from pesq import pesq
+from speechbrain.core import AMPConfig
 
 
 # Define training procedure
@@ -92,7 +92,8 @@ class Separation(sb.Brain):
                     targets = targets[:, :min_len, :]
 
                 if self.hparams.use_wavedrop:
-                    mix = self.hparams.wavedrop(mix, mix_lens)
+                    mix = self.hparams.drop_chunk(mix, mix_lens)
+                    mix = self.hparams.drop_freq(mix)
 
                 if self.hparams.limit_training_signal_len:
                     mix, targets = self.cut_signals(mix, targets)
@@ -150,19 +151,59 @@ class Separation(sb.Brain):
 
     def fit_batch(self, batch):
         """Trains one batch"""
+        amp = AMPConfig.from_name(self.precision)
+        should_step = (self.step % self.grad_accumulation_factor) == 0
+
         # Unpacking batch list
         mixture = batch.mix_sig
         targets = [batch.s1_sig, batch.s2_sig]
         noise = batch.noise_sig[0]
 
-        if self.auto_mix_prec:
-            with autocast():
+        with self.no_sync(not should_step):
+            if self.use_amp:
+                with torch.autocast(
+                    dtype=amp.dtype, device_type=torch.device(self.device).type,
+                ):
+                    predictions, targets = self.compute_forward(
+                        mixture, targets, sb.Stage.TRAIN, noise
+                    )
+                    loss = self.compute_objectives(predictions, targets)
+
+                    # hard threshold the easy dataitems
+                    if self.hparams.threshold_byloss:
+                        th = self.hparams.threshold
+                        loss = loss[loss > th]
+                        if loss.nelement() > 0:
+                            loss = loss.mean()
+                    else:
+                        loss = loss.mean()
+
+                if (
+                    loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
+                ):  # the fix for computational problems
+                    self.scaler.scale(loss).backward()
+                    if self.hparams.clip_grad_norm >= 0:
+                        self.scaler.unscale_(self.optimizer)
+                        torch.nn.utils.clip_grad_norm_(
+                            self.modules.parameters(),
+                            self.hparams.clip_grad_norm,
+                        )
+                    self.scaler.step(self.optimizer)
+                    self.scaler.update()
+                else:
+                    self.nonfinite_count += 1
+                    logger.info(
+                        "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
+                            self.nonfinite_count
+                        )
+                    )
+                    loss.data = torch.tensor(0).to(self.device)
+            else:
                 predictions, targets = self.compute_forward(
                     mixture, targets, sb.Stage.TRAIN, noise
                 )
                 loss = self.compute_objectives(predictions, targets)
 
-                # hard threshold the easy dataitems
                 if self.hparams.threshold_byloss:
                     th = self.hparams.threshold
                     loss = loss[loss > th]
@@ -171,56 +212,24 @@ class Separation(sb.Brain):
                 else:
                     loss = loss.mean()
 
-            if (
-                loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
-            ):  # the fix for computational problems
-                self.scaler.scale(loss).backward()
-                if self.hparams.clip_grad_norm >= 0:
-                    self.scaler.unscale_(self.optimizer)
-                    torch.nn.utils.clip_grad_norm_(
-                        self.modules.parameters(), self.hparams.clip_grad_norm,
-                    )
-                self.scaler.step(self.optimizer)
-                self.scaler.update()
-            else:
-                self.nonfinite_count += 1
-                logger.info(
-                    "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
-                        self.nonfinite_count
-                    )
-                )
-                loss.data = torch.tensor(0).to(self.device)
-        else:
-            predictions, targets = self.compute_forward(
-                mixture, targets, sb.Stage.TRAIN, noise
-            )
-            loss = self.compute_objectives(predictions, targets)
-
-            if self.hparams.threshold_byloss:
-                th = self.hparams.threshold
-                loss = loss[loss > th]
-                if loss.nelement() > 0:
-                    loss = loss.mean()
-            else:
-                loss = loss.mean()
-
-            if (
-                loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
-            ):  # the fix for computational problems
-                loss.backward()
-                if self.hparams.clip_grad_norm >= 0:
-                    torch.nn.utils.clip_grad_norm_(
-                        self.modules.parameters(), self.hparams.clip_grad_norm
-                    )
-                self.optimizer.step()
-            else:
-                self.nonfinite_count += 1
-                logger.info(
-                    "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
-                        self.nonfinite_count
+                if (
+                    loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
+                ):  # the fix for computational problems
+                    loss.backward()
+                    if self.hparams.clip_grad_norm >= 0:
+                        torch.nn.utils.clip_grad_norm_(
+                            self.modules.parameters(),
+                            self.hparams.clip_grad_norm,
+                        )
+                    self.optimizer.step()
+                else:
+                    self.nonfinite_count += 1
+                    logger.info(
+                        "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
+                            self.nonfinite_count
+                        )
                     )
-                )
-                loss.data = torch.tensor(0).to(self.device)
+                    loss.data = torch.tensor(0).to(self.device)
         self.optimizer.zero_grad()
 
         return loss.detach().cpu()
@@ -332,9 +341,7 @@ class Separation(sb.Brain):
             recombine = True
 
             for i in range(targets.shape[-1]):
-                new_target = self.hparams.speedperturb(
-                    targets[:, :, i], targ_lens
-                )
+                new_target = self.hparams.speed_perturb(targets[:, :, i])
                 new_targets.append(new_target)
                 if i == 0:
                     min_len = new_target.shape[-1]
@@ -624,6 +631,10 @@ if __name__ == "__main__":
         overrides=overrides,
     )
 
+    # Update precision to bf16 if the device is CPU and precision is fp16
+    if run_opts.get("device") == "cpu" and hparams.get("precision") == "fp16":
+        hparams["precision"] = "bf16"
+
     # Check if wsj0_tr is set with dynamic mixing
     if hparams["dynamic_mixing"] and not os.path.exists(
         hparams["base_folder_dm"]
diff --git a/recipes/WHAMandWHAMR/meta/create_whamr_rirs.py b/recipes/WHAMandWHAMR/meta/create_whamr_rirs.py
index bd777e9aab22ec20bbe5c304222580660fa4146a..d461b443c75303ae143e471b784fca125a070dfd 100644
--- a/recipes/WHAMandWHAMR/meta/create_whamr_rirs.py
+++ b/recipes/WHAMandWHAMR/meta/create_whamr_rirs.py
@@ -12,7 +12,7 @@ import torchaudio
 from wham_room import WhamRoom
 from scipy.signal import resample_poly
 import torch
-from speechbrain.pretrained.fetching import fetch
+from speechbrain.utils.fetching import fetch
 from tqdm import tqdm
 
 
diff --git a/recipes/WHAMandWHAMR/separation/README.md b/recipes/WHAMandWHAMR/separation/README.md
index 2b7809469bb478a9f189343b75401a6c24ee1795..45ceaf17bcbff5bd4817f2c06a62d708107c0339 100644
--- a/recipes/WHAMandWHAMR/separation/README.md
+++ b/recipes/WHAMandWHAMR/separation/README.md
@@ -91,8 +91,8 @@ The 16kHz version of the sepformer can be found [here](https://huggingface.co/sp
 
 You can run the following command to train the model using Distributed Data Parallel (DDP) with 2 GPUs:
 
-```
- python -m torch.distributed.launch --nproc_per_node=2 train.py hparams/sepformer-whamr.yaml --data_folder /yourdatapath --distributed_launch --distributed_backend='nccl'
+```bash
+torchrun --nproc_per_node=2 train.py hparams/sepformer-whamr.yaml --data_folder /yourdatapath
 ```
 You can add the other runtime options as appropriate. For more complete information on multi-GPU usage, take a look at this [tutorial](https://colab.research.google.com/drive/13pBUacPiotw1IvyffvGZ-HrtBr9T6l15).
 
diff --git a/recipes/WHAMandWHAMR/separation/hparams/sepformer-wham.yaml b/recipes/WHAMandWHAMR/separation/hparams/sepformer-wham.yaml
index d66fe225935f2a207035ece7b391f2f43045fbd8..db920e7fab8d969cbebe61a15b4884ce6f843da1 100644
--- a/recipes/WHAMandWHAMR/separation/hparams/sepformer-wham.yaml
+++ b/recipes/WHAMandWHAMR/separation/hparams/sepformer-wham.yaml
@@ -36,13 +36,13 @@ skip_prep: False
 
 
 # Experiment params
-auto_mix_prec: True # Set it to True for mixed precision
+precision: fp16 # bf16, fp16 or fp32 # Set it to True for mixed precision
 num_spks: 2 # set to 3 for wsj0-3mix
 noprogressbar: False
 save_audio: False # Save estimated sources on disk
 sample_rate: 8000
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -64,18 +64,38 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/WHAMandWHAMR/separation/hparams/sepformer-whamr.yaml b/recipes/WHAMandWHAMR/separation/hparams/sepformer-whamr.yaml
index d670ea23e0729c1f1f0c15a0fe7c7d070dc9d3bc..8538529a629e4ed99176930e46d778e9c8a12438 100644
--- a/recipes/WHAMandWHAMR/separation/hparams/sepformer-whamr.yaml
+++ b/recipes/WHAMandWHAMR/separation/hparams/sepformer-whamr.yaml
@@ -35,12 +35,12 @@ test_data: !ref <save_folder>/whamr_tt.csv
 skip_prep: False
 
 # Experiment params
-auto_mix_prec: True # Set it to True for mixed precision
+precision: fp16 # bf16, fp16 or fp32 # Set it to True for mixed precision
 num_spks: 2 # set to 3 for wsj0-3mix
 save_audio: False # Save estimated sources on disk
 sample_rate: 8000
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -68,18 +68,38 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/WHAMandWHAMR/separation/train.py b/recipes/WHAMandWHAMR/separation/train.py
index 277f053fb81cdf4d535512284e559facfa11ef85..a9f12c0691d42e8c0e5cdcd660aace76772677da 100755
--- a/recipes/WHAMandWHAMR/separation/train.py
+++ b/recipes/WHAMandWHAMR/separation/train.py
@@ -26,12 +26,12 @@ import torchaudio
 import speechbrain as sb
 import speechbrain.nnet.schedulers as schedulers
 from speechbrain.utils.distributed import run_on_main
-from torch.cuda.amp import autocast
 from hyperpyyaml import load_hyperpyyaml
 import numpy as np
 from tqdm import tqdm
 import csv
 import logging
+from speechbrain.core import AMPConfig
 
 
 # Define training procedure
@@ -77,7 +77,8 @@ class Separation(sb.Brain):
                     targets = targets[:, :min_len, :]
 
                 if self.hparams.use_wavedrop:
-                    mix = self.hparams.wavedrop(mix, mix_lens)
+                    mix = self.hparams.drop_chunk(mix, mix_lens)
+                    mix = self.hparams.drop_freq(mix)
 
                 if self.hparams.limit_training_signal_len:
                     mix, targets = self.cut_signals(mix, targets)
@@ -113,19 +114,59 @@ class Separation(sb.Brain):
 
     def fit_batch(self, batch):
         """Trains one batch"""
+        amp = AMPConfig.from_name(self.precision)
+        should_step = (self.step % self.grad_accumulation_factor) == 0
+
         # Unpacking batch list
         mixture = batch.mix_sig
         targets = [batch.s1_sig, batch.s2_sig]
         noise = batch.noise_sig[0]
 
-        if self.auto_mix_prec:
-            with autocast():
+        with self.no_sync(not should_step):
+            if self.use_amp:
+                with torch.autocast(
+                    dtype=amp.dtype, device_type=torch.device(self.device).type,
+                ):
+                    predictions, targets = self.compute_forward(
+                        mixture, targets, sb.Stage.TRAIN, noise
+                    )
+                    loss = self.compute_objectives(predictions, targets)
+
+                    # hard threshold the easy dataitems
+                    if self.hparams.threshold_byloss:
+                        th = self.hparams.threshold
+                        loss = loss[loss > th]
+                        if loss.nelement() > 0:
+                            loss = loss.mean()
+                    else:
+                        loss = loss.mean()
+
+                if (
+                    loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
+                ):  # the fix for computational problems
+                    self.scaler.scale(loss).backward()
+                    if self.hparams.clip_grad_norm >= 0:
+                        self.scaler.unscale_(self.optimizer)
+                        torch.nn.utils.clip_grad_norm_(
+                            self.modules.parameters(),
+                            self.hparams.clip_grad_norm,
+                        )
+                    self.scaler.step(self.optimizer)
+                    self.scaler.update()
+                else:
+                    self.nonfinite_count += 1
+                    logger.info(
+                        "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
+                            self.nonfinite_count
+                        )
+                    )
+                    loss.data = torch.tensor(0).to(self.device)
+            else:
                 predictions, targets = self.compute_forward(
                     mixture, targets, sb.Stage.TRAIN, noise
                 )
                 loss = self.compute_objectives(predictions, targets)
 
-                # hard threshold the easy dataitems
                 if self.hparams.threshold_byloss:
                     th = self.hparams.threshold
                     loss = loss[loss > th]
@@ -134,56 +175,24 @@ class Separation(sb.Brain):
                 else:
                     loss = loss.mean()
 
-            if (
-                loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
-            ):  # the fix for computational problems
-                self.scaler.scale(loss).backward()
-                if self.hparams.clip_grad_norm >= 0:
-                    self.scaler.unscale_(self.optimizer)
-                    torch.nn.utils.clip_grad_norm_(
-                        self.modules.parameters(), self.hparams.clip_grad_norm,
-                    )
-                self.scaler.step(self.optimizer)
-                self.scaler.update()
-            else:
-                self.nonfinite_count += 1
-                logger.info(
-                    "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
-                        self.nonfinite_count
-                    )
-                )
-                loss.data = torch.tensor(0).to(self.device)
-        else:
-            predictions, targets = self.compute_forward(
-                mixture, targets, sb.Stage.TRAIN, noise
-            )
-            loss = self.compute_objectives(predictions, targets)
-
-            if self.hparams.threshold_byloss:
-                th = self.hparams.threshold
-                loss = loss[loss > th]
-                if loss.nelement() > 0:
-                    loss = loss.mean()
-            else:
-                loss = loss.mean()
-
-            if (
-                loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
-            ):  # the fix for computational problems
-                loss.backward()
-                if self.hparams.clip_grad_norm >= 0:
-                    torch.nn.utils.clip_grad_norm_(
-                        self.modules.parameters(), self.hparams.clip_grad_norm
-                    )
-                self.optimizer.step()
-            else:
-                self.nonfinite_count += 1
-                logger.info(
-                    "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
-                        self.nonfinite_count
+                if (
+                    loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
+                ):  # the fix for computational problems
+                    loss.backward()
+                    if self.hparams.clip_grad_norm >= 0:
+                        torch.nn.utils.clip_grad_norm_(
+                            self.modules.parameters(),
+                            self.hparams.clip_grad_norm,
+                        )
+                    self.optimizer.step()
+                else:
+                    self.nonfinite_count += 1
+                    logger.info(
+                        "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
+                            self.nonfinite_count
+                        )
                     )
-                )
-                loss.data = torch.tensor(0).to(self.device)
+                    loss.data = torch.tensor(0).to(self.device)
         self.optimizer.zero_grad()
 
         return loss.detach().cpu()
@@ -265,9 +274,7 @@ class Separation(sb.Brain):
             recombine = True
 
             for i in range(targets.shape[-1]):
-                new_target = self.hparams.speedperturb(
-                    targets[:, :, i], targ_lens
-                )
+                new_target = self.hparams.speed_perturb(targets[:, :, i])
                 new_targets.append(new_target)
                 if i == 0:
                     min_len = new_target.shape[-1]
@@ -543,6 +550,11 @@ if __name__ == "__main__":
         hyperparams_to_save=hparams_file,
         overrides=overrides,
     )
+    print("aaaaaaaaaaaaaaaaaa")
+    # Update precision to bf16 if the device is CPU and precision is fp16
+    if run_opts.get("device") == "cpu" and hparams.get("precision") == "fp16":
+        hparams["precision"] = "bf16"
+        print("bbbbbbbbbbbbbbbbb")
 
     # Check if wsj0_tr is set with dynamic mixing
     if hparams["dynamic_mixing"] and not os.path.exists(
diff --git a/recipes/WSJ0Mix/separation/README.md b/recipes/WSJ0Mix/separation/README.md
index 7dd068335c607573d1c40f3e67c775388e999370..56c5accf062d4c74683d0918997dab58db32b13d 100644
--- a/recipes/WSJ0Mix/separation/README.md
+++ b/recipes/WSJ0Mix/separation/README.md
@@ -97,8 +97,8 @@ Pretrained models for SepFormer on WSJ0-2Mix, WSJ0-3Mix, and WHAM! datasets can
 
 You can run the following command to train the model using Distributed Data Parallel (DDP) with 2 GPUs:
 
-```
- python -m torch.distributed.launch --nproc_per_node=2 train.py hparams/sepformer.yaml --data_folder /yourdatapath --distributed_launch --distributed_backend='nccl'
+```bash
+torchrun --nproc_per_node=2 train.py hparams/sepformer.yaml --data_folder /yourdatapath
 ```
 You can add the other runtime options as appropriate. For more complete information on multi-GPU usage, take a look at this [tutorial](https://colab.research.google.com/drive/13pBUacPiotw1IvyffvGZ-HrtBr9T6l15).
 
diff --git a/recipes/WSJ0Mix/separation/hparams/convtasnet.yaml b/recipes/WSJ0Mix/separation/hparams/convtasnet.yaml
index fc2b533cf04f1fb5e1dcea44ac9c8edbc820186d..0305c6236d5cb7faa07739321428e4ad97e1a5a4 100644
--- a/recipes/WSJ0Mix/separation/hparams/convtasnet.yaml
+++ b/recipes/WSJ0Mix/separation/hparams/convtasnet.yaml
@@ -30,13 +30,13 @@ skip_prep: False
 
 
 # Experiment params
-auto_mix_prec: False # Set it to True for mixed precision
+precision: fp32 # bf16, fp16 or fp32
 num_spks: 2 # set to 3 for wsj0-3mix
 noprogressbar: False
 save_audio: False # Save estimated sources on disk
 sample_rate: 8000
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -58,18 +58,38 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/WSJ0Mix/separation/hparams/dprnn.yaml b/recipes/WSJ0Mix/separation/hparams/dprnn.yaml
index c69234e29d757a42f3664c8d6e7ba31c2ab8a98b..df1952d8c93d034db2f97749865a02a40c45ee22 100644
--- a/recipes/WSJ0Mix/separation/hparams/dprnn.yaml
+++ b/recipes/WSJ0Mix/separation/hparams/dprnn.yaml
@@ -30,13 +30,13 @@ skip_prep: False
 
 
 # Experiment params
-auto_mix_prec: False # Set it to True for mixed precision
+precision: fp32 # bf16, fp16 or fp32
 num_spks: 2 # set to 3 for wsj0-3mix
 noprogressbar: False
 save_audio: False # Save estimated sources on disk
 sample_rate: 8000
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -58,18 +58,38 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/WSJ0Mix/separation/hparams/resepformer.yaml b/recipes/WSJ0Mix/separation/hparams/resepformer.yaml
index e0a052d1889204f4d6a8e734333456fd0b903990..406f2aa76510ddce75fed5fc8e40e2cee3aea08b 100644
--- a/recipes/WSJ0Mix/separation/hparams/resepformer.yaml
+++ b/recipes/WSJ0Mix/separation/hparams/resepformer.yaml
@@ -32,12 +32,12 @@ skip_prep: False
 
 
 # Experiment params
-auto_mix_prec: False # Set it to True for mixed precision
+precision: fp32 # bf16, fp16 or fp32
 num_spks: 2 # set to 3 for wsj0-3mix
 save_audio: False # Save estimated sources on disk
 sample_rate: 8000
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -59,18 +59,38 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/WSJ0Mix/separation/hparams/sepformer-conformerintra.yaml b/recipes/WSJ0Mix/separation/hparams/sepformer-conformerintra.yaml
index e12cacca2c4227a4e9fa1d169a2a2ddf0f704a71..2cf2b7ac551547a88df59a1f2f658cc895f0b4be 100644
--- a/recipes/WSJ0Mix/separation/hparams/sepformer-conformerintra.yaml
+++ b/recipes/WSJ0Mix/separation/hparams/sepformer-conformerintra.yaml
@@ -32,12 +32,12 @@ skip_prep: False
 
 
 # Experiment params
-auto_mix_prec: False # Set it to True for mixed precision
+precision: fp32 # bf16, fp16 or fp32
 num_spks: 2 # set to 3 for wsj0-3mix
 save_audio: False # Save estimated sources on disk
 sample_rate: 8000
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -59,18 +59,38 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/WSJ0Mix/separation/hparams/sepformer-customdataset.yaml b/recipes/WSJ0Mix/separation/hparams/sepformer-customdataset.yaml
index ecafc8e002d43f2a3f042ea8082da11a42754367..c896f2dfd844ee1dbac3625f3eea38358fcc8d0d 100644
--- a/recipes/WSJ0Mix/separation/hparams/sepformer-customdataset.yaml
+++ b/recipes/WSJ0Mix/separation/hparams/sepformer-customdataset.yaml
@@ -33,13 +33,13 @@ skip_prep: False
 
 
 # Experiment params
-auto_mix_prec: False # Set it to True for mixed precision
+precision: fp32 # bf16, fp16 or fp32
 num_spks: 2 # set to 3 for wsj0-3mix
 noprogressbar: False
 save_audio: True # Save estimated sources on disk
 sample_rate: 16000
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -61,18 +61,38 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/WSJ0Mix/separation/hparams/sepformer.yaml b/recipes/WSJ0Mix/separation/hparams/sepformer.yaml
index f495360a718b6e7bdae6bf2d4a54ebd382d6ebc2..77319604d02669a6823e79acdb6aac75f70f88aa 100644
--- a/recipes/WSJ0Mix/separation/hparams/sepformer.yaml
+++ b/recipes/WSJ0Mix/separation/hparams/sepformer.yaml
@@ -33,14 +33,14 @@ skip_prep: False
 
 
 # Experiment params
-auto_mix_prec: True # Set it to True for mixed precision
+precision: fp16 # bf16, fp16 or fp32 # Set it to True for mixed precision
 num_spks: 2 # set to 3 for wsj0-3mix
 noprogressbar: False
 save_audio: True # Save estimated sources on disk
 n_audio_to_save: 20
 sample_rate: 8000
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -62,18 +62,38 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/WSJ0Mix/separation/hparams/skim.yaml b/recipes/WSJ0Mix/separation/hparams/skim.yaml
index ea8ffc3c76d586d7bfd7a46192fde0fcf25fb531..606c7060a5115b863805c835575bd6a6db182cf7 100644
--- a/recipes/WSJ0Mix/separation/hparams/skim.yaml
+++ b/recipes/WSJ0Mix/separation/hparams/skim.yaml
@@ -32,12 +32,12 @@ skip_prep: False
 
 
 # Experiment params
-auto_mix_prec: False # Set it to True for mixed precision
+precision: fp32 # bf16, fp16 or fp32 # Set it to True for mixed precision
 num_spks: 2 # set to 3 for wsj0-3mix
 save_audio: False # Save estimated sources on disk
 sample_rate: 8000
 
-# Training parameters
+####################### Training Parameters ####################################
 N_epochs: 200
 batch_size: 1
 lr: 0.00015
@@ -59,18 +59,38 @@ use_rand_shift: False
 min_shift: -8000
 max_shift: 8000
 
-speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 1.0
-    drop_freq_prob: 0.0
-    drop_chunk_prob: 0.0
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    perturb_prob: 0.0
-    drop_freq_prob: 1.0
-    drop_chunk_prob: 1.0
-    sample_rate: !ref <sample_rate>
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
 
 # loss thresholding -- this thresholds the training loss
 threshold_byloss: True
diff --git a/recipes/WSJ0Mix/separation/train.py b/recipes/WSJ0Mix/separation/train.py
index 9f53787be9591b0b7e267b1f5ce02da408dc3e86..f8ccf93fbde3a67f339a85124363c25aa35730a4 100755
--- a/recipes/WSJ0Mix/separation/train.py
+++ b/recipes/WSJ0Mix/separation/train.py
@@ -29,12 +29,12 @@ import torchaudio
 import speechbrain as sb
 import speechbrain.nnet.schedulers as schedulers
 from speechbrain.utils.distributed import run_on_main
-from torch.cuda.amp import autocast
 from hyperpyyaml import load_hyperpyyaml
 import numpy as np
 from tqdm import tqdm
 import csv
 import logging
+from speechbrain.core import AMPConfig
 
 
 # Define training procedure
@@ -55,13 +55,14 @@ class Separation(sb.Brain):
         # Add speech distortions
         if stage == sb.Stage.TRAIN:
             with torch.no_grad():
-                if self.hparams.use_speedperturb or self.hparams.use_rand_shift:
+                if self.hparams.use_speedperturb:
                     mix, targets = self.add_speed_perturb(targets, mix_lens)
 
                     mix = targets.sum(-1)
 
                 if self.hparams.use_wavedrop:
-                    mix = self.hparams.wavedrop(mix, mix_lens)
+                    mix = self.hparams.drop_chunk(mix, mix_lens)
+                    mix = self.hparams.drop_freq(mix)
 
                 if self.hparams.limit_training_signal_len:
                     mix, targets = self.cut_signals(mix, targets)
@@ -97,6 +98,9 @@ class Separation(sb.Brain):
 
     def fit_batch(self, batch):
         """Trains one batch"""
+        amp = AMPConfig.from_name(self.precision)
+        should_step = (self.step % self.grad_accumulation_factor) == 0
+
         # Unpacking batch list
         mixture = batch.mix_sig
         targets = [batch.s1_sig, batch.s2_sig]
@@ -104,14 +108,51 @@ class Separation(sb.Brain):
         if self.hparams.num_spks == 3:
             targets.append(batch.s3_sig)
 
-        if self.auto_mix_prec:
-            with autocast():
+        with self.no_sync(not should_step):
+            if self.use_amp:
+                with torch.autocast(
+                    dtype=amp.dtype, device_type=torch.device(self.device).type,
+                ):
+                    predictions, targets = self.compute_forward(
+                        mixture, targets, sb.Stage.TRAIN
+                    )
+                    loss = self.compute_objectives(predictions, targets)
+
+                    # hard threshold the easy dataitems
+                    if self.hparams.threshold_byloss:
+                        th = self.hparams.threshold
+                        loss = loss[loss > th]
+                        if loss.nelement() > 0:
+                            loss = loss.mean()
+                    else:
+                        loss = loss.mean()
+
+                if (
+                    loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
+                ):  # the fix for computational problems
+                    self.scaler.scale(loss).backward()
+                    if self.hparams.clip_grad_norm >= 0:
+                        self.scaler.unscale_(self.optimizer)
+                        torch.nn.utils.clip_grad_norm_(
+                            self.modules.parameters(),
+                            self.hparams.clip_grad_norm,
+                        )
+                    self.scaler.step(self.optimizer)
+                    self.scaler.update()
+                else:
+                    self.nonfinite_count += 1
+                    logger.info(
+                        "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
+                            self.nonfinite_count
+                        )
+                    )
+                    loss.data = torch.tensor(0).to(self.device)
+            else:
                 predictions, targets = self.compute_forward(
                     mixture, targets, sb.Stage.TRAIN
                 )
                 loss = self.compute_objectives(predictions, targets)
 
-                # hard threshold the easy dataitems
                 if self.hparams.threshold_byloss:
                     th = self.hparams.threshold
                     loss = loss[loss > th]
@@ -120,56 +161,24 @@ class Separation(sb.Brain):
                 else:
                     loss = loss.mean()
 
-            if (
-                loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
-            ):  # the fix for computational problems
-                self.scaler.scale(loss).backward()
-                if self.hparams.clip_grad_norm >= 0:
-                    self.scaler.unscale_(self.optimizer)
-                    torch.nn.utils.clip_grad_norm_(
-                        self.modules.parameters(), self.hparams.clip_grad_norm,
-                    )
-                self.scaler.step(self.optimizer)
-                self.scaler.update()
-            else:
-                self.nonfinite_count += 1
-                logger.info(
-                    "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
-                        self.nonfinite_count
-                    )
-                )
-                loss.data = torch.tensor(0).to(self.device)
-        else:
-            predictions, targets = self.compute_forward(
-                mixture, targets, sb.Stage.TRAIN
-            )
-            loss = self.compute_objectives(predictions, targets)
-
-            if self.hparams.threshold_byloss:
-                th = self.hparams.threshold
-                loss = loss[loss > th]
-                if loss.nelement() > 0:
-                    loss = loss.mean()
-            else:
-                loss = loss.mean()
-
-            if (
-                loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
-            ):  # the fix for computational problems
-                loss.backward()
-                if self.hparams.clip_grad_norm >= 0:
-                    torch.nn.utils.clip_grad_norm_(
-                        self.modules.parameters(), self.hparams.clip_grad_norm
-                    )
-                self.optimizer.step()
-            else:
-                self.nonfinite_count += 1
-                logger.info(
-                    "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
-                        self.nonfinite_count
+                if (
+                    loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
+                ):  # the fix for computational problems
+                    loss.backward()
+                    if self.hparams.clip_grad_norm >= 0:
+                        torch.nn.utils.clip_grad_norm_(
+                            self.modules.parameters(),
+                            self.hparams.clip_grad_norm,
+                        )
+                    self.optimizer.step()
+                else:
+                    self.nonfinite_count += 1
+                    logger.info(
+                        "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
+                            self.nonfinite_count
+                        )
                     )
-                )
-                loss.data = torch.tensor(0).to(self.device)
+                    loss.data = torch.tensor(0).to(self.device)
         self.optimizer.zero_grad()
 
         return loss.detach().cpu()
@@ -239,15 +248,13 @@ class Separation(sb.Brain):
         min_len = -1
         recombine = False
 
-        if self.hparams.use_speedperturb:
+        if self.hparams.use_speedperturb or self.hparams.use_rand_shift:
             # Performing speed change (independently on each source)
             new_targets = []
             recombine = True
 
             for i in range(targets.shape[-1]):
-                new_target = self.hparams.speedperturb(
-                    targets[:, :, i], targ_lens
-                )
+                new_target = self.hparams.speed_perturb(targets[:, :, i])
                 new_targets.append(new_target)
                 if i == 0:
                     min_len = new_target.shape[-1]
@@ -528,6 +535,10 @@ if __name__ == "__main__":
         overrides=overrides,
     )
 
+    # Update precision to bf16 if the device is CPU and precision is fp16
+    if run_opts.get("device") == "cpu" and hparams.get("precision") == "fp16":
+        hparams["precision"] = "bf16"
+
     # Check if wsj0_tr is set with dynamic mixing
     if hparams["dynamic_mixing"] and not os.path.exists(
         hparams["base_folder_dm"]
diff --git a/recipes/ZaionEmotionDataset/emotion_diarization/hparams/train.yaml b/recipes/ZaionEmotionDataset/emotion_diarization/hparams/train.yaml
index 3ec38136c0e11b4e299922a8806c31ebc3e2b150..0d9601d2068e91db1819a2647a8628657bc0a47c 100644
--- a/recipes/ZaionEmotionDataset/emotion_diarization/hparams/train.yaml
+++ b/recipes/ZaionEmotionDataset/emotion_diarization/hparams/train.yaml
@@ -30,14 +30,13 @@ train_annotation: !ref <output_folder>/train.json
 valid_annotation: !ref <output_folder>/valid.json
 test_annotation: !ref <output_folder>/test.json
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 15
 lr: 0.0001
 lr_wav2vec: 0.00001
-sorting: ascending
-# auto_mix_prec: False
+# precision: fp32 # bf16, fp16 or fp32
 # do_resample: False
-sample_rate: 16000
+# sample_rate: 16000
 
 # With data_parallel batch_size is split into N jobs
 # With DDP batch_size is multiplied by N jobs
@@ -69,10 +68,6 @@ dataloader_options:
 test_dataloader_opts:
     batch_size: !ref <test_batch_size>
 
-# Model parameters
-activation: !name:torch.nn.LeakyReLU
-# dnn_layers: 2
-
 # # DER evaluation parameters
 # ignore_overlap: True
 # forgiveness_collar: 0.25
@@ -80,15 +75,12 @@ activation: !name:torch.nn.LeakyReLU
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
 
 input_norm: !new:speechbrain.processing.features.InputNormalization
     norm_type: sentence
     std_norm: False
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec2>
@@ -120,11 +112,6 @@ modules:
 model: !new:torch.nn.ModuleList
     - [!ref <output_mlp>]
 
-model_opt_class: !name:torch.optim.Adadelta
-    lr: !ref <lr>
-    rho: 0.95
-    eps: 1.e-8
-
 opt_class: !name:torch.optim.Adam
     lr: !ref <lr>
 
@@ -157,8 +144,6 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
 train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
     save_file: !ref <train_log>
 
-error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
-
 error_stats: !name:speechbrain.utils.metric_stats.MetricStats
     metric: !name:speechbrain.nnet.losses.classification_error
         reduction: batch
diff --git a/recipes/ZaionEmotionDataset/emotion_diarization/train.py b/recipes/ZaionEmotionDataset/emotion_diarization/train.py
index ffd64dc425bdfe3ea9c0be534843ae268336b14a..6ddf45d4ffa1ae6b183a9d6cffbf5be7f4129f34 100644
--- a/recipes/ZaionEmotionDataset/emotion_diarization/train.py
+++ b/recipes/ZaionEmotionDataset/emotion_diarization/train.py
@@ -85,20 +85,6 @@ class EmoDiaBrain(sb.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Trains the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-        loss.backward()
-        if self.check_gradients(loss):
-            self.wav2vec2_optimizer.step()
-            self.optimizer.step()
-
-        self.wav2vec2_optimizer.zero_grad()
-        self.optimizer.zero_grad()
-
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch=None):
         """Gets called at the beginning of each epoch.
         Arguments
@@ -200,6 +186,11 @@ class EmoDiaBrain(sb.Brain):
             )
             self.checkpointer.add_recoverable("optimizer", self.optimizer)
 
+        self.optimizers_dict = {
+            "wav2vec2": self.wav2vec2_optimizer,
+            "model": self.optimizer,
+        }
+
 
 def dataio_prep(hparams):
     """This function prepares the datasets to be used in the brain class.
diff --git a/recipes/fluent-speech-commands/Tokenizer/hparams/tokenizer_bpe51.yaml b/recipes/fluent-speech-commands/Tokenizer/hparams/tokenizer_bpe51.yaml
index db7c1ddb7287627e1f59ac1374bc774e94f4e302..eff38c7bf42af12f0bb5945cc54209c12ba36551 100644
--- a/recipes/fluent-speech-commands/Tokenizer/hparams/tokenizer_bpe51.yaml
+++ b/recipes/fluent-speech-commands/Tokenizer/hparams/tokenizer_bpe51.yaml
@@ -13,7 +13,7 @@ train_csv: !ref <output_folder>/train.csv
 valid_csv: !ref <output_folder>/valid.csv
 skip_prep: False
 
-# Training parameters
+####################### Training Parameters ####################################
 token_type: unigram  # ["unigram", "bpe", "char"]
 token_output: 51  # index(blank/eos/bos/unk) = 0
 character_coverage: 1.0
diff --git a/recipes/fluent-speech-commands/Tokenizer/train.py b/recipes/fluent-speech-commands/Tokenizer/train.py
index 4b8b72e69f9cede631fb4adee68dbca56fd37396..d1cf9dfaae8e44e5663208d729bd9a9fc7147af3 100644
--- a/recipes/fluent-speech-commands/Tokenizer/train.py
+++ b/recipes/fluent-speech-commands/Tokenizer/train.py
@@ -25,7 +25,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
diff --git a/recipes/fluent-speech-commands/direct/hparams/train.yaml b/recipes/fluent-speech-commands/direct/hparams/train.yaml
index 61e57782a109acf2b68a4b0ed080bf2894b3a999..428faf144a9fb1e371bde9c3a6816bfb1e95138d 100644
--- a/recipes/fluent-speech-commands/direct/hparams/train.yaml
+++ b/recipes/fluent-speech-commands/direct/hparams/train.yaml
@@ -16,22 +16,30 @@ save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
 test_wer_file: !ref <output_folder>/wer_test.txt
 
+# Data for augmentation
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+RIR_DATASET_URL: https://www.dropbox.com/scl/fi/linhy77c36mu10965a836/RIRs.zip?rlkey=pg9cu8vrpn2u173vhiqyu743u&dl=1
+
 # Data files
 data_folder: !PLACEHOLDER # e.g, /localscratch/fluent_speech_commands_dataset
-rir_folder: !ref <data_folder> # Change it if needed
-csv_train: !ref <output_folder>/train.csv
-csv_valid: !ref <output_folder>/valid.csv
-csv_test: !ref <output_folder>/test.csv
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+data_folder_rir: !ref <data_folder>/rir # The impulse responses used for data augmentation will automatically be downloaded here.
+csv_train: !ref <save_folder>/train.csv
+csv_valid: !ref <save_folder>/valid.csv
+csv_test: !ref <save_folder>/test.csv
+noise_annotation: !ref <save_folder>/noise.csv
+rir_annotation: !ref <save_folder>/rir.csv
+
 tokenizer_file: https://www.dropbox.com/s/hvf2huofnq0sjbn/51_unigram.model?dl=1
 skip_prep: False
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 6
 batch_size: 16
 lr: 0.0003
 # token_type: unigram # ["unigram", "bpe", "char"]
 sorting: random
 
-# Model parameters
+####################### Model Parameters #######################################
 sample_rate: 16000
 emb_size: 128
 dec_neurons: 512
@@ -48,17 +56,78 @@ slu_beam_size: 80
 eos_threshold: 1.5
 temperature: 1.25
 
+num_workers: 4
 dataloader_opts:
+    num_workers: !ref <num_workers>
     batch_size: !ref <batch_size>
     shuffle: True
 
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-# Models
-asr_model: !apply:speechbrain.pretrained.EncoderDecoderASR.from_hparams
-    source: speechbrain/asr-crdnn-rnnlm-librispeech
-    run_opts: {"device":"cuda:0"}
+############################## Augmentations ###################################
+
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+# Download and prepare the dataset of room impulse responses for augmentation
+prepare_rir_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <RIR_DATASET_URL>
+    dest_folder: !ref <data_folder_rir>
+    ext: wav
+    csv_file: !ref <rir_annotation>
+
+# Add reverberation to input signal
+add_reverb: !new:speechbrain.augment.time_domain.AddReverb
+    csv_file: !ref <rir_annotation>
+    reverb_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Add noise to input signal
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: 9
+    snr_high: 15
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 2
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    concat_original: True
+    shuffle_augmentations: True
+    min_augmentations: 1
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <add_noise>,
+        !ref <add_reverb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
+asr_model_source: speechbrain/asr-crdnn-rnnlm-librispeech
 
 slu_enc: !new:speechbrain.nnet.containers.Sequential
     input_shape: [null, null, !ref <ASR_encoder_dim>]
@@ -90,57 +159,8 @@ seq_lin: !new:speechbrain.nnet.linear.Linear
     input_size: !ref <dec_neurons>
     n_neurons: !ref <output_neurons>
 
-augment_wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [100]
-
-augment_speed: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
-add_rev: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    openrir_max_noise_len: 3.0  # seconds
-    reverb_prob: 1.0
-    noise_prob: 0.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-add_noise: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    openrir_max_noise_len: 3.0  # seconds
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-add_rev_noise: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    openrir_max_noise_len: 3.0  # seconds
-    reverb_prob: 1.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-    rir_scale_factor: 1.0
-
-
-augment_pipeline: [
-    !ref <augment_wavedrop>,
-    !ref <augment_speed>,
-    !ref <add_rev>,
-    !ref <add_noise>,
-    !ref <add_rev_noise>
-]
-
 
 modules:
-    augment_wavedrop: !ref <augment_wavedrop>
-    augment_speed: !ref <augment_speed>
-    add_rev: !ref <add_rev>
-    add_noise: !ref <add_noise>
-    add_rev_noise: !ref <add_rev_noise>
     slu_enc: !ref <slu_enc>
     output_emb: !ref <output_emb>
     dec: !ref <dec>
@@ -172,7 +192,6 @@ beam_searcher: !new:speechbrain.decoders.S2SRNNBeamSearcher
     temperature: !ref <temperature>
     using_max_attn_shift: False
     max_attn_shift: 30
-    coverage_penalty: 0.
 
 opt_class: !name:torch.optim.Adam
     lr: !ref <lr>
diff --git a/recipes/fluent-speech-commands/direct/train.py b/recipes/fluent-speech-commands/direct/train.py
index 61271635fda02059476f512014efbf9bc4c8f432..8175efe64650d4f2ff9127860a7fb41aac83d0c5 100644
--- a/recipes/fluent-speech-commands/direct/train.py
+++ b/recipes/fluent-speech-commands/direct/train.py
@@ -34,30 +34,13 @@ class SLU(sb.Brain):
         wavs, wav_lens = batch.sig
         tokens_bos, tokens_bos_lens = batch.tokens_bos
 
-        # Add augmentation if specified
-        if stage == sb.Stage.TRAIN:
-            # Applying the augmentation pipeline
-            wavs_aug_tot = []
-            wavs_aug_tot.append(wavs)
-            for count, augment in enumerate(self.hparams.augment_pipeline):
-
-                # Apply augment
-                wavs_aug = augment(wavs, wav_lens)
-
-                # Managing speed change
-                if wavs_aug.shape[1] > wavs.shape[1]:
-                    wavs_aug = wavs_aug[:, 0 : wavs.shape[1]]
-                else:
-                    zero_sig = torch.zeros_like(wavs)
-                    zero_sig[:, 0 : wavs_aug.shape[1]] = wavs_aug
-                    wavs_aug = zero_sig
-
-                wavs_aug_tot.append(wavs_aug)
-
-            wavs = torch.cat(wavs_aug_tot, dim=0)
-            self.n_augment = len(wavs_aug_tot)
-            wav_lens = torch.cat([wav_lens] * self.n_augment)
-            tokens_bos = torch.cat([tokens_bos] * self.n_augment)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+            tokens_bos = self.hparams.wav_augment.replicate_labels(tokens_bos)
+            tokens_bos_lens = self.hparams.wav_augment.replicate_labels(
+                tokens_bos_lens
+            )
 
         # ASR encoder forward pass
         with torch.no_grad():
@@ -75,40 +58,27 @@ class SLU(sb.Brain):
         p_seq = self.hparams.log_softmax(logits)
 
         # Compute outputs
-        if (
-            stage == sb.Stage.TRAIN
-            and self.batch_count % show_results_every != 0
-        ):
-            return p_seq, wav_lens
-        else:
-            p_tokens, scores = self.hparams.beam_searcher(encoder_out, wav_lens)
-            return p_seq, wav_lens, p_tokens
+        p_tokens = None
+        if stage != sb.Stage.TRAIN:
+            p_tokens, _, _, _ = self.hparams.beam_searcher(
+                encoder_out, wav_lens
+            )
+
+        return p_seq, wav_lens, p_tokens
 
     def compute_objectives(self, predictions, batch, stage):
         """Computes the loss (NLL) given predictions and targets."""
 
-        if (
-            stage == sb.Stage.TRAIN
-            and self.batch_count % show_results_every != 0
-        ):
-            p_seq, wav_lens = predictions
-        else:
-            p_seq, wav_lens, predicted_tokens = predictions
+        p_seq, wav_lens, predicted_tokens = predictions
 
         ids = batch.id
         tokens_eos, tokens_eos_lens = batch.tokens_eos
         tokens, tokens_lens = batch.tokens
 
-        if hasattr(self.hparams, "env_corrupt") and stage == sb.Stage.TRAIN:
-            tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
-            tokens_eos_lens = torch.cat(
-                [tokens_eos_lens, tokens_eos_lens], dim=0
-            )
-
-        if stage == sb.Stage.TRAIN:
-            tokens_eos = torch.cat([tokens_eos] * self.n_augment, dim=0)
-            tokens_eos_lens = torch.cat(
-                [tokens_eos_lens] * self.n_augment, dim=0
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            tokens_eos = self.hparams.wav_augment.replicate_labels(tokens_eos)
+            tokens_eos_lens = self.hparams.wav_augment.replicate_labels(
+                tokens_eos_lens
             )
 
         loss_seq = self.hparams.seq_cost(
@@ -118,9 +88,7 @@ class SLU(sb.Brain):
         # (No ctc loss)
         loss = loss_seq
 
-        if (stage != sb.Stage.TRAIN) or (
-            self.batch_count % show_results_every == 0
-        ):
+        if (stage != sb.Stage.TRAIN) or (self.step % show_results_every == 0):
             # Decode token terms to words
             predicted_semantics = [
                 tokenizer.decode_ids(utt_seq).split(" ")
@@ -144,26 +112,8 @@ class SLU(sb.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-        loss.backward()
-        if self.check_gradients(loss):
-            self.optimizer.step()
-        self.optimizer.zero_grad()
-        self.batch_count += 1
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
-        self.batch_count = 0
 
         if stage != sb.Stage.TRAIN:
 
@@ -291,7 +241,6 @@ if __name__ == "__main__":
 
     show_results_every = 100  # plots results every N iterations
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -314,13 +263,22 @@ if __name__ == "__main__":
             "skip_prep": hparams["skip_prep"],
         },
     )
-
+    run_on_main(hparams["prepare_noise_data"])
+    run_on_main(hparams["prepare_rir_data"])
     # here we create the datasets objects as well as tokenization and encoding
     (train_set, valid_set, test_set, tokenizer,) = dataio_prepare(hparams)
 
     # We download and pretrain the tokenizer
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
+
+    # Download pretrained ASR model
+    from speechbrain.inference.ASR import EncoderDecoderASR
+
+    hparams["asr_model"] = EncoderDecoderASR.from_hparams(
+        source=hparams["asr_model_source"],
+        run_opts={"device": run_opts["device"]},
+    )
 
     # Brain class initialization
     slu_brain = SLU(
diff --git a/recipes/timers-and-such/LM/hparams/train.yaml b/recipes/timers-and-such/LM/hparams/train.yaml
index 485dd54265a57612f862b180d223c79149d3aa3b..f3ba652edbc6effb882d697e13ee3747d7badf9f 100644
--- a/recipes/timers-and-such/LM/hparams/train.yaml
+++ b/recipes/timers-and-such/LM/hparams/train.yaml
@@ -23,7 +23,7 @@ csv_test_synth: !ref <output_folder>/test-synth-type=decoupled.csv
 csv_test_real: !ref <output_folder>/test-real-type=decoupled.csv
 skip_prep: False
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 10
 batch_size: 128
 lr: 0.0003
diff --git a/recipes/timers-and-such/LM/train.py b/recipes/timers-and-such/LM/train.py
index d88c995300b2e12b7caaefaabff6889a0ff98b5f..e0b926152c697935bd0bde94d167ca276c1f6a18 100644
--- a/recipes/timers-and-such/LM/train.py
+++ b/recipes/timers-and-such/LM/train.py
@@ -38,27 +38,6 @@ class LM(sb.Brain):
         loss = self.hparams.seq_cost(p_seq, tokens_eos, length=tokens_eos_lens)
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-        loss.backward()
-        if self.check_gradients(loss):
-            self.optimizer.step()
-        self.optimizer.zero_grad()
-        self.batch_count += 1
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
-    def on_stage_start(self, stage, epoch):
-        """Gets called at the beginning of each epoch"""
-        self.batch_count = 0
-
     def on_stage_end(self, stage, stage_loss, epoch):
         """Gets called at the end of a epoch."""
         # Compute/store important stats
@@ -169,7 +148,6 @@ if __name__ == "__main__":
 
     show_results_every = 100  # plots results every N iterations
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -190,7 +168,7 @@ if __name__ == "__main__":
 
     # We download and pretrain the tokenizer
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     # Create experiment directory
     sb.create_experiment_directory(
diff --git a/recipes/timers-and-such/Tokenizer/hparams/tokenizer_bpe51.yaml b/recipes/timers-and-such/Tokenizer/hparams/tokenizer_bpe51.yaml
index 7554a03421cfa18627b4809af2296f6c54d4ea13..2a9f39161d2237e228ee0a9d82def76728a52324 100644
--- a/recipes/timers-and-such/Tokenizer/hparams/tokenizer_bpe51.yaml
+++ b/recipes/timers-and-such/Tokenizer/hparams/tokenizer_bpe51.yaml
@@ -15,7 +15,7 @@ train_csv: !ref <output_folder>/train-type=direct.csv
 valid_csv: !ref <output_folder>/dev-real-type=direct.csv
 
 
-# Training parameters
+####################### Training Parameters ####################################
 token_type: unigram  # ["unigram", "bpe", "char"]
 token_output: 51  # index(blank/eos/bos/unk) = 0
 character_coverage: 1.0
diff --git a/recipes/timers-and-such/Tokenizer/train.py b/recipes/timers-and-such/Tokenizer/train.py
index bd32f8fc3c6b0a6c24bba13f7c7fe02556f12d7d..925978f4171a465088dc920bfa7c089ed2175bdd 100644
--- a/recipes/timers-and-such/Tokenizer/train.py
+++ b/recipes/timers-and-such/Tokenizer/train.py
@@ -25,7 +25,6 @@ if __name__ == "__main__":
     with open(hparams_file) as fin:
         hparams = load_hyperpyyaml(fin, overrides)
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
diff --git a/recipes/timers-and-such/decoupled/hparams/train_LS_LM.yaml b/recipes/timers-and-such/decoupled/hparams/train_LS_LM.yaml
index 626ff07bda74d056849573060e080b1695a71903..15b60202ee72c367e587aca71af4aa96083bf7d3 100644
--- a/recipes/timers-and-such/decoupled/hparams/train_LS_LM.yaml
+++ b/recipes/timers-and-such/decoupled/hparams/train_LS_LM.yaml
@@ -34,7 +34,7 @@ skip_prep: False
 ckpt_interval_minutes: 15 # save checkpoint every N min
 test_on_all_real: False
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 1
 batch_size: 16
 lr: 0.0003
@@ -67,7 +67,7 @@ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
 # Models
-asr_model: !apply:speechbrain.pretrained.EncoderDecoderASR.from_hparams
+asr_model: !apply:speechbrain.inference.ASR.EncoderDecoderASR.from_hparams
     source: speechbrain/asr-crdnn-rnnlm-librispeech
     run_opts: {"device":"cuda:0"}
     overrides: {"beam_size": !ref <asr_beam_size>}
@@ -138,8 +138,6 @@ beam_searcher: !new:speechbrain.decoders.S2SRNNBeamSearcher
     eos_threshold: !ref <eos_threshold>
     temperature: !ref <temperature>
     using_max_attn_shift: False
-    max_attn_shift: 30
-    coverage_penalty: 0.
 
 opt_class: !name:torch.optim.Adam
     lr: !ref <lr>
diff --git a/recipes/timers-and-such/decoupled/hparams/train_TAS_LM.yaml b/recipes/timers-and-such/decoupled/hparams/train_TAS_LM.yaml
index 94bca81fc90bbf8a7140a293e83d9695b964c4ba..a0ab73e237afee40a02268de608a02fb653e63e6 100644
--- a/recipes/timers-and-such/decoupled/hparams/train_TAS_LM.yaml
+++ b/recipes/timers-and-such/decoupled/hparams/train_TAS_LM.yaml
@@ -34,7 +34,7 @@ skip_prep: False
 ckpt_interval_minutes: 15 # save checkpoint every N min
 test_on_all_real: False
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 1
 batch_size: 16
 lr: 0.0003
@@ -68,7 +68,7 @@ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
 # Models
-asr_model: !apply:speechbrain.pretrained.EncoderDecoderASR.from_hparams
+asr_model: !apply:speechbrain.inference.ASR.EncoderDecoderASR.from_hparams
     source: speechbrain/asr-crdnn-rnnlm-librispeech
     run_opts: {"device":"cuda:0"}
     overrides:
@@ -153,8 +153,6 @@ beam_searcher: !new:speechbrain.decoders.S2SRNNBeamSearcher
     eos_threshold: !ref <eos_threshold>
     temperature: !ref <temperature>
     using_max_attn_shift: False
-    max_attn_shift: 30
-    coverage_penalty: 0.
 
 opt_class: !name:torch.optim.Adam
     lr: !ref <lr>
diff --git a/recipes/timers-and-such/decoupled/train.py b/recipes/timers-and-such/decoupled/train.py
index 9835df460e0a7228e00279ef4004bafe7193f211..47d24ba8cdaf86028aca986aadf5028e5518b461 100644
--- a/recipes/timers-and-such/decoupled/train.py
+++ b/recipes/timers-and-such/decoupled/train.py
@@ -68,27 +68,18 @@ class SLU(sb.Brain):
         p_seq = self.hparams.log_softmax(logits)
 
         # Compute outputs
-        if (
-            stage == sb.Stage.TRAIN
-            and self.batch_count % show_results_every != 0
-        ):
-            return p_seq, asr_tokens_lens
-        else:
-            p_tokens, scores = self.hparams.beam_searcher(
+        p_tokens = None
+        if stage != sb.Stage.TRAIN:
+            p_tokens, _, _, _ = self.hparams.beam_searcher(
                 encoder_out, asr_tokens_lens
             )
-            return p_seq, asr_tokens_lens, p_tokens
+
+        return p_seq, asr_tokens_lens, p_tokens
 
     def compute_objectives(self, predictions, batch, stage):
         """Computes the loss (NLL) given predictions and targets."""
 
-        if (
-            stage == sb.Stage.TRAIN
-            and self.batch_count % show_results_every != 0
-        ):
-            p_seq, asr_tokens_lens = predictions
-        else:
-            p_seq, asr_tokens_lens, predicted_tokens = predictions
+        p_seq, asr_tokens_lens, predicted_tokens = predictions
 
         ids = batch.id
         tokens_eos, tokens_eos_lens = batch.tokens_eos
@@ -101,9 +92,7 @@ class SLU(sb.Brain):
         # (No ctc loss)
         loss = loss_seq
 
-        if (stage != sb.Stage.TRAIN) or (
-            self.batch_count % show_results_every == 0
-        ):
+        if (stage != sb.Stage.TRAIN) or (self.step % show_results_every == 0):
             # Decode token terms to words
             predicted_semantics = [
                 tokenizer.decode_ids(utt_seq).split(" ")
@@ -127,26 +116,8 @@ class SLU(sb.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-        loss.backward()
-        if self.check_gradients(loss):
-            self.optimizer.step()
-        self.optimizer.zero_grad()
-        self.batch_count += 1
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
-        self.batch_count = 0
 
         if stage != sb.Stage.TRAIN:
 
@@ -315,7 +286,6 @@ if __name__ == "__main__":
 
     show_results_every = 100  # plots results every N iterations
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -353,7 +323,7 @@ if __name__ == "__main__":
 
     # We download and pretrain the tokenizer
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     # Brain class initialization
     slu_brain = SLU(
diff --git a/recipes/timers-and-such/direct/hparams/train.yaml b/recipes/timers-and-such/direct/hparams/train.yaml
index f9edc3d789edebf0add05af3fa95561c6c4ff33b..01909eb5b02891a6bc91b66ed5a360f6e24df5bd 100644
--- a/recipes/timers-and-such/direct/hparams/train.yaml
+++ b/recipes/timers-and-such/direct/hparams/train.yaml
@@ -21,27 +21,31 @@ test_synth_wer_file: !ref <output_folder>/test_synth_wer.txt
 
 # Data files
 data_folder: !PLACEHOLDER # e.g, /localscratch/timers-and-such
-data_folder_rirs: !ref <data_folder>
 train_splits: ["train-synth", "train-real"]
-csv_train: !ref <output_folder>/train-type=direct.csv
-csv_dev_real: !ref <output_folder>/dev-real-type=direct.csv
-csv_dev_synth: !ref <output_folder>/dev-synth-type=direct.csv
-csv_test_real: !ref <output_folder>/test-real-type=direct.csv
-csv_test_synth: !ref <output_folder>/test-synth-type=direct.csv
-csv_all_real: !ref <output_folder>/all-real-type=direct.csv
+csv_train: !ref <save_folder>/train-type=direct.csv
+csv_dev_real: !ref <save_folder>/dev-real-type=direct.csv
+csv_dev_synth: !ref <save_folder>/dev-synth-type=direct.csv
+csv_test_real: !ref <save_folder>/test-real-type=direct.csv
+csv_test_synth: !ref <save_folder>/test-synth-type=direct.csv
+csv_all_real: !ref <save_folder>/all-real-type=direct.csv
 tokenizer_file: https://huggingface.co/speechbrain/slu-timers-and-such-direct-librispeech-asr/resolve/main/tokenizer.ckpt
 skip_prep: False
 ckpt_interval_minutes: 15 # save checkpoint every N min
 test_on_all_real: False
 
-# Training parameters
+# Data for augmentation
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+noise_annotation: !ref <save_folder>/noise.csv #The data manifest files are created by the data preparation script
+
+####################### Training Parameters ####################################
 number_of_epochs: 1
 batch_size: 16
 lr: 0.0003
 # token_type: unigram # ["unigram", "bpe", "char"]
 sorting: random
 
-# Model parameters
+####################### Model Parameters #######################################
 sample_rate: 16000
 emb_size: 128
 dec_neurons: 512
@@ -58,17 +62,68 @@ slu_beam_size: 80
 eos_threshold: 1.5
 temperature: 1.25
 
+num_workers: 4
 dataloader_opts:
+    num_workers: !ref <num_workers>
     batch_size: !ref <batch_size>
     shuffle: True
 
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-# Models
-asr_model: !apply:speechbrain.pretrained.EncoderDecoderASR.from_hparams
-    source: speechbrain/asr-crdnn-rnnlm-librispeech
-    run_opts: {"device":"cuda:0"}
+############################## Augmentations ###################################
+
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+# Add noise to input signal
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: 0
+    snr_high: 15
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: [95, 100, 105]
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    concat_original: True
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <add_noise>,
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
+asr_model_source: speechbrain/asr-crdnn-rnnlm-librispeech
 
 slu_enc: !new:speechbrain.nnet.containers.Sequential
     input_shape: [null, null, !ref <ASR_encoder_dim>]
@@ -100,20 +155,12 @@ seq_lin: !new:speechbrain.nnet.linear.Linear
     input_size: !ref <dec_neurons>
     n_neurons: !ref <output_neurons>
 
-env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <data_folder_rirs>
-    babble_prob: 0.0
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
 
 modules:
     slu_enc: !ref <slu_enc>
     output_emb: !ref <output_emb>
     dec: !ref <dec>
     seq_lin: !ref <seq_lin>
-    env_corrupt: !ref <env_corrupt>
 
 model: !new:torch.nn.ModuleList
     - [!ref <slu_enc>, !ref <output_emb>,
@@ -140,8 +187,6 @@ beam_searcher: !new:speechbrain.decoders.S2SRNNBeamSearcher
     eos_threshold: !ref <eos_threshold>
     temperature: !ref <temperature>
     using_max_attn_shift: False
-    max_attn_shift: 30
-    coverage_penalty: 0.
 
 opt_class: !name:torch.optim.Adam
     lr: !ref <lr>
@@ -159,9 +204,6 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
         scheduler: !ref <lr_annealing>
         counter: !ref <epoch_counter>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
 
 log_softmax: !new:speechbrain.nnet.activations.Softmax
     apply_log: True
diff --git a/recipes/timers-and-such/direct/hparams/train_with_wav2vec2.yaml b/recipes/timers-and-such/direct/hparams/train_with_wav2vec2.yaml
index e44e553773552f5ee98e1226df07a6eaf2e4a869..4ac02716648e78dddb8de604e1fa2ade2c8d0d62 100644
--- a/recipes/timers-and-such/direct/hparams/train_with_wav2vec2.yaml
+++ b/recipes/timers-and-such/direct/hparams/train_with_wav2vec2.yaml
@@ -37,7 +37,7 @@ ckpt_interval_minutes: 15 # save checkpoint every N min
 test_on_all_real: False
 
 
-# Training parameters
+####################### Training Parameters ####################################
 number_of_epochs: 50
 batch_size: 8
 lr: 0.0004
@@ -49,7 +49,7 @@ freeze_wav2vec: False
 # token_type: unigram # ["unigram", "bpe", "char"]
 sorting: ascending
 
-# Model parameters
+####################### Model Parameters #######################################
 sample_rate: 16000
 emb_size: 128
 dec_neurons: 512
@@ -73,7 +73,7 @@ dataloader_opts:
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec>
@@ -143,7 +143,6 @@ beam_searcher: !new:speechbrain.decoders.S2SRNNBeamSearcher
     temperature: !ref <temperature>
     using_max_attn_shift: False
     max_attn_shift: 30
-    coverage_penalty: 0.
 
 opt_class: !name:torch.optim.Adam
     lr: !ref <lr>
@@ -171,10 +170,39 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
         lr_annealing_wav2vec2: !ref <lr_annealing_wav2vec2>
         counter: !ref <epoch_counter>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
+############################## Augmentations ###################################
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
     speeds: [95, 100, 105]
 
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    concat_original: True
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
 log_softmax: !new:speechbrain.nnet.activations.Softmax
     apply_log: True
 
diff --git a/recipes/timers-and-such/direct/train.py b/recipes/timers-and-such/direct/train.py
index d7dd5edd2e275277a6b582bcc12266e98624346e..3d396151dd63916f6520f7ab2f2558bf5cfa5cbd 100644
--- a/recipes/timers-and-such/direct/train.py
+++ b/recipes/timers-and-such/direct/train.py
@@ -34,16 +34,13 @@ class SLU(sb.Brain):
         wavs, wav_lens = batch.sig
         tokens_bos, tokens_bos_lens = batch.tokens_bos
 
-        # Add augmentation if specified
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "env_corrupt"):
-                wavs_noise = self.hparams.env_corrupt(wavs, wav_lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                wav_lens = torch.cat([wav_lens, wav_lens])
-                tokens_bos = torch.cat([tokens_bos, tokens_bos], dim=0)
-                tokens_bos_lens = torch.cat([tokens_bos_lens, tokens_bos_lens])
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+            tokens_bos = self.hparams.wav_augment.replicate_labels(tokens_bos)
+            tokens_bos_lens = self.hparams.wav_augment.replicate_labels(
+                tokens_bos_lens
+            )
 
         # ASR encoder forward pass
         with torch.no_grad():
@@ -61,34 +58,26 @@ class SLU(sb.Brain):
         p_seq = self.hparams.log_softmax(logits)
 
         # Compute outputs
-        if (
-            stage == sb.Stage.TRAIN
-            and self.batch_count % show_results_every != 0
-        ):
-            return p_seq, wav_lens
-        else:
-            p_tokens, scores = self.hparams.beam_searcher(encoder_out, wav_lens)
-            return p_seq, wav_lens, p_tokens
+        p_tokens = None
+        if stage != sb.Stage.TRAIN:
+            p_tokens, _, _, _ = self.hparams.beam_searcher(
+                encoder_out, wav_lens
+            )
+
+        return p_seq, wav_lens, p_tokens
 
     def compute_objectives(self, predictions, batch, stage):
         """Computes the loss (NLL) given predictions and targets."""
 
-        if (
-            stage == sb.Stage.TRAIN
-            and self.batch_count % show_results_every != 0
-        ):
-            p_seq, wav_lens = predictions
-        else:
-            p_seq, wav_lens, predicted_tokens = predictions
+        p_seq, wav_lens, predicted_tokens = predictions
 
         ids = batch.id
         tokens_eos, tokens_eos_lens = batch.tokens_eos
-        tokens, tokens_lens = batch.tokens
 
-        if hasattr(self.hparams, "env_corrupt") and stage == sb.Stage.TRAIN:
-            tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
-            tokens_eos_lens = torch.cat(
-                [tokens_eos_lens, tokens_eos_lens], dim=0
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            tokens_eos = self.hparams.wav_augment.replicate_labels(tokens_eos)
+            tokens_eos_lens = self.hparams.wav_augment.replicate_labels(
+                tokens_eos_lens
             )
 
         loss_seq = self.hparams.seq_cost(
@@ -98,9 +87,7 @@ class SLU(sb.Brain):
         # (No ctc loss)
         loss = loss_seq
 
-        if (stage != sb.Stage.TRAIN) or (
-            self.batch_count % show_results_every == 0
-        ):
+        if (stage != sb.Stage.TRAIN) or (self.step % show_results_every == 0):
             # Decode token terms to words
             predicted_semantics = [
                 tokenizer.decode_ids(utt_seq).split(" ")
@@ -124,26 +111,8 @@ class SLU(sb.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-        loss.backward()
-        if self.check_gradients(loss):
-            self.optimizer.step()
-        self.optimizer.zero_grad()
-        self.batch_count += 1
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
-        self.batch_count = 0
 
         if stage != sb.Stage.TRAIN:
 
@@ -304,7 +273,6 @@ if __name__ == "__main__":
 
     show_results_every = 100  # plots results every N iterations
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -329,6 +297,7 @@ if __name__ == "__main__":
             "skip_prep": hparams["skip_prep"],
         },
     )
+    run_on_main(hparams["prepare_noise_data"])
 
     # here we create the datasets objects as well as tokenization and encoding
     (
@@ -342,7 +311,15 @@ if __name__ == "__main__":
 
     # We download and pretrain the tokenizer
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
+
+    # Download pretrained ASR model
+    from speechbrain.inference.ASR import EncoderDecoderASR
+
+    hparams["asr_model"] = EncoderDecoderASR.from_hparams(
+        source=hparams["asr_model_source"],
+        run_opts={"device": run_opts["device"]},
+    )
 
     # Brain class initialization
     slu_brain = SLU(
diff --git a/recipes/timers-and-such/direct/train_with_wav2vec2.py b/recipes/timers-and-such/direct/train_with_wav2vec2.py
index 5f2511ef4d3e1d8429b71843701c47356bc7f127..2d0f7e9b5257acba04542a71f07e9774421d8754 100644
--- a/recipes/timers-and-such/direct/train_with_wav2vec2.py
+++ b/recipes/timers-and-such/direct/train_with_wav2vec2.py
@@ -36,10 +36,11 @@ class SLU(sb.Brain):
         wavs, wav_lens = batch.sig
         tokens_bos, tokens_bos_lens = batch.tokens_bos
 
-        # Add augmentation if specified
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+            tokens_bos = self.hparams.wav_augment.replicate_labels(tokens_bos)
+
         # wav2vec forward pass
         wav2vec2_out = self.modules.wav2vec2(wavs, wav_lens)
         # SLU forward pass
@@ -51,22 +52,18 @@ class SLU(sb.Brain):
         p_seq = self.hparams.log_softmax(logits)
 
         # Compute outputs
-        if (
-            stage == sb.Stage.TRAIN
-            and self.batch_count % show_results_every != 0
-        ):
+        if stage == sb.Stage.TRAIN and self.step % show_results_every != 0:
             return p_seq, wav_lens
         else:
-            p_tokens, scores = self.hparams.beam_searcher(encoder_out, wav_lens)
+            p_tokens, _, _, _ = self.hparams.beam_searcher(
+                encoder_out, wav_lens
+            )
             return p_seq, wav_lens, p_tokens
 
     def compute_objectives(self, predictions, batch, stage):
         """Computes the loss (NLL) given predictions and targets."""
 
-        if (
-            stage == sb.Stage.TRAIN
-            and self.batch_count % show_results_every != 0
-        ):
+        if stage == sb.Stage.TRAIN and self.step % show_results_every != 0:
             p_seq, wav_lens = predictions
         else:
             p_seq, wav_lens, predicted_tokens = predictions
@@ -75,15 +72,20 @@ class SLU(sb.Brain):
         tokens_eos, tokens_eos_lens = batch.tokens_eos
         tokens, tokens_lens = batch.tokens
 
+        # Label Augmentation
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            tokens_eos = self.hparams.wav_augment.replicate_labels(tokens_eos)
+            tokens_eos_lens = self.hparams.wav_augment.replicate_labels(
+                tokens_eos_lens
+            )
+
         loss_seq = self.hparams.seq_cost(
             p_seq, tokens_eos, length=tokens_eos_lens
         )
 
         loss = loss_seq
 
-        if (stage != sb.Stage.TRAIN) or (
-            self.batch_count % show_results_every == 0
-        ):
+        if (stage != sb.Stage.TRAIN) or (self.step % show_results_every == 0):
             # Decode token terms to words
             predicted_semantics = [
                 tokenizer.decode_ids(utt_seq).split(" ")
@@ -107,31 +109,8 @@ class SLU(sb.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-
-        loss.backward()
-        if self.check_gradients(loss):
-            self.wav2vec2_optimizer.step()
-            self.optimizer.step()
-
-        self.wav2vec2_optimizer.zero_grad()
-        self.optimizer.zero_grad()
-        self.batch_count += 1
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
-        self.batch_count = 0
 
         if stage != sb.Stage.TRAIN:
 
@@ -196,9 +175,10 @@ class SLU(sb.Brain):
             )
             self.checkpointer.add_recoverable("optimizer", self.optimizer)
 
-    def zero_grad(self, set_to_none=False):
-        self.wav2vec2_optimizer.zero_grad(set_to_none)
-        self.optimizer.zero_grad(set_to_none)
+        self.optimizers_dict = {
+            "wav2vec2_optimizer": self.wav2vec2_optimizer,
+            "model_optimizer": self.optimizer,
+        }
 
 
 def dataio_prepare(hparams):
@@ -322,7 +302,6 @@ if __name__ == "__main__":
 
     show_results_every = 100  # plots results every N iterations
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -360,7 +339,7 @@ if __name__ == "__main__":
 
     # We download and pretrain the tokenizer
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     hparams["wav2vec2"] = hparams["wav2vec2"].to(run_opts["device"])
 
diff --git a/recipes/timers-and-such/multistage/hparams/train_LS_LM.yaml b/recipes/timers-and-such/multistage/hparams/train_LS_LM.yaml
index 7e7bd0d522ec1fefb723013f21e752c6c2806ae3..47892f73d7ab9a1ab7dbadb293461fb94a327642 100644
--- a/recipes/timers-and-such/multistage/hparams/train_LS_LM.yaml
+++ b/recipes/timers-and-such/multistage/hparams/train_LS_LM.yaml
@@ -21,7 +21,6 @@ test_synth_wer_file: !ref <output_folder>/test_synth_wer.txt
 
 # Data files
 data_folder: !PLACEHOLDER # e.g, /localscratch/timers-and-such
-data_folder_rirs: !ref <data_folder>
 train_splits: ["train-synth", "train-real"]
 csv_train: !ref <output_folder>/train-type=multistage.csv
 csv_dev_real: !ref <output_folder>/dev-real-type=multistage.csv
@@ -34,7 +33,13 @@ skip_prep: False
 ckpt_interval_minutes: 15 # save checkpoint every N min
 test_on_all_real: False
 
-# Training parameters
+# Data for augmentation
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+noise_annotation: !ref <save_folder>/noise.csv #The data manifest files are created by the data preparation script
+
+
+####################### Training Parameters ####################################
 number_of_epochs: 1
 batch_size: 16
 lr: 0.0003
@@ -59,15 +64,68 @@ slu_beam_size: 80
 eos_threshold: 1.5
 temperature: 1.25
 
+num_workers: 4
 dataloader_opts:
+    num_workers: !ref <num_workers>
     batch_size: !ref <batch_size>
     shuffle: True
 
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-# Models
-asr_model: !apply:speechbrain.pretrained.EncoderDecoderASR.from_hparams
+############################## Augmentations ###################################
+
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+# Add noise to input signal
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: 0
+    snr_high: 15
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: [95, 100, 105]
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    concat_original: True
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <add_noise>,
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
+asr_model: !apply:speechbrain.inference.ASR.EncoderDecoderASR.from_hparams
     source: speechbrain/asr-crdnn-rnnlm-librispeech
     run_opts: {"device":"cuda:0"}
     overrides: {"beam_size": !ref <asr_beam_size>}
@@ -106,13 +164,6 @@ seq_lin: !new:speechbrain.nnet.linear.Linear
     input_size: !ref <dec_neurons>
     n_neurons: !ref <output_neurons>
 
-env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <data_folder_rirs>
-    babble_prob: 0.0
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
 
 modules:
     slu_enc: !ref <slu_enc>
@@ -120,7 +171,6 @@ modules:
     output_emb: !ref <output_emb>
     dec: !ref <dec>
     seq_lin: !ref <seq_lin>
-    env_corrupt: !ref <env_corrupt>
 
 
 model: !new:torch.nn.ModuleList
@@ -149,8 +199,6 @@ beam_searcher: !new:speechbrain.decoders.S2SRNNBeamSearcher
     eos_threshold: !ref <eos_threshold>
     temperature: !ref <temperature>
     using_max_attn_shift: False
-    max_attn_shift: 30
-    coverage_penalty: 0.
 
 opt_class: !name:torch.optim.Adam
     lr: !ref <lr>
@@ -168,10 +216,6 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
         scheduler: !ref <lr_annealing>
         counter: !ref <epoch_counter>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
 log_softmax: !new:speechbrain.nnet.activations.Softmax
     apply_log: True
 
diff --git a/recipes/timers-and-such/multistage/hparams/train_TAS_LM.yaml b/recipes/timers-and-such/multistage/hparams/train_TAS_LM.yaml
index 1af9281fadc7e8cbaf01a961a1018edf3c056163..009f226282f155622b4202883a49edecf4cd406d 100644
--- a/recipes/timers-and-such/multistage/hparams/train_TAS_LM.yaml
+++ b/recipes/timers-and-such/multistage/hparams/train_TAS_LM.yaml
@@ -21,7 +21,6 @@ test_synth_wer_file: !ref <output_folder>/test_synth_wer.txt
 
 # Data files
 data_folder: !PLACEHOLDER # e.g, /localscratch/timers-and-such
-data_folder_rirs: !ref <data_folder>
 train_splits: ["train-synth", "train-real"]
 csv_train: !ref <output_folder>/train-type=multistage.csv
 csv_dev_real: !ref <output_folder>/dev-real-type=multistage.csv
@@ -34,7 +33,13 @@ skip_prep: False
 ckpt_interval_minutes: 15 # save checkpoint every N min
 test_on_all_real: False
 
-# Training parameters
+# Data for augmentation
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+noise_annotation: !ref <save_folder>/noise.csv #The data manifest files are created by the data preparation script
+
+
+####################### Training Parameters ####################################
 number_of_epochs: 1
 batch_size: 16
 lr: 0.0003
@@ -59,15 +64,69 @@ slu_beam_size: 80
 eos_threshold: 1.5
 temperature: 1.25
 
+num_workers: 4
 dataloader_opts:
+    num_workers: !ref <num_workers>
     batch_size: !ref <batch_size>
     shuffle: True
 
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
+############################## Augmentations ###################################
+
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+# Add noise to input signal
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: 0
+    snr_high: 15
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: [95, 100, 105]
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: 0
+    drop_freq_high: 1
+    drop_freq_count_low: 1
+    drop_freq_count_high: 3
+    drop_freq_width: 0.05
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: 1000
+    drop_length_high: 2000
+    drop_count_low: 1
+    drop_count_high: 5
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    concat_original: True
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <add_noise>,
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
+
+############################## Models ##########################################
+
 # Models
-asr_model: !apply:speechbrain.pretrained.EncoderDecoderASR.from_hparams
+asr_model: !apply:speechbrain.inference.ASR.EncoderDecoderASR.from_hparams
     source: speechbrain/asr-crdnn-rnnlm-librispeech
     run_opts: {"device":"cuda:0"}
     overrides:
@@ -119,21 +178,12 @@ seq_lin: !new:speechbrain.nnet.linear.Linear
     input_size: !ref <dec_neurons>
     n_neurons: !ref <output_neurons>
 
-env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <data_folder_rirs>
-    babble_prob: 0.0
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-
 modules:
     slu_enc: !ref <slu_enc>
     input_emb: !ref <input_emb>
     output_emb: !ref <output_emb>
     dec: !ref <dec>
     seq_lin: !ref <seq_lin>
-    env_corrupt: !ref <env_corrupt>
 
 
 model: !new:torch.nn.ModuleList
@@ -161,8 +211,6 @@ beam_searcher: !new:speechbrain.decoders.S2SRNNBeamSearcher
     eos_threshold: !ref <eos_threshold>
     temperature: !ref <temperature>
     using_max_attn_shift: False
-    max_attn_shift: 30
-    coverage_penalty: 0.
 
 opt_class: !name:torch.optim.Adam
     lr: !ref <lr>
@@ -180,10 +228,6 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
         scheduler: !ref <lr_annealing>
         counter: !ref <epoch_counter>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
-
 log_softmax: !new:speechbrain.nnet.activations.Softmax
     apply_log: True
 
diff --git a/recipes/timers-and-such/multistage/train.py b/recipes/timers-and-such/multistage/train.py
index 3384f468a460d2bc1b8f6ad8d964f8eef7273e98..ba0c926b35fb4ec9b14814ea93d451d1d62ed727 100644
--- a/recipes/timers-and-such/multistage/train.py
+++ b/recipes/timers-and-such/multistage/train.py
@@ -33,16 +33,13 @@ class SLU(sb.Brain):
         wavs, wav_lens = batch.sig
         tokens_bos, tokens_bos_lens = batch.tokens_bos
 
-        # Add augmentation if specified
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.hparams, "env_corrupt"):
-                wavs_noise = self.hparams.env_corrupt(wavs, wav_lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                wav_lens = torch.cat([wav_lens, wav_lens])
-                tokens_bos = torch.cat([tokens_bos, tokens_bos], dim=0)
-                tokens_bos_lens = torch.cat([tokens_bos_lens, tokens_bos_lens])
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+            tokens_bos = self.hparams.wav_augment.replicate_labels(tokens_bos)
+            tokens_bos_lens = self.hparams.wav_augment.replicate_labels(
+                tokens_bos_lens
+            )
 
         # ASR forward pass
         words, asr_tokens = self.hparams.asr_model.transcribe_batch(
@@ -75,36 +72,26 @@ class SLU(sb.Brain):
         p_seq = self.hparams.log_softmax(logits)
 
         # Compute outputs
-        if (
-            stage == sb.Stage.TRAIN
-            and self.batch_count % show_results_every != 0
-        ):
-            return p_seq, asr_tokens_lens
-        else:
-            p_tokens, scores = self.hparams.beam_searcher(
+        p_tokens = None
+        if stage != sb.Stage.TRAIN:
+            p_tokens, _, _, _ = self.hparams.beam_searcher(
                 encoder_out, asr_tokens_lens
             )
-            return p_seq, asr_tokens_lens, p_tokens
+
+        return p_seq, asr_tokens_lens, p_tokens
 
     def compute_objectives(self, predictions, batch, stage):
         """Computes the loss (NLL) given predictions and targets."""
 
-        if (
-            stage == sb.Stage.TRAIN
-            and self.batch_count % show_results_every != 0
-        ):
-            p_seq, asr_tokens_lens = predictions
-        else:
-            p_seq, asr_tokens_lens, predicted_tokens = predictions
+        p_seq, asr_tokens_lens, predicted_tokens = predictions
 
         ids = batch.id
         tokens_eos, tokens_eos_lens = batch.tokens_eos
-        tokens, tokens_lens = batch.tokens
 
-        if hasattr(self.hparams, "env_corrupt") and stage == sb.Stage.TRAIN:
-            tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
-            tokens_eos_lens = torch.cat(
-                [tokens_eos_lens, tokens_eos_lens], dim=0
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            tokens_eos = self.hparams.wav_augment.replicate_labels(tokens_eos)
+            tokens_eos_lens = self.hparams.wav_augment.replicate_labels(
+                tokens_eos_lens
             )
 
         loss_seq = self.hparams.seq_cost(
@@ -114,9 +101,7 @@ class SLU(sb.Brain):
         # (No ctc loss)
         loss = loss_seq
 
-        if (stage != sb.Stage.TRAIN) or (
-            self.batch_count % show_results_every == 0
-        ):
+        if (stage != sb.Stage.TRAIN) or (self.step % show_results_every == 0):
             # Decode token terms to words
             predicted_semantics = [
                 tokenizer.decode_ids(utt_seq).split(" ")
@@ -140,26 +125,8 @@ class SLU(sb.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-        loss.backward()
-        if self.check_gradients(loss):
-            self.optimizer.step()
-        self.optimizer.zero_grad()
-        self.batch_count += 1
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
-        self.batch_count = 0
 
         if stage != sb.Stage.TRAIN:
 
@@ -320,7 +287,6 @@ if __name__ == "__main__":
 
     show_results_every = 100  # plots results every N iterations
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -345,6 +311,7 @@ if __name__ == "__main__":
             "skip_prep": hparams["skip_prep"],
         },
     )
+    run_on_main(hparams["prepare_noise_data"])
 
     # here we create the datasets objects as well as tokenization and encoding
     (
@@ -358,7 +325,7 @@ if __name__ == "__main__":
 
     # We download and pretrain the tokenizer
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     # Brain class initialization
     slu_brain = SLU(
diff --git a/requirements.txt b/requirements.txt
index f458d4d11eeab738eac8bd9f46d486ce0162b32e..bc6dec53d07543b429e1837632f9fac3f3fde526 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,15 +1,16 @@
 -r lint-requirements.txt
-huggingface_hub>=0.7.0
+huggingface_hub>=0.8.0
 hyperpyyaml>=0.0.1
 joblib>=0.14.1
 numpy>=1.17.0
 packaging
 pandas>=1.0.1
 pre-commit>=2.3.0
+pygtrie>=2.1,<3.0
 scipy>=1.4.1
 sentencepiece>=0.1.91
 SoundFile; sys_platform == 'win32'
 torch>=1.9.0
-torchaudio>=0.9.0
+torchaudio>=1.9.0
 tqdm>=4.42.0
 transformers>=4.30.0
diff --git a/speechbrain/alignment/aligner.py b/speechbrain/alignment/aligner.py
index 669fca8e404f1affaf0b1c1e8263765904894078..764f90fa582d31c5d323db1e77bf86ee1214b683 100644
--- a/speechbrain/alignment/aligner.py
+++ b/speechbrain/alignment/aligner.py
@@ -1315,9 +1315,8 @@ class HMMAligner(torch.nn.Module):
         torch.save(self.align_dict, path)
 
     @mark_as_loader
-    def _load(self, path, end_of_epoch=False, device=None):
+    def _load(self, path, end_of_epoch=False):
         del end_of_epoch  # Not used here.
-        del device
         self.align_dict = torch.load(path)
 
 
diff --git a/speechbrain/alignment/ctc_segmentation.py b/speechbrain/alignment/ctc_segmentation.py
index e1a7efc412b6a4dfbfaedfa12d003f616369cc91..c8964a838d343cfe6da8ce77d2e6b467ea3c5928 100644
--- a/speechbrain/alignment/ctc_segmentation.py
+++ b/speechbrain/alignment/ctc_segmentation.py
@@ -1,10 +1,12 @@
 #!/usr/bin/env python3
-# 2021, Technische Universität München, Ludwig Kürzinger
 """Perform CTC segmentation to align utterances within audio files.
 
 This uses the ctc-segmentation Python package.
 Install it with pip or see the installing instructions in
 https://github.com/lumaku/ctc-segmentation
+
+Authors
+ * Ludwig Kürzinger 2021
 """
 
 import logging
@@ -18,7 +20,7 @@ import torch
 from typing import List
 
 # speechbrain interface
-from speechbrain.pretrained.interfaces import EncoderASR, EncoderDecoderASR
+from speechbrain.inference.ASR import EncoderASR, EncoderDecoderASR
 
 # imports for CTC segmentation
 try:
@@ -49,8 +51,8 @@ class CTCSegmentationTask(SimpleNamespace):
     The human-readable output can be configured with
     the printing options.
 
-    Properties
-    ---------
+    Attributes
+    ----------
     text : list
         Utterance texts, separated by line. But without the utterance
             name at the beginning of the line (as in kaldi-style text).
@@ -75,9 +77,6 @@ class CTCSegmentationTask(SimpleNamespace):
         have the same length as the number of utterances.
     lpz : np.ndarray
         CTC posterior log probabilities (Optional).
-
-    Properties for printing
-    ----------------------
     print_confidence_score : bool
         Include the confidence score.
         Default: True.
@@ -142,7 +141,7 @@ class CTCSegmentation:
     If needed, parameters for CTC segmentation can be set with ``set_config(·)``.
     Then call the instance as function to align text within an audio file.
 
-    Arguments
+    Attributes
     ---------
     asr_model : EncoderDecoderASR
         Speechbrain ASR interface. This requires a model that has a
@@ -178,16 +177,16 @@ class CTCSegmentation:
     Example
     -------
         >>> # using example file included in the SpeechBrain repository
-        >>> from speechbrain.pretrained import EncoderDecoderASR
+        >>> from speechbrain.inference.ASR import EncoderDecoderASR
         >>> from speechbrain.alignment.ctc_segmentation import CTCSegmentation
         >>> # load an ASR model
         >>> pre_trained = "speechbrain/asr-transformer-transformerlm-librispeech"
-        >>> asr_model = EncoderDecoderASR.from_hparams(source=pre_trained)
-        >>> aligner = CTCSegmentation(asr_model, kaldi_style_text=False)
+        >>> asr_model = EncoderDecoderASR.from_hparams(source=pre_trained)  # doctest: +SKIP
+        >>> aligner = CTCSegmentation(asr_model, kaldi_style_text=False)  # doctest: +SKIP
         >>> # load data
         >>> audio_path = "tests/samples/single-mic/example1.wav"
         >>> text = ["THE BIRCH CANOE", "SLID ON THE", "SMOOTH PLANKS"]
-        >>> segments = aligner(audio_path, text, name="example1")
+        >>> segments = aligner(audio_path, text, name="example1")  # doctest: +SKIP
 
     On multiprocessing
     ------------------
@@ -430,8 +429,8 @@ class CTCSegmentation:
         of samples per encoded CTC frame are needed. This function estimates them by
         doing one inference, which is only needed once.
 
-        Args
-        ----
+        Arguments
+        ---------
         speech_len : int
             Length of randomly generated speech vector for single
             inference. Default: 215040.
@@ -452,8 +451,8 @@ class CTCSegmentation:
     def get_lpz(self, speech: Union[torch.Tensor, np.ndarray]):
         """Obtain CTC posterior log probabilities for given speech data.
 
-        Args
-        ----
+        Arguments
+        ---------
         speech : Union[torch.Tensor, np.ndarray]
             Speech audio input.
 
@@ -521,8 +520,8 @@ class CTCSegmentation:
         ``['▁', '▁r', '▁re', '▁real', '▁really']``. The alignment will be
         based on the most probable activation sequence given by the network.
 
-        Args
-        ----
+        Arguments
+        ---------
         text : list
             List or multiline-string with utterance ground truths.
         lpz : np.ndarray
@@ -590,8 +589,8 @@ class CTCSegmentation:
     def get_segments(task: CTCSegmentationTask):
         """Obtain segments for given utterance texts and CTC log posteriors.
 
-        Args
-        ----
+        Arguments
+        ---------
         task : CTCSegmentationTask
             Task object that contains ground truth and
             CTC posterior probabilities.
@@ -636,8 +635,8 @@ class CTCSegmentation:
     ) -> CTCSegmentationTask:
         """Align utterances.
 
-        Args
-        ----
+        Arguments
+        ---------
         speech : Union[torch.Tensor, np.ndarray, str, Path]
             Audio file that can be given as path or as array.
         text : Union[List[str], str]
diff --git a/speechbrain/augment/__init__.py b/speechbrain/augment/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a16815f8462cada3d236c2ee3f7df4a9cca9ec5
--- /dev/null
+++ b/speechbrain/augment/__init__.py
@@ -0,0 +1,2 @@
+""" Package containing various techniques of data augmentation
+"""
diff --git a/speechbrain/augment/augmenter.py b/speechbrain/augment/augmenter.py
new file mode 100644
index 0000000000000000000000000000000000000000..55aee785cbc94f7ff996d5c350f2bbd1090029f7
--- /dev/null
+++ b/speechbrain/augment/augmenter.py
@@ -0,0 +1,525 @@
+"""Classes for implementing data augmentation pipelines.
+
+Authors
+ * Mirco Ravanelli 2022
+"""
+import random
+import logging
+import torch
+import torch.nn.functional as F
+from speechbrain.utils.callchains import lengths_arg_exists
+
+logger = logging.getLogger(__name__)
+
+
+class Augmenter(torch.nn.Module):
+    """Applies pipelines of data augmentation.
+
+    Arguments
+    ---------
+    parallel_augment: bool
+        If False, the augmentations are applied sequentially with
+        the order specified in the pipeline argument.
+        When True, all the N augmentations are concatenated in the output
+        on the batch axis.
+    parallel_augment_fixed_bs: bool
+        If False, each augmenter (performed in parallel) generates a number of
+        augmented examples equal to the batch size. Thus, overall, with this
+        option N*batch size artificial data are
+        generated, where N is the number of augmenters.
+        When True, the number of total augmented examples is kept fixed at
+        the batch size, thus, for each augmenter, fixed at batch size // N examples.
+        This option is useful to keep controlled the number of synthetic examples
+        with respect to the original data distribution, as it keep always
+        50% of original data, and 50% of augmented data.
+    concat_original: bool
+        if True, the original input is concatenated with the
+        augmented outputs (on the batch axis).
+    min_augmentations: int
+        The number of augmentations applied to the input signal is randomly
+        sampled between min_augmentations and max_augmentations. For instance,
+        if the augmentation dict contains N=6 augmentations and we set
+        select min_augmentations=1 and max_augmentations=4 we apply up to
+        M=4 augmentations. The selected augmentations are applied in the order
+        specified in the augmentations dict. If shuffle_augmentations = True,
+        a random set of M augmentations is selected.
+    max_augmentations: int
+        Maximum number of augmentations to apply. See min_augmentations for
+        more details.
+    shuffle_augmentations:  bool
+        If True, it shuffles the entries of the augmentations dictionary.
+        The effect is to randomply select the order of the augmentations
+        to apply.
+    repeat_augment: int
+        Applies the augmentation algorithm N times. This can be used to
+        perform more data augmentation.
+    augment_start_index: int
+        The index of the first element in the input batch from which data
+        augmentation should begin.
+        This argument allows you to specify the starting point for applying
+        data augmentation.
+    augment_end_index: int
+        The index of the last element in the input batch at which data
+        augmentation should stop.
+        You can use this argument to define the endpoint for applying data
+        augmentation within the batch.
+    concat_start_index: int
+        If `concat_original` is set to True, you can specify a subpart of the
+        original batch to concatenate in the output.
+        Use this argument to select the index of the first element from the
+        original input batch to start copying from.
+    concat_end_index: int
+        If `concat_original` is set to True, you can specify a subpart of the
+        original batch to concatenate in the output. Use this argument to select
+        the index of the last element from the original input batch to end the
+        copying process.
+    augment_prob: float
+        The probability (0.0 to 1.0) of applying data augmentation. When set to 0.0,
+        the original signal is returned without any augmentation. When set to 1.0,
+        augmentation is always applied. Values in between determine the likelihood
+        of augmentation.
+    augmentations: list
+        List of augmentater objects to combine to perform data augmentation.
+    enable_augmentations: list
+        A list of booleans used to selectively enable or disable specific augmentation
+        techniques within the 'augmentations' list.
+        Each boolean corresponds to an augmentation object in the 'augmentations' list
+        and should be of the same length and order.
+        This feature is useful for performing ablations on augmentation techniques to
+        tailor them for a specific task.
+
+    Example
+    -------
+    >>> from speechbrain.augment.time_domain import DropFreq, DropChunk
+    >>> freq_dropper = DropFreq()
+    >>> chunk_dropper = DropChunk(drop_start=100, drop_end=16000)
+    >>> augment = Augmenter(parallel_augment=False, concat_original=False, augmentations=[freq_dropper, chunk_dropper])
+    >>> signal = torch.rand([4, 16000])
+    >>> output_signal, lenghts = augment(signal, lengths=torch.tensor([0.2,0.5,0.7,1.0]))
+    """
+
+    def __init__(
+        self,
+        parallel_augment=False,
+        parallel_augment_fixed_bs=False,
+        concat_original=False,
+        min_augmentations=None,
+        max_augmentations=None,
+        shuffle_augmentations=False,
+        repeat_augment=1,
+        augment_start_index=0,
+        augment_end_index=None,
+        concat_start_index=0,
+        concat_end_index=None,
+        augment_prob=1.0,
+        augmentations=list(),
+        enable_augmentations=None,
+    ):
+        super().__init__()
+        self.parallel_augment = parallel_augment
+        self.parallel_augment_fixed_bs = parallel_augment_fixed_bs
+        self.concat_original = concat_original
+        self.augmentations = augmentations
+        self.min_augmentations = min_augmentations
+        self.max_augmentations = max_augmentations
+        self.shuffle_augmentations = shuffle_augmentations
+        self.augment_start_index = augment_start_index
+        self.augment_end_index = augment_end_index
+        self.concat_start_index = concat_start_index
+        self.concat_end_index = concat_end_index
+        self.repeat_augment = repeat_augment
+        self.augment_prob = augment_prob
+        # Check min and max augmentations
+        self.check_min_max_augmentations()
+
+        # This variable represents the total number of augmentations to perform for each signal,
+        # including the original signal in the count.
+        self.num_augmentations = None
+        self.do_augment = True
+
+        # Check repeat augment arguments
+        if not isinstance(self.repeat_augment, int):
+            raise ValueError("repeat_augment must be an integer.")
+
+        if self.repeat_augment < 0:
+            raise ValueError("repeat_augment must be greater than 0.")
+
+        if self.augment_end_index is not None:
+            if self.augment_end_index < self.augment_start_index:
+                raise ValueError(
+                    "augment_end_index must be smaller or equal to augment_start_index."
+                )
+
+        if self.concat_end_index is not None:
+            if self.concat_end_index < self.concat_start_index:
+                raise ValueError(
+                    "concat_end_index must be smaller or equal to concat_start_index."
+                )
+
+        # Managing enable augmentations
+        if enable_augmentations is None:
+            enable_augmentations = [True] * len(augmentations)
+        elif not isinstance(enable_augmentations, list):
+            raise ValueError("enable_augmentations must be a list.")
+        elif len(enable_augmentations) != len(augmentations):
+            raise ValueError(
+                "enable_augmentations must have the same length as augmentations."
+            )
+        else:
+            augmentations = [
+                aug
+                for aug, enabled in zip(augmentations, enable_augmentations)
+                if enabled
+            ]
+
+        # Turn augmentations into a dictionary
+        self.augmentations = {
+            augmentation.__class__.__name__ + str(i): augmentation
+            for i, augmentation in enumerate(augmentations)
+        }
+
+        if len(self.augmentations) == 0:
+            logger.warning(
+                "No augmentation is applied because the augmentation list is empty."
+            )
+
+        # Check min and max augmentations
+        if self.max_augmentations <= 0:
+            logger.warning(
+                "No augmentations applied because max_augmentations is non-positive."
+            )
+        if self.min_augmentations < 0:
+            self.min_augmentations = 0
+            logger.warning(
+                "min_augmentations is negative. Modified to be non-negative."
+            )
+        if self.min_augmentations > self.max_augmentations:
+            logger.warning(
+                "min_augmentations is greater than max_augmentations. min_augmentations set to max_augmentations."
+            )
+            self.max_augmentations = self.min_augmentations
+
+        # Check if augmentation modules need the length argument
+        self.require_lengths = {}
+        for aug_key, aug_fun in self.augmentations.items():
+            self.require_lengths[aug_key] = lengths_arg_exists(aug_fun.forward)
+
+    def augment(self, x, lengths, selected_augmentations):
+        """Applies data augmentation on the seleted augmentations.
+
+        Arguments
+        ---------
+        x : torch.Tensor (batch, time, channel)
+            input to augment.
+        lengths : torch.Tensor
+            The length of each sequence in the batch.
+        selected_augmentations: dict
+            Dictionary containg the selected augmentation to apply.
+        """
+        next_input = x
+        next_lengths = lengths
+        output = []
+        output_lengths = []
+        out_lengths = lengths
+        for k, augment_name in enumerate(selected_augmentations):
+            augment_fun = self.augmentations[augment_name]
+
+            idx = torch.arange(x.shape[0])
+            if self.parallel_augment and self.parallel_augment_fixed_bs:
+                idx_startstop = torch.linspace(
+                    0, x.shape[0], len(selected_augmentations) + 1
+                ).to(torch.int)
+                idx_start = idx_startstop[k]
+                idx_stop = idx_startstop[k + 1]
+                idx = idx[idx_start:idx_stop]
+
+            # Check input arguments
+            if self.require_lengths[augment_name]:
+                out = augment_fun(
+                    next_input[idx, ...], lengths=next_lengths[idx]
+                )
+            else:
+                out = augment_fun(next_input[idx, ...])
+
+            # Check output arguments
+            if isinstance(out, tuple):
+                if len(out) == 2:
+                    out, out_lengths = out
+                else:
+                    raise ValueError(
+                        "The function must return max two arguments (Tensor, Length[optional])"
+                    )
+
+            # Manage sequential or parallel augmentation
+            if not self.parallel_augment:
+                next_input = out
+                next_lengths = out_lengths[idx]
+            else:
+                output.append(out)
+                output_lengths.append(out_lengths)
+
+        if self.parallel_augment:
+            # Concatenate all the augmented data
+            output, output_lengths = self.concatenate_outputs(
+                output, output_lengths
+            )
+        else:
+            # Take the last agumented signal of the pipeline
+            output = out
+            output_lengths = out_lengths
+
+        return output, output_lengths
+
+    def forward(self, x, lengths):
+        """Applies data augmentation.
+
+        Arguments
+        ---------
+        x : torch.Tensor (batch, time, channel)
+            input to augment.
+        lengths : torch.Tensor
+            The length of each sequence in the batch.
+        """
+
+        # Determine whether to apply data augmentation
+        self.do_augment = True
+        if random.random() > self.augment_prob:
+            self.do_augment = False
+            return x, lengths
+
+        x_original = x
+        len_original = lengths
+
+        # Determine the ending index for augmentation, considering user-specified or default values.
+        self.augment_end_index_batch = (
+            min(self.augment_end_index, x.shape[0])
+            if self.augment_end_index is not None
+            else x.shape[0]
+        )
+
+        # If the augmentation starting index is beyond the size of the data, return the original data.
+        if self.augment_start_index >= x.shape[0]:
+            self.do_augment = False
+            logger.warning(
+                "No augmentation is applied because the augmentation start index is greater than or equal to the number of examples in the input batch."
+            )
+            return x, lengths
+
+        # Select the number of augmentations to apply
+        self.N_augment = torch.randint(
+            low=self.min_augmentations,
+            high=self.max_augmentations + 1,
+            size=(1,),
+            device=x.device,
+        )
+
+        # Get augmentations list
+        augmentations_lst = list(self.augmentations.keys())
+
+        # No augmentation
+        if (
+            self.repeat_augment == 0
+            or self.N_augment == 0
+            or len(augmentations_lst) == 0
+        ):
+            self.do_augment = False
+            return x, lengths
+
+        # Shuffle augmentation
+        if self.shuffle_augmentations:
+            random.shuffle(augmentations_lst)
+
+        # Select the augmentations to apply
+        selected_augmentations = augmentations_lst[0 : self.N_augment]
+
+        # Select the portion of the input to augment and update lengths accordingly.
+        x = x[self.augment_start_index : self.augment_end_index_batch]
+        lengths = lengths[
+            self.augment_start_index : self.augment_end_index_batch
+        ]
+
+        # Lists to collect the outputs
+        output_lst = []
+        output_len_lst = []
+
+        # Concatenate the original signal if required
+        self.skip_concat = not (self.concat_original)
+        if self.concat_original:
+
+            # Check start index
+            if self.concat_start_index >= x.shape[0]:
+                self.skip_concat = True
+                pass
+            else:
+                self.skip_concat = False
+                # Determine the ending index for concatenation, considering user-specified or default values.
+                self.concat_end_index_batch = (
+                    min(self.concat_end_index, x_original.shape[0])
+                    if self.concat_end_index is not None
+                    else x_original.shape[0]
+                )
+
+                output_lst.append(
+                    x_original[
+                        self.concat_start_index : self.concat_end_index_batch
+                    ]
+                )
+                output_len_lst.append(
+                    len_original[
+                        self.concat_start_index : self.concat_end_index_batch
+                    ]
+                )
+
+        # Perform augmentations
+        for i in range(self.repeat_augment):
+            output, output_lengths = self.augment(
+                x, lengths, selected_augmentations
+            )
+            output_lst.append(output)
+            output_len_lst.append(output_lengths)
+
+        # Concatenate the final outputs while handling scenarios where
+        # different temporal dimensions may arise due to augmentations
+        # like speed change.
+        output, output_lengths = self.concatenate_outputs(
+            output_lst, output_len_lst
+        )
+
+        return output, output_lengths
+
+    def concatenate_outputs(self, augment_lst, augment_len_lst):
+        """
+        Concatenate a list of augmented signals, accounting for varying temporal lengths.
+        Padding is applied to ensure all signals can be concatenated.
+
+        Arguments
+        ---------
+        augmentations : List of torch.Tensor
+            List of augmented signals to be concatenated.
+
+        augmentation_lengths : List of torch.Tensor
+            List of lengths corresponding to the augmented signals.
+
+        Returns
+        -------
+        concatenated_signals : torch.Tensor
+            A tensor containing the concatenated signals.
+
+        concatenated_lengths : torch.Tensor
+            A tensor containing the concatenated signal lengths.
+
+        Notes
+        -----
+        This function takes a list of augmented signals, which may have different temporal
+        lengths due to variations such as speed changes. It pads the signals to match the
+        maximum temporal dimension found among the input signals and rescales the lengths
+        accordingly before concatenating them.
+        """
+
+        # Find the maximum temporal dimension (batch length) among the sequences
+        max_len = max(augment.shape[1] for augment in augment_lst)
+
+        # Rescale the sequence lengths to adjust for augmented batches with different temporal dimensions.
+        augment_len_lst = [
+            length * (output.shape[1] / max_len)
+            for length, output in zip(augment_len_lst, augment_lst)
+        ]
+
+        # Pad sequences to match the maximum temporal dimension.
+        # Note that some augmented batches, like those with speed changes, may have different temporal dimensions.
+        augment_lst = [
+            F.pad(output, (0, max_len - output.shape[1]))
+            for output in augment_lst
+        ]
+
+        # Concatenate the padded sequences and rescaled lengths
+        output = torch.cat(augment_lst, dim=0)
+        output_lengths = torch.cat(augment_len_lst, dim=0)
+
+        return output, output_lengths
+
+    def replicate_multiple_labels(self, *args):
+        """
+        Replicates the labels along the batch axis a number of times that
+        corresponds to the number of augmentations. Indeed parallel and
+        concatenation augmentations alter the time dimension.
+
+        Arguments
+        ---------
+        args : torch.Tensor
+            Input label tensors to be replicated. Can be a uniq or a list of
+            Tensors.
+
+        Returns
+        -------
+        augmented_labels: torch.Tensor
+            Labels corresponding to the augmented input. Returns as many Tensor
+            as given in input.
+        """
+
+        # Determine whether to apply data augmentation
+        if not self.do_augment:
+            return args
+
+        list_of_augmented_labels = []
+
+        for labels in args:
+            list_of_augmented_labels.append(self.replicate_labels(labels))
+
+        return list_of_augmented_labels
+
+    def replicate_labels(self, labels):
+        """
+        Replicates the labels along the batch axis a number of times that
+        corresponds to the number of augmentations. Indeed parallel and
+        concatenation augmentations alter the time dimension.
+
+        Arguments
+        ---------
+        labels : torch.Tensor
+            Input label tensors to be replicated.
+
+        Returns
+        -------
+        augmented_labels: torch.Tensor
+            Labels corresponding to the augmented input. Returns as many Tensor
+            as given in input.
+        """
+
+        # Determine whether to apply data augmentation
+        if not self.do_augment:
+            return labels
+
+        augmented_labels = []
+        if self.concat_original and not (self.skip_concat):
+            augmented_labels = [
+                labels[self.concat_start_index : self.concat_end_index_batch]
+            ]
+        selected_labels = labels[
+            self.augment_start_index : self.augment_end_index_batch
+        ]
+
+        if self.parallel_augment:
+            selected_labels = torch.cat(
+                [selected_labels] * self.N_augment, dim=0
+            )
+
+        augmented_labels = (
+            augmented_labels + [selected_labels] * self.repeat_augment
+        )
+
+        augmented_labels = torch.cat(augmented_labels, dim=0)
+
+        return augmented_labels
+
+    def check_min_max_augmentations(self):
+        """Checks the min_augmentations and max_augmentations arguments.
+        """
+        if self.min_augmentations is None:
+            self.min_augmentations = 1
+        if self.max_augmentations is None:
+            self.max_augmentations = len(self.augmentations)
+        if self.max_augmentations > len(self.augmentations):
+            self.max_augmentations = len(self.augmentations)
+        if self.min_augmentations > len(self.augmentations):
+            self.min_augmentations = len(self.augmentations)
diff --git a/speechbrain/augment/codec.py b/speechbrain/augment/codec.py
new file mode 100644
index 0000000000000000000000000000000000000000..48f7e0f5f0f7d87f8e0815ffd4d5747d4fddaecf
--- /dev/null
+++ b/speechbrain/augment/codec.py
@@ -0,0 +1,91 @@
+"""
+Codec Augmentation via torchaudio
+
+This library provides codec augmentation techniques in torchaudio for enhanced
+audio data processing.
+
+For detailed guidance and usage examples, refer to the tutorial at:
+https://pytorch.org/audio/stable/tutorials/audio_data_augmentation_tutorial.html
+
+Note: This code is compatible with FFmpeg as the torchaudio backend.
+When using FFmpeg2, the maximum number of samples for processing is limited to 16.
+
+Authors
+ * Mirco Ravanelli 2023
+"""
+
+import random
+import torch
+import torchaudio
+
+
+class CodecAugment(torch.nn.Module):
+    """
+    Apply random audio codecs to input waveforms using torchaudio.
+
+    This class provides an interface for applying codec augmentation techniques to audio data.
+
+    Arguments
+    ---------
+    sample_rate: int
+        The sample rate of the input waveform.
+
+    Example
+    -------
+    >>> waveform = torch.rand(4, 16000)
+    >>> if torchaudio.list_audio_backends()[0] == 'ffmpeg':
+    ...     augmenter = CodecAugment(16000)
+    ...     output_waveform = augmenter(waveform)
+    """
+
+    def __init__(self, sample_rate=16000):
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.available_format_encoders = [
+            ("wav", "pcm_mulaw"),
+            ("mp3", None),
+            ("g722", None),
+        ]
+
+    def apply_codec(self, waveform, format=None, encoder=None):
+        """
+        Apply the selected audio codec.
+
+        Arguments
+        ----------
+        waveform: torch.Tensor
+            Input waveform of shape `[batch, time]`.
+        format: str
+            The audio format to use (e.g., "wav", "mp3"). Default is None.
+        encoder: str
+            The encoder to use for the format (e.g., "opus", "vorbis"). Default is None.
+
+        Returns
+        ---------
+        torch.Tensor:
+            Coded version of the input waveform of shape `[batch, time]`.
+        """
+        audio_effector = torchaudio.io.AudioEffector(
+            format=format, encoder=encoder
+        )
+        waveform_aug = audio_effector.apply(
+            waveform.transpose(0, 1).to("cpu"), self.sample_rate
+        )
+        return waveform_aug.transpose(0, 1).to(waveform.device)
+
+    def forward(self, waveform):
+        """
+        Apply a random audio codec from the available list.
+
+        Arguments
+        ---------
+        waveform: torch.Tensor
+            Input waveform of shape `[batch, time]`.
+
+        Returns
+        ---------
+        torch.Tensor
+            Coded version of the input waveform of shape `[batch, time]`.
+        """
+        format, encoder = random.choice(self.available_format_encoders)
+        return self.apply_codec(waveform, format=format, encoder=encoder)
diff --git a/speechbrain/augment/freq_domain.py b/speechbrain/augment/freq_domain.py
new file mode 100644
index 0000000000000000000000000000000000000000..6981ff2e0553ed97dcbcbe70236d928f7a159c8d
--- /dev/null
+++ b/speechbrain/augment/freq_domain.py
@@ -0,0 +1,394 @@
+"""Frequency-Domain Sequential Data Augmentation Classes
+
+This module comprises classes tailored for augmenting sequential data in the
+frequency domain, such as spectrograms and mel spectrograms.
+Its primary purpose is to enhance the resilience of neural models during the training process.
+
+Authors:
+- Peter Plantinga (2020)
+- Mirco Ravanelli (2023)
+"""
+
+import torch
+import random
+
+
+class SpectrogramDrop(torch.nn.Module):
+    """This class drops slices of the input spectrogram.
+
+    Using `SpectrogramDrop` as an augmentation strategy helps a models learn to rely
+    on all parts of the signal, since it can't expect a given part to be
+    present.
+
+    Reference:
+        https://arxiv.org/abs/1904.08779
+
+    Arguments
+    ---------
+    drop_length_low : int
+        The low end of lengths for which to drop the
+        spectrogram, in samples.
+    drop_length_high : int
+        The high end of lengths for which to drop the
+        signal, in samples.
+    drop_count_low : int
+        The low end of number of times that the signal
+        can be dropped.
+    drop_count_high : int
+        The high end of number of times that the signal
+        can be dropped.
+    replace: str
+        - 'zeros': Masked values are replaced with zeros.
+        - 'mean': Masked values are replaced with the mean value of the spectrogram.
+        - 'rand': Masked values are replaced with random numbers ranging between
+                  the maximum and minimum values of the spectrogram.
+        - 'cutcat': Masked values are replaced with chunks from other signals in the batch.
+        - 'swap': Masked values are replaced with other chunks from the same sentence.
+        - 'random_selection': A random selection among the approaches above.
+    dim : int
+        Corresponding dimension to mask. If dim=1, we apply time masking.
+        If dim=2, we apply frequency masking.
+
+    Example
+    -------
+    >>> # time-masking
+    >>> drop = SpectrogramDrop(dim=1)
+    >>> spectrogram = torch.rand(4, 150, 40)
+    >>> print(spectrogram.shape)
+    torch.Size([4, 150, 40])
+    >>> out = drop(spectrogram)
+    >>> print(out.shape)
+    torch.Size([4, 150, 40])
+    >>> # frequency-masking
+    >>> drop = SpectrogramDrop(dim=2)
+    >>> spectrogram = torch.rand(4, 150, 40)
+    >>> print(spectrogram.shape)
+    torch.Size([4, 150, 40])
+    >>> out = drop(spectrogram)
+    >>> print(out.shape)
+    torch.Size([4, 150, 40])
+    """
+
+    def __init__(
+        self,
+        drop_length_low=5,
+        drop_length_high=15,
+        drop_count_low=1,
+        drop_count_high=3,
+        replace="zeros",
+        dim=1,
+    ):
+        super().__init__()
+        self.drop_length_low = drop_length_low
+        self.drop_length_high = drop_length_high
+        self.drop_count_low = drop_count_low
+        self.drop_count_high = drop_count_high
+        self.replace = replace
+        self.dim = dim
+
+        # Validate low < high
+        if drop_length_low > drop_length_high:
+            raise ValueError("Low limit must not be more than high limit")
+        if drop_count_low > drop_count_high:
+            raise ValueError("Low limit must not be more than high limit")
+
+        self.replace_opts = [
+            "zeros",
+            "mean",
+            "rand",
+            "cutcat",
+            "swap",
+            "random_selection",
+        ]
+        if self.replace not in self.replace_opts:
+            raise ValueError(
+                f"Invalid 'replace' option. Select one of {', '.join(self.replace_opts)}"
+            )
+
+    def forward(self, spectrogram):
+        """
+        Apply the DropChunk augmentation to the input spectrogram.
+
+        This method randomly drops chunks of the input spectrogram to augment the data.
+
+        Arguments
+        ---------
+        spectrogram : torch.Tensor
+            Input spectrogram of shape `[batch, time, fea]`.
+
+        Returns
+        -------
+        torch.Tensor
+            Augmented spectrogram of shape `[batch, time, fea]`.
+        """
+
+        # Manage 4D tensors
+        if spectrogram.dim() == 4:
+            spectrogram = spectrogram.view(
+                -1, spectrogram.shape[2], spectrogram.shape[3]
+            )
+
+        # Get the batch size
+        batch_size, time_duration, fea_size = spectrogram.shape
+
+        # Managing masking dimensions
+        if self.dim == 1:
+            D = time_duration
+        else:
+            D = fea_size
+
+        # Randomly select the number of chunks to drop (same for all samples in the batch)
+        n_masks = torch.randint(
+            low=self.drop_count_low,
+            high=self.drop_count_high + 1,
+            size=(1,),
+            device=spectrogram.device,
+        )
+
+        # Randomly sample the lengths of the chunks to drop
+        mask_len = torch.randint(
+            low=self.drop_length_low,
+            high=self.drop_length_high,
+            size=(batch_size, n_masks),
+            device=spectrogram.device,
+        ).unsqueeze(2)
+
+        # Randomly sample the positions of the chunks to drop
+        mask_pos = torch.randint(
+            0,
+            max(1, D, -mask_len.max()),
+            (batch_size, n_masks),
+            device=spectrogram.device,
+        ).unsqueeze(2)
+
+        # Compute the mask for the selected chunk positions
+        arange = torch.arange(D, device=spectrogram.device).view(1, 1, -1)
+        mask = (mask_pos <= arange) * (arange < (mask_pos + mask_len))
+        mask = mask.any(dim=1)
+        mask = mask.unsqueeze(2) if self.dim == 1 else mask.unsqueeze(1)
+
+        # Determine the value to replace the masked chunks (zero or mean of the spectrogram)
+        if self.replace == "random_selection":
+            self.replace = random.choice(self.replace_opts[:-1])
+
+        if self.replace == "zeros":
+            spectrogram = spectrogram.masked_fill_(mask, 0.0)
+        elif self.replace == "mean":
+            mean = spectrogram.mean().detach()
+            spectrogram = spectrogram.masked_fill_(mask, mean)
+        elif self.replace == "rand":
+            max_spectrogram = spectrogram.max().detach()
+            min_spectrogram = spectrogram.min().detach()
+            rand_spectrogram = torch.rand_like(spectrogram)
+            rand_spectrogram = (
+                rand_spectrogram * (max_spectrogram - min_spectrogram)
+                + min_spectrogram
+            )
+            mask = mask.float()
+            spectrogram = (1 - mask) * spectrogram + mask * rand_spectrogram
+        elif self.replace == "cutcat":
+            rolled_spectrogram = torch.roll(spectrogram, shifts=1, dims=0)
+            mask = mask.float()
+            spectrogram = (1 - mask) * spectrogram + mask * rolled_spectrogram
+        elif self.replace == "swap":
+            shift = torch.randint(
+                low=1,
+                high=spectrogram.shape[1],
+                size=(1,),
+                device=spectrogram.device,
+            )
+            rolled_spectrogram = torch.roll(
+                spectrogram, shifts=shift.item(), dims=1
+            )
+            mask = mask.float()
+            spectrogram = (1 - mask) * spectrogram + mask * rolled_spectrogram
+
+        return spectrogram.view(*spectrogram.shape)
+
+
+class Warping(torch.nn.Module):
+    """
+    Apply time or frequency warping to a spectrogram.
+
+    If `dim=1`, time warping is applied; if `dim=2`, frequency warping is applied.
+    This implementation selects a center and a window length to perform warping.
+    It ensures that the temporal dimension remains unchanged by upsampling or
+    downsampling the affected regions accordingly.
+
+    Reference:
+        https://arxiv.org/abs/1904.08779
+
+    Arguments
+    ---------
+    warp_window : int, optional
+        The width of the warping window. Default is 5.
+    warp_mode : str, optional
+        The interpolation mode for time warping. Default is "bicubic."
+    dim : int, optional
+        Dimension along which to apply warping (1 for time, 2 for frequency).
+        Default is 1.
+
+    Example
+    -------
+    >>> # Time-warping
+    >>> warp = Warping()
+    >>> spectrogram = torch.rand(4, 150, 40)
+    >>> print(spectrogram.shape)
+    torch.Size([4, 150, 40])
+    >>> out = warp(spectrogram)
+    >>> print(out.shape)
+    torch.Size([4, 150, 40])
+    >>> # Frequency-warping
+    >>> warp = Warping(dim=2)
+    >>> spectrogram = torch.rand(4, 150, 40)
+    >>> print(spectrogram.shape)
+    torch.Size([4, 150, 40])
+    >>> out = warp(spectrogram)
+    >>> print(out.shape)
+    torch.Size([4, 150, 40])
+    """
+
+    def __init__(self, warp_window=5, warp_mode="bicubic", dim=1):
+        super().__init__()
+        self.warp_window = warp_window
+        self.warp_mode = warp_mode
+        self.dim = dim
+
+    def forward(self, spectrogram):
+        """
+        Apply warping to the input spectrogram.
+
+        Arguments
+        ---------
+        spectrogram : torch.Tensor
+            Input spectrogram with shape `[batch, time, fea]`.
+
+        Returns
+        -------
+        torch.Tensor
+            Augmented spectrogram with shape `[batch, time, fea]`.
+        """
+
+        # Set warping dimension
+        if self.dim == 2:
+            spectrogram = spectrogram.transpose(1, 2)
+
+        original_size = spectrogram.shape
+        window = self.warp_window
+
+        # 2d interpolation requires 4D or higher dimension tensors
+        # x: (Batch, Time, Freq) -> (Batch, 1, Time, Freq)
+        if spectrogram.dim() == 3:
+            spectrogram = spectrogram.unsqueeze(1)
+
+        len_original = spectrogram.shape[2]
+        if len_original - window <= window:
+            return spectrogram.view(*original_size)
+
+        # Compute center and corresponding window
+        c = torch.randint(window, len_original - window, (1,))[0]
+        w = torch.randint(c - window, c + window, (1,))[0] + 1
+
+        # Update the left part of the spectrogram
+        left = torch.nn.functional.interpolate(
+            spectrogram[:, :, :c],
+            (w, spectrogram.shape[3]),
+            mode=self.warp_mode,
+            align_corners=True,
+        )
+
+        # Update the right part of the spectrogram.
+        # When the left part is expanded, the right part is compressed by the
+        # same factor, and vice versa.
+        right = torch.nn.functional.interpolate(
+            spectrogram[:, :, c:],
+            (len_original - w, spectrogram.shape[3]),
+            mode=self.warp_mode,
+            align_corners=True,
+        )
+
+        # Injecting the warped left and right parts.
+        spectrogram[:, :, :w] = left
+        spectrogram[:, :, w:] = right
+        spectrogram = spectrogram.view(*original_size)
+
+        # Transpose if freq warping is applied.
+        if self.dim == 2:
+            spectrogram = spectrogram.transpose(1, 2)
+
+        return spectrogram
+
+
+class RandomShift(torch.nn.Module):
+    """Shifts the input tensor by a random amount, allowing for either a time
+    or frequency (or channel) shift depending on the specified axis.
+    It is crucial to calibrate the minimum and maximum shifts according to the
+    requirements of your specific task.
+    We recommend using small shifts to preserve information integrity.
+    Using large shifts may result in the loss of significant data and could
+    potentially lead to misalignments with corresponding labels.
+
+    Arguments
+    ---------
+    min_shift : int
+        The mininum channel shift.
+    max_shift : int
+        The maximum channel shift.
+    dim: int
+        The dimension to shift.
+
+    Example
+    -------
+    >>> # time shift
+    >>> signal = torch.zeros(4, 100, 80)
+    >>> signal[0,50,:] = 1
+    >>> rand_shift =  RandomShift(dim=1, min_shift=-10, max_shift=10)
+    >>> lenghts = torch.tensor([0.2, 0.8, 0.9,1.0])
+    >>> output_signal, lenghts = rand_shift(signal,lenghts)
+
+    >>> # frequency shift
+    >>> signal = torch.zeros(4, 100, 80)
+    >>> signal[0,:,40] = 1
+    >>> rand_shift =  RandomShift(dim=2, min_shift=-10, max_shift=10)
+    >>> lenghts = torch.tensor([0.2, 0.8, 0.9,1.0])
+    >>> output_signal, lenghts = rand_shift(signal,lenghts)
+    """
+
+    def __init__(self, min_shift=0, max_shift=0, dim=1):
+        super().__init__()
+        self.min_shift = min_shift
+        self.max_shift = max_shift
+        self.dim = dim
+
+        # Check arguments
+        if self.max_shift < self.min_shift:
+            raise ValueError("max_shift must be  >= min_shift")
+
+    def forward(self, waveforms, lengths):
+        """
+        Arguments
+        ---------
+        waveforms : tensor
+            Shape should be `[batch, time]` or `[batch, time, channels]`.
+        lengths : tensor
+            Shape should be a single dimension, `[batch]`.
+
+        Returns
+        -------
+        Tensor of shape `[batch, time]` or `[batch, time, channels]`
+        """
+        # Pick a frequency to drop
+        N_shifts = torch.randint(
+            low=self.min_shift,
+            high=self.max_shift + 1,
+            size=(1,),
+            device=waveforms.device,
+        )
+        waveforms = torch.roll(waveforms, shifts=N_shifts.item(), dims=self.dim)
+
+        # Update lenghts in the case of temporal shift.
+        if self.dim == 1:
+            lengths = lengths + N_shifts / waveforms.shape[self.dim]
+            lengths = torch.clamp(lengths, min=0.0, max=1.0)
+
+        return waveforms, lengths
diff --git a/speechbrain/augment/preparation.py b/speechbrain/augment/preparation.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ed0324bf6563ba6470357313b632f13bec91875
--- /dev/null
+++ b/speechbrain/augment/preparation.py
@@ -0,0 +1,208 @@
+"""Library for Downloading and Preparing Datasets for Data Augmentation,
+This library provides functions for downloading datasets from the web and
+preparing the necessary CSV data manifest files for use by data augmenters.
+
+Authors:
+* Mirco Ravanelli 2023
+
+"""
+
+import os
+import logging
+import torchaudio
+from speechbrain.utils.data_utils import download_file
+from speechbrain.utils.data_utils import get_all_files
+
+# Logger init
+logger = logging.getLogger(__name__)
+
+
+def prepare_dataset_from_URL(URL, dest_folder, ext, csv_file, max_length=None):
+    """Downloads a dataset containing recordings (e.g., noise sequences)
+    from the provided URL and prepares the necessary CSV files for use by the noise augmenter.
+
+    Arguments
+    ---------
+    URL: str
+        The URL of the dataset to download.
+    dest_folder : str
+        The local folder where the noisy dataset will be downloaded.
+    ext: str
+        File extensions to search for within the downloaded dataset.
+    csv_file : str
+        The path to store the prepared noise CSV file.
+    max_length : float
+        The maximum length in seconds.
+        Recordings longer than this will be automatically cut into pieces.
+    """
+
+    # Download and unpack if necessary
+    data_file = os.path.join(dest_folder, "data.zip")
+
+    if not os.path.isdir(dest_folder):
+        download_file(URL, data_file, unpack=True)
+    else:
+        download_file(URL, data_file)
+
+    # Prepare noise csv if necessary
+    if not os.path.isfile(csv_file):
+        filelist = get_all_files(dest_folder, match_and=["." + ext])
+        prepare_csv(filelist, csv_file, max_length)
+
+
+def prepare_csv(filelist, csv_file, max_length=None):
+    """Iterate a set of wavs and write the corresponding csv file.
+
+    Arguments
+    ---------
+    filelist : str
+        A list containing the paths of files of interest.
+    csv_file : str
+        The path to store the prepared noise CSV file.
+    max_length : float
+        The maximum length in seconds.
+        Recordings longer than this will be automatically cut into pieces.
+    """
+    try:
+        write_csv(filelist, csv_file, max_length)
+    except Exception as e:
+        # Handle the exception or log the error message
+        logger.error("Exception:", exc_info=(e))
+
+        # Delete the file if something fails
+        if os.path.exists(csv_file):
+            os.remove(csv_file)
+
+
+def write_csv(filelist, csv_file, max_length=None):
+    """
+    Iterate through a list of audio files and write the corresponding CSV file.
+
+    Arguments
+    ---------
+    filelist: list of str
+        A list containing the paths of audio files of interest.
+    csv_file: str
+        The path where to store the prepared noise CSV file.
+    max_lengthL float (optional):
+        The maximum recording length in seconds.
+        Recordings longer than this will be automatically cut into pieces.
+    """
+    with open(csv_file, "w") as w:
+        w.write("ID,duration,wav,wav_format,wav_opts\n")
+        for i, filename in enumerate(filelist):
+            _write_csv_row(w, filename, i, max_length)
+
+
+def _write_csv_row(w, filename, index, max_length):
+    """
+    Write a single row to the CSV file based on the audio file information.
+
+    Arguments
+    ---------
+    w: file
+        The open CSV file for writing.
+    filename: str
+        The path to the audio file.
+    index: int
+        The index of the audio file in the list.
+    max_length: float (optional)
+        The maximum recording length in seconds.
+    """
+    signal, rate = torchaudio.load(filename)
+    signal = _ensure_single_channel(signal, filename, rate)
+
+    ID, ext = os.path.basename(filename).split(".")
+    duration = signal.shape[1] / rate
+
+    if max_length is not None and duration > max_length:
+        _handle_long_waveform(
+            w, filename, ID, ext, signal, rate, duration, max_length, index
+        )
+    else:
+        _write_short_waveform_csv(w, ID, ext, duration, filename, index)
+
+
+def _ensure_single_channel(signal, filename, rate):
+    """
+    Ensure that the audio signal has only one channel.
+
+    Arguments
+    ---------
+    signal: Tensor
+        The audio signal.
+    filename: str
+        The path to the audio file.
+    rate: int
+        The sampling frequency of the signal.
+
+    Returns:
+    ---------
+        Torch.Tensor
+        The audio signal with a single channel.
+    """
+    if signal.shape[0] > 1:
+        signal = signal[0].unsqueeze(0)
+        torchaudio.save(filename, signal, rate)
+    return signal
+
+
+def _handle_long_waveform(
+    w, filename, ID, ext, signal, rate, duration, max_length, index
+):
+    """
+    Handle long audio waveforms by cutting them into pieces and writing to the CSV.
+
+    Arguments
+    ---------
+        w: file
+            The open CSV file for writing.
+        filename: str
+            The path to the audio file.
+        ID: str
+            The unique identifier for the audio.
+        ext:  str
+            The audio file extension.
+        signal: Tensor
+            The audio signal.
+        rate: int
+            The audio sample rate.
+        duration:  float
+            The duration of the audio in seconds.
+        max_length:  float
+            The maximum recording length in seconds.
+        index: int
+            The index of the audio file in the list.
+    """
+    os.remove(filename)
+    for j in range(int(duration / max_length)):
+        start = int(max_length * j * rate)
+        stop = int(min(max_length * (j + 1), duration) * rate)
+        ext = filename.split(".")[1]
+        new_filename = filename.replace("." + ext, "_" + str(j) + "." + ext)
+
+        torchaudio.save(new_filename, signal[:, start:stop], rate)
+        csv_row = (
+            f"{ID}_{index}_{j}",
+            str((stop - start) / rate),
+            new_filename,
+            ext,
+            "\n",
+        )
+        w.write(",".join(csv_row))
+
+
+def _write_short_waveform_csv(w, ID, ext, duration, filename, index):
+    """
+    Write a CSV row for a short audio waveform.
+
+    Arguments
+    ---------
+        w (file): The open CSV file for writing.
+        ID (str): The unique identifier for the audio.
+        ext (str): The audio file extension.
+        duration (float): The duration of the audio in seconds.
+        filename (str): The path to the audio file.
+        index (int): The index of the audio file in the list.
+    """
+    w.write(",".join((f"{ID}_{index}", str(duration), filename, ext, "\n",)))
diff --git a/speechbrain/processing/speech_augmentation.py b/speechbrain/augment/time_domain.py
similarity index 54%
rename from speechbrain/processing/speech_augmentation.py
rename to speechbrain/augment/time_domain.py
index cf5440537fb4acf57047554b2018e6a107e0a215..5a68942e80da48e683d51a6723e2fa7399ff455d 100644
--- a/speechbrain/processing/speech_augmentation.py
+++ b/speechbrain/augment/time_domain.py
@@ -1,21 +1,20 @@
-"""Classes for mutating speech data for data augmentation.
-
-This module provides classes that produce realistic distortions of speech
-data for the purpose of training speech processing models. The list of
-distortions includes adding noise, adding reverberation, changing speed,
-and more. All the classes are of type `torch.nn.Module`. This gives the
-possibility to have end-to-end differentiability and
-backpropagate the gradient through them. In addition, all operations
-are expected to be performed on the GPU (where available) for efficiency.
-
-Authors
- * Peter Plantinga 2020
+"""Time-Domain Sequential Data Augmentation Classes
+
+This module contains classes designed for augmenting sequential data in the time domain.
+It is particularly useful for enhancing the robustness of neural models during training.
+The available data distortions include adding noise, applying reverberation, adjusting playback speed, and more.
+All classes are implemented as `torch.nn.Module`, enabling end-to-end differentiability and gradient backpropagation.
+
+Authors:
+- Peter Plantinga (2020)
+- Mirco Ravanelli (2023)
 """
 
 # Importing libraries
-import math
+import random
 import torch
 import torch.nn.functional as F
+import torchaudio
 from speechbrain.dataio.legacy import ExtendedCSVDataset
 from speechbrain.dataio.dataloader import make_dataloader
 from speechbrain.processing.signal_processing import (
@@ -51,15 +50,18 @@ class AddNoise(torch.nn.Module):
         If True, copy noise signals that are shorter than
         their corresponding clean signals so as to cover the whole clean
         signal. Otherwise, leave the noise un-padded.
-    mix_prob : float
-        The probability that a batch of signals will be mixed
-        with a noise signal. By default, every batch is mixed with noise.
     start_index : int
         The index in the noise waveforms to start from. By default, chooses
         a random index in [0, len(noise) - len(waveforms)].
     normalize : bool
         If True, output noisy signals that exceed [-1,1] will be
         normalized to [-1,1].
+    noise_funct: funct object
+        function to use to draw a noisy sample. It is enabled if the csv files
+        containing the noisy sequences are not provided. By default,
+        torch.randn_like is used (to sample white noise). In general, it must
+        be a function that takes in input the original waveform and returns
+        a tensor with the corresponsing noise to add (e.g., see pink_noise_like).
     replacements : dict
         A set of string replacements to carry out in the
         csv file. Each time a key is found in the text, it will be replaced
@@ -91,9 +93,9 @@ class AddNoise(torch.nn.Module):
         snr_low=0,
         snr_high=0,
         pad_noise=False,
-        mix_prob=1.0,
         start_index=None,
         normalize=False,
+        noise_funct=torch.randn_like,
         replacements={},
         noise_sample_rate=16000,
         clean_sample_rate=16000,
@@ -107,13 +109,12 @@ class AddNoise(torch.nn.Module):
         self.snr_low = snr_low
         self.snr_high = snr_high
         self.pad_noise = pad_noise
-        self.mix_prob = mix_prob
         self.start_index = start_index
         self.normalize = normalize
         self.replacements = replacements
-
-        if noise_sample_rate != clean_sample_rate:
-            self.resampler = Resample(noise_sample_rate, clean_sample_rate)
+        self.noise_funct = noise_funct
+        self.noise_sample_rate = noise_sample_rate
+        self.clean_sample_rate = clean_sample_rate
 
     def forward(self, waveforms, lengths):
         """
@@ -133,37 +134,44 @@ class AddNoise(torch.nn.Module):
         noisy_waveform = waveforms.clone()
         lengths = (lengths * waveforms.shape[1]).unsqueeze(1)
 
-        # Don't add noise (return early) 1-`mix_prob` portion of the batches
-        if torch.rand(1) > self.mix_prob:
-            return noisy_waveform
-
         # Compute the average amplitude of the clean waveforms
-        clean_amplitude = compute_amplitude(waveforms, lengths)
+        clean_amplitude = compute_amplitude(waveforms, lengths, amp_type="rms")
 
         # Pick an SNR and use it to compute the mixture amplitude factors
         SNR = torch.rand(len(waveforms), 1, device=waveforms.device)
         SNR = SNR * (self.snr_high - self.snr_low) + self.snr_low
         noise_amplitude_factor = 1 / (dB_to_amplitude(SNR) + 1)
-        new_noise_amplitude = noise_amplitude_factor * clean_amplitude
+
+        # Support for multichannel waveforms
+        if len(noisy_waveform.shape) == 3:
+            noise_amplitude_factor = noise_amplitude_factor.unsqueeze(1)
 
         # Scale clean signal appropriately
+        new_noise_amplitude = noise_amplitude_factor * clean_amplitude
         noisy_waveform *= 1 - noise_amplitude_factor
 
         # Loop through clean samples and create mixture
         if self.csv_file is None:
-            white_noise = torch.randn_like(waveforms)
-            noisy_waveform += new_noise_amplitude * white_noise
+            noise_waveform = self.noise_funct(waveforms)
+            if noise_waveform.shape[0] == 1:
+                noise_waveform = torch.cat(
+                    [noise_waveform] * waveforms.shape[0], dim=0
+                )
+
+            noise_length = lengths
         else:
             tensor_length = waveforms.shape[1]
             noise_waveform, noise_length = self._load_noise(
-                lengths, tensor_length,
+                lengths, tensor_length
             )
 
-            # Rescale and add
-            noise_amplitude = compute_amplitude(noise_waveform, noise_length)
-            noise_waveform *= new_noise_amplitude / (noise_amplitude + 1e-14)
-            noisy_waveform += noise_waveform
+        # Rescale and add
+        noise_amplitude = compute_amplitude(
+            noise_waveform, noise_length, amp_type="rms"
+        )
+        noise_waveform *= new_noise_amplitude / (noise_amplitude + 1e-14)
 
+        noisy_waveform += noise_waveform
         # Normalizing to prevent clipping
         if self.normalize:
             abs_max, _ = torch.max(
@@ -180,6 +188,12 @@ class AddNoise(torch.nn.Module):
 
         # Load a noise batch
         if not hasattr(self, "data_loader"):
+
+            if self.noise_sample_rate != self.clean_sample_rate:
+                self.resampler = Resample(
+                    self.noise_sample_rate, self.clean_sample_rate
+                )
+
             # Set parameters based on input
             self.device = lengths.device
 
@@ -303,9 +317,8 @@ class AddReverb(torch.nn.Module):
     sorting : str
         The order to iterate the csv file, from one of
         the following options: random, original, ascending, and descending.
-    reverb_prob : float
-        The chance that the audio signal will be reverbed.
-        By default, every batch is reverbed.
+    num_workers : int
+        Number of workers in the DataLoader (See PyTorch DataLoader docs).
     rir_scale_factor: float
         It compresses or dilates the given impulse response.
         If 0 < scale_factor < 1, the impulse response is compressed
@@ -330,14 +343,14 @@ class AddReverb(torch.nn.Module):
     >>> clean = signal.unsqueeze(0) # [batch, time, channels]
     >>> reverb = AddReverb('tests/samples/annotation/RIRs.csv',
     ...                     replacements={'rir_folder': 'tests/samples/RIRs'})
-    >>> reverbed = reverb(clean, torch.ones(1))
+    >>> reverbed = reverb(clean)
     """
 
     def __init__(
         self,
         csv_file,
         sorting="random",
-        reverb_prob=1.0,
+        num_workers=0,
         rir_scale_factor=1.0,
         replacements={},
         reverb_sample_rate=16000,
@@ -346,41 +359,28 @@ class AddReverb(torch.nn.Module):
         super().__init__()
         self.csv_file = csv_file
         self.sorting = sorting
-        self.reverb_prob = reverb_prob
+        self.num_workers = num_workers
         self.replacements = replacements
+        self.reverb_sample_rate = reverb_sample_rate
+        self.clean_sample_rate = clean_sample_rate
         self.rir_scale_factor = rir_scale_factor
 
-        # Create a data loader for the RIR waveforms
-        dataset = ExtendedCSVDataset(
-            csvpath=self.csv_file,
-            sorting=self.sorting if self.sorting != "random" else "original",
-            replacements=self.replacements,
-        )
-        self.data_loader = make_dataloader(
-            dataset, shuffle=(self.sorting == "random")
-        )
-        self.rir_data = iter(self.data_loader)
-
-        if reverb_sample_rate != clean_sample_rate:
-            self.resampler = Resample(reverb_sample_rate, clean_sample_rate)
-
-    def forward(self, waveforms, lengths):
+    def forward(self, waveforms):
         """
         Arguments
         ---------
         waveforms : tensor
             Shape should be `[batch, time]` or `[batch, time, channels]`.
-        lengths : tensor
-            Shape should be a single dimension, `[batch]`.
 
         Returns
         -------
         Tensor of shape `[batch, time]` or `[batch, time, channels]`.
         """
 
-        # Don't add reverb (return early) 1-`reverb_prob` portion of the time
-        if torch.rand(1) > self.reverb_prob:
-            return waveforms.clone()
+        if self.reverb_sample_rate != self.clean_sample_rate:
+            self.resampler = Resample(
+                self.reverb_sample_rate, self.clean_sample_rate
+            )
 
         # Add channels dimension if necessary
         channel_added = False
@@ -388,9 +388,6 @@ class AddReverb(torch.nn.Module):
             waveforms = waveforms.unsqueeze(-1)
             channel_added = True
 
-        # Convert length from ratio to number of indices
-        # lengths = (lengths * waveforms.shape[1])[:, None, None]
-
         # Load and prepare RIR
         rir_waveform = self._load_rir(waveforms)
 
@@ -417,6 +414,22 @@ class AddReverb(torch.nn.Module):
         return rev_waveform
 
     def _load_rir(self, waveforms):
+        # Create a data loader for the RIR waveforms
+        if not hasattr(self, "data_loader"):
+            dataset = ExtendedCSVDataset(
+                csvpath=self.csv_file,
+                sorting=self.sorting
+                if self.sorting != "random"
+                else "original",
+                replacements=self.replacements,
+            )
+            self.data_loader = make_dataloader(
+                dataset,
+                shuffle=(self.sorting == "random"),
+                num_workers=self.num_workers,
+            )
+            self.rir_data = iter(self.data_loader)
+
         try:
             rir_waveform, length = next(self.rir_data).at_position(0)
         except StopIteration:
@@ -446,9 +459,6 @@ class SpeedPerturb(torch.nn.Module):
     speeds : list
         The speeds that the signal should be changed to, as a percentage of the
         original signal (i.e. `speeds` is divided by 100 to get a ratio).
-    perturb_prob : float
-        The chance that the batch will be speed-
-        perturbed. By default, every batch is perturbed.
 
     Example
     -------
@@ -463,14 +473,10 @@ class SpeedPerturb(torch.nn.Module):
     torch.Size([1, 46956])
     """
 
-    def __init__(
-        self, orig_freq, speeds=[90, 100, 110], perturb_prob=1.0,
-    ):
+    def __init__(self, orig_freq, speeds=[90, 100, 110]):
         super().__init__()
         self.orig_freq = orig_freq
         self.speeds = speeds
-        self.perturb_prob = perturb_prob
-
         # Initialize index of perturbation
         self.samp_index = 0
 
@@ -489,30 +495,22 @@ class SpeedPerturb(torch.nn.Module):
         ---------
         waveforms : tensor
             Shape should be `[batch, time]` or `[batch, time, channels]`.
-        lengths : tensor
-            Shape should be a single dimension, `[batch]`.
 
         Returns
         -------
         Tensor of shape `[batch, time]` or `[batch, time, channels]`.
         """
 
-        # Don't perturb (return early) 1-`perturb_prob` portion of the batches
-        if torch.rand(1) > self.perturb_prob:
-            return waveform.clone()
-
         # Perform a random perturbation
-        self.samp_index = torch.randint(len(self.speeds), (1,))[0]
+        self.samp_index = torch.randint(0, len(self.speeds), (1,))
         perturbed_waveform = self.resamplers[self.samp_index](waveform)
-
         return perturbed_waveform
 
 
 class Resample(torch.nn.Module):
-    """This class resamples an audio signal using sinc-based interpolation.
-
-    It is a modification of the `resample` function from torchaudio
-    (https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html)
+    """This class resamples audio using the
+    :class:`torchaudio resampler <torchaudio.transforms.Resample>` based on
+    sinc interpolation.
 
     Arguments
     ---------
@@ -520,10 +518,12 @@ class Resample(torch.nn.Module):
         the sampling frequency of the input signal.
     new_freq : int
         the new sampling frequency after this operation is performed.
-    lowpass_filter_width : int
-        Controls the sharpness of the filter, larger numbers result in a
-        sharper filter, but they are less efficient. Values from 4 to 10 are
-        allowed.
+    *args
+        additional arguments forwarded to the
+        :class:`torchaudio.transforms.Resample` constructor
+    **kwargs
+        additional keyword arguments forwarded to the
+        :class:`torchaudio.transforms.Resample` constructor
 
     Example
     -------
@@ -538,33 +538,15 @@ class Resample(torch.nn.Module):
     torch.Size([1, 26087])
     """
 
-    def __init__(
-        self, orig_freq=16000, new_freq=16000, lowpass_filter_width=6,
-    ):
+    def __init__(self, orig_freq=16000, new_freq=16000, *args, **kwargs):
         super().__init__()
+
         self.orig_freq = orig_freq
         self.new_freq = new_freq
-        self.lowpass_filter_width = lowpass_filter_width
-
-        # Compute rate for striding
-        self._compute_strides()
-        assert self.orig_freq % self.conv_stride == 0
-        assert self.new_freq % self.conv_transpose_stride == 0
-
-    def _compute_strides(self):
-        """Compute the phases in polyphase filter.
-
-        (almost directly from torchaudio.compliance.kaldi)
-        """
 
-        # Compute new unit based on ratio of in/out frequencies
-        base_freq = math.gcd(self.orig_freq, self.new_freq)
-        input_samples_in_unit = self.orig_freq // base_freq
-        self.output_samples = self.new_freq // base_freq
-
-        # Store the appropriate stride based on the new units
-        self.conv_stride = input_samples_in_unit
-        self.conv_transpose_stride = self.output_samples
+        self.resampler = torchaudio.transforms.Resample(
+            orig_freq=orig_freq, new_freq=new_freq, *args, **kwargs,
+        )
 
     def forward(self, waveforms):
         """
@@ -572,17 +554,12 @@ class Resample(torch.nn.Module):
         ---------
         waveforms : tensor
             Shape should be `[batch, time]` or `[batch, time, channels]`.
-        lengths : tensor
-            Shape should be a single dimension, `[batch]`.
 
         Returns
         -------
         Tensor of shape `[batch, time]` or `[batch, time, channels]`.
         """
 
-        if not hasattr(self, "first_indices"):
-            self._indices_and_weights(waveforms)
-
         # Don't do anything if the frequencies are the same
         if self.orig_freq == self.new_freq:
             return waveforms
@@ -596,8 +573,15 @@ class Resample(torch.nn.Module):
         else:
             raise ValueError("Input must be 2 or 3 dimensions")
 
+        # If necessary, migrate the resampler to the current device, for
+        # backwards compat with scripts that do not call `resampler.to()`
+        # themselves.
+        # Please do not reuse the sample resampler for tensors that live on
+        # different devices, though.
+        self.resampler.to(waveforms.device)  # in-place
+
         # Do resampling
-        resampled_waveform = self._perform_resample(waveforms)
+        resampled_waveform = self.resampler(waveforms)
 
         if unsqueezed:
             resampled_waveform = resampled_waveform.squeeze(1)
@@ -606,305 +590,6 @@ class Resample(torch.nn.Module):
 
         return resampled_waveform
 
-    def _perform_resample(self, waveforms):
-        """Resamples the waveform at the new frequency.
-
-        This matches Kaldi's OfflineFeatureTpl ResampleWaveform which uses a
-        LinearResample (resample a signal at linearly spaced intervals to
-        up/downsample a signal). LinearResample (LR) means that the output
-        signal is at linearly spaced intervals (i.e the output signal has a
-        frequency of `new_freq`). It uses sinc/bandlimited interpolation to
-        upsample/downsample the signal.
-
-        (almost directly from torchaudio.compliance.kaldi)
-
-        https://ccrma.stanford.edu/~jos/resample/
-        Theory_Ideal_Bandlimited_Interpolation.html
-
-        https://github.com/kaldi-asr/kaldi/blob/master/src/feat/resample.h#L56
-
-        Arguments
-        ---------
-        waveforms : tensor
-            The batch of audio signals to resample.
-
-        Returns
-        -------
-        The waveforms at the new frequency.
-        """
-
-        # Compute output size and initialize
-        batch_size, num_channels, wave_len = waveforms.size()
-        window_size = self.weights.size(1)
-        tot_output_samp = self._output_samples(wave_len)
-        resampled_waveform = torch.zeros(
-            (batch_size, num_channels, tot_output_samp),
-            device=waveforms.device,
-        )
-        self.weights = self.weights.to(waveforms.device)
-
-        # Check weights are on correct device
-        if waveforms.device != self.weights.device:
-            self.weights = self.weights.to(waveforms.device)
-
-        # eye size: (num_channels, num_channels, 1)
-        eye = torch.eye(num_channels, device=waveforms.device).unsqueeze(2)
-
-        # Iterate over the phases in the polyphase filter
-        for i in range(self.first_indices.size(0)):
-            wave_to_conv = waveforms
-            first_index = int(self.first_indices[i].item())
-            if first_index >= 0:
-                # trim the signal as the filter will not be applied
-                # before the first_index
-                wave_to_conv = wave_to_conv[..., first_index:]
-
-            # pad the right of the signal to allow partial convolutions
-            # meaning compute values for partial windows (e.g. end of the
-            # window is outside the signal length)
-            max_index = (tot_output_samp - 1) // self.output_samples
-            end_index = max_index * self.conv_stride + window_size
-            current_wave_len = wave_len - first_index
-            right_padding = max(0, end_index + 1 - current_wave_len)
-            left_padding = max(0, -first_index)
-            wave_to_conv = torch.nn.functional.pad(
-                wave_to_conv, (left_padding, right_padding)
-            )
-            conv_wave = torch.nn.functional.conv1d(
-                input=wave_to_conv,
-                weight=self.weights[i].repeat(num_channels, 1, 1),
-                stride=self.conv_stride,
-                groups=num_channels,
-            )
-
-            # we want conv_wave[:, i] to be at
-            # output[:, i + n*conv_transpose_stride]
-            dilated_conv_wave = torch.nn.functional.conv_transpose1d(
-                conv_wave, eye, stride=self.conv_transpose_stride
-            )
-
-            # pad dilated_conv_wave so it reaches the output length if needed.
-            left_padding = i
-            previous_padding = left_padding + dilated_conv_wave.size(-1)
-            right_padding = max(0, tot_output_samp - previous_padding)
-            dilated_conv_wave = torch.nn.functional.pad(
-                dilated_conv_wave, (left_padding, right_padding)
-            )
-            dilated_conv_wave = dilated_conv_wave[..., :tot_output_samp]
-
-            resampled_waveform += dilated_conv_wave
-
-        return resampled_waveform
-
-    def _output_samples(self, input_num_samp):
-        """Based on LinearResample::GetNumOutputSamples.
-
-        LinearResample (LR) means that the output signal is at
-        linearly spaced intervals (i.e the output signal has a
-        frequency of ``new_freq``). It uses sinc/bandlimited
-        interpolation to upsample/downsample the signal.
-
-        (almost directly from torchaudio.compliance.kaldi)
-
-        Arguments
-        ---------
-        input_num_samp : int
-            The number of samples in each example in the batch.
-
-        Returns
-        -------
-        Number of samples in the output waveform.
-        """
-
-        # For exact computation, we measure time in "ticks" of 1.0 / tick_freq,
-        # where tick_freq is the least common multiple of samp_in and
-        # samp_out.
-        samp_in = int(self.orig_freq)
-        samp_out = int(self.new_freq)
-
-        tick_freq = abs(samp_in * samp_out) // math.gcd(samp_in, samp_out)
-        ticks_per_input_period = tick_freq // samp_in
-
-        # work out the number of ticks in the time interval
-        # [ 0, input_num_samp/samp_in ).
-        interval_length = input_num_samp * ticks_per_input_period
-        if interval_length <= 0:
-            return 0
-        ticks_per_output_period = tick_freq // samp_out
-
-        # Get the last output-sample in the closed interval,
-        # i.e. replacing [ ) with [ ]. Note: integer division rounds down.
-        # See http://en.wikipedia.org/wiki/Interval_(mathematics) for an
-        # explanation of the notation.
-        last_output_samp = interval_length // ticks_per_output_period
-
-        # We need the last output-sample in the open interval, so if it
-        # takes us to the end of the interval exactly, subtract one.
-        if last_output_samp * ticks_per_output_period == interval_length:
-            last_output_samp -= 1
-
-        # First output-sample index is zero, so the number of output samples
-        # is the last output-sample plus one.
-        num_output_samp = last_output_samp + 1
-
-        return num_output_samp
-
-    def _indices_and_weights(self, waveforms):
-        """Based on LinearResample::SetIndexesAndWeights
-
-        Retrieves the weights for resampling as well as the indices in which
-        they are valid. LinearResample (LR) means that the output signal is at
-        linearly spaced intervals (i.e the output signal has a frequency
-        of ``new_freq``). It uses sinc/bandlimited interpolation to
-        upsample/downsample the signal.
-
-        Returns
-        -------
-        - the place where each filter should start being applied
-        - the filters to be applied to the signal for resampling
-        """
-
-        # Lowpass filter frequency depends on smaller of two frequencies
-        min_freq = min(self.orig_freq, self.new_freq)
-        lowpass_cutoff = 0.99 * 0.5 * min_freq
-
-        assert lowpass_cutoff * 2 <= min_freq
-        window_width = self.lowpass_filter_width / (2.0 * lowpass_cutoff)
-
-        assert lowpass_cutoff < min(self.orig_freq, self.new_freq) / 2
-        output_t = torch.arange(
-            start=0.0, end=self.output_samples, device=waveforms.device,
-        )
-        output_t /= self.new_freq
-        min_t = output_t - window_width
-        max_t = output_t + window_width
-
-        min_input_index = torch.ceil(min_t * self.orig_freq)
-        max_input_index = torch.floor(max_t * self.orig_freq)
-        num_indices = max_input_index - min_input_index + 1
-
-        max_weight_width = num_indices.max()
-        j = torch.arange(max_weight_width, device=waveforms.device)
-        input_index = min_input_index.unsqueeze(1) + j.unsqueeze(0)
-        delta_t = (input_index / self.orig_freq) - output_t.unsqueeze(1)
-
-        weights = torch.zeros_like(delta_t)
-        inside_window_indices = delta_t.abs().lt(window_width)
-
-        # raised-cosine (Hanning) window with width `window_width`
-        weights[inside_window_indices] = 0.5 * (
-            1
-            + torch.cos(
-                2
-                * math.pi
-                * lowpass_cutoff
-                / self.lowpass_filter_width
-                * delta_t[inside_window_indices]
-            )
-        )
-
-        t_eq_zero_indices = delta_t.eq(0.0)
-        t_not_eq_zero_indices = ~t_eq_zero_indices
-
-        # sinc filter function
-        weights[t_not_eq_zero_indices] *= torch.sin(
-            2 * math.pi * lowpass_cutoff * delta_t[t_not_eq_zero_indices]
-        ) / (math.pi * delta_t[t_not_eq_zero_indices])
-
-        # limit of the function at t = 0
-        weights[t_eq_zero_indices] *= 2 * lowpass_cutoff
-
-        # size (output_samples, max_weight_width)
-        weights /= self.orig_freq
-
-        self.first_indices = min_input_index
-        self.weights = weights
-
-
-class AddBabble(torch.nn.Module):
-    """Simulate babble noise by mixing the signals in a batch.
-
-    Arguments
-    ---------
-    speaker_count : int
-        The number of signals to mix with the original signal.
-    snr_low : int
-        The low end of the mixing ratios, in decibels.
-    snr_high : int
-        The high end of the mixing ratios, in decibels.
-    mix_prob : float
-        The probability that the batch of signals will be
-        mixed with babble noise. By default, every signal is mixed.
-
-    Example
-    -------
-    >>> import pytest
-    >>> babbler = AddBabble()
-    >>> dataset = ExtendedCSVDataset(
-    ...     csvpath='tests/samples/annotation/speech.csv',
-    ...     replacements={"data_folder": "tests/samples/single-mic"}
-    ... )
-    >>> loader = make_dataloader(dataset, batch_size=5)
-    >>> speech, lengths = next(iter(loader)).at_position(0)
-    >>> noisy = babbler(speech, lengths)
-    """
-
-    def __init__(
-        self, speaker_count=3, snr_low=0, snr_high=0, mix_prob=1,
-    ):
-        super().__init__()
-        self.speaker_count = speaker_count
-        self.snr_low = snr_low
-        self.snr_high = snr_high
-        self.mix_prob = mix_prob
-
-    def forward(self, waveforms, lengths):
-        """
-        Arguments
-        ---------
-        waveforms : tensor
-            A batch of audio signals to process, with shape `[batch, time]` or
-            `[batch, time, channels]`.
-        lengths : tensor
-            The length of each audio in the batch, with shape `[batch]`.
-
-        Returns
-        -------
-        Tensor with processed waveforms.
-        """
-
-        babbled_waveform = waveforms.clone()
-        lengths = (lengths * waveforms.shape[1]).unsqueeze(1)
-        batch_size = len(waveforms)
-
-        # Don't mix (return early) 1-`mix_prob` portion of the batches
-        if torch.rand(1) > self.mix_prob:
-            return babbled_waveform
-
-        # Pick an SNR and use it to compute the mixture amplitude factors
-        clean_amplitude = compute_amplitude(waveforms, lengths)
-        SNR = torch.rand(batch_size, 1, device=waveforms.device)
-        SNR = SNR * (self.snr_high - self.snr_low) + self.snr_low
-        noise_amplitude_factor = 1 / (dB_to_amplitude(SNR) + 1)
-        new_noise_amplitude = noise_amplitude_factor * clean_amplitude
-
-        # Scale clean signal appropriately
-        babbled_waveform *= 1 - noise_amplitude_factor
-
-        # For each speaker in the mixture, roll and add
-        babble_waveform = waveforms.roll((1,), dims=0)
-        babble_len = lengths.roll((1,), dims=0)
-        for i in range(1, self.speaker_count):
-            babble_waveform += waveforms.roll((1 + i,), dims=0)
-            babble_len = torch.max(babble_len, babble_len.roll((1,), dims=0))
-
-        # Rescale and add to mixture
-        babble_amplitude = compute_amplitude(babble_waveform, babble_len)
-        babble_waveform *= new_noise_amplitude / (babble_amplitude + 1e-14)
-        babbled_waveform += babble_waveform
-
-        return babbled_waveform
-
 
 class DropFreq(torch.nn.Module):
     """This class drops a random frequency from the signal.
@@ -920,16 +605,13 @@ class DropFreq(torch.nn.Module):
     drop_freq_high : float
         The high end of frequencies that can be
         dropped, as a fraction of the sampling rate / 2.
-    drop_count_low : int
+    drop_freq_count_low : int
         The low end of number of frequencies that could be dropped.
-    drop_count_high : int
+    drop_freq_count_high : int
         The high end of number of frequencies that could be dropped.
-    drop_width : float
+    drop_freq_width : float
         The width of the frequency band to drop, as
         a fraction of the sampling_rate / 2.
-    drop_prob : float
-        The probability that the batch of signals will  have a frequency
-        dropped. By default, every batch has frequencies dropped.
 
     Example
     -------
@@ -943,18 +625,16 @@ class DropFreq(torch.nn.Module):
         self,
         drop_freq_low=1e-14,
         drop_freq_high=1,
-        drop_count_low=1,
-        drop_count_high=2,
-        drop_width=0.05,
-        drop_prob=1,
+        drop_freq_count_low=1,
+        drop_freq_count_high=3,
+        drop_freq_width=0.05,
     ):
         super().__init__()
         self.drop_freq_low = drop_freq_low
         self.drop_freq_high = drop_freq_high
-        self.drop_count_low = drop_count_low
-        self.drop_count_high = drop_count_high
-        self.drop_width = drop_width
-        self.drop_prob = drop_prob
+        self.drop_freq_count_low = drop_freq_count_low
+        self.drop_freq_count_high = drop_freq_count_high
+        self.drop_freq_width = drop_freq_width
 
     def forward(self, waveforms):
         """
@@ -970,8 +650,6 @@ class DropFreq(torch.nn.Module):
 
         # Don't drop (return early) 1-`drop_prob` portion of the batches
         dropped_waveform = waveforms.clone()
-        if torch.rand(1) > self.drop_prob:
-            return dropped_waveform
 
         # Add channels dimension
         if len(waveforms.shape) == 2:
@@ -979,7 +657,9 @@ class DropFreq(torch.nn.Module):
 
         # Pick number of frequencies to drop
         drop_count = torch.randint(
-            low=self.drop_count_low, high=self.drop_count_high + 1, size=(1,),
+            low=self.drop_freq_count_low,
+            high=self.drop_freq_count_high + 1,
+            size=(1,),
         )
 
         # Pick a frequency to drop
@@ -999,13 +679,26 @@ class DropFreq(torch.nn.Module):
         # Subtract each frequency
         for frequency in drop_frequency:
             notch_kernel = notch_filter(
-                frequency, filter_length, self.drop_width,
+                frequency, filter_length, self.drop_freq_width
             ).to(waveforms.device)
             drop_filter = convolve1d(drop_filter, notch_kernel, pad)
 
+        # Manage multiple channels
+        if len(waveforms.shape) == 3:
+            dropped_waveform = dropped_waveform.reshape(
+                dropped_waveform.shape[0] * dropped_waveform.shape[2],
+                dropped_waveform.shape[1],
+                1,
+            )
+
         # Apply filter
         dropped_waveform = convolve1d(dropped_waveform, drop_filter, pad)
 
+        if len(waveforms.shape) == 3:
+            dropped_waveform = dropped_waveform.reshape(
+                waveforms.shape[0], waveforms.shape[1], waveforms.shape[2]
+            )
+
         # Remove channels dimension if added
         return dropped_waveform.squeeze(-1)
 
@@ -1035,10 +728,6 @@ class DropChunk(torch.nn.Module):
         The first index for which dropping will be allowed.
     drop_end : int
         The last index for which dropping will be allowed.
-    drop_prob : float
-        The probability that the batch of signals will
-        have a portion dropped. By default, every batch
-        has portions dropped.
     noise_factor : float
         The factor relative to average amplitude of an utterance
         to use for scaling the white noise inserted. 1 keeps
@@ -1061,10 +750,9 @@ class DropChunk(torch.nn.Module):
         drop_length_low=100,
         drop_length_high=1000,
         drop_count_low=1,
-        drop_count_high=10,
+        drop_count_high=3,
         drop_start=0,
         drop_end=None,
-        drop_prob=1,
         noise_factor=0.0,
     ):
         super().__init__()
@@ -1074,7 +762,6 @@ class DropChunk(torch.nn.Module):
         self.drop_count_high = drop_count_high
         self.drop_start = drop_start
         self.drop_end = drop_end
-        self.drop_prob = drop_prob
         self.noise_factor = noise_factor
 
         # Validate low < high
@@ -1112,10 +799,6 @@ class DropChunk(torch.nn.Module):
         batch_size = waveforms.size(0)
         dropped_waveform = waveforms.clone()
 
-        # Don't drop (return early) 1-`drop_prob` portion of the batches
-        if torch.rand(1) > self.drop_prob:
-            return dropped_waveform
-
         # Store original amplitude for computing white noise amplitude
         clean_amplitude = compute_amplitude(waveforms, lengths.unsqueeze(1))
 
@@ -1151,7 +834,7 @@ class DropChunk(torch.nn.Module):
 
             # Pick starting locations
             start = torch.randint(
-                low=start_min, high=start_max + 1, size=(drop_times[i],),
+                low=start_min, high=start_max + 1, size=(drop_times[i],)
             )
 
             end = start + length
@@ -1173,8 +856,194 @@ class DropChunk(torch.nn.Module):
         return dropped_waveform
 
 
+class FastDropChunk(torch.nn.Module):
+    """This class drops portions of the input signal. The difference with
+    DropChunk is that in this case we pre-compute the dropping masks in the
+    first time the forward function is called. For all the other calls, we only
+    shuffle and apply them. This makes the code faster and more suitable for
+    data augmentation of large batches.
+
+    It can be used only for fixed-length sequences.
+
+    Arguments
+    ---------
+    drop_length_low : int
+        The low end of lengths for which to set the
+        signal to zero, in samples.
+    drop_length_high : int
+        The high end of lengths for which to set the
+        signal to zero, in samples.
+    drop_count_low : int
+        The low end of number of times that the signal
+        can be dropped to zero.
+    drop_count_high : int
+        The high end of number of times that the signal
+        can be dropped to zero.
+    drop_start : int
+        The first index for which dropping will be allowed.
+    drop_end : int
+        The last index for which dropping will be allowed.
+    n_masks : int
+        The number of precomputed masks.
+
+    Example
+    -------
+    >>> from speechbrain.dataio.dataio import read_audio
+    >>> dropper = FastDropChunk(drop_start=100, drop_end=200)
+    >>> signal = torch.rand(10, 250, 22)
+    >>> dropped_signal = dropper(signal)
+    """
+
+    def __init__(
+        self,
+        drop_length_low=100,
+        drop_length_high=1000,
+        drop_count_low=1,
+        drop_count_high=10,
+        drop_start=0,
+        drop_end=None,
+        n_masks=1000,
+    ):
+        super().__init__()
+        self.drop_length_low = drop_length_low
+        self.drop_length_high = drop_length_high
+        self.drop_count_low = drop_count_low
+        self.drop_count_high = drop_count_high
+        self.drop_start = drop_start
+        self.drop_end = drop_end
+        self.n_masks = n_masks
+        self.first = True
+
+        # Validate low < high
+        if drop_length_low > drop_length_high:
+            raise ValueError("Low limit must not be more than high limit")
+        if drop_count_low > drop_count_high:
+            raise ValueError("Low limit must not be more than high limit")
+
+        # Make sure the length doesn't exceed end - start
+        if drop_end is not None and drop_end >= 0:
+            if drop_start > drop_end:
+                raise ValueError("Low limit must not be more than high limit")
+            drop_range = drop_end - drop_start
+            self.drop_length_low = min(drop_length_low, drop_range)
+            self.drop_length_high = min(drop_length_high, drop_range)
+
+    def initialize_masks(self, waveforms):
+        """
+        Arguments
+        ---------
+        waveforms : tensor
+            Shape should be `[batch, time]` or `[batch, time, channels]`.
+`.
+        Returns
+        -------
+        dropped_masks: tensor
+            Tensor of size `[n_masks, time]` with the dropped chunks. Dropped
+            regions are assigned to 0.
+        """
+
+        if self.n_masks < waveforms.shape[0]:
+            raise ValueError("n_mask cannot be smaller than the batch size")
+
+        # Initiaizing the drop mask
+        dropped_masks = torch.ones(
+            [self.n_masks, self.sig_len], device=waveforms.device
+        )
+
+        # Pick a number of times to drop
+        drop_times = torch.randint(
+            low=self.drop_count_low,
+            high=self.drop_count_high + 1,
+            size=(self.n_masks,),
+            device=waveforms.device,
+        )
+
+        # Iterate batch to set mask
+        for i in range(self.n_masks):
+            if drop_times[i] == 0:
+                continue
+
+            # Pick lengths
+            length = torch.randint(
+                low=self.drop_length_low,
+                high=self.drop_length_high + 1,
+                size=(drop_times[i],),
+                device=waveforms.device,
+            )
+
+            # Compute range of starting locations
+            start_min = self.drop_start
+            if start_min < 0:
+                start_min += self.sig_len
+            start_max = self.drop_end
+            if start_max is None:
+                start_max = self.sig_len
+            if start_max < 0:
+                start_max += self.sig_len
+            start_max = max(0, start_max - length.max())
+
+            # Pick starting locations
+            start = torch.randint(
+                low=start_min,
+                high=start_max + 1,
+                size=(drop_times[i],),
+                device=waveforms.device,
+            )
+
+            end = start + length
+
+            # Update waveform
+            for j in range(drop_times[i]):
+                dropped_masks[i, start[j] : end[j]] = 0.0
+
+        return dropped_masks
+
+    def forward(self, waveforms):
+        """
+        Arguments
+        ---------
+        waveforms : tensor
+            Shape should be `[batch, time]` or `[batch, time, channels]`.
+
+        Returns
+        -------
+        Tensor of shape `[batch, time]` or `[batch, time, channels]`
+        """
+
+        dropped_waveforms = waveforms.clone()
+
+        # Initialize the masks
+        if self.first:
+            self.sig_len = waveforms.shape[1]
+            self.dropped_masks = self.initialize_masks(waveforms)
+            self.first = False
+
+        # Random Permutation
+        rand_perm = torch.randperm(self.dropped_masks.shape[0])
+        self.dropped_masks = self.dropped_masks[rand_perm, :]
+
+        # Random shift in time
+        rand_shifts = torch.randint(low=0, high=self.sig_len, size=(1,))
+        self.dropped_masks = torch.roll(
+            self.dropped_masks, shifts=rand_shifts.item(), dims=1
+        )
+
+        if len(waveforms.shape) == 3:
+            dropped_waveforms = dropped_waveforms * self.dropped_masks[
+                0 : waveforms.shape[0]
+            ].unsqueeze(2)
+        else:
+            dropped_waveforms = (
+                dropped_waveforms * self.dropped_masks[0 : waveforms.shape[0]]
+            )
+
+        return dropped_waveforms
+
+
 class DoClip(torch.nn.Module):
     """This function mimics audio clipping by clamping the input tensor.
+    First, it normalizes the waveforms from -1 to -1. Then, clipping is applied.
+    Finally, the original amplitude is restored.
 
     Arguments
     ---------
@@ -1182,9 +1051,6 @@ class DoClip(torch.nn.Module):
         The low end of amplitudes for which to clip the signal.
     clip_high : float
         The high end of amplitudes for which to clip the signal.
-    clip_prob : float
-        The probability that the batch of signals will have a portion clipped.
-        By default, every batch has portions clipped.
 
     Example
     -------
@@ -1192,17 +1058,12 @@ class DoClip(torch.nn.Module):
     >>> clipper = DoClip(clip_low=0.01, clip_high=0.01)
     >>> signal = read_audio('tests/samples/single-mic/example1.wav')
     >>> clipped_signal = clipper(signal.unsqueeze(0))
-    >>> "%.2f" % clipped_signal.max()
-    '0.01'
     """
 
-    def __init__(
-        self, clip_low=0.5, clip_high=1, clip_prob=1,
-    ):
+    def __init__(self, clip_low=0.5, clip_high=0.5):
         super().__init__()
         self.clip_low = clip_low
         self.clip_high = clip_high
-        self.clip_prob = clip_prob
 
     def forward(self, waveforms):
         """
@@ -1216,15 +1077,385 @@ class DoClip(torch.nn.Module):
         Tensor of shape `[batch, time]` or `[batch, time, channels]`
         """
 
-        # Don't clip (return early) 1-`clip_prob` portion of the batches
-        if torch.rand(1) > self.clip_prob:
-            return waveforms.clone()
+        # Normalize the signal
+        abs_max, _ = torch.max(torch.abs(waveforms), dim=1, keepdim=True)
+        waveforms = waveforms / abs_max
 
         # Randomly select clip value
         clipping_range = self.clip_high - self.clip_low
-        clip_value = torch.rand(1,)[0] * clipping_range + self.clip_low
+        clip_value = (
+            torch.rand(1, device=waveforms.device)[0] * clipping_range
+            + self.clip_low
+        )
 
         # Apply clipping
         clipped_waveform = waveforms.clamp(-clip_value, clip_value)
 
+        # Restore orignal amplitude
+        clipped_waveform = clipped_waveform * abs_max / clip_value
+
         return clipped_waveform
+
+
+class RandAmp(torch.nn.Module):
+    """This function multiples the signal by a random amplitude. Firist, the
+    signal is normalized to have amplitude between -1 and 1. Then it is
+    multiplied with a random number.
+
+    Arguments
+    ---------
+    amp_low : float
+        The minumum amplitude multiplication factor.
+    amp_high : float
+        The maximum amplitude multiplication factor.
+
+    Example
+    -------
+    >>> from speechbrain.dataio.dataio import read_audio
+    >>> rand_amp = RandAmp(amp_low=0.25, amp_high=1.75)
+    >>> signal = read_audio('tests/samples/single-mic/example1.wav')
+    >>> output_signal = rand_amp(signal.unsqueeze(0))
+    """
+
+    def __init__(self, amp_low=0.5, amp_high=1.5):
+        super().__init__()
+        self.amp_low = amp_low
+        self.amp_high = amp_high
+
+    def forward(self, waveforms):
+        """
+        Arguments
+        ---------
+        waveforms : tensor
+            Shape should be `[batch, time]` or `[batch, time, channels]`.
+
+        Returns
+        -------
+        Tensor of shape `[batch, time]` or `[batch, time, channels]`
+        """
+
+        # Normalize the signal
+        abs_max, _ = torch.max(torch.abs(waveforms), dim=1, keepdim=True)
+        waveforms = waveforms / abs_max
+
+        # Pick a frequency to drop
+        rand_range = self.amp_high - self.amp_low
+        amp = (
+            torch.rand(waveforms.shape[0], device=waveforms.device) * rand_range
+            + self.amp_low
+        )
+        amp = amp.unsqueeze(1)
+        if len(waveforms.shape) == 3:
+            amp = amp.unsqueeze(2)
+        waveforms = waveforms * amp
+
+        return waveforms
+
+
+class ChannelDrop(torch.nn.Module):
+    """This function drops random channels in the multi-channel nput waveform.
+
+    Arguments
+    ---------
+    drop_rate : float
+        The channel droput factor
+
+    Example
+    -------
+    >>> signal = torch.rand(4, 256, 8)
+    >>> ch_drop = ChannelDrop(drop_rate=0.5)
+    >>> output_signal = ch_drop(signal)
+    """
+
+    def __init__(self, drop_rate=0.1):
+        super().__init__()
+        self.drop_rate = drop_rate
+
+    def forward(self, waveforms):
+        """
+        Arguments
+        ---------
+        waveforms : tensor
+            Shape should be `[batch, time]` or `[batch, time, channels]`.
+
+        Returns
+        -------
+        Tensor of shape `[batch, time]` or `[batch, time, channels]`
+        """
+
+        # Pick a channel to drop
+        x = torch.rand(waveforms.shape[-1], device=waveforms.device)
+        channel_mask = x.ge(self.drop_rate)
+        waveforms = waveforms * channel_mask.unsqueeze(0).unsqueeze(1)
+        return waveforms
+
+
+class ChannelSwap(torch.nn.Module):
+    """This function randomly swaps N channels.
+
+    Arguments
+    ---------
+    min_swap : int
+        The mininum number of channels to swap.
+    max_swap : int
+        The maximum number of channels to swap.
+
+    Example
+    -------
+    >>> signal = torch.rand(4, 256, 8)
+    >>> ch_swap = ChannelSwap()
+    >>> output_signal = ch_swap(signal)
+    """
+
+    def __init__(self, min_swap=0, max_swap=0):
+        super().__init__()
+        self.min_swap = min_swap
+        self.max_swap = max_swap
+
+        # Check arguments
+        if self.min_swap < 0:
+            raise ValueError("min_swap must be  >= 0.")
+        if self.max_swap < 0:
+            raise ValueError("max_swap must be  >= 0.")
+        if self.max_swap < self.min_swap:
+            raise ValueError("max_swap must be  >= min_swap")
+
+    def forward(self, waveforms):
+        """
+        Arguments
+        ---------
+        waveforms : tensor
+            Shape should be `[batch, time]` or `[batch, time, channels]`.
+
+        Returns
+        -------
+        Tensor of shape `[batch, time]` or `[batch, time, channels]`
+        """
+
+        # Pick a frequency to drop
+        rand_perm1 = torch.randperm(waveforms.shape[-1])
+        rand_perm2 = torch.randperm(waveforms.shape[-1])
+        N_swaps = torch.randint(
+            low=self.min_swap, high=self.max_swap + 1, size=(1,)
+        )
+
+        if N_swaps < waveforms.shape[-1]:
+            for i in range(N_swaps):
+                store_channel = waveforms[:, :, rand_perm2[i]]
+                waveforms[:, :, rand_perm2[i]] = waveforms[:, :, rand_perm1[i]]
+                waveforms[:, :, rand_perm1[i]] = store_channel
+        else:
+            # Full swap
+            waveforms = waveforms[:, :, rand_perm1]
+
+        return waveforms
+
+
+class CutCat(torch.nn.Module):
+    """This function combines segments (with equal length in time) of the time series contained in the batch.
+    Proposed for EEG signals in https://doi.org/10.1016/j.neunet.2021.05.032.
+
+    Arguments
+    ---------
+    min_num_segments : int
+        The number of segments to combine.
+    max_num_segments : int
+        The maximum number of segments to combine. Default is 10.
+
+    Example
+    -------
+    >>> signal = torch.ones((4, 256, 22)) * torch.arange(4).reshape((4, 1, 1,))
+    >>> cutcat =  CutCat()
+    >>> output_signal = cutcat(signal)
+    """
+
+    def __init__(self, min_num_segments=2, max_num_segments=10):
+        super().__init__()
+        self.min_num_segments = min_num_segments
+        self.max_num_segments = max_num_segments
+        # Check arguments
+        if self.max_num_segments < self.min_num_segments:
+            raise ValueError("max_num_segments must be  >= min_num_segments")
+
+    def forward(self, waveforms):
+        """
+        Arguments
+        ---------
+        waveforms : tensor
+            Shape should be `[batch, time]` or `[batch, time, channels]`.
+
+        Returns
+        -------
+        Tensor of shape `[batch, time]` or `[batch, time, channels]`
+        """
+        if (
+            waveforms.shape[0] > 1
+        ):  # only if there are at least 2 examples in batch
+            # rolling waveforms to point to segments of other examples in batch
+            waveforms_rolled = torch.roll(waveforms, shifts=1, dims=0)
+            # picking number of segments to use
+            num_segments = torch.randint(
+                low=self.min_num_segments,
+                high=self.max_num_segments + 1,
+                size=(1,),
+            )
+            # index of cuts (both starts and stops)
+            idx_cut = torch.linspace(
+                0, waveforms.shape[1], num_segments.item() + 1, dtype=torch.int
+            )
+            for i in range(idx_cut.shape[0] - 1):
+                # half of segments from other examples in batch
+                if i % 2 == 1:
+                    start = idx_cut[i]
+                    stop = idx_cut[i + 1]
+                    waveforms[:, start:stop, ...] = waveforms_rolled[
+                        :, start:stop, ...  # noqa: W504
+                    ]
+
+        return waveforms
+
+
+def pink_noise_like(waveforms, alpha_low=1.0, alpha_high=1.0, sample_rate=50):
+    """Creates a sequence of pink noise (also known as 1/f). The pink noise
+    is obtained by multipling the spectrum of a white noise sequence by a
+    factor (1/f^alpha).
+    The alpha factor controls the decrease factor in the frequnecy domain
+    (alpha=0 adds white noise, alpha>>0 adds low frequnecy noise). It is
+    randomly sampled between alpha_low and alpha_high. With negative alpha this
+    funtion generates blue noise.
+
+    Arguments
+    ---------
+    waveforms : torch.Tensor
+        The original waveform. It is just used to infer the shape.
+    alpha_low : float
+        The minimum value for the alpha spectral smooting factor.
+    alpha_high : float
+        The maximum value for the alpha spectral smooting factor.
+    sample_rate : float
+        The sample rate of the original signal.
+
+    Example
+    -------
+    >>> waveforms = torch.randn(4,257,10)
+    >>> noise = pink_noise_like(waveforms)
+    >>> noise.shape
+    torch.Size([4, 257, 10])
+    """
+    # Sampling white noise (flat spectrum)
+    white_noise = torch.randn_like(waveforms)
+
+    # Computing the fft of the input white noise
+    white_noise_fft = torch.fft.fft(white_noise, dim=1)
+
+    # Sampling the spectral smoothing factor
+    rand_range = alpha_high - alpha_low
+    alpha = (
+        torch.rand(waveforms.shape[0], device=waveforms.device) * rand_range
+        + alpha_low
+    )
+
+    # preparing the spectral mask (1/f^alpha)
+    f = torch.linspace(
+        0,
+        sample_rate / 2,
+        int(white_noise.shape[1] / 2),
+        device=waveforms.device,
+    )
+    spectral_mask = 1 / torch.pow(f.unsqueeze(0), alpha.unsqueeze(1))
+
+    # Avoid inf due to 1/0 division at f=0
+    spectral_mask[:, 0] = spectral_mask[:, 1]
+
+    # Mask for the upper part of the spectrum (f > sample_rate/2)
+    spectral_mask_up = torch.flip(spectral_mask, dims=(1,))
+
+    # Managing odd/even sequences
+    if white_noise.shape[1] % 2:
+        mid_element = spectral_mask[
+            :, int(white_noise.shape[1] / 2) - 1
+        ].unsqueeze(1)
+        spectral_mask = torch.cat(
+            [spectral_mask, mid_element, spectral_mask_up], dim=1
+        )
+    else:
+        spectral_mask = torch.cat([spectral_mask, spectral_mask_up], dim=1)
+
+    # Managing multi-channel inputs
+    if len(white_noise.shape) == 3:
+        spectral_mask = spectral_mask.unsqueeze(2)
+
+    # Spectral masking
+    pink_noise_fft = white_noise_fft * spectral_mask
+
+    # Return to the time-domain
+    pink_noise = torch.fft.ifft(pink_noise_fft, dim=1).real
+    return pink_noise
+
+
+class DropBitResolution(torch.nn.Module):
+    """
+    This class transforms a float32 tensor into a lower resolution one
+    (e.g., int16, int8, float16) and then converts it back to a float32.
+    This process loses information and can be used for data augmentation.
+
+    Arguments:
+    ---------
+        target_dtype: str
+            One of "int16", "int8", "float16". If "random", the bit resolution
+            is randomly selected among the options listed above.
+
+    Example:
+        >>> dropper = DropBitResolution()
+        >>> signal = torch.rand(4, 16000)
+        >>> signal_dropped = dropper(signal)
+    """
+
+    def __init__(self, target_dtype="random"):
+        super().__init__()
+
+        self.target_dtype = target_dtype
+        self.bit_depths = {
+            "int16": (16, torch.int16),
+            "int8": (8, torch.int8),
+            "float16": (16, torch.float16),
+        }
+
+        if (
+            self.target_dtype != "random"
+            and self.target_dtype not in self.bit_depths
+        ):
+            raise ValueError(
+                f"target_dtype must be one of {list(self.bit_depths.keys())}"
+            )
+
+    def forward(self, float32_tensor):
+        """
+        Arguments:
+        ---------
+            float32_tensor: torch.Tensor
+                Float32 tensor with shape `[batch, time]` or `[batch, time, channels]`.
+
+        Returns:
+        ---------
+            torch.Tensor
+                Tensor of shape `[batch, time]` or `[batch, time, channels]` (Float32)
+        """
+
+        if self.target_dtype == "random":
+            random_key = random.choice(list(self.bit_depths.keys()))
+            bit, target_dtype = self.bit_depths[random_key]
+        else:
+            bit, target_dtype = self.bit_depths[self.target_dtype]
+
+        # Define a scale factor to map the float32 range to the target bit depth
+        if target_dtype != torch.float16:
+            scale_factor = (2 ** (bit - 1) - 1) / float32_tensor.abs().max()
+            quantized_tensor = (float32_tensor * scale_factor).to(target_dtype)
+        else:
+            quantized_tensor = float32_tensor.half()
+            scale_factor = 1
+
+        # To dequantize and recover the original float32 values
+        dequantized_tensor = quantized_tensor.to(torch.float32) / scale_factor
+        return dequantized_tensor
diff --git a/speechbrain/core.py b/speechbrain/core.py
index 9ded4c0ce3f6d07a2b0712a6380ea059806ae014..0f3079609ee76ba22421463f79b8ddeb114d8585 100644
--- a/speechbrain/core.py
+++ b/speechbrain/core.py
@@ -35,12 +35,13 @@ from torch.utils.data import IterableDataset
 from torch.utils.data import DistributedSampler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from hyperpyyaml import resolve_references
-from speechbrain.utils.distributed import if_main_process
 from speechbrain.utils.optimizers import rm_vector_weight_decay
 from speechbrain.dataio.dataloader import LoopedLoader
 from speechbrain.dataio.dataloader import SaveableDataLoader
 from speechbrain.dataio.sampler import DistributedSamplerWrapper
 from speechbrain.dataio.sampler import ReproducibleRandomSampler
+from speechbrain.utils.profiling import prepare_profiler
+from dataclasses import dataclass
 
 logger = logging.getLogger(__name__)
 DEFAULT_LOG_CONFIG = os.path.dirname(os.path.abspath(__file__))
@@ -51,6 +52,83 @@ INTRA_EPOCH_CKPT_FLAG = "brain_intra_epoch_ckpt"
 PYTHON_VERSION_MAJOR = 3
 PYTHON_VERSION_MINOR = 7
 
+# Arguments passed via the run opts dictionary
+run_opt_defaults = {
+    "test_only": False,
+    "debug": False,
+    "debug_batches": 2,
+    "debug_epochs": 2,
+    "debug_persistently": False,
+    "device": "cpu",
+    "data_parallel_backend": False,
+    "distributed_backend": "nccl",
+    "find_unused_parameters": False,
+    "jit": False,
+    "jit_module_keys": None,
+    "compile": False,
+    "compile_module_keys": None,
+    "compile_mode": "reduce-overhead",
+    "compile_using_fullgraph": False,
+    "compile_using_dynamic_shape_tracing": False,
+    "precision": "fp32",
+    "eval_precision": "fp32",
+    "auto_mix_prec": False,
+    "bfloat16_mix_prec": False,
+    "max_grad_norm": 5.0,
+    "skip_nonfinite_grads": False,
+    "nonfinite_patience": 3,
+    "noprogressbar": False,
+    "ckpt_interval_minutes": 0,
+    "ckpt_interval_steps": 0,
+    "grad_accumulation_factor": 1,
+    "optimizer_step_limit": None,
+    "tqdm_colored_bar": False,
+    "tqdm_barcolor": {"train": "GREEN", "valid": "MAGENTA", "test": "CYAN"},
+    "remove_vector_weight_decay": False,
+    "profile_training": False,
+    "profile_warmup": 5,
+    "profile_steps": 5,
+}
+
+
+@dataclass
+class AMPConfig:
+    """Configuration for automatic mixed precision (AMP).
+
+    Arguments
+    ---------
+    dtype : torch.dtype
+        The dtype to use for AMP.
+    """
+
+    dtype: torch.dtype
+
+    @classmethod
+    def from_name(self, name):
+        """Create an AMPConfig from a string name.
+
+        Arguments
+        ---------
+        name : str
+            The name of the AMPConfig to create.  Must be one of `fp32`,
+            `fp16`, or `bf16`.
+
+        Returns
+        -------
+        AMPConfig
+            The AMPConfig corresponding to the name.
+        """
+        if name is None or name == "fp32":
+            return AMPConfig(torch.float32)
+        elif name == "fp16":
+            return AMPConfig(torch.float16)
+        elif name == "bf16":
+            return AMPConfig(torch.bfloat16)
+        else:
+            raise ValueError(
+                f"Specified autocast mode ({name}) incorrect, expected one of `fp32`, `fp16`, `bf16`."
+            )
+
 
 def create_experiment_directory(
     experiment_directory,
@@ -208,13 +286,6 @@ def parse_arguments(arg_list=None):
         type=str,
         help="A file storing the configuration options for logging",
     )
-    # if use_env = False in torch.distributed.lunch then local_rank arg is given
-    parser.add_argument(
-        "--local_rank",
-        "--local-rank",  # alias required for PyTorch 2.x
-        type=int,
-        help="Rank on local machine",
-    )
     parser.add_argument(
         "--device",
         type=str,
@@ -227,13 +298,6 @@ def parse_arguments(arg_list=None):
         action="store_true",
         help="This flag enables training with data_parallel.",
     )
-    parser.add_argument(
-        "--distributed_launch",
-        default=False,
-        action="store_true",
-        help="This flag enables training with DDP. Assumes script run with "
-        "`torch.distributed.launch`",
-    )
     parser.add_argument(
         "--distributed_backend",
         type=str,
@@ -295,6 +359,18 @@ def parse_arguments(arg_list=None):
         nargs="*",
         help="Use dynamic shape tracing for compilation",
     )
+    parser.add_argument(
+        "--precision",
+        type=str,
+        help="This flag enables training with automatic mixed-precision."
+        "It can be set to `fp32`, `fp16`, or `bf16`.",
+    )
+    parser.add_argument(
+        "--eval_precision",
+        type=str,
+        help="This flag enables inference with automatic mixed-precision."
+        "It can be set to `fp32`, `fp16`, or `bf16`.",
+    )
     parser.add_argument(
         "--auto_mix_prec",
         default=None,
@@ -313,6 +389,12 @@ def parse_arguments(arg_list=None):
         help="Gradient norm will be clipped to this value, "
         "enter negative value to disable.",
     )
+    parser.add_argument(
+        "--skip_nonfinite_grads",
+        default=False,
+        action="store_true",
+        help="Set the gradients to None if they are nonfinite (inf or nan).",
+    )
     parser.add_argument(
         "--nonfinite_patience",
         type=int,
@@ -359,6 +441,27 @@ def parse_arguments(arg_list=None):
         action="store_true",
         help="Make vectors (e.g. norms and biases) a separate parameter group without weight_decay.",
     )
+    parser.add_argument(
+        "--profile_training",
+        default=False,
+        action="store_true",
+        help=(
+            "If set to True, a profiler will be initiated and tensorboard logs will be generated. "
+            "Please ensure you have installed the TensorBoard profiler with 'pip install torch_tb_profiler'."
+        ),
+    )
+    parser.add_argument(
+        "--profile_warmup",
+        default=5,
+        type=int,
+        help="Number of warmup steps before logging for the profiler.",
+    )
+    parser.add_argument(
+        "--profile_steps",
+        default=5,
+        type=int,
+        help="Number of steps of logging for the profiler",
+    )
 
     # Accept extra args to override yaml
     run_opts, overrides = parser.parse_known_args(arg_list)
@@ -376,17 +479,8 @@ def parse_arguments(arg_list=None):
         if torch.cuda.device_count() == 0:
             raise ValueError("You must have at least 1 GPU.")
 
-    # For DDP, the device args must equal to local_rank used by
-    # torch.distributed.launch. If run_opts["local_rank"] exists,
-    # use os.environ["LOCAL_RANK"]
-    local_rank = None
-    if "local_rank" in run_opts:
-        local_rank = run_opts["local_rank"]
-    else:
-        if "LOCAL_RANK" in os.environ and os.environ["LOCAL_RANK"] != "":
-            local_rank = int(os.environ["LOCAL_RANK"])
-
-    # force device arg to be the same as local_rank from torch.distributed.lunch
+    # force device arg to be the same as local_rank from torchrun
+    local_rank = os.environ.get("LOCAL_RANK")
     if local_rank is not None and "cuda" in run_opts["device"]:
         run_opts["device"] = run_opts["device"][:-1] + str(local_rank)
 
@@ -491,12 +585,24 @@ class Brain:
             One of ``nccl``, ``gloo``, ``mpi``.
         device (str)
             The location for performing computations.
+        precision (str)
+            One of ``fp32``, ``fp16``, ``bf16``.
+        eval_precision (str)
+            One of ``fp32``, ``fp16``, ``bf16``.
         auto_mix_prec (bool)
-            If ``True``, automatic mixed-precision is used.
-            Activate it only with cuda.
+            If ``True``, automatic mixed-precision (fp16) is used.
+            Activate it only with cuda. Note: this is a
+            deprecated feature, and will be removed in the future.
+        bfloat16_mix_prec (bool)
+            If ``True``, automatic mixed-precision (bf16) is used.
+            Activate it only with cuda. Note: this is a
+            deprecated feature, and will be removed in the future.
         max_grad_norm (float)
             Default implementation of ``fit_batch()`` uses
             ``clip_grad_norm_`` with this value. Default: ``5``.
+        skip_nonfinite_grads (bool)
+            If ``True``, sets gradients to zero if they are non-finite
+            (e.g., NaN, Inf). Default: ``False``.
         nonfinite_patience (int)
             Number of times to ignore non-finite losses before stopping.
             Default: ``3``.
@@ -517,9 +623,6 @@ class Brain:
     checkpointer : speechbrain.Checkpointer
         By default, this will be used to load checkpoints, and will have the
         optimizer added to continue training if interrupted.
-    profiler : torch.profiler.profile
-        Context manager for profiling and benchmarking of training/inference steps.
-        Default: ``None`` (skip profiling).
 
     Example
     -------
@@ -541,48 +644,10 @@ class Brain:
         hparams=None,
         run_opts=None,
         checkpointer=None,
-        profiler=None,
     ):
+        self.optimizers_dict = None
         self.opt_class = opt_class
         self.checkpointer = checkpointer
-        self.profiler = profiler
-
-        # Arguments passed via the run opts dictionary
-        run_opt_defaults = {
-            "test_only": False,
-            "debug": False,
-            "debug_batches": 2,
-            "debug_epochs": 2,
-            "debug_persistently": False,
-            "device": "cpu",
-            "data_parallel_backend": False,
-            "distributed_launch": False,
-            "distributed_backend": "nccl",
-            "find_unused_parameters": False,
-            "jit": False,
-            "jit_module_keys": None,
-            "compile": False,
-            "compile_module_keys": None,
-            "compile_mode": "reduce-overhead",
-            "compile_using_fullgraph": False,
-            "compile_using_dynamic_shape_tracing": False,
-            "auto_mix_prec": False,
-            "bfloat16_mix_prec": False,
-            "max_grad_norm": 5.0,
-            "nonfinite_patience": 3,
-            "noprogressbar": False,
-            "ckpt_interval_minutes": 0,
-            "ckpt_interval_steps": 0,
-            "grad_accumulation_factor": 1,
-            "optimizer_step_limit": None,
-            "tqdm_colored_bar": False,
-            "tqdm_barcolor": {
-                "train": "GREEN",
-                "valid": "MAGENTA",
-                "test": "CYAN",
-            },
-            "remove_vector_weight_decay": False,
-        }
 
         for arg, default in run_opt_defaults.items():
             if run_opts is not None and arg in run_opts:
@@ -621,30 +686,25 @@ class Brain:
                 + str(PYTHON_VERSION_MINOR)
             )
 
+        # Assume `torchrun` was used if `RANK` and `LOCAL_RANK` are set
+        self.distributed_launch = (
+            os.environ.get("RANK") is not None
+            and os.environ.get("LOCAL_RANK") is not None
+        )
+
         if self.data_parallel_backend and self.distributed_launch:
             raise ValueError(
                 "To use data_parallel backend, start your script with:\n\t"
                 "python experiment.py hyperparams.yaml "
-                "--data_parallel_backend=True"
+                "--data_parallel_backend=True\n"
                 "To use DDP backend, start your script with:\n\t"
-                "python -m torch.distributed.lunch [args]\n"
-                "experiment.py hyperparams.yaml --distributed_launch=True "
-                "--distributed_backend=nccl"
-            )
-
-        if self.distributed_launch and self.ckpt_interval_minutes > 0:
-            logger.warning(
-                "The --ckpt_interval_minutes option saves only on the main "
-                "process to avoid race conditions. If you need to save an "
-                "intra-epoch checkpoint on multiple processes (e.g. FSDP), "
-                "consider switching to intervals based on # of steps with the "
-                "argument --ckpt_interval_steps."
+                "torchrun [args] experiment.py hyperparams.yaml"
             )
 
         if self.ckpt_interval_minutes > 0 and self.ckpt_interval_steps > 0:
             sys.exit(
                 "The options `ckpt_interval_minutes` and `ckpt_interval_steps` "
-                "are mutually exclusive to prevent race conditions. "
+                "are mutually exclusive. "
                 "Please keep only one active per experiment run."
             )
 
@@ -690,11 +750,53 @@ class Brain:
         # to have your_sampler.set_epoch() called on each epoch.
         self.train_sampler = None
 
-        # Automatic mixed precision init
         if self.auto_mix_prec:
-            self.scaler = torch.cuda.amp.GradScaler()
-            if self.checkpointer is not None:
-                self.checkpointer.add_recoverable("scaler", self.scaler)
+            logger.warning(
+                "The option `--auto_mix_prec` is deprecated and will be removed in the future. "
+                "Please use `--precision=fp16` instead."
+            )
+            self.precision = "fp16"
+
+        if self.bfloat16_mix_prec:
+            logger.warning(
+                "The option `--bfloat16_mix_prec` is deprecated and will be removed in the future. "
+                "Please use `--precision=bf16` instead."
+            )
+            self.precision = "bf16"
+
+        if self.device == "cpu" and (
+            self.precision == "fp16" or self.eval_precision == "fp16"
+        ):
+            raise ValueError(
+                "The option `--precision` or `--eval_precision` is set to fp16. "
+                "This option is not yet supported on CPU. "
+                "Please use `--precision=bf16` or `--eval_precision=bf16` instead "
+                "to enable mixed precision on CPU."
+            )
+
+        gradscaler_enabled = self.precision == "fp16" and "cuda" in self.device
+        if self.skip_nonfinite_grads and gradscaler_enabled:
+            logger.warning(
+                "The option `skip_nonfinite_grads` will be ignored "
+                "because GradScaler is enabled and will automatically "
+                "skip nonfinite gradients."
+            )
+
+        logger.info(
+            f"Gradscaler enabled: {gradscaler_enabled}. Using precision: {self.precision}."
+        )
+        self.scaler = torch.cuda.amp.GradScaler(enabled=gradscaler_enabled)
+
+        self.use_amp = False
+        if self.device == "cpu" and self.precision == "bf16":
+            self.use_amp = True
+        elif "cuda" in self.device and self.precision in ["fp16", "bf16"]:
+            self.use_amp = True
+
+        if self.use_amp and self.checkpointer is not None:
+            self.checkpointer.add_recoverable(
+                "scaler", self.scaler, optional_load=True
+            )
 
         # List parameter count for the user
         total_params = sum(
@@ -713,9 +815,7 @@ class Brain:
                         " ================ WARNING ==============="
                         "Please add sb.ddp_init_group() into your exp.py"
                         "To use DDP backend, start your script with:\n\t"
-                        "python -m torch.distributed.launch [args]\n\t"
-                        "experiment.py hyperparams.yaml "
-                        "--distributed_launch=True --distributed_backend=nccl"
+                        "torchrun [args] experiment.py hyperparams.yaml"
                     )
                 else:
                     logger.warning(
@@ -730,7 +830,6 @@ class Brain:
         # Prepare iterating variables
         self.avg_train_loss = 0.0
         self.step = 0
-        self.valid_step = 0
         self.optimizer_step = 0
 
         # Add this class to the checkpointer for intra-epoch checkpoints
@@ -741,6 +840,17 @@ class Brain:
         if not self.tqdm_colored_bar:
             self.tqdm_barcolor = dict.fromkeys(self.tqdm_barcolor, "")
 
+        # Profiler setup
+        self.profiler = None
+        if self.profile_training:
+            logger.info("Pytorch profiler has been activated.")
+            self.tot_prof_steps = (self.profile_steps + self.profile_warmup) - 1
+            self.profiler = prepare_profiler(
+                self.profile_warmup,
+                self.profile_steps,
+                self.hparams.output_folder,
+            )
+
     def compute_forward(self, batch, stage):
         """Forward pass, to be overridden by sub-classes.
 
@@ -968,9 +1078,7 @@ class Brain:
 
         # Load latest checkpoint to resume training if interrupted
         if self.checkpointer is not None:
-            self.checkpointer.recover_if_possible(
-                device=torch.device(self.device)
-            )
+            self.checkpointer.recover_if_possible()
 
     def init_optimizers(self):
         """Called during ``on_fit_start()``, initialize optimizers
@@ -992,6 +1100,8 @@ class Brain:
 
             self.optimizer = self.opt_class(all_params)
 
+            self.optimizers_dict = {"opt_class": self.optimizer}
+
             if self.checkpointer is not None:
                 self.checkpointer.add_recoverable("optimizer", self.optimizer)
 
@@ -1002,8 +1112,11 @@ class Brain:
         Setting gradients to None should save the memory, e.g.
         during ``evaluate()`` and thus larger batch might be used.
         """
-        if hasattr(self, "optimizer"):
-            self.optimizer.zero_grad(set_to_none)
+        if self.optimizers_dict is not None:
+            for opt in self.freeze_optimizers(self.optimizers_dict).values():
+                opt.zero_grad(set_to_none=set_to_none)
+        elif self.opt_class is not None:
+            self.optimizer.zero_grad(set_to_none=set_to_none)
 
     def on_evaluate_start(self, max_key=None, min_key=None):
         """Gets called at the beginning of ``evaluate()``
@@ -1024,9 +1137,7 @@ class Brain:
         # Recover best checkpoint for evaluation
         if self.checkpointer is not None:
             self.checkpointer.recover_if_possible(
-                max_key=max_key,
-                min_key=min_key,
-                device=torch.device(self.device),
+                max_key=max_key, min_key=min_key,
             )
 
     def fit_batch(self, batch):
@@ -1037,6 +1148,7 @@ class Brain:
 
         * ``compute_forward()``
         * ``compute_objectives()``
+        * ``optimizers_step()``
 
         Also depends on having optimizers passed at initialization.
 
@@ -1050,101 +1162,54 @@ class Brain:
         -------
         detached loss
         """
-        valid_loss = False
+        amp = AMPConfig.from_name(self.precision)
+        should_step = (self.step % self.grad_accumulation_factor) == 0
 
-        # Managing automatic mixed precision
-        if self.auto_mix_prec:
-            with torch.autocast(device_type=torch.device(self.device).type):
-                outputs = self.compute_forward(batch, Stage.TRAIN)
-
-            # Losses are excluded from mixed precision to avoid instabilities
-            loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
-
-            if self.check_gradients(loss):
-                valid_loss = True
-                self.valid_step += 1
-
-            should_step = self.valid_step % self.grad_accumulation_factor == 0
-            if valid_loss:
-                with self.no_sync(not should_step):
-                    self.scaler.scale(
-                        loss / self.grad_accumulation_factor
-                    ).backward()
-                if should_step:
-                    self.scaler.unscale_(self.optimizer)
-                    self.scaler.step(self.optimizer)
-                    self.scaler.update()
-                    self.zero_grad()
-                    self.optimizer_step += 1
-        else:
-            if self.bfloat16_mix_prec:
+        with self.no_sync(not should_step):
+            if self.use_amp:
                 with torch.autocast(
-                    device_type=torch.device(self.device).type,
-                    dtype=torch.bfloat16,
+                    dtype=amp.dtype, device_type=torch.device(self.device).type,
                 ):
-                    outputs = self.compute_forward(batch, Stage.TRAIN)
-                    loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
+                    outputs = self.compute_forward(batch, sb.Stage.TRAIN)
+                    loss = self.compute_objectives(
+                        outputs, batch, sb.Stage.TRAIN
+                    )
             else:
-                outputs = self.compute_forward(batch, Stage.TRAIN)
-                loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
-
-            if self.check_gradients(loss):
-                valid_loss = True
-                self.valid_step += 1
-
-            should_step = self.valid_step % self.grad_accumulation_factor == 0
-            if valid_loss:
-                with self.no_sync(not should_step):
-                    (loss / self.grad_accumulation_factor).backward()
-                if should_step:
-                    self.optimizer.step()
-                    self.zero_grad()
-                    self.optimizer_step += 1
+                outputs = self.compute_forward(batch, sb.Stage.TRAIN)
+                loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
+
+            scaled_loss = self.scaler.scale(
+                loss / self.grad_accumulation_factor
+            )
+            self.check_loss_isfinite(scaled_loss)
+            scaled_loss.backward()
+
+        if should_step:
+            self.optimizers_step()
 
         self.on_fit_batch_end(batch, outputs, loss, should_step)
         return loss.detach().cpu()
 
-    def on_fit_batch_end(self, batch, outputs, loss, should_step):
-        """Called after ``fit_batch()``, meant for calculating and logging metrics.
+    def check_loss_isfinite(self, loss):
+        """Check if the loss is finite.
 
-        Arguments
-        ---------
-        batch : list of torch.Tensors
-            Batch of data to use for training. Default implementation assumes
-            this batch has two elements: inputs and targets.
-        outputs : list or dictionary of torch.Tensors
-            Returned value of compute_forward().
-        loss : torch.Tensor
-            Returned value of compute_objectives().
-        should_step : boolean
-            Whether optimizer.step() was called or not.
-        """
-        pass
+        If the loss is not finite, log a helpful message and increment the `nonfinite_count`.
+        If the `nonfinite_count` exceeds the `--nonfinite_patience` threshold, stop the training
+        and raise an error.
 
-    def check_gradients(self, loss):
-        """Check if gradients are finite and not too large.
-        Automatically clips large gradients.
+        This check is particularly useful when the loss becomes NaN or inf, while the
+        parameters and gradients remain finite. It helps prevent getting stuck in an
+        infinite loop during training.
 
         Arguments
         ---------
         loss : tensor
             The loss tensor after ``backward()`` has been called but
             before the optimizers ``step()``.
-
-        Returns
-        -------
-        bool
-            Whether or not the optimizer step should be carried out.
         """
         if not torch.isfinite(loss):
             self.nonfinite_count += 1
 
-            # Print helpful debug info
-            logger.warning(f"Loss is {loss}.")
-            for p in self.modules.parameters():
-                if not torch.isfinite(p).all():
-                    logger.warning("Parameter is not finite: " + str(p))
-
             # Check if patience is exhausted
             if self.nonfinite_count > self.nonfinite_patience:
                 raise ValueError(
@@ -1154,18 +1219,103 @@ class Brain:
                     "torch.autograd.detect_anomaly():\n\tbrain.fit(...)"
                 )
             else:
-                logger.warning(
-                    "Patience not yet exhausted, ignoring this batch."
-                )
-                return False
+                logger.warning("Patience not yet exhausted.")
+
+    def check_gradients(self):
+        """ Checks if the gradients are finite. If not, it will emit a warning and set them to zero."""
+        for param in self.modules.parameters():
+            if param.requires_grad and param.grad is not None:
+                if not torch.isfinite(param.grad).all():
+                    param.grad = None
+                    logger.warning(
+                        f"Gradients {param.name} contain NaN or Inf. Setting to None."
+                    )
+
+    def freeze_optimizers(self, optimizers):
+        """By default, this method returns the passed optimizers.
+        Override this method if you want to freeze some optimizers
+        during training. To do so, return a of active optimizers.
+        """
+        return optimizers
+
+    def optimizers_step(self):
+        """Performs a step of gradient descent on the optimizers. This method is called every
+        ``grad_accumulation_factor`` steps."""
+        # 1. get the valid optimizers, i.e., the ones that are not frozen during this step
+        if self.optimizers_dict is not None:
+            valid_optimizers = self.freeze_optimizers(self.optimizers_dict)
+        elif self.opt_class is not None:
+            # if valid_optimizers is not defined which could happen if a user is using an old
+            # init_optimizers() method, then we assume that the only valid optimizer is
+            # self.optimizer (which is the default behavior).
+            valid_optimizers = {"optimizer": self.optimizer}
+        else:
+            # Note: in some cases you might want to only compute gradients statistics and
+            # you do not need to call the optimizers.step() method. In this case, you can
+            # simply return from this method and skip the rest of the code.
+            return
+
+        # 2. unscale the gradients of the valid optimizers
+        for opt in valid_optimizers.values():
+            self.scaler.unscale_(opt)
 
-        if self.max_grad_norm > 0.0:
+        # 3. clip gradients
+        # We are clipping this way because clipping on self.modules.parameters()
+        # can leads to NaN/Inf gradients norm as doing the concatenation
+        # of all parameters in a single vector can lead to overflow/underflow.
+        for opt in valid_optimizers.values():
             torch.nn.utils.clip_grad_norm_(
-                (p for p in self.modules.parameters()), self.max_grad_norm
+                opt.param_groups[0]["params"], self.max_grad_norm
             )
 
-        return True
+        # Note: no need to activate this flag if you are in fp16
+        # since GradScaler is automatically handling the nonfinite gradients
+        if not self.scaler.is_enabled() and self.skip_nonfinite_grads:
+            self.check_gradients()
+
+        # 4. step the valid optimizers
+        # If the scaler is disable, it simply calls optimizer.step()
+        for opt in valid_optimizers.values():
+            self.scaler.step(opt)
+
+        self.scaler.update()
+
+        for opt in valid_optimizers.values():
+            opt.zero_grad(set_to_none=True)
+
+        self.optimizer_step += 1
+
+    def on_fit_batch_start(self, batch, should_step):
+        """Called at the beginning of ``fit_batch()``.
+
+        Arguments
+        ---------
+        batch : list of torch.Tensors
+            Batch of data to use for training. Default implementation assumes
+            this batch has two elements: inputs and targets.
+        should_step : boolean
+            Whether optimizer.step() was called or not.
+        """
+        pass
+
+    def on_fit_batch_end(self, batch, outputs, loss, should_step):
+        """Called after ``fit_batch()``.
+
+        Arguments
+        ---------
+        batch : list of torch.Tensors
+            Batch of data to use for training. Default implementation assumes
+            this batch has two elements: inputs and targets.
+        outputs : list or dictionary of torch.Tensors
+            Returned value of compute_forward().
+        loss : torch.Tensor
+            Returned value of compute_objectives().
+        should_step : boolean
+            Whether optimizer.step() was called or not.
+        """
+        pass
 
+    @torch.no_grad()
     def evaluate_batch(self, batch, stage):
         """Evaluate one batch, override for different procedure than train.
 
@@ -1187,9 +1337,16 @@ class Brain:
         -------
         detached loss
         """
-
-        out = self.compute_forward(batch, stage=stage)
-        loss = self.compute_objectives(out, batch, stage=stage)
+        amp = AMPConfig.from_name(self.eval_precision)
+        if self.use_amp:
+            with torch.autocast(
+                dtype=amp.dtype, device_type=torch.device(self.device).type,
+            ):
+                out = self.compute_forward(batch, stage=stage)
+                loss = self.compute_objectives(out, batch, stage=stage)
+        else:
+            out = self.compute_forward(batch, stage=stage)
+            loss = self.compute_objectives(out, batch, stage=stage)
         return loss.detach().cpu()
 
     def _fit_train(self, train_set, epoch, enable):
@@ -1216,6 +1373,8 @@ class Brain:
             disable=not enable,
             colour=self.tqdm_barcolor["train"],
         ) as t:
+            if self.profiler is not None:
+                self.profiler.start()
             for batch in t:
                 if self._optimizer_step_limit_exceeded:
                     logger.info("Train iteration limit exceeded")
@@ -1228,10 +1387,14 @@ class Brain:
                 )
                 t.set_postfix(train_loss=self.avg_train_loss)
 
-                # Profile only if desired (steps allow the profiler to know when all is warmed up)
                 if self.profiler is not None:
-                    if self.profiler.record_steps:
-                        self.profiler.step()
+                    self.profiler.step()
+                    if self.profiler.step_num > self.tot_prof_steps:
+                        logger.info(
+                            "The profiler finished, training is stopped."
+                        )
+                        self.profiler.stop()
+                        quit()
 
                 # Debug mode only runs a few batches
                 if self.debug and self.step == self.debug_batches:
@@ -1250,7 +1413,6 @@ class Brain:
         self.on_stage_end(Stage.TRAIN, self.avg_train_loss, epoch)
         self.avg_train_loss = 0.0
         self.step = 0
-        self.valid_step = 0
 
     def _should_save_intra_epoch_ckpt(self, last_ckpt_time, steps_since_ckpt):
         """Determines if an intra-epoch checkpoint should be saved.
@@ -1260,15 +1422,28 @@ class Brain:
         if self.checkpointer is None:
             return False
 
+        # Return early if mid-epoch checkpoints are disabled to avoid sync
+        if self.ckpt_interval_minutes <= 0 and self.ckpt_interval_steps <= 0:
+            return False
+
         # Check if we've run for the requested amount of time
-        if 0 < self.ckpt_interval_minutes * 60.0 < time.time() - last_ckpt_time:
-            # Only save on the main process to avoid race conditions
-            return if_main_process()
+        elapsed_minutes = (time.time() - last_ckpt_time) / 60.0
+        decision = 0 < self.ckpt_interval_minutes < elapsed_minutes
+
+        # Save after requested # of steps
+        decision = decision or 0 < self.ckpt_interval_steps <= steps_since_ckpt
 
-        # Save after requested # of steps. This option is the only one that
-        # allows saving on multiple processes. The logic for whether saving
-        # is run only on the main process is handled by the checkpointer.
-        return 0 < self.ckpt_interval_steps <= steps_since_ckpt
+        # If the program is not distributed, just return
+        if not torch.distributed.is_initialized():
+            return decision
+
+        # Otherwise, broadcast decision to all processes from main (rank 0)
+        # This solves synchronization issues where main gets a different
+        # timing result than the other processes.
+        else:
+            broadcast_list = [decision]
+            torch.distributed.broadcast_object_list(broadcast_list, src=0)
+            return broadcast_list[0]
 
     def _fit_valid(self, valid_set, epoch, enable):
         # Validation stage
@@ -1287,11 +1462,6 @@ class Brain:
                     loss = self.evaluate_batch(batch, stage=Stage.VALID)
                     avg_valid_loss = self.update_average(loss, avg_valid_loss)
 
-                    # Profile only if desired (steps allow the profiler to know when all is warmed up)
-                    if self.profiler is not None:
-                        if self.profiler.record_steps:
-                            self.profiler.step()
-
                     # Debug mode only runs a few batches
                     if self.debug and self.step == self.debug_batches:
                         break
@@ -1562,11 +1732,6 @@ class Brain:
                 loss = self.evaluate_batch(batch, stage=Stage.TEST)
                 avg_test_loss = self.update_average(loss, avg_test_loss)
 
-                # Profile only if desired (steps allow the profiler to know when all is warmed up)
-                if self.profiler is not None:
-                    if self.profiler.record_steps:
-                        self.profiler.step()
-
                 # Debug mode only runs a few batches
                 if self.debug and self.step == self.debug_batches:
                     break
@@ -1614,16 +1779,16 @@ class Brain:
             for module in self.modules.values():
                 if not hasattr(module, "require_backward_grad_sync"):
                     # if not using DDP
-                    break
+                    continue
                 old_values_list.append(module.require_backward_grad_sync)
                 module.require_backward_grad_sync = False
             yield
-            for module, old_value in zip(
-                self.modules.values(), old_values_list
-            ):
+            i = 0
+            for module in self.modules.values():
                 if not hasattr(module, "require_backward_grad_sync"):
-                    break
-                module.require_backward_grad_sync = old_value
+                    continue
+                module.require_backward_grad_sync = old_values_list[i]
+                i += 1
         else:
             yield
 
@@ -1638,9 +1803,8 @@ class Brain:
             w.write(yaml.dump(save_dict))
 
     @sb.utils.checkpoints.mark_as_loader
-    def _recover(self, path, end_of_epoch, device):
+    def _recover(self, path, end_of_epoch):
         del end_of_epoch
-        del device
         with open(path) as f:
             save_dict = yaml.safe_load(f)
         self.step = save_dict["step"]
diff --git a/speechbrain/dataio/dataloader.py b/speechbrain/dataio/dataloader.py
index 02c87e0445aed80daaa527e437f7b0a46cac3443..008772ed46e74eea8a9e15f96e7de87fa3a4b6b0 100644
--- a/speechbrain/dataio/dataloader.py
+++ b/speechbrain/dataio/dataloader.py
@@ -310,8 +310,7 @@ class SaveableDataLoader(DataLoader):
             fo.write(str(to_save))
 
     @mark_as_loader
-    def _speechbrain_load(self, path, end_of_epoch, device=None):
-        del device  # Unused here
+    def _speechbrain_load(self, path, end_of_epoch):
         if self._speechbrain_iterator is not None:
             logging.debug(
                 "SaveableDataLoader was requested to load a "
@@ -394,9 +393,8 @@ class LoopedLoader:
             print(self.total_samples, file=fo)
 
     @mark_as_loader
-    def load(self, path, end_of_epoch=True, device=None):
+    def load(self, path, end_of_epoch=True):
         """Loads the needed information."""
-        del device  # Unused here
         with open(path) as fi:
             self.step = int(fi.readline().strip())
             self.total_steps = int(fi.readline().strip())
diff --git a/speechbrain/dataio/dataset.py b/speechbrain/dataio/dataset.py
index 851b60fd26b776b773644e585c37fe7352f613c0..321eb9d51ae0f95480a7f4fdf18b0cade7ee928f 100644
--- a/speechbrain/dataio/dataset.py
+++ b/speechbrain/dataio/dataset.py
@@ -460,14 +460,24 @@ def set_output_keys(datasets, output_keys):
         dataset.set_output_keys(output_keys)
 
 
-def apply_overfit_test(hparams, dataset):
+def apply_overfit_test(
+    overfit_test,
+    overfit_test_sample_count,
+    overfit_test_epoch_data_count,
+    dataset,
+):
     """Applies the overfit test to the specified dataset,
     as configured in the hyperparameters file
 
     Arguments
     ---------
-    hparams: dict
-        the hyperparameters dictionary
+
+    overfit_test: bool
+        when True the overfitting test is performed
+    overfit_test_sample_count: int
+        number of samples for the overfitting test
+    overfit_test_epoch_data_count: int
+        number of epochs for the overfitting test
 
     dataset: DynamicItemDataset
         the dataset
@@ -477,8 +487,8 @@ def apply_overfit_test(hparams, dataset):
     dataset: DynamicItemDataset
         the dataset, with the overfit test apply
     """
-    if hparams["overfit_test"]:
-        sample_count = hparams["overfit_test_sample_count"]
-        epoch_data_count = hparams["overfit_test_epoch_data_count"]
+    if overfit_test:
+        sample_count = overfit_test_sample_count
+        epoch_data_count = overfit_test_epoch_data_count
         dataset = dataset.overfit_test(sample_count, epoch_data_count)
     return dataset
diff --git a/speechbrain/dataio/encoder.py b/speechbrain/dataio/encoder.py
index 9da8d7840837b3fd68cc0b020ff463dcf8edd8f8..76502a5258bbe4aa2da8a0e8c9852a6cdf468a4e 100644
--- a/speechbrain/dataio/encoder.py
+++ b/speechbrain/dataio/encoder.py
@@ -9,7 +9,6 @@ import torch
 import collections
 import itertools
 import logging
-import warnings
 import speechbrain as sb
 from speechbrain.utils.checkpoints import (
     mark_as_saver,
@@ -613,7 +612,7 @@ class CategoricalEncoder:
         logger.debug(f"Loaded categorical encoding from {path}")
 
     @mark_as_loader
-    def load_if_possible(self, path, end_of_epoch=False, device=None):
+    def load_if_possible(self, path, end_of_epoch=False):
         """Loads if possible, returns a bool indicating if loaded or not.
 
         Arguments
@@ -644,7 +643,6 @@ class CategoricalEncoder:
         ['a', 'b', 'c', 'd']
         """
         del end_of_epoch  # Unused here.
-        del device  # Unused here.
 
         try:
             self.load(path)
@@ -719,7 +717,7 @@ class CategoricalEncoder:
                     f"but {real_len} categories found"
                 )
         else:
-            warnings.warn(
+            logger.debug(
                 f"{self.__class__.__name__}.expect_len was never called: "
                 f"assuming category count of {len(self)} to be correct! "
                 "Sanity check your encoder using `.expect_len`. "
@@ -1172,3 +1170,25 @@ class CTCTextEncoder(TextEncoder):
         super()._set_extras(extras)
         if "blank_label" in extras:
             self.blank_label = extras["blank_label"]
+
+
+def load_text_encoder_tokens(model_path):
+    """Loads the encoder tokens from a pretrained model.
+
+    This method is useful when you used with a pretrained HF model.
+    It will load the tokens in the yaml and then you will be able
+    to instantiate any CTCBaseSearcher directly in the YAML file.
+
+    Arguments
+    ---------
+    model_path : str, Path
+        Path to the pretrained model.
+
+    Returns
+    -------
+    list
+        List of tokens.
+    """
+    label_encoder = TextEncoder()
+    label_encoder.load(model_path)
+    return list(label_encoder.lab2ind.keys())
diff --git a/speechbrain/dataio/preprocess.py b/speechbrain/dataio/preprocess.py
index 47d4338bf3eee024282eda4163560a1a74042e5d..4cee8a9e2b0a4265bf0e449e2c9a8e341bcbf496 100644
--- a/speechbrain/dataio/preprocess.py
+++ b/speechbrain/dataio/preprocess.py
@@ -1,6 +1,6 @@
 """Preprocessors for audio"""
 import torch
-from speechbrain.processing.speech_augmentation import Resample
+from speechbrain.augment.time_domain import Resample
 
 
 class AudioNormalizer:
diff --git a/speechbrain/dataio/sampler.py b/speechbrain/dataio/sampler.py
index ef1e5a92e0ef7f1c111452cbbd2ac96970e5f081..cdda2ae496b6d607141214abdadd18c2cae90377 100644
--- a/speechbrain/dataio/sampler.py
+++ b/speechbrain/dataio/sampler.py
@@ -7,7 +7,8 @@ Authors:
   * Samuele Cornell 2020
   * Ralf Leibold 2020
   * Artem Ploujnikov 2021
-  * Andreas Nautsch 2021
+  * Andreas Nautsch 2021, 2023
+  * Adel Moumen 2023
 """
 import torch
 import logging
@@ -485,7 +486,13 @@ class DynamicBatchSampler(Sampler):
         self._max_batch_ex = max_batch_ex
         # Calculate bucket lengths - how often does one bucket boundary fit into max_batch_length?
         self._bucket_lens = [
-            max(1, int(max_batch_length / self._bucket_boundaries[i]))
+            min(
+                self._max_batch_ex,  # tops max_duration_per_batch
+                max(
+                    1,  # and at least 1
+                    int(self._max_batch_length / self._bucket_boundaries[i]),
+                ),
+            )
             for i in range(len(self._bucket_boundaries))
         ] + [1]
         self._epoch = epoch
diff --git a/speechbrain/decoders/__init__.py b/speechbrain/decoders/__init__.py
index dfe31da3b093756d0e8b2eaaacfeb85af18c5ed0..1eb613f764513af012a4b090ce462d2b4f2f951e 100644
--- a/speechbrain/decoders/__init__.py
+++ b/speechbrain/decoders/__init__.py
@@ -1,4 +1,7 @@
 """ Package containing the different decoders (ctc, beamsearch ...)
 """
-from .seq2seq import *  # noqa
+
 from .ctc import *  # noqa
+from .seq2seq import *  # noqa
+from .transducer import *  # noqa
+from .scorer import *  # noqa
diff --git a/speechbrain/decoders/ctc.py b/speechbrain/decoders/ctc.py
index f52743de02358e21b3cd2b941f98df3f1ec71ed3..dd4a92ab06ba7dcd66f75a7aed65008ddb6149f8 100644
--- a/speechbrain/decoders/ctc.py
+++ b/speechbrain/decoders/ctc.py
@@ -4,14 +4,24 @@ Authors
  * Mirco Ravanelli 2020
  * Aku Rouhe 2020
  * Sung-Lin Yeh 2020
+ * Adel Moumen 2023, 2024
 """
-import torch
 from itertools import groupby
 from speechbrain.dataio.dataio import length_to_mask
+import math
+import dataclasses
+import numpy as np
+import heapq
+import logging
+import torch
+import warnings
+from typing import Dict, List, Optional, Union, Any, Tuple
 
+logger = logging.getLogger(__name__)
 
-class CTCPrefixScorer:
-    """This class implements the CTC prefix scorer of Algorithm 2 in
+
+class CTCPrefixScore:
+    """This class implements the CTC prefix score of Algorithm 2 in
     reference: https://www.merl.com/publications/docs/TR2017-190.pdf.
     Official implementation: https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py
 
@@ -35,25 +45,18 @@ class CTCPrefixScorer:
     """
 
     def __init__(
-        self,
-        x,
-        enc_lens,
-        batch_size,
-        beam_size,
-        blank_index,
-        eos_index,
-        ctc_window_size=0,
+        self, x, enc_lens, blank_index, eos_index, ctc_window_size=0,
     ):
         self.blank_index = blank_index
         self.eos_index = eos_index
+        self.batch_size = x.size(0)
         self.max_enc_len = x.size(1)
-        self.batch_size = batch_size
-        self.beam_size = beam_size
         self.vocab_size = x.size(-1)
         self.device = x.device
         self.minus_inf = -1e20
         self.last_frame_index = enc_lens - 1
         self.ctc_window_size = ctc_window_size
+        self.prefix_length = -1
 
         # mask frames > enc_lens
         mask = 1 - length_to_mask(enc_lens)
@@ -72,39 +75,38 @@ class CTCPrefixScorer:
         # (2, L, batch_size * beam_size, vocab_size)
         self.x = torch.stack([xnb, xb])
 
-        # The first index of each sentence.
-        self.beam_offset = (
-            torch.arange(batch_size, device=self.device) * self.beam_size
-        )
-        # The first index of each candidates.
-        self.cand_offset = (
-            torch.arange(batch_size, device=self.device) * self.vocab_size
-        )
+        # indices of batch.
+        self.batch_index = torch.arange(self.batch_size, device=self.device)
 
-    def forward_step(self, g, state, candidates=None, attn=None):
+    @torch.no_grad()
+    def forward_step(self, inp_tokens, states, candidates=None, attn=None):
         """This method if one step of forwarding operation
         for the prefix ctc scorer.
 
         Arguments
         ---------
-        g : torch.Tensor
-            The tensor of prefix label sequences, h = g + c.
-        state : tuple
+        inp_tokens : torch.Tensor
+            The last chars of prefix label sequences g, where h = g + c.
+        states : tuple
             Previous ctc states.
         candidates : torch.Tensor
             (batch_size * beam_size, ctc_beam_size), The topk candidates for rescoring.
-            The ctc_beam_size is set as 2 * beam_size. If given, performing partial ctc scoring.
+            If given, performing partial ctc scoring.
+        attn : torch.Tensor
+            (batch_size * beam_size, max_enc_len), The attention weights.
         """
 
-        prefix_length = g.size(1)
-        last_char = [gi[-1] for gi in g] if prefix_length > 0 else [0] * len(g)
+        n_bh = inp_tokens.size(0)
+        beam_size = n_bh // self.batch_size
+        last_char = inp_tokens
+        self.prefix_length += 1
         self.num_candidates = (
             self.vocab_size if candidates is None else candidates.size(-1)
         )
-        if state is None:
+        if states is None:
             # r_prev: (L, 2, batch_size * beam_size)
             r_prev = torch.full(
-                (self.max_enc_len, 2, self.batch_size, self.beam_size),
+                (self.max_enc_len, 2, self.batch_size, beam_size),
                 self.minus_inf,
                 device=self.device,
             )
@@ -113,64 +115,59 @@ class CTCPrefixScorer:
             r_prev[:, 1] = torch.cumsum(
                 self.x[0, :, :, self.blank_index], 0
             ).unsqueeze(2)
-            r_prev = r_prev.view(-1, 2, self.batch_size * self.beam_size)
-            psi_prev = 0.0
+            r_prev = r_prev.view(-1, 2, n_bh)
+            psi_prev = torch.full(
+                (n_bh, self.vocab_size), 0.0, device=self.device,
+            )
         else:
-            r_prev, psi_prev = state
+            r_prev, psi_prev = states
 
         # for partial search
         if candidates is not None:
+            # The first index of each candidate.
+            cand_offset = self.batch_index * self.vocab_size
             scoring_table = torch.full(
-                (self.batch_size * self.beam_size, self.vocab_size),
+                (n_bh, self.vocab_size),
                 -1,
                 dtype=torch.long,
                 device=self.device,
             )
             # Assign indices of candidates to their positions in the table
-            col_index = torch.arange(
-                self.batch_size * self.beam_size, device=self.device
-            ).unsqueeze(1)
+            col_index = torch.arange(n_bh, device=self.device).unsqueeze(1)
             scoring_table[col_index, candidates] = torch.arange(
                 self.num_candidates, device=self.device
             )
             # Select candidates indices for scoring
             scoring_index = (
                 candidates
-                + self.cand_offset.unsqueeze(1)
-                .repeat(1, self.beam_size)
-                .view(-1, 1)
+                + cand_offset.unsqueeze(1).repeat(1, beam_size).view(-1, 1)
             ).view(-1)
             x_inflate = torch.index_select(
                 self.x.view(2, -1, self.batch_size * self.vocab_size),
                 2,
                 scoring_index,
-            ).view(2, -1, self.batch_size * self.beam_size, self.num_candidates)
+            ).view(2, -1, n_bh, self.num_candidates)
         # for full search
         else:
             scoring_table = None
+            # Inflate x to (2, -1, batch_size * beam_size, num_candidates)
+            # It is used to compute forward probs in a batched way
             x_inflate = (
                 self.x.unsqueeze(3)
-                .repeat(1, 1, 1, self.beam_size, 1)
-                .view(
-                    2, -1, self.batch_size * self.beam_size, self.num_candidates
-                )
+                .repeat(1, 1, 1, beam_size, 1)
+                .view(2, -1, n_bh, self.num_candidates)
             )
 
         # Prepare forward probs
         r = torch.full(
-            (
-                self.max_enc_len,
-                2,
-                self.batch_size * self.beam_size,
-                self.num_candidates,
-            ),
+            (self.max_enc_len, 2, n_bh, self.num_candidates,),
             self.minus_inf,
             device=self.device,
         )
         r.fill_(self.minus_inf)
 
         # (Alg.2-6)
-        if prefix_length == 0:
+        if self.prefix_length == 0:
             r[0, 0] = x_inflate[0, 0]
         # (Alg.2-10): phi = prev_nonblank + prev_blank = r_t-1^nb(g) + r_t-1^b(g)
         r_sum = torch.logsumexp(r_prev, 1)
@@ -178,24 +175,24 @@ class CTCPrefixScorer:
 
         # (Alg.2-10): if last token of prefix g in candidates, phi = prev_b + 0
         if candidates is not None:
-            for i in range(self.batch_size * self.beam_size):
+            for i in range(n_bh):
                 pos = scoring_table[i, last_char[i]]
                 if pos != -1:
                     phi[:, i, pos] = r_prev[:, 1, i]
         else:
-            for i in range(self.batch_size * self.beam_size):
+            for i in range(n_bh):
                 phi[:, i, last_char[i]] = r_prev[:, 1, i]
 
         # Start, end frames for scoring (|g| < |h|).
         # Scoring based on attn peak if ctc_window_size > 0
         if self.ctc_window_size == 0 or attn is None:
-            start = max(1, prefix_length)
+            start = max(1, self.prefix_length)
             end = self.max_enc_len
         else:
             _, attn_peak = torch.max(attn, dim=1)
             max_frame = torch.max(attn_peak).item() + self.ctc_window_size
             min_frame = torch.min(attn_peak).item() - self.ctc_window_size
-            start = max(max(1, prefix_length), int(min_frame))
+            start = max(max(1, self.prefix_length), int(min_frame))
             end = min(self.max_enc_len, int(max_frame))
 
         # Compute forward prob log(r_t^nb(h)) and log(r_t^b(h)):
@@ -205,7 +202,7 @@ class CTCPrefixScorer:
             # (Alg.2-12): dim=1, p(h|cur step is blank) = [p(prev step is blank) + p(prev step is nonblank)] * p(blank)
             rb_prev = r[t - 1, 1]
             r_ = torch.stack([rnb_prev, phi[t - 1], rnb_prev, rb_prev]).view(
-                2, 2, self.batch_size * self.beam_size, self.num_candidates
+                2, 2, n_bh, self.num_candidates
             )
             r[t] = torch.logsumexp(r_, 1) + x_inflate[:, t]
 
@@ -216,15 +213,13 @@ class CTCPrefixScorer:
         # (Alg.2-13): psi = psi + phi * p(c)
         if candidates is not None:
             psi = torch.full(
-                (self.batch_size * self.beam_size, self.vocab_size),
-                self.minus_inf,
-                device=self.device,
+                (n_bh, self.vocab_size), self.minus_inf, device=self.device,
             )
             psi_ = torch.logsumexp(
                 torch.cat((phix[start:end], psi_init), dim=0), dim=0
             )
             # only assign prob to candidates
-            for i in range(self.batch_size * self.beam_size):
+            for i in range(n_bh):
                 psi[i, candidates[i]] = psi_[i]
         else:
             psi = torch.logsumexp(
@@ -232,13 +227,14 @@ class CTCPrefixScorer:
             )
 
         # (Alg.2-3): if c = <eos>, psi = log(r_T^n(g) + r_T^b(g)), where T is the length of max frames
-        for i in range(self.batch_size * self.beam_size):
+        for i in range(n_bh):
             psi[i, self.eos_index] = r_sum[
-                self.last_frame_index[i // self.beam_size], i
+                self.last_frame_index[i // beam_size], i
             ]
 
-        # Exclude blank probs for joint scoring
-        psi[:, self.blank_index] = self.minus_inf
+        if self.eos_index != self.blank_index:
+            # Exclude blank probs for joint scoring
+            psi[:, self.blank_index] = self.minus_inf
 
         return psi - psi_prev, (r, psi, scoring_table)
 
@@ -258,38 +254,41 @@ class CTCPrefixScorer:
         The variable of the memory being permuted.
 
         """
+
         r, psi, scoring_table = memory
-        # The index of top-K vocab came from in (t-1) timesteps.
-        best_index = (
-            index
-            + (self.beam_offset.unsqueeze(1).expand_as(index) * self.vocab_size)
-        ).view(-1)
+
+        beam_size = index.size(1)
+        n_bh = self.batch_size * beam_size
+
+        # The first index of each batch.
+        beam_offset = self.batch_index * beam_size
+        # The index of top-K vocab came from in (t-1) timesteps at batch * beam * vocab dimension.
+        cand_index = (
+            index + beam_offset.unsqueeze(1).expand_as(index) * self.vocab_size
+        ).view(n_bh)
         # synchronize forward prob
-        psi = torch.index_select(psi.view(-1), dim=0, index=best_index)
+        psi = torch.index_select(psi.view(-1), dim=0, index=cand_index)
         psi = (
             psi.view(-1, 1)
             .repeat(1, self.vocab_size)
-            .view(self.batch_size * self.beam_size, self.vocab_size)
+            .view(n_bh, self.vocab_size)
         )
-
+        # The index of top-K vocab came from in (t-1) timesteps at batch * beam dimension.
+        hyp_index = (
+            torch.div(index, self.vocab_size, rounding_mode="floor")
+            + beam_offset.unsqueeze(1).expand_as(index)
+        ).view(n_bh)
         # synchronize ctc states
         if scoring_table is not None:
-            effective_index = (
-                index // self.vocab_size + self.beam_offset.view(-1, 1)
-            ).view(-1)
             selected_vocab = (index % self.vocab_size).view(-1)
-            score_index = scoring_table[effective_index, selected_vocab]
+            score_index = scoring_table[hyp_index, selected_vocab]
             score_index[score_index == -1] = 0
-            best_index = score_index + effective_index * self.num_candidates
+            cand_index = score_index + hyp_index * self.num_candidates
 
         r = torch.index_select(
-            r.view(
-                -1, 2, self.batch_size * self.beam_size * self.num_candidates
-            ),
-            dim=-1,
-            index=best_index,
+            r.view(-1, 2, n_bh * self.num_candidates), dim=-1, index=cand_index,
         )
-        r = r.view(-1, 2, self.batch_size * self.beam_size)
+        r = r.view(-1, 2, n_bh)
 
         return r, psi
 
@@ -338,7 +337,7 @@ def ctc_greedy_decode(probabilities, seq_lens, blank_id=-1):
     ---------
     probabilities : torch.tensor
         Output probabilities (or log-probabilities) from the network with shape
-        [batch, probabilities, time]
+        [batch, lengths, probabilities]
     seq_lens : torch.tensor
         Relative true sequence lengths (to deal with padded inputs),
         the longest sequence has length 1.0, others a value between zero and one
@@ -374,3 +373,1858 @@ def ctc_greedy_decode(probabilities, seq_lens, blank_id=-1):
         out = filter_ctc_output(predictions.tolist(), blank_id=blank_id)
         batch_outputs.append(out)
     return batch_outputs
+
+
+@dataclasses.dataclass
+class CTCBeam:
+    """This class handle the CTC beam informations during decoding.
+
+    Arguments
+    ---------
+    text : str
+        The current text of the beam.
+    full_text : str
+        The full text of the beam.
+    next_word : str
+        The next word to be added to the beam.
+    partial_word : str
+        The partial word being added to the beam.
+    last_token : str, optional
+        The last token of the beam.
+    last_token_index : int, optional
+        The index of the last token of the beam.
+    text_frames : List[Tuple[int, int]]
+        The start and end frame of the text.
+    partial_frames : Tuple[int, int]
+        The start and end frame of the partial word.
+    p : float
+        The probability of the beam.
+    p_b : float
+        The probability of the beam ending in a blank.
+    p_nb : float
+        The probability of the beam not ending in a blank.
+    n_p_b : float
+        The previous probability of the beam ending in a blank.
+    n_p_nb : float
+        The previous probability of the beam not ending in a blank.
+    score : float
+        The score of the beam (LM + CTC)
+    score_ctc : float
+        The CTC score computed.
+
+    Example
+    -------
+    >>> beam = CTCBeam(
+    ...     text="",
+    ...     full_text="",
+    ...     next_word="",
+    ...     partial_word="",
+    ...     last_token=None,
+    ...     last_token_index=None,
+    ...     text_frames=[(0, 0)],
+    ...     partial_frames=(0, 0),
+    ...     p=-math.inf,
+    ...     p_b=-math.inf,
+    ...     p_nb=-math.inf,
+    ...     n_p_b=-math.inf,
+    ...     n_p_nb=-math.inf,
+    ...     score=-math.inf,
+    ...     score_ctc=-math.inf,
+    ... )
+    """
+
+    text: str
+    full_text: str
+    next_word: str
+    partial_word: str
+    last_token: Optional[str]
+    last_token_index: Optional[int]
+    text_frames: List[Tuple[int, int]]
+    partial_frames: Tuple[int, int]
+    p: float = -math.inf
+    p_b: float = -math.inf
+    p_nb: float = -math.inf
+    n_p_b: float = -math.inf
+    n_p_nb: float = -math.inf
+    score: float = -math.inf
+    score_ctc: float = -math.inf
+
+    @classmethod
+    def from_lm_beam(self, lm_beam: "LMCTCBeam") -> "CTCBeam":
+        """Create a CTCBeam from a LMCTCBeam
+
+        Arguments
+        ---------
+        lm_beam : LMCTCBeam
+            The LMCTCBeam to convert.
+
+        Returns
+        -------
+        CTCBeam
+            The CTCBeam converted.
+        """
+        return CTCBeam(
+            text=lm_beam.text,
+            full_text=lm_beam.full_text,
+            next_word=lm_beam.next_word,
+            partial_word=lm_beam.partial_word,
+            last_token=lm_beam.last_token,
+            last_token_index=lm_beam.last_token_index,
+            text_frames=lm_beam.text_frames,
+            partial_frames=lm_beam.partial_frames,
+            p=lm_beam.p,
+            p_b=lm_beam.p_b,
+            p_nb=lm_beam.p_nb,
+            n_p_b=lm_beam.n_p_b,
+            n_p_nb=lm_beam.n_p_nb,
+            score=lm_beam.score,
+            score_ctc=lm_beam.score_ctc,
+        )
+
+    def step(self) -> None:
+        """Update the beam probabilities."""
+        self.p_b, self.p_nb = self.n_p_b, self.n_p_nb
+        self.n_p_b = self.n_p_nb = -math.inf
+        self.score_ctc = np.logaddexp(self.p_b, self.p_nb)
+        self.score = self.score_ctc
+
+
+@dataclasses.dataclass
+class LMCTCBeam(CTCBeam):
+    """This class handle the LM scores during decoding.
+
+    Arguments
+    ---------
+    lm_score: float
+        The LM score of the beam.
+    **kwargs
+        See CTCBeam for the other arguments.
+    """
+
+    lm_score: float = -math.inf
+
+
+@dataclasses.dataclass
+class CTCHypothesis:
+    """This class is a data handler over the generated hypotheses.
+
+    This class is the default output of the CTC beam searchers.
+
+    It can be re-used for other decoders if using
+    the beam searchers in an online fashion.
+
+    Arguments
+    ---------
+    text : str
+        The text of the hypothesis.
+    last_lm_state : None
+        The last LM state of the hypothesis.
+    score : float
+        The score of the hypothesis.
+    lm_score : float
+        The LM score of the hypothesis.
+    text_frames : List[Tuple[str, Tuple[int, int]]], optional
+        The list of the text and the corresponding frames.
+    """
+
+    text: str
+    last_lm_state: None
+    score: float
+    lm_score: float
+    text_frames: list = None
+
+
+class CTCBaseSearcher(torch.nn.Module):
+    """CTCBaseSearcher class to be inherited by other
+    CTC beam searchers.
+
+    This class provides the basic functionalities for
+    CTC beam search decoding.
+
+    The space_token is required with a non-sentencepiece vocabulary list
+    if your transcription is expecting to contain spaces.
+
+    Arguments
+    ---------
+    blank_index : int
+        The index of the blank token.
+    vocab_list : list
+        The list of the vocabulary tokens.
+    space_token : int, optional
+        The index of the space token. (default: -1)
+    kenlm_model_path : str, optional
+        The path to the kenlm model. Use .bin for a faster loading.
+        If None, no language model will be used. (default: None)
+    unigrams : list, optional
+        The list of known word unigrams. (default: None)
+    alpha : float
+        Weight for language model during shallow fusion. (default: 0.5)
+    beta : float
+        Weight for length score adjustment of during scoring. (default: 1.5)
+    unk_score_offset : float
+        Amount of log score offset for unknown tokens. (default: -10.0)
+    score_boundary : bool
+        Whether to have kenlm respect boundaries when scoring. (default: True)
+    beam_size : int, optional
+        The width of the beam. (default: 100)
+    beam_prune_logp : float, optional
+        The pruning threshold for the beam. (default: -10.0)
+    token_prune_min_logp : float, optional
+        The pruning threshold for the tokens. (default: -5.0)
+    prune_history : bool, optional
+        Whether to prune the history. (default: True)
+        Note: when using topk > 1, this should be set to False as
+        it is pruning a lot of beams.
+    blank_skip_threshold : float, optional
+        Skip frames if log_prob(blank) > log(blank_skip_threshold), to speed up decoding.
+        Note: This is only used when using the CUDA decoder, and it might worsen the WER/CER results. Use it at your own risk. (default: 1.0)
+    topk : int, optional
+        The number of top hypotheses to return. (default: 1)
+    spm_token: str, optional
+        The sentencepiece token. (default: "▁")
+
+    Example
+    -------
+    >>> blank_index = 0
+    >>> vocab_list = ['blank', 'a', 'b', 'c', ' ']
+    >>> space_token = ' '
+    >>> kenlm_model_path = None
+    >>> unigrams = None
+    >>> beam_size = 100
+    >>> beam_prune_logp = -10.0
+    >>> token_prune_min_logp = -5.0
+    >>> prune_history = True
+    >>> blank_skip_threshold = 1.0
+    >>> topk = 1
+    >>> searcher = CTCBaseSearcher(
+    ...     blank_index=blank_index,
+    ...     vocab_list=vocab_list,
+    ...     space_token=space_token,
+    ...     kenlm_model_path=kenlm_model_path,
+    ...     unigrams=unigrams,
+    ...     beam_size=beam_size,
+    ...     beam_prune_logp=beam_prune_logp,
+    ...     token_prune_min_logp=token_prune_min_logp,
+    ...     prune_history=prune_history,
+    ...     blank_skip_threshold=blank_skip_threshold,
+    ...     topk=topk,
+    ... )
+    """
+
+    def __init__(
+        self,
+        blank_index: int,
+        vocab_list: List[str],
+        space_token: str = " ",
+        kenlm_model_path: Union[None, str] = None,
+        unigrams: Union[None, List[str]] = None,
+        alpha: float = 0.5,
+        beta: float = 1.5,
+        unk_score_offset: float = -10.0,
+        score_boundary: bool = True,
+        beam_size: int = 100,
+        beam_prune_logp: int = -10.0,
+        token_prune_min_logp: int = -5.0,
+        prune_history: bool = True,
+        blank_skip_threshold: Union[None, int] = 1.0,
+        topk: int = 1,
+        spm_token: str = "▁",
+    ):
+        super().__init__()
+
+        self.blank_index = blank_index
+        self.vocab_list = vocab_list
+        self.space_token = space_token
+        self.kenlm_model_path = kenlm_model_path
+        self.unigrams = unigrams
+        self.alpha = alpha
+        self.beta = beta
+        self.unk_score_offset = unk_score_offset
+        self.score_boundary = score_boundary
+        self.beam_size = beam_size
+        self.beam_prune_logp = beam_prune_logp
+        self.token_prune_min_logp = token_prune_min_logp
+        self.prune_history = prune_history
+        self.blank_skip_threshold = math.log(blank_skip_threshold)
+        self.topk = topk
+        self.spm_token = spm_token
+
+        # check if the vocab is coming from SentencePiece
+        self.is_spm = any(
+            [str(s).startswith(self.spm_token) for s in vocab_list]
+        )
+
+        # fetch the index of space_token
+        if not self.is_spm:
+            try:
+                self.space_index = vocab_list.index(space_token)
+            except ValueError:
+                logger.warning(
+                    f"space_token `{space_token}` not found in the vocabulary."
+                    "Using value -1 as `space_index`."
+                    "Note: If your transcription is not expected to contain spaces, "
+                    "you can ignore this warning."
+                )
+                self.space_index = -1
+            logger.info(f"Found `space_token` at index {self.space_index}.")
+
+        self.kenlm_model = None
+        if kenlm_model_path is not None:
+            try:
+                import kenlm  # type: ignore
+
+                from speechbrain.decoders.language_model import (
+                    LanguageModel,
+                    load_unigram_set_from_arpa,
+                )
+            except ImportError:
+                raise ImportError(
+                    "kenlm python bindings are not installed. To install it use: "
+                    "pip install https://github.com/kpu/kenlm/archive/master.zip"
+                )
+
+            self.kenlm_model = kenlm.Model(kenlm_model_path)
+
+        if kenlm_model_path is not None and kenlm_model_path.endswith(".arpa"):
+            logger.info(
+                "Using arpa instead of binary LM file, decoder instantiation might be slow."
+            )
+
+        if unigrams is None and kenlm_model_path is not None:
+            if kenlm_model_path.endswith(".arpa"):
+                unigrams = load_unigram_set_from_arpa(kenlm_model_path)
+            else:
+                logger.warning(
+                    "Unigrams not provided and cannot be automatically determined from LM file (only "
+                    "arpa format). Decoding accuracy might be reduced."
+                )
+
+        if self.kenlm_model is not None:
+            self.lm = LanguageModel(
+                kenlm_model=self.kenlm_model,
+                unigrams=unigrams,
+                alpha=self.alpha,
+                beta=self.beta,
+                unk_score_offset=self.unk_score_offset,
+                score_boundary=self.score_boundary,
+            )
+        else:
+            self.lm = None
+
+    def partial_decoding(
+        self,
+        log_probs: torch.Tensor,
+        beams: List[CTCBeam],
+        cached_lm_scores: dict,
+        cached_p_lm_scores: dict,
+        processed_frames: int = 0,
+    ):
+        """Perform a single step of decoding.
+
+        Arguments
+        ---------
+        log_probs : torch.Tensor
+            The log probabilities of the CTC output.
+        beams : list
+            The list of the beams.
+        cached_lm_scores : dict
+            The cached language model scores.
+        cached_p_lm_scores : dict
+            The cached prefix language model scores.
+        processed_frames : int, default: 0
+            The start frame of the current decoding step.
+        """
+        raise NotImplementedError
+
+    def normalize_whitespace(self, text: str) -> str:
+        """Efficiently normalize whitespace.
+
+        Arguments
+        ---------
+        text : str
+            The text to normalize.
+
+        Returns
+        -------
+        str
+            The normalized text.
+        """
+        return " ".join(text.split())
+
+    def merge_tokens(self, token_1: str, token_2: str) -> str:
+        """Merge two tokens, and avoid empty ones.
+
+        Taken from: https://github.com/kensho-technologies/pyctcdecode
+
+        Arguments
+        ---------
+        token_1 : str
+            The first token.
+        token_2 : str
+            The second token.
+
+        Returns
+        -------
+        str
+            The merged token.
+        """
+        if len(token_2) == 0:
+            text = token_1
+        elif len(token_1) == 0:
+            text = token_2
+        else:
+            text = token_1 + " " + token_2
+        return text
+
+    def merge_beams(self, beams: List[CTCBeam]) -> List[CTCBeam]:
+        """Merge beams with the same text.
+
+        Taken from: https://github.com/kensho-technologies/pyctcdecode
+
+        Arguments
+        ---------
+        beams : list
+            The list of the beams.
+
+        Returns
+        -------
+        list
+            The list of CTCBeam merged.
+        """
+        beam_dict = {}
+        for beam in beams:
+            new_text = self.merge_tokens(beam.text, beam.next_word)
+            hash_idx = (new_text, beam.partial_word, beam.last_token)
+            if hash_idx not in beam_dict:
+                beam_dict[hash_idx] = beam
+            else:
+                # We've already seen this text - we want to combine the scores
+                beam_dict[hash_idx] = dataclasses.replace(
+                    beam,
+                    score=np.logaddexp(beam_dict[hash_idx].score, beam.score),
+                )
+        return list(beam_dict.values())
+
+    def sort_beams(self, beams: List[CTCBeam]) -> List[CTCBeam]:
+        """Sort beams by lm_score.
+
+        Arguments
+        ---------
+        beams : list
+            The list of CTCBeam.
+
+        Returns
+        -------
+        list
+            The list of CTCBeam sorted.
+        """
+        return heapq.nlargest(self.beam_size, beams, key=lambda x: x.lm_score)
+
+    def _prune_history(
+        self, beams: List[CTCBeam], lm_order: int
+    ) -> List[CTCBeam]:
+        """Filter out beams that are the same over max_ngram history.
+
+        Since n-gram language models have a finite history when scoring a new token, we can use that
+        fact to prune beams that only differ early on (more than n tokens in the past) and keep only the
+        higher scoring ones. Note that this helps speed up the decoding process but comes at the cost of
+        some amount of beam diversity. If more than the top beam is used in the output it should
+        potentially be disabled.
+
+        Taken from: https://github.com/kensho-technologies/pyctcdecode
+
+        Arguments
+        ---------
+        beams : list
+            The list of the beams.
+        lm_order : int
+            The order of the language model.
+
+        Returns
+        -------
+        list
+            The list of CTCBeam.
+        """
+        # let's keep at least 1 word of history
+        min_n_history = max(1, lm_order - 1)
+        seen_hashes = set()
+        filtered_beams = []
+        # for each beam after this, check if we need to add it
+        for lm_beam in beams:
+            # hash based on history that can still affect lm scoring going forward
+            hash_idx = (
+                tuple(lm_beam.text.split()[-min_n_history:]),
+                lm_beam.partial_word,
+                lm_beam.last_token,
+            )
+            if hash_idx not in seen_hashes:
+                filtered_beams.append(CTCBeam.from_lm_beam(lm_beam))
+                seen_hashes.add(hash_idx)
+        return filtered_beams
+
+    def finalize_decoding(
+        self,
+        beams: List[CTCBeam],
+        cached_lm_scores: dict,
+        cached_p_lm_scores: dict,
+        force_next_word=False,
+        is_end=False,
+    ) -> List[CTCBeam]:
+        """Finalize the decoding process by adding and scoring the last partial word.
+
+        Arguments
+        ---------
+        beams : list
+            The list of CTCBeam.
+        cached_lm_scores : dict
+            The cached language model scores.
+        cached_p_lm_scores : dict
+            The cached prefix language model scores.
+        force_next_word : bool, default: False
+            Whether to force the next word.
+        is_end : bool, default: False
+            Whether the end of the sequence has been reached.
+
+        Returns
+        -------
+        list
+            The list of the CTCBeam.
+        """
+        if force_next_word or is_end:
+            new_beams = []
+            for beam in beams:
+                new_token_times = (
+                    beam.text_frames
+                    if beam.partial_word == ""
+                    else beam.text_frames + [beam.partial_frames]
+                )
+                new_beams.append(
+                    CTCBeam(
+                        text=beam.text,
+                        full_text=beam.full_text,
+                        next_word=beam.partial_word,
+                        partial_word="",
+                        last_token=None,
+                        last_token_index=None,
+                        text_frames=new_token_times,
+                        partial_frames=(-1, -1),
+                        score=beam.score,
+                    )
+                )
+
+            new_beams = self.merge_beams(new_beams)
+        else:
+            new_beams = list(beams)
+
+        scored_beams = self.get_lm_beams(
+            new_beams, cached_lm_scores, cached_p_lm_scores,
+        )
+        # remove beam outliers
+        max_score = max([b.lm_score for b in scored_beams])
+        scored_beams = [
+            b
+            for b in scored_beams
+            if b.lm_score >= max_score + self.beam_prune_logp
+        ]
+
+        sorted_beams = self.sort_beams(scored_beams)
+        return sorted_beams
+
+    def decode_beams(
+        self,
+        log_probs: torch.Tensor,
+        wav_lens: Optional[torch.Tensor] = None,
+        lm_start_state: Any = None,
+    ) -> List[List[CTCHypothesis]]:
+        """Decodes the input log probabilities of the CTC output.
+
+        It automatically converts the SpeechBrain's relative length of the wav input
+        to the absolute length.
+
+        Make sure that the input are in the log domain. The decoder will fail to decode
+        logits or probabilities. The input should be the log probabilities of the CTC output.
+
+        Arguments
+        ---------
+        log_probs : torch.Tensor
+            The log probabilities of the CTC output.
+            The expected shape is [batch_size, seq_length, vocab_size].
+        wav_lens : torch.Tensor, optional (default: None)
+            The SpeechBrain's relative length of the wav input.
+        lm_start_state : Any, optional (default: None)
+            The start state of the language model.
+
+        Returns
+        -------
+        list of list
+            The list of topk list of CTCHypothesis.
+        """
+        # check that the last dimension of log_probs is equal to the vocab size
+        if log_probs.size(2) != len(self.vocab_list):
+            warnings.warn(
+                f"Vocab size mismatch: log_probs vocab dim is {log_probs.size(2)} "
+                f"while vocab_list is {len(self.vocab_list)}. "
+                "During decoding, going to truncate the log_probs vocab dim to match vocab_list."
+            )
+
+        # compute wav_lens and cast to numpy as it is faster
+        if wav_lens is not None:
+            wav_lens = log_probs.size(1) * wav_lens
+            wav_lens = wav_lens.cpu().numpy().astype(int)
+        else:
+            wav_lens = [log_probs.size(1)] * log_probs.size(0)
+
+        log_probs = log_probs.cpu().numpy()
+
+        hyps = [
+            self.decode_log_probs(log_prob, wav_len, lm_start_state)
+            for log_prob, wav_len in zip(log_probs, wav_lens)
+        ]
+        return hyps
+
+    def __call__(
+        self,
+        log_probs: torch.Tensor,
+        wav_lens: Optional[torch.Tensor] = None,
+        lm_start_state: Any = None,
+    ) -> List[List[CTCHypothesis]]:
+        """Decodes the log probabilities of the CTC output.
+
+        It automatically converts the SpeechBrain's relative length of the wav input
+        to the absolute length.
+
+        Each tensors is converted to numpy and CPU as it is faster and consummes less memory.
+
+        Arguments
+        ---------
+        log_probs : torch.Tensor
+            The log probabilities of the CTC output.
+            The expected shape is [batch_size, seq_length, vocab_size].
+        wav_lens : torch.Tensor, optional (default: None)
+            The SpeechBrain's relative length of the wav input.
+        lm_start_state : Any, optional (default: None)
+            The start state of the language model.
+
+        Returns
+        -------
+        list of list
+            The list of topk list of CTCHypothesis.
+        """
+        return self.decode_beams(log_probs, wav_lens, lm_start_state)
+
+    def partial_decode_beams(
+        self,
+        log_probs: torch.Tensor,
+        cached_lm_scores: dict,
+        cached_p_lm_scores: dict,
+        beams: List[CTCBeam],
+        processed_frames: int,
+        force_next_word=False,
+        is_end=False,
+    ) -> List[CTCBeam]:
+        """ Perform a single step of decoding.
+
+        Arguments
+        ---------
+        log_probs : torch.Tensor
+            The log probabilities of the CTC output.
+        cached_lm_scores : dict
+            The cached language model scores.
+        cached_p_lm_scores : dict
+            The cached prefix language model scores.
+        beams : list
+            The list of the beams.
+        processed_frames : int
+            The start frame of the current decoding step.
+        force_next_word : bool, optional (default: False)
+            Whether to force the next word.
+        is_end : bool, optional (default: False)
+            Whether the end of the sequence has been reached.
+
+        Returns
+        -------
+        list
+            The list of CTCBeam.
+        """
+        beams = self.partial_decoding(
+            log_probs,
+            beams,
+            cached_lm_scores,
+            cached_p_lm_scores,
+            processed_frames=processed_frames,
+        )
+
+        trimmed_beams = self.finalize_decoding(
+            beams,
+            cached_lm_scores,
+            cached_p_lm_scores,
+            force_next_word=force_next_word,
+            is_end=is_end,
+        )
+
+        return trimmed_beams
+
+    def decode_log_probs(
+        self,
+        log_probs: torch.Tensor,
+        wav_len: int,
+        lm_start_state: Optional[Any] = None,
+    ) -> List[CTCHypothesis]:
+        """Decodes the log probabilities of the CTC output.
+
+        Arguments
+        ---------
+        log_probs : torch.Tensor
+            The log probabilities of the CTC output.
+            The expected shape is [seq_length, vocab_size].
+        wav_len : int
+            The length of the wav input.
+        lm_start_state : Any, optional (default: None)
+            The start state of the language model.
+
+        Returns
+        -------
+        list
+            The topk list of CTCHypothesis.
+        """
+        # prepare caching/state for language model
+        language_model = self.lm
+        if language_model is None:
+            cached_lm_scores = {}
+        else:
+            if lm_start_state is None:
+                start_state = language_model.get_start_state()
+            else:
+                start_state = lm_start_state
+            cached_lm_scores = {("", False): (0.0, start_state)}
+        cached_p_lm_scores: Dict[str, float] = {}
+
+        beams = [
+            CTCBeam(
+                text="",
+                full_text="",
+                next_word="",
+                partial_word="",
+                last_token=None,
+                last_token_index=None,
+                text_frames=[],
+                partial_frames=(-1, -1),
+                score=0.0,
+                score_ctc=0.0,
+                p_b=0.0,
+            )
+        ]
+
+        # loop over the frames and perform the decoding
+        beams = self.partial_decoding(
+            log_probs, wav_len, beams, cached_lm_scores, cached_p_lm_scores,
+        )
+
+        # finalize decoding by adding and scoring the last partial word
+        trimmed_beams = self.finalize_decoding(
+            beams,
+            cached_lm_scores,
+            cached_p_lm_scores,
+            force_next_word=True,
+            is_end=True,
+        )
+
+        # transform the beams into hypotheses and select the topk
+        output_beams = [
+            CTCHypothesis(
+                text=self.normalize_whitespace(lm_beam.text),
+                last_lm_state=(
+                    cached_lm_scores[(lm_beam.text, True)][-1]
+                    if (lm_beam.text, True) in cached_lm_scores
+                    else None
+                ),
+                text_frames=list(
+                    zip(lm_beam.text.split(), lm_beam.text_frames)
+                ),
+                score=lm_beam.score,
+                lm_score=lm_beam.lm_score,
+            )
+            for lm_beam in trimmed_beams
+        ][: self.topk]
+        return output_beams
+
+
+class CTCBeamSearcher(CTCBaseSearcher):
+    """CTC Beam Search is a Beam Search for CTC which does not keep track of
+    the blank and non-blank probabilities. Each new token probability is
+    added to the general score, and each beams that share the same text are
+    merged together.
+
+    The implementation suppors n-gram scoring on words and SentencePiece tokens. The input
+    is expected to be a log-probabilities tensor of shape [batch, time, vocab_size].
+
+    The main advantage of this CTCBeamSearcher over the CTCPrefixBeamSearcher is that it is
+    relatively faster, and obtains slightly better results. However, the implementation is
+    based on the one from the PyCTCDecode toolkit, adpated for the SpeechBrain's needs and does
+    not follow a specific paper. We do recommand to use the CTCPrefixBeamSearcher if you want
+    to cite the appropriate paper for the decoding method.
+
+    Several heuristics are implemented to speed up the decoding process:
+    - pruning of the beam : the beams are pruned if their score is lower than
+        the best beam score minus the beam_prune_logp
+    - pruning of the tokens : the tokens are pruned if their score is lower than
+        the token_prune_min_logp
+    - pruning of the history : the beams are pruned if they are the same over
+        max_ngram history
+    - skipping of the blank : the frame is skipped if the blank probability is
+        higher than the blank_skip_threshold
+
+    Note: if the Acoustic Model is not trained, the Beam Search will
+    take a lot of time. We do recommand to use Greedy Search during validation
+    until the model is fully trained and ready to be evaluated on test sets.
+
+    Arguments
+    ---------
+    **kwargs
+        see CTCBaseSearcher, arguments are directly passed.
+
+    Example
+    -------
+    >>> import torch
+    >>> from speechbrain.decoders import CTCBeamSearcher
+    >>> probs = torch.tensor([[[0.2, 0.0, 0.8],
+    ...                   [0.4, 0.0, 0.6]]])
+    >>> log_probs = torch.log(probs)
+    >>> lens = torch.tensor([1.0])
+    >>> blank_index = 2
+    >>> vocab_list = ['a', 'b', '-']
+    >>> searcher = CTCBeamSearcher(blank_index=blank_index, vocab_list=vocab_list)
+    >>> hyps = searcher(probs, lens)
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def get_lm_beams(
+        self,
+        beams: List[CTCBeam],
+        cached_lm_scores: dict,
+        cached_partial_token_scores: dict,
+        is_eos=False,
+    ) -> List[LMCTCBeam]:
+        """Score the beams with the language model if not None, and
+        return the new beams.
+
+        This function is modified and adapted from
+        https://github.com/kensho-technologies/pyctcdecode
+
+        Arguments
+        ---------
+        beams : list
+            The list of the beams.
+        cached_lm_scores : dict
+            The cached language model scores.
+        cached_partial_token_scores : dict
+            The cached partial token scores.
+        is_eos : bool (default: False)
+            Whether the end of the sequence has been reached.
+
+        Returns
+        -------
+        new_beams : list
+            The list of the new beams.
+        """
+        if self.lm is None:
+            # no lm is used, lm_score is equal to score and we can return the beams
+            new_beams = []
+            for beam in beams:
+                new_text = self.merge_tokens(beam.text, beam.next_word)
+                new_beams.append(
+                    LMCTCBeam(
+                        text=new_text,
+                        full_text=beam.full_text,
+                        next_word="",
+                        partial_word=beam.partial_word,
+                        last_token=beam.last_token,
+                        last_token_index=beam.last_token,
+                        text_frames=beam.text_frames,
+                        partial_frames=beam.partial_frames,
+                        score=beam.score,
+                        lm_score=beam.score,
+                    )
+                )
+            return new_beams
+        else:
+            # lm is used, we need to compute the lm_score
+            # first we compute the lm_score of the next word
+            # we check if the next word is in the cache
+            # if not, we compute the score and add it to the cache
+            new_beams = []
+            for beam in beams:
+                # fast token merge
+                new_text = self.merge_tokens(beam.text, beam.next_word)
+                cache_key = (new_text, is_eos)
+                if cache_key not in cached_lm_scores:
+                    prev_raw_lm_score, start_state = cached_lm_scores[
+                        (beam.text, False)
+                    ]
+                    score, end_state = self.lm.score(
+                        start_state, beam.next_word, is_last_word=is_eos
+                    )
+                    raw_lm_score = prev_raw_lm_score + score
+                    cached_lm_scores[cache_key] = (raw_lm_score, end_state)
+                lm_score, _ = cached_lm_scores[cache_key]
+
+                # we score the partial word
+                word_part = beam.partial_word
+                if len(word_part) > 0:
+                    if word_part not in cached_partial_token_scores:
+
+                        cached_partial_token_scores[
+                            word_part
+                        ] = self.lm.score_partial_token(word_part)
+                    lm_score += cached_partial_token_scores[word_part]
+
+                new_beams.append(
+                    LMCTCBeam(
+                        text=new_text,
+                        full_text=beam.full_text,
+                        next_word="",
+                        partial_word=word_part,
+                        last_token=beam.last_token,
+                        last_token_index=beam.last_token,
+                        text_frames=beam.text_frames,
+                        partial_frames=beam.partial_frames,
+                        score=beam.score,
+                        lm_score=beam.score + lm_score,
+                    )
+                )
+            return new_beams
+
+    def partial_decoding(
+        self,
+        log_probs: torch.Tensor,
+        wav_len: int,
+        beams: List[CTCBeam],
+        cached_lm_scores: dict,
+        cached_p_lm_scores: dict,
+        processed_frames: int = 0,
+    ) -> List[CTCBeam]:
+        """Perform CTC Prefix Beam Search decoding.
+
+        If self.lm is not None, the language model scores are computed and added to the CTC scores.
+
+        Arguments
+        ---------
+        log_probs : torch.Tensor
+            The log probabilities of the CTC input.
+            Shape: (seq_length, vocab_size)
+        wav_len : int
+            The length of the input sequence.
+        beams : list
+            The list of CTCBeam objects.
+        cached_lm_scores : dict
+            The cached language model scores.
+        cached_p_lm_scores : dict
+            The cached prefix language model scores.
+        processed_frames : int
+            The start frame of the current decoding step. (default: 0)
+
+        Returns
+        -------
+        beams : list
+            The list of CTCBeam objects.
+        """
+        # select only the valid frames i.e. the frames that are not padded
+        log_probs = log_probs[:wav_len]
+
+        for frame_index, logit_col in enumerate(
+            log_probs, start=processed_frames
+        ):
+            # skip the frame if the blank probability is higher than the threshold
+            if logit_col[self.blank_index] > self.blank_skip_threshold:
+                continue
+
+            # get the tokens with the highest probability
+            max_index = logit_col.argmax()
+            tokens_index_list = set(
+                np.where(logit_col > self.token_prune_min_logp)[0]
+            ) | {max_index}
+            new_beams = []
+
+            # select tokens that are in the vocab
+            # this is useful if the logit vocab_size is larger than the vocab_list
+            tokens_index_list = tokens_index_list & set(
+                range(len(self.vocab_list))
+            )
+
+            for token_index in tokens_index_list:
+                p_token = logit_col[token_index]
+                token = self.vocab_list[token_index]
+
+                for beam in beams:
+
+                    if (
+                        token_index == self.blank_index
+                        or beam.last_token == token
+                    ):
+                        if token_index == self.blank_index:
+                            new_end_frame = beam.partial_frames[0]
+                        else:
+                            new_end_frame = frame_index + 1
+
+                        new_part_frames = (
+                            beam.partial_frames
+                            if token_index == self.blank_index
+                            else (beam.partial_frames[0], new_end_frame)
+                        )
+
+                        # if blank or repeated token, we only change the score
+                        new_beams.append(
+                            CTCBeam(
+                                text=beam.text,
+                                full_text=beam.full_text,
+                                next_word=beam.next_word,
+                                partial_word=beam.partial_word,
+                                last_token=token,
+                                last_token_index=token_index,
+                                text_frames=beam.text_frames,
+                                partial_frames=new_part_frames,
+                                score=beam.score + p_token,
+                            )
+                        )
+
+                    elif self.is_spm and token[:1] == self.spm_token:
+                        # remove the spm token at the beginning of the token
+                        clean_token = token[1:]
+
+                        new_frame_list = (
+                            beam.text_frames
+                            if beam.partial_word == ""
+                            else beam.text_frames + [beam.partial_frames]
+                        )
+
+                        # If the beginning of the token is the spm_token
+                        # then it means that we are extending the beam with a new word.
+                        # We need to change the new_word with the partial_word
+                        # and reset the partial_word with the new token
+                        new_beams.append(
+                            CTCBeam(
+                                text=beam.text,
+                                full_text=beam.full_text,
+                                next_word=beam.partial_word,
+                                partial_word=clean_token,
+                                last_token=token,
+                                last_token_index=token_index,
+                                text_frames=new_frame_list,
+                                partial_frames=(frame_index, frame_index + 1),
+                                score=beam.score + p_token,
+                            )
+                        )
+
+                    elif not self.is_spm and token_index == self.space_index:
+                        new_frame_list = (
+                            beam.text_frames
+                            if beam.partial_word == ""
+                            else beam.text_frames + [beam.partial_frames]
+                        )
+
+                        # same as before but in the case of a non spm vocab
+                        new_beams.append(
+                            CTCBeam(
+                                text=beam.text,
+                                full_text=beam.full_text,
+                                next_word=beam.partial_word,
+                                partial_word="",
+                                last_token=token,
+                                last_token_index=token_index,
+                                text_frames=new_frame_list,
+                                partial_frames=(-1, -1),
+                                score=beam.score + p_token,
+                            )
+                        )
+                    else:
+                        new_part_frames = (
+                            (frame_index, frame_index + 1)
+                            if beam.partial_frames[0] < 0
+                            else (beam.partial_frames[0], frame_index + 1)
+                        )
+
+                        # last case, we are extending the partial_word with a new token
+                        new_beams.append(
+                            CTCBeam(
+                                text=beam.text,
+                                full_text=beam.full_text,
+                                next_word=beam.next_word,
+                                partial_word=beam.partial_word + token,
+                                last_token=token,
+                                last_token_index=token_index,
+                                text_frames=beam.text_frames,
+                                partial_frames=new_part_frames,
+                                score=beam.score + p_token,
+                            )
+                        )
+
+            # we merge the beams with the same text
+            new_beams = self.merge_beams(new_beams)
+
+            # kenlm scoring
+            scored_beams = self.get_lm_beams(
+                new_beams, cached_lm_scores, cached_p_lm_scores,
+            )
+
+            # remove beam outliers
+            max_score = max([b.lm_score for b in scored_beams])
+            scored_beams = [
+                b
+                for b in scored_beams
+                if b.lm_score >= max_score + self.beam_prune_logp
+            ]
+
+            trimmed_beams = self.sort_beams(scored_beams)
+
+            if self.prune_history:
+                lm_order = 1 if self.lm is None else self.lm.order
+                beams = self._prune_history(trimmed_beams, lm_order=lm_order)
+            else:
+                beams = [CTCBeam.from_lm_beam(b) for b in trimmed_beams]
+
+        return beams
+
+
+class CTCPrefixBeamSearcher(CTCBaseSearcher):
+    """CTC Prefix Beam Search is based on the paper
+    `First-Pass Large Vocabulary Continuous Speech Recognition using Bi-Directional Recurrent DNNs`
+    by Awni Y. Hannun and al (https://arxiv.org/abs/1408.2873).
+
+    The implementation keep tracks of the blank and non-blank probabilities.
+    It also suppors n-gram scoring on words and SentencePiece tokens. The input
+    is expected to be a log-probabilities tensor of shape [batch, time, vocab_size].
+
+    Several heuristics are implemented to speed up the decoding process:
+    - pruning of the beam : the beams are pruned if their score is lower than
+        the best beam score minus the beam_prune_logp
+    - pruning of the tokens : the tokens are pruned if their score is lower than
+        the token_prune_min_logp
+    - pruning of the history : the beams are pruned if they are the same over
+        max_ngram history
+    - skipping of the blank : the frame is skipped if the blank probability is
+        higher than the blank_skip_threshold
+
+    Note: The CTCPrefixBeamSearcher can be more unstable than the CTCBeamSearcher
+    or the TorchAudioCTCPrefixBeamSearch searcher. Please, use it with caution
+    and check the results carefully.
+
+    Note: if the Acoustic Model is not trained, the Beam Search will
+    take a lot of time. We do recommand to use Greedy Search during validation
+    until the model is fully trained and ready to be evaluated on test sets.
+
+    Note: This implementation does not provide the time alignment of the
+    hypothesis. If you need it, please use the CTCBeamSearcher.
+
+    Arguments
+    ---------
+    **kwargs
+        see CTCBaseSearcher, arguments are directly passed.
+
+    Example
+    -------
+    >>> import torch
+    >>> from speechbrain.decoders import CTCPrefixBeamSearcher
+    >>> probs = torch.tensor([[[0.2, 0.0, 0.8],
+    ...                   [0.4, 0.0, 0.6]]])
+    >>> log_probs = torch.log(probs)
+    >>> lens = torch.tensor([1.0])
+    >>> blank_index = 2
+    >>> vocab_list = ['a', 'b', '-']
+    >>> searcher = CTCPrefixBeamSearcher(blank_index=blank_index, vocab_list=vocab_list)
+    >>> hyps = searcher(probs, lens)
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def get_lm_beams(
+        self,
+        beams: List[CTCBeam],
+        cached_lm_scores: dict,
+        cached_partial_token_scores: dict,
+        is_eos=False,
+    ) -> List[LMCTCBeam]:
+        """Score the beams with the language model if not None, and
+        return the new beams.
+
+        This function is modified and adapted from
+        https://github.com/kensho-technologies/pyctcdecode
+
+        Arguments
+        ---------
+        beams : list
+            The list of the beams.
+        cached_lm_scores : dict
+            The cached language model scores.
+        cached_partial_token_scores : dict
+            The cached partial token scores.
+        is_eos : bool (default: False)
+            Whether the end of the sequence has been reached.
+
+        Returns
+        -------
+        new_beams : list
+            The list of the new beams.
+        """
+        if self.lm is None:
+            # no lm is used, lm_score is equal to score and we can return the beams
+            # we have to keep track of the probabilities as well
+            new_beams = []
+            for beam in beams:
+                new_text = self.merge_tokens(beam.full_text, beam.next_word)
+                new_beams.append(
+                    LMCTCBeam(
+                        text=beam.text,
+                        full_text=new_text,
+                        next_word="",
+                        partial_word=beam.partial_word,
+                        last_token=beam.last_token,
+                        last_token_index=beam.last_token_index,
+                        text_frames=beam.text_frames,
+                        partial_frames=beam.partial_frames,
+                        p=beam.p,
+                        p_b=beam.p_b,
+                        p_nb=beam.p_nb,
+                        n_p_b=beam.n_p_b,
+                        n_p_nb=beam.n_p_nb,
+                        score=beam.score,
+                        score_ctc=beam.score_ctc,
+                        lm_score=beam.score,
+                    )
+                )
+            return new_beams
+        else:
+            # lm is used, we need to compute the lm_score
+            # first we compute the lm_score of the next word
+            # we check if the next word is in the cache
+            # if not, we compute the score and add it to the cache
+            new_beams = []
+            for beam in beams:
+                # fast token merge
+                new_text = self.merge_tokens(beam.full_text, beam.next_word)
+                cache_key = (new_text, is_eos)
+                if cache_key not in cached_lm_scores:
+                    prev_raw_lm_score, start_state = cached_lm_scores[
+                        (beam.full_text, False)
+                    ]
+                    score, end_state = self.lm.score(
+                        start_state, beam.next_word, is_last_word=is_eos
+                    )
+                    raw_lm_score = prev_raw_lm_score + score
+                    cached_lm_scores[cache_key] = (raw_lm_score, end_state)
+                lm_score, _ = cached_lm_scores[cache_key]
+                word_part = beam.partial_word
+
+                # we score the partial word
+                if len(word_part) > 0:
+                    if word_part not in cached_partial_token_scores:
+
+                        cached_partial_token_scores[
+                            word_part
+                        ] = self.lm.score_partial_token(word_part)
+                    lm_score += cached_partial_token_scores[word_part]
+
+                new_beams.append(
+                    LMCTCBeam(
+                        text=beam.text,
+                        full_text=new_text,
+                        next_word="",
+                        partial_word=beam.partial_word,
+                        last_token=beam.last_token,
+                        last_token_index=beam.last_token_index,
+                        text_frames=beam.text_frames,
+                        partial_frames=beam.partial_frames,
+                        p=beam.p,
+                        p_b=beam.p_b,
+                        p_nb=beam.p_nb,
+                        n_p_b=beam.n_p_b,
+                        n_p_nb=beam.n_p_nb,
+                        score=beam.score,
+                        score_ctc=beam.score_ctc,
+                        lm_score=beam.score + lm_score,
+                    )
+                )
+            return new_beams
+
+    def _get_new_beam(
+        self,
+        frame_index: int,
+        new_prefix: str,
+        new_token: str,
+        new_token_index: int,
+        beams: List[CTCBeam],
+        p: float,
+        previous_beam: CTCBeam,
+    ) -> CTCBeam:
+        """Create a new beam and add it to the list of beams.
+
+        Arguments
+        ---------
+        frame_index : int
+            The index of the current frame.
+        new_prefix : str
+            The new prefix.
+        new_token : str
+            The new token.
+        new_token_index : int
+            The index of the new token.
+        beams : list
+            The list of beams.
+        p : float
+            The probability of the new token.
+        previous_beam : CTCBeam
+            The previous beam.
+
+        Returns
+        -------
+        new_beam : CTCBeam
+            The new beam.
+        """
+        for beam in beams:
+            if beam.text == new_prefix:
+                if p and p > beam.p:
+                    beam.p = p
+                return beam
+
+        if not self.is_spm and new_token_index == self.space_index:
+            new_frame_list = (
+                previous_beam.text_frames
+                if previous_beam.partial_word == ""
+                else previous_beam.text_frames + [previous_beam.partial_frames]
+            )
+
+            # if we extend the beam with a space, we need to reset the partial word
+            # and move it to the next word
+            new_beam = CTCBeam(
+                text=new_prefix,
+                full_text=previous_beam.full_text,
+                next_word=previous_beam.partial_word,
+                partial_word="",
+                last_token=new_token,
+                last_token_index=new_token_index,
+                text_frames=new_frame_list,
+                partial_frames=(-1, -1),
+                score=-math.inf,
+                score_ctc=-math.inf,
+                p_b=-math.inf,
+            )
+        elif self.is_spm and new_token[:1] == self.spm_token:
+            # remove the spm token at the beginning of the token
+            clean_token = new_token[1:]
+
+            new_frame_list = (
+                previous_beam.text_frames
+                if previous_beam.partial_word == ""
+                else previous_beam.text_frames + [previous_beam.partial_frames]
+            )
+
+            # If the beginning of the token is the spm_token
+            # then it means that we are extending the beam with a new word.
+            # We need to change the new_word with the partial_word
+            # and reset the partial_word with the new token
+            new_prefix = previous_beam.text + " " + clean_token
+            new_beam = CTCBeam(
+                text=new_prefix,
+                full_text=previous_beam.full_text,
+                next_word=previous_beam.partial_word,
+                partial_word=clean_token,
+                last_token=new_token,
+                last_token_index=new_token_index,
+                text_frames=new_frame_list,
+                partial_frames=(frame_index, frame_index + 1),
+                score=-math.inf,
+                score_ctc=-math.inf,
+                p_b=-math.inf,
+            )
+        elif new_token_index == previous_beam.last_token_index:
+            new_end_frame = frame_index + 1
+
+            new_part_frames = (
+                previous_beam.partial_frames
+                if new_token_index == self.blank_index
+                else (previous_beam.partial_frames[0], new_end_frame)
+            )
+
+            # if repeated token, we only change the score
+            new_beam = CTCBeam(
+                text=new_prefix,
+                full_text=previous_beam.full_text,
+                next_word="",
+                partial_word=previous_beam.partial_word,
+                last_token=new_token,
+                last_token_index=new_token_index,
+                text_frames=previous_beam.text_frames,
+                partial_frames=new_part_frames,
+                score=-math.inf,
+                score_ctc=-math.inf,
+                p_b=-math.inf,
+            )
+        else:
+            new_part_frames = (
+                (frame_index, frame_index + 1)
+                if previous_beam.partial_frames[0] < 0
+                else (previous_beam.partial_frames[0], frame_index + 1)
+            )
+
+            # last case, we are extending the partial_word with a new token
+            new_beam = CTCBeam(
+                text=new_prefix,
+                full_text=previous_beam.full_text,
+                next_word="",
+                partial_word=previous_beam.partial_word + new_token,
+                last_token=new_token,
+                last_token_index=new_token_index,
+                text_frames=previous_beam.text_frames,
+                partial_frames=new_part_frames,
+                score=-math.inf,
+                score_ctc=-math.inf,
+                p_b=-math.inf,
+            )
+        beams.append(new_beam)
+        if previous_beam:
+            new_beam.p = previous_beam.p
+        return new_beam
+
+    def partial_decoding(
+        self,
+        log_probs: torch.Tensor,
+        wav_len: int,
+        beams: List[CTCBeam],
+        cached_lm_scores: dict,
+        cached_p_lm_scores: dict,
+        processed_frames: int = 0,
+    ) -> List[CTCBeam]:
+        """Perform CTC Prefix Beam Search decoding.
+
+        If self.lm is not None, the language model scores are computed and added to the CTC scores.
+
+        Arguments
+        ---------
+        log_probs : torch.Tensor
+            The log probabilities of the CTC input.
+            Shape: (seq_length, vocab_size)
+        wav_len : int
+            The length of the input sequence.
+        beams : list
+            The list of CTCBeam objects.
+        cached_lm_scores : dict
+            The cached language model scores.
+        cached_p_lm_scores : dict
+            The cached prefix language model scores.
+        processed_frames : int
+            The start frame of the current decoding step. (default: 0)
+
+        Returns
+        -------
+        beams : list
+            The list of CTCBeam objects.
+        """
+        # select only the valid frames, i.e., the frames that are not padded
+        log_probs = log_probs[:wav_len]
+
+        for frame_index, logit_col in enumerate(
+            log_probs, start=processed_frames
+        ):
+            # skip the frame if the blank probability is higher than the threshold
+            if logit_col[self.blank_index] > self.blank_skip_threshold:
+                continue
+
+            # get the tokens with the highest probability
+            max_index = logit_col.argmax()
+            tokens_index_list = set(
+                np.where(logit_col > self.token_prune_min_logp)[0]
+            ) | {max_index}
+
+            curr_beams = beams.copy()
+
+            # select tokens that are in the vocab
+            # this is useful if the logit vocab_size is larger than the vocab_list
+            tokens_index_list = tokens_index_list & set(
+                range(len(self.vocab_list))
+            )
+
+            for token_index in tokens_index_list:
+                p_token = logit_col[token_index]
+                token = self.vocab_list[token_index]
+
+                for beam in curr_beams:
+                    p_b, p_nb = beam.p_b, beam.p_nb
+
+                    # blank case
+                    if token_index == self.blank_index:
+                        beam.n_p_b = np.logaddexp(
+                            beam.n_p_b, beam.score_ctc + p_token
+                        )
+                        continue
+
+                    if token == beam.last_token:
+                        beam.n_p_nb = np.logaddexp(beam.n_p_nb, p_nb + p_token)
+
+                    new_text = beam.text + token
+
+                    new_beam = self._get_new_beam(
+                        frame_index,
+                        new_text,
+                        token,
+                        token_index,
+                        beams,
+                        p=p_token,
+                        previous_beam=beam,
+                    )
+
+                    n_p_nb = new_beam.n_p_nb
+
+                    if token_index == beam.last_token_index and p_b > -math.inf:
+                        n_p_nb = np.logaddexp(n_p_nb, p_b + p_token)
+                    elif token_index != beam.last_token_index:
+                        n_p_nb = np.logaddexp(n_p_nb, beam.score_ctc + p_token)
+                    new_beam.n_p_nb = n_p_nb
+
+            # update the CTC probabilities
+            for beam in beams:
+                beam.step()
+
+            # kenLM scores
+            scored_beams = self.get_lm_beams(
+                beams, cached_lm_scores, cached_p_lm_scores,
+            )
+
+            # remove beams outliers
+            max_score = max([b.lm_score for b in scored_beams])
+            scored_beams = [
+                b
+                for b in scored_beams
+                if b.lm_score >= max_score + self.beam_prune_logp
+            ]
+            trimmed_beams = self.sort_beams(scored_beams)
+
+            if self.prune_history:
+                lm_order = 1 if self.lm is None else self.lm.order
+                beams = self._prune_history(trimmed_beams, lm_order=lm_order)
+            else:
+                beams = [CTCBeam.from_lm_beam(b) for b in trimmed_beams]
+
+        return beams
+
+
+class TorchAudioCTCPrefixBeamSearcher:
+    """TorchAudio CTC Prefix Beam Search Decoder.
+
+    This class is a wrapper around the CTC decoder from TorchAudio. It provides a simple interface
+    where you can either use the CPU or CUDA CTC decoder.
+
+    The CPU decoder is slower but uses less memory. The CUDA decoder is faster but uses more memory.
+    The CUDA decoder is also only available in the nightly version of torchaudio.
+
+    A lot of features are missing in the CUDA decoder, such as the ability to use a language model,
+    constraint search, and more. If you want to use those features, you have to use the CPU decoder.
+
+    For more information about the CPU decoder, please refer to the documentation of TorchAudio:
+    https://pytorch.org/audio/main/generated/torchaudio.models.decoder.ctc_decoder.html
+
+    For more information about the CUDA decoder, please refer to the documentation of TorchAudio:
+    https://pytorch.org/audio/main/generated/torchaudio.models.decoder.cuda_ctc_decoder.html#torchaudio.models.decoder.cuda_ctc_decoder
+
+    If you want to use the language model, or the lexicon search, please make sure that your
+    tokenizer/acoustic model uses the same tokens as the language model/lexicon. Otherwise, the decoding will fail.
+
+    The implementation is compatible with Sentenpiece Tokens.
+
+    Note: When using CUDA CTC decoder, the blank_index has to be 0. Furthermore, using CUDA CTC decoder
+    requires the nightly version of torchaudio and a lot of VRAM memory (if you want to use a lot of beams).
+    Overall, we do recommand to use the CTCBeamSearcher or CTCPrefixBeamSearcher in SpeechBrain if you wants to use
+    n-gram + beam search decoding. If you wants to have constraint search, please use the CPU version of torchaudio,
+    and if you want to speedup as much as possible the decoding, please use the CUDA version.
+
+    Arguments
+    ---------
+    tokens : list or str
+        The list of tokens or the path to the tokens file.
+        If this is a path, then the file should contain one token per line.
+    lexicon : str, default: None
+        Lexicon file containing the possible words and corresponding spellings. Each line consists of a word and its space separated spelling.
+        If None, uses lexicon-free decoding. (default: None)
+    lm : str, optional
+        A path containing KenLM language model or None if not using a language model. (default: None)
+    lm_dict : str, optional
+        File consisting of the dictionary used for the LM, with a word per line sorted by LM index.
+        If decoding with a lexicon, entries in lm_dict must also occur in the lexicon file.
+        If None, dictionary for LM is constructed using the lexicon file. (default: None)
+    topk : int, optional
+        Number of top CTCHypothesis to return. (default: 1)
+    beam_size : int, optional
+        Numbers of hypotheses to hold after each decode step. (default: 50)
+    beam_size_token : int, optional
+        Max number of tokens to consider at each decode step. If None, it is set to the total number of tokens. (default: None)
+    beam_threshold : float, optional
+        Threshold for pruning hypothesis. (default: 50)
+    lm_weight : float, optional
+        Weight of language model. (default: 2)
+    word_score : float, optional
+        Word insertion score. (default: 0)
+    unk_score : float, optional
+        Unknown word insertion score. (default: float("-inf"))
+    sil_score : float, optional
+        Silence insertion score. (default: 0)
+    log_add : bool, optional
+        Whether to use use logadd when merging hypotheses. (default: False)
+    blank_index : int or str, optional
+        Index of the blank token. If tokens is a file path, then this should be an str. Otherwise, this should be a int. (default: 0)
+    sil_index : int or str, optional
+        Index of the silence token. If tokens is a file path, then this should be an str. Otherwise, this should be a int. (default: 0)
+    unk_word : str, optional
+        Unknown word token. (default: "<unk>")
+    using_cpu_decoder : bool, optional
+        Whether to use the CPU searcher. If False, then the CUDA decoder is used. (default: True)
+    blank_skip_threshold : float, optional
+        Skip frames if log_prob(blank) > log(blank_skip_threshold), to speed up decoding (default: 1.0).
+        Note: This is only used when using the CUDA decoder, and it might worsen the WER/CER results. Use it at your own risk.
+
+    Example
+    -------
+    >>> import torch
+    >>> from speechbrain.decoders import TorchAudioCTCPrefixBeamSearcher
+    >>> probs = torch.tensor([[[0.2, 0.0, 0.8],
+    ...                   [0.4, 0.0, 0.6]]])
+    >>> log_probs = torch.log(probs)
+    >>> lens = torch.tensor([1.0])
+    >>> blank_index = 2
+    >>> vocab_list = ['a', 'b', '-']
+    >>> searcher = TorchAudioCTCPrefixBeamSearcher(tokens=vocab_list, blank_index=blank_index, sil_index=blank_index) # doctest: +SKIP
+    >>> hyps = searcher(probs, lens) # doctest: +SKIP
+    """
+
+    def __init__(
+        self,
+        tokens: Union[list, str],
+        lexicon: Optional[str] = None,
+        lm: Optional[str] = None,
+        lm_dict: Optional[str] = None,
+        topk: int = 1,
+        beam_size: int = 50,
+        beam_size_token: Optional[int] = None,
+        beam_threshold: float = 50,
+        lm_weight: float = 2,
+        word_score: float = 0,
+        unk_score: float = float("-inf"),
+        sil_score: float = 0,
+        log_add: bool = False,
+        blank_index: Union[str, int] = 0,
+        sil_index: Union[str, int] = 0,
+        unk_word: str = "<unk>",
+        using_cpu_decoder: bool = True,
+        blank_skip_threshold: float = 1.0,
+    ):
+        self.lexicon = lexicon
+        self.tokens = tokens
+        self.lm = lm
+        self.lm_dict = lm_dict
+        self.topk = topk
+        self.beam_size = beam_size
+        self.beam_size_token = beam_size_token
+        self.beam_threshold = beam_threshold
+        self.lm_weight = lm_weight
+        self.word_score = word_score
+        self.unk_score = unk_score
+        self.sil_score = sil_score
+        self.log_add = log_add
+        self.blank_index = blank_index
+        self.sil_index = sil_index
+        self.unk_word = unk_word
+        self.using_cpu_decoder = using_cpu_decoder
+        self.blank_skip_threshold = blank_skip_threshold
+
+        if self.using_cpu_decoder:
+            try:
+                from torchaudio.models.decoder import ctc_decoder
+            except ImportError:
+                raise ImportError(
+                    "ctc_decoder not found. Please install torchaudio and flashlight to use this decoder."
+                )
+
+            # if this is a path, then torchaudio expect to be an index
+            # while if its a list then it expects to be a token
+            if isinstance(self.tokens, str):
+                blank_token = self.blank_index
+                sil_token = self.sil_index
+            else:
+                blank_token = self.tokens[self.blank_index]
+                sil_token = self.tokens[self.sil_index]
+
+            self._ctc_decoder = ctc_decoder(
+                lexicon=self.lexicon,
+                tokens=self.tokens,
+                lm=self.lm,
+                lm_dict=self.lm_dict,
+                nbest=self.topk,
+                beam_size=self.beam_size,
+                beam_size_token=self.beam_size_token,
+                beam_threshold=self.beam_threshold,
+                lm_weight=self.lm_weight,
+                word_score=self.word_score,
+                unk_score=self.unk_score,
+                sil_score=self.sil_score,
+                log_add=self.log_add,
+                blank_token=blank_token,
+                sil_token=sil_token,
+                unk_word=self.unk_word,
+            )
+        else:
+            try:
+                from torchaudio.models.decoder import cuda_ctc_decoder
+            except ImportError:
+                raise ImportError(
+                    "cuda_ctc_decoder not found. Please install the latest version of torchaudio to use this decoder."
+                )
+            assert (
+                self.blank_index == 0
+            ), "Index of blank token has to be 0 when using CUDA CTC decoder."
+
+            self._ctc_decoder = cuda_ctc_decoder(
+                tokens=self.tokens,
+                nbest=self.topk,
+                beam_size=self.beam_size,
+                blank_skip_threshold=self.blank_skip_threshold,
+            )
+
+    def decode_beams(
+        self, log_probs: torch.Tensor, wav_len: Union[torch.Tensor, None] = None
+    ) -> List[List[CTCHypothesis]]:
+        """Decode log_probs using TorchAudio CTC decoder.
+
+        If `using_cpu_decoder=True` then log_probs and wav_len are moved to CPU before decoding.
+        When using CUDA CTC decoder, the timestep information is not available. Therefore, the timesteps
+        in the returned hypotheses are set to None.
+
+        Make sure that the input are in the log domain. The decoder will fail to decode
+        logits or probabilities. The input should be the log probabilities of the CTC output.
+
+        Arguments
+        ---------
+        log_probs : torch.Tensor
+            The log probabilities of the input audio.
+            Shape: (batch_size, seq_length, vocab_size)
+        wav_len : torch.Tensor, default: None
+            The speechbrain-style relative length. Shape: (batch_size,)
+            If None, then the length of each audio is assumed to be seq_length.
+
+        Returns
+        -------
+        list of list of CTCHypothesis
+            The decoded hypotheses. The outer list is over the batch dimension, and the inner list is over the topk dimension.
+        """
+        if wav_len is not None:
+            wav_len = log_probs.size(1) * wav_len
+        else:
+            wav_len = torch.tensor(
+                [log_probs.size(1)] * log_probs.size(0),
+                device=log_probs.device,
+                dtype=torch.int32,
+            )
+
+        if wav_len.dtype != torch.int32:
+            wav_len = wav_len.to(torch.int32)
+
+        if log_probs.dtype != torch.float32:
+            raise ValueError("log_probs must be float32.")
+
+        # When using CPU decoder, we need to move the log_probs and wav_len to CPU
+        if self.using_cpu_decoder and log_probs.is_cuda:
+            log_probs = log_probs.cpu()
+
+        if self.using_cpu_decoder and wav_len.is_cuda:
+            wav_len = wav_len.cpu()
+
+        if not log_probs.is_contiguous():
+            raise RuntimeError("log_probs must be contiguous.")
+
+        results = self._ctc_decoder(log_probs, wav_len)
+
+        tokens_preds = []
+        words_preds = []
+        scores_preds = []
+        timesteps_preds = []
+
+        # over batch dim
+        for i in range(len(results)):
+
+            if self.using_cpu_decoder:
+
+                preds = [
+                    results[i][j].tokens.tolist()
+                    for j in range(len(results[i]))
+                ]
+                preds = [
+                    [self.tokens[token] for token in tokens] for tokens in preds
+                ]
+                tokens_preds.append(preds)
+
+                timesteps = [
+                    results[i][j].timesteps.tolist()
+                    for j in range(len(results[i]))
+                ]
+                timesteps_preds.append(timesteps)
+
+            else:
+                # no timesteps is available for CUDA CTC decoder
+                timesteps = [None for _ in range(len(results[i]))]
+                timesteps_preds.append(timesteps)
+
+                preds = [results[i][j].tokens for j in range(len(results[i]))]
+                preds = [
+                    [self.tokens[token] for token in tokens] for tokens in preds
+                ]
+                tokens_preds.append(preds)
+
+            words = [results[i][j].words for j in range(len(results[i]))]
+            words_preds.append(words)
+
+            scores = [results[i][j].score for j in range(len(results[i]))]
+            scores_preds.append(scores)
+
+        hyps = []
+        for (
+            batch_index,
+            (batch_text, batch_score, batch_timesteps),
+        ) in enumerate(zip(tokens_preds, scores_preds, timesteps_preds)):
+            hyps.append([])
+            for text, score, timestep in zip(
+                batch_text, batch_score, batch_timesteps
+            ):
+                hyps[batch_index].append(
+                    CTCHypothesis(
+                        text="".join(text),
+                        last_lm_state=None,
+                        score=score,
+                        lm_score=score,
+                        text_frames=timestep,
+                    )
+                )
+        return hyps
+
+    def __call__(
+        self, log_probs: torch.Tensor, wav_len: Union[torch.Tensor, None] = None
+    ) -> List[List[CTCHypothesis]]:
+        """Decode log_probs using TorchAudio CTC decoder.
+
+        If `using_cpu_decoder=True` then log_probs and wav_len are moved to CPU before decoding.
+        When using CUDA CTC decoder, the timestep information is not available. Therefore, the timesteps
+        in the returned hypotheses are set to None.
+
+        Arguments
+        ---------
+        log_probs : torch.Tensor
+            The log probabilities of the input audio.
+            Shape: (batch_size, seq_length, vocab_size)
+        wav_len : torch.Tensor, default: None
+            The speechbrain-style relative length. Shape: (batch_size,)
+            If None, then the length of each audio is assumed to be seq_length.
+
+        Returns
+        -------
+        list of list of CTCHypothesis
+            The decoded hypotheses. The outer list is over the batch dimension, and the inner list is over the topk dimension.
+        """
+        return self.decode_beams(log_probs, wav_len)
diff --git a/speechbrain/decoders/language_model.py b/speechbrain/decoders/language_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9af191dfb818e2ab8fea052416cbb28d945938cb
--- /dev/null
+++ b/speechbrain/decoders/language_model.py
@@ -0,0 +1,260 @@
+"""Language model wrapper for kenlm n-gram.
+
+This file is based on the implementation of the kenLM wrapper from
+PyCTCDecode (see: https://github.com/kensho-technologies/pyctcdecode) and
+is used in CTC decoders.
+
+See: speechbrain.decoders.ctc.py
+
+Authors
+ * Adel Moumen 2023
+"""
+import logging
+from typing import (
+    Collection,
+    Optional,
+    Set,
+    Tuple,
+    cast,
+)
+
+from pygtrie import CharTrie
+
+import math
+
+logger = logging.getLogger(__name__)
+
+try:
+    import kenlm
+except ImportError:
+    raise ImportError(
+        "kenlm python bindings are not installed. To install it use: "
+        "pip install https://github.com/kpu/kenlm/archive/master.zip"
+    )
+
+
+def load_unigram_set_from_arpa(arpa_path: str) -> Set[str]:
+    """Read unigrams from arpa file.
+
+    Taken from: https://github.com/kensho-technologies/pyctcdecode
+
+    Arguments
+    ---------
+    arpa_path : str
+        Path to arpa file.
+
+    Returns
+    -------
+    unigrams : set
+        Set of unigrams.
+    """
+    unigrams = set()
+    with open(arpa_path) as f:
+        start_1_gram = False
+        for line in f:
+            line = line.strip()
+            if line == "\\1-grams:":
+                start_1_gram = True
+            elif line == "\\2-grams:":
+                break
+            if start_1_gram and len(line) > 0:
+                parts = line.split("\t")
+                if len(parts) == 3:
+                    unigrams.add(parts[1])
+    if len(unigrams) == 0:
+        raise ValueError(
+            "No unigrams found in arpa file. Something is wrong with the file."
+        )
+    return unigrams
+
+
+class KenlmState:
+    """Wrapper for kenlm state.
+
+    This is a wrapper for the kenlm state object. It is used to make sure that the
+    state is not modified outside of the language model class.
+
+    Taken from: https://github.com/kensho-technologies/pyctcdecode
+
+    Arguments
+    ---------
+    state : kenlm.State
+        Kenlm state object.
+    """
+
+    def __init__(self, state: "kenlm.State"):
+        self._state = state
+
+    @property
+    def state(self) -> "kenlm.State":
+        """Get the raw state object."""
+        return self._state
+
+
+def _prepare_unigram_set(
+    unigrams: Collection[str], kenlm_model: "kenlm.Model"
+) -> Set[str]:
+    """Filter unigrams down to vocabulary that exists in kenlm_model.
+
+    Taken from: https://github.com/kensho-technologies/pyctcdecode
+
+    Arguments
+    ---------
+    unigrams : list
+        List of unigrams.
+    kenlm_model : kenlm.Model
+        Kenlm model.
+
+    Returns
+    -------
+    unigram_set : set
+        Set of unigrams.
+    """
+    if len(unigrams) < 1000:
+        logger.warning(
+            "Only %s unigrams passed as vocabulary. Is this small or artificial data?",
+            len(unigrams),
+        )
+    unigram_set = set(unigrams)
+    unigram_set = set([t for t in unigram_set if t in kenlm_model])
+    retained_fraction = (
+        1.0 if len(unigrams) == 0 else len(unigram_set) / len(unigrams)
+    )
+    if retained_fraction < 0.1:
+        logger.warning(
+            "Only %s%% of unigrams in vocabulary found in kenlm model-- this might mean that your "
+            "vocabulary and language model are incompatible. Is this intentional?",
+            round(retained_fraction * 100, 1),
+        )
+    return unigram_set
+
+
+def _get_empty_lm_state() -> "kenlm.State":
+    """Get unintialized kenlm state.
+
+    Taken from: https://github.com/kensho-technologies/pyctcdecode
+
+    Returns
+    -------
+    kenlm_state : kenlm.State
+        Empty kenlm state.
+    """
+    try:
+        kenlm_state = kenlm.State()
+    except ImportError:
+        raise ValueError("To use a language model, you need to install kenlm.")
+    return kenlm_state
+
+
+class LanguageModel:
+    """Language model container class to consolidate functionality.
+
+    This class is a wrapper around the kenlm language model. It provides
+    functionality to score tokens and to get the initial state.
+
+    Taken from: https://github.com/kensho-technologies/pyctcdecode
+
+    Arguments
+    ---------
+    kenlm_model : kenlm.Model
+        Kenlm model.
+    unigrams : list
+        List of known word unigrams.
+    alpha : float
+        Weight for language model during shallow fusion.
+    beta : float
+        Weight for length score adjustment of during scoring.
+    unk_score_offset : float
+        Amount of log score offset for unknown tokens.
+    score_boundary : bool
+        Whether to have kenlm respect boundaries when scoring.
+    """
+
+    def __init__(
+        self,
+        kenlm_model: "kenlm.Model",
+        unigrams: Optional[Collection[str]] = None,
+        alpha: float = 0.5,
+        beta: float = 1.5,
+        unk_score_offset: float = -10.0,
+        score_boundary: bool = True,
+    ) -> None:
+        self._kenlm_model = kenlm_model
+        if unigrams is None:
+            logger.warning(
+                "No known unigrams provided, decoding results might be a lot worse."
+            )
+            unigram_set = set()
+            char_trie = None
+        else:
+            unigram_set = _prepare_unigram_set(unigrams, self._kenlm_model)
+            char_trie = CharTrie.fromkeys(unigram_set)
+        self._unigram_set = unigram_set
+        self._char_trie = char_trie
+        self.alpha = alpha
+        self.beta = beta
+        self.unk_score_offset = unk_score_offset
+        self.score_boundary = score_boundary
+
+    @property
+    def order(self) -> int:
+        """Get the order of the n-gram language model."""
+        return cast(int, self._kenlm_model.order)
+
+    def get_start_state(self) -> KenlmState:
+        """Get initial lm state."""
+        start_state = _get_empty_lm_state()
+        if self.score_boundary:
+            self._kenlm_model.BeginSentenceWrite(start_state)
+        else:
+            self._kenlm_model.NullContextWrite(start_state)
+        return KenlmState(start_state)
+
+    def _get_raw_end_score(self, start_state: "kenlm.State") -> float:
+        """Calculate final lm score."""
+        if self.score_boundary:
+            end_state = _get_empty_lm_state()
+            score: float = self._kenlm_model.BaseScore(
+                start_state, "</s>", end_state
+            )
+        else:
+            score = 0.0
+        return score
+
+    def score_partial_token(self, partial_token: str) -> float:
+        """Get partial token score."""
+        if self._char_trie is None:
+            is_oov = 1.0
+        else:
+            is_oov = int(self._char_trie.has_node(partial_token) == 0)
+        unk_score = self.unk_score_offset * is_oov
+        # if unk token length exceeds expected length then additionally decrease score
+        if len(partial_token) > 6:
+            unk_score = unk_score * len(partial_token) / 6
+        return unk_score
+
+    def score(
+        self, prev_state, word: str, is_last_word: bool = False
+    ) -> Tuple[float, KenlmState]:
+        """Score word conditional on start state."""
+        if not isinstance(prev_state, KenlmState):
+            raise AssertionError(
+                f"Wrong input state type found. Expected KenlmState, got {type(prev_state)}"
+            )
+        end_state = _get_empty_lm_state()
+        lm_score = self._kenlm_model.BaseScore(
+            prev_state.state, word, end_state
+        )
+        # override UNK prob. use unigram set if we have because it's faster
+        if (
+            len(self._unigram_set) > 0
+            and word not in self._unigram_set
+            or word not in self._kenlm_model
+        ):
+            lm_score += self.unk_score_offset
+        # add end of sentence context if needed
+        if is_last_word:
+            # note that we want to return the unmodified end_state to keep extension capabilities
+            lm_score = lm_score + self._get_raw_end_score(end_state)
+        lm_score = self.alpha * lm_score * 1.0 / math.log10(math.e) + self.beta
+        return lm_score, KenlmState(end_state)
diff --git a/speechbrain/decoders/scorer.py b/speechbrain/decoders/scorer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b698c304ed61a25d6e3ffa6bc2230bf4dac577db
--- /dev/null
+++ b/speechbrain/decoders/scorer.py
@@ -0,0 +1,2061 @@
+"""
+Token scorer abstraction and specifications.
+
+Authors:
+ * Adel Moumen 2022, 2023
+ * Sung-Lin Yeh 2021
+"""
+
+import torch
+import numpy as np
+import speechbrain as sb
+from speechbrain.decoders.ctc import CTCPrefixScore
+
+
+class BaseScorerInterface:
+    """A scorer abstraction to be inherited by other
+    scoring approaches for beam search.
+
+    A scorer is a module that scores tokens in vocabulary
+    based on the current timestep input and the previous
+    scorer states. It can be used to score on full vocabulary
+    set (i.e., full scorers) or a pruned set of tokens (i.e. partial scorers)
+    to prevent computation overhead. In the latter case, the partial scorers
+    will be called after the full scorers. It will only scores the
+    top-k candidates (i.e., pruned set of tokens) extracted from the full scorers.
+    The top-k candidates are extracted based on the beam size and the
+    scorer_beam_scale such that the number of candidates is
+    int(beam_size * scorer_beam_scale). It can be very useful
+    when the full scorers are computationally expensive (e.g., KenLM scorer).
+
+    Inherit this class to implement your own scorer compatible with
+    speechbrain.decoders.seq2seq.S2SBeamSearcher().
+
+    See:
+        - speechbrain.decoders.scorer.CTCPrefixScorer
+        - speechbrain.decoders.scorer.RNNLMScorer
+        - speechbrain.decoders.scorer.TransformerLMScorer
+        - speechbrain.decoders.scorer.KenLMScorer
+        - speechbrain.decoders.scorer.CoverageScorer
+        - speechbrain.decoders.scorer.LengthScorer
+    """
+
+    def score(self, inp_tokens, memory, candidates, attn):
+        """This method scores the new beams based on the
+        informations of the current timestep.
+
+        A score is a tensor of shape (batch_size x beam_size, vocab_size).
+        It is the log probability of the next token given the current
+        timestep input and the previous scorer states.
+
+        It can be used to score on pruned top-k candidates
+        to prevent computation overhead, or on full vocabulary set
+        when candidates is None.
+
+        Arguments
+        ---------
+        inp_tokens : torch.Tensor
+            The input tensor of the current timestep.
+        memory : No limit
+            The scorer states for this timestep.
+        candidates : torch.Tensor
+            (batch_size x beam_size, scorer_beam_size).
+            The top-k candidates to be scored after the full scorers.
+            If None, scorers will score on full vocabulary set.
+        attn : torch.Tensor
+            The attention weight to be used in CoverageScorer or CTCScorer.
+
+        Returns
+        ---------
+        torch.Tensor
+            (batch_size x beam_size, vocab_size), Scores for the next tokens.
+        memory : No limit
+            The memory variables input for this timestep.
+        """
+        raise NotImplementedError
+
+    def permute_mem(self, memory, index):
+        """This method permutes the scorer memory to synchronize
+        the memory index with the current output and perform
+        batched beam search.
+
+        Arguments
+        ---------
+        memory : No limit
+            The memory variables input for this timestep.
+        index : torch.Tensor
+            (batch_size, beam_size). The index of the previous path.
+        """
+        return None
+
+    def reset_mem(self, x, enc_lens):
+        """This method should implement the resetting of
+        memory variables for the scorer.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            The precomputed encoder states to be used when decoding.
+            (ex. the encoded speech representation to be attended).
+        enc_lens : torch.Tensor
+            The speechbrain-style relative length.
+        """
+        return None
+
+
+class CTCScorer(BaseScorerInterface):
+    """A wrapper of CTCPrefixScore based on the BaseScorerInterface.
+
+    This Scorer is used to provides the CTC label-synchronous scores
+    of the next input tokens. The implementation is based on
+    https://www.merl.com/publications/docs/TR2017-190.pdf.
+
+    See:
+        - speechbrain.decoders.scorer.CTCPrefixScore
+
+    Arguments
+    ---------
+    ctc_fc : torch.nn.Module
+        A output linear layer for ctc.
+    blank_index : int
+        The index of the blank token.
+    eos_index : int
+        The index of the end-of-sequence (eos) token.
+    ctc_window_size : int
+        Compute the ctc scores over the time frames using windowing
+        based on attention peaks. If 0, no windowing applied. (default: 0)
+
+    Example
+    -------
+    >>> import torch
+    >>> from speechbrain.nnet.linear import Linear
+    >>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR
+    >>> from speechbrain.decoders import S2STransformerBeamSearcher, CTCScorer, ScorerBuilder
+    >>> batch_size=8
+    >>> n_channels=6
+    >>> input_size=40
+    >>> d_model=128
+    >>> tgt_vocab=140
+    >>> src = torch.rand([batch_size, n_channels, input_size])
+    >>> tgt = torch.randint(0, tgt_vocab, [batch_size, n_channels])
+    >>> net = TransformerASR(
+    ...    tgt_vocab, input_size, d_model, 8, 1, 1, 1024, activation=torch.nn.GELU
+    ... )
+    >>> ctc_lin = Linear(input_shape=(1, 40, d_model), n_neurons=tgt_vocab)
+    >>> lin = Linear(input_shape=(1, 40, d_model), n_neurons=tgt_vocab)
+    >>> eos_index = 2
+    >>> ctc_scorer = CTCScorer(
+    ...    ctc_fc=ctc_lin,
+    ...    blank_index=0,
+    ...    eos_index=eos_index,
+    ... )
+    >>> scorer = ScorerBuilder(
+    ...     full_scorers=[ctc_scorer],
+    ...     weights={'ctc': 1.0}
+    ... )
+    >>> searcher = S2STransformerBeamSearcher(
+    ...     modules=[net, lin],
+    ...     bos_index=1,
+    ...     eos_index=eos_index,
+    ...     min_decode_ratio=0.0,
+    ...     max_decode_ratio=1.0,
+    ...     using_eos_threshold=False,
+    ...     beam_size=7,
+    ...     temperature=1.15,
+    ...     scorer=scorer
+    ... )
+    >>> enc, dec = net.forward(src, tgt)
+    >>> hyps, _, _, _ = searcher(enc, torch.ones(batch_size))
+    """
+
+    def __init__(
+        self, ctc_fc, blank_index, eos_index, ctc_window_size=0,
+    ):
+        self.ctc_fc = ctc_fc
+        self.blank_index = blank_index
+        self.eos_index = eos_index
+        self.ctc_window_size = ctc_window_size
+        self.softmax = sb.nnet.activations.Softmax(apply_log=True)
+
+    def score(self, inp_tokens, memory, candidates, attn):
+        """This method scores the new beams based on the
+        CTC scores computed over the time frames.
+
+        See:
+            - speechbrain.decoders.scorer.CTCPrefixScore
+
+        Arguments
+        ---------
+        inp_tokens : torch.Tensor
+            The input tensor of the current timestep.
+        memory : No limit
+            The scorer states for this timestep.
+        candidates : torch.Tensor
+            (batch_size x beam_size, scorer_beam_size).
+            The top-k candidates to be scored after the full scorers.
+            If None, scorers will score on full vocabulary set.
+        attn : torch.Tensor
+            The attention weight to be used in CoverageScorer or CTCScorer.
+        """
+        scores, memory = self.ctc_score.forward_step(
+            inp_tokens, memory, candidates, attn
+        )
+        return scores, memory
+
+    def permute_mem(self, memory, index):
+        """This method permutes the scorer memory to synchronize
+        the memory index with the current output and perform
+        batched CTC beam search.
+
+        Arguments
+        ---------
+        memory : No limit
+            The memory variables input for this timestep.
+        index : torch.Tensor
+            (batch_size, beam_size). The index of the previous path.
+        """
+        r, psi = self.ctc_score.permute_mem(memory, index)
+        return r, psi
+
+    def reset_mem(self, x, enc_lens):
+        """This method implement the resetting of
+        memory variables for the CTC scorer.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            The precomputed encoder states to be used when decoding.
+            (ex. the encoded speech representation to be attended).
+        enc_lens : torch.Tensor
+            The speechbrain-style relative length.
+        """
+        logits = self.ctc_fc(x)
+        x = self.softmax(logits)
+        self.ctc_score = CTCPrefixScore(
+            x, enc_lens, self.blank_index, self.eos_index, self.ctc_window_size
+        )
+        return None
+
+
+class RNNLMScorer(BaseScorerInterface):
+    """A wrapper of RNNLM based on BaseScorerInterface.
+
+    The RNNLMScorer is used to provide the RNNLM scores of the next input tokens
+    based on the current timestep input and the previous scorer states.
+
+    Arguments
+    ---------
+    language_model : torch.nn.Module
+        A RNN-based language model.
+    temperature : float
+        Temperature factor applied to softmax. It changes the probability
+        distribution, being softer when T>1 and sharper with T<1. (default: 1.0)
+
+    Example
+    -------
+    >>> from speechbrain.nnet.linear import Linear
+    >>> from speechbrain.lobes.models.RNNLM import RNNLM
+    >>> from speechbrain.nnet.RNN import AttentionalRNNDecoder
+    >>> from speechbrain.decoders import S2SRNNBeamSearcher, RNNLMScorer, ScorerBuilder
+    >>> input_size=17
+    >>> vocab_size=11
+    >>> emb = torch.nn.Embedding(
+    ...     embedding_dim=input_size,
+    ...     num_embeddings=vocab_size,
+    ... )
+    >>> d_model=7
+    >>> dec = AttentionalRNNDecoder(
+    ...     rnn_type="gru",
+    ...     attn_type="content",
+    ...     hidden_size=3,
+    ...     attn_dim=3,
+    ...     num_layers=1,
+    ...     enc_dim=d_model,
+    ...     input_size=input_size,
+    ... )
+    >>> n_channels=3
+    >>> seq_lin = Linear(input_shape=[d_model, n_channels], n_neurons=vocab_size)
+    >>> lm_weight = 0.4
+    >>> lm_model = RNNLM(
+    ...     embedding_dim=d_model,
+    ...     output_neurons=vocab_size,
+    ...     dropout=0.0,
+    ...     rnn_neurons=128,
+    ...     dnn_neurons=64,
+    ...     return_hidden=True,
+    ... )
+    >>> rnnlm_scorer = RNNLMScorer(
+    ...     language_model=lm_model,
+    ...     temperature=1.25,
+    ... )
+    >>> scorer = ScorerBuilder(
+    ...     full_scorers=[rnnlm_scorer],
+    ...     weights={'rnnlm': lm_weight}
+    ... )
+    >>> beam_size=5
+    >>> searcher = S2SRNNBeamSearcher(
+    ...     embedding=emb,
+    ...     decoder=dec,
+    ...     linear=seq_lin,
+    ...     bos_index=1,
+    ...     eos_index=2,
+    ...     min_decode_ratio=0.0,
+    ...     max_decode_ratio=1.0,
+    ...     topk=2,
+    ...     using_eos_threshold=False,
+    ...     beam_size=beam_size,
+    ...     temperature=1.25,
+    ...     scorer=scorer
+    ... )
+    >>> batch_size=2
+    >>> enc = torch.rand([batch_size, n_channels, d_model])
+    >>> wav_len = torch.ones([batch_size])
+    >>> hyps, _, _, _ = searcher(enc, wav_len)
+    """
+
+    def __init__(self, language_model, temperature=1.0):
+        self.lm = language_model
+        self.lm.eval()
+        self.temperature = temperature
+        self.softmax = sb.nnet.activations.Softmax(apply_log=True)
+
+    def score(self, inp_tokens, memory, candidates, attn):
+        """This method scores the new beams based on the
+        RNNLM scores computed over the previous tokens.
+
+        Arguments
+        ---------
+        inp_tokens : torch.Tensor
+            The input tensor of the current timestep.
+        memory : No limit
+            The scorer states for this timestep.
+        candidates : torch.Tensor
+            (batch_size x beam_size, scorer_beam_size).
+            The top-k candidates to be scored after the full scorers.
+            If None, scorers will score on full vocabulary set.
+        attn : torch.Tensor
+            The attention weight to be used in CoverageScorer or CTCScorer.
+        """
+        with torch.no_grad():
+            logits, hs = self.lm(inp_tokens, hx=memory)
+            log_probs = self.softmax(logits / self.temperature)
+        return log_probs, hs
+
+    def permute_mem(self, memory, index):
+        """This method permutes the scorer memory to synchronize
+        the memory index with the current output and perform
+        batched beam search.
+
+        Arguments
+        ---------
+        memory : No limit
+            The memory variables input for this timestep.
+        index : torch.Tensor
+            (batch_size, beam_size). The index of the previous path.
+        """
+        if isinstance(memory, tuple):
+            memory_0 = torch.index_select(memory[0], dim=1, index=index)
+            memory_1 = torch.index_select(memory[1], dim=1, index=index)
+            memory = (memory_0, memory_1)
+        else:
+            memory = torch.index_select(memory, dim=1, index=index)
+        return memory
+
+    def reset_mem(self, x, enc_lens):
+        """This method implement the resetting of
+        memory variables for the RNNLM scorer.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            The precomputed encoder states to be used when decoding.
+            (ex. the encoded speech representation to be attended).
+        enc_lens : torch.Tensor
+            The speechbrain-style relative length.
+        """
+        return None
+
+
+class TransformerLMScorer(BaseScorerInterface):
+    """A wrapper of TransformerLM based on BaseScorerInterface.
+
+    The TransformerLMScorer is used to provide the TransformerLM scores
+    of the next input tokens based on the current timestep input and the
+    previous scorer states.
+
+    Arguments
+    ---------
+    language_model : torch.nn.Module
+        A Transformer-based language model.
+    temperature : float
+        Temperature factor applied to softmax. It changes the probability
+        distribution, being softer when T>1 and sharper with T<1. (default: 1.0)
+
+    Example
+    -------
+    >>> from speechbrain.nnet.linear import Linear
+    >>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR
+    >>> from speechbrain.lobes.models.transformer.TransformerLM import TransformerLM
+    >>> from speechbrain.decoders import S2STransformerBeamSearcher, TransformerLMScorer, CTCScorer, ScorerBuilder
+    >>> input_size=17
+    >>> vocab_size=11
+    >>> d_model=128
+    >>> net = TransformerASR(
+    ...     tgt_vocab=vocab_size,
+    ...     input_size=input_size,
+    ...     d_model=d_model,
+    ...     nhead=8,
+    ...     num_encoder_layers=1,
+    ...     num_decoder_layers=1,
+    ...     d_ffn=256,
+    ...     activation=torch.nn.GELU
+    ... )
+    >>> lm_model = TransformerLM(
+    ...     vocab=vocab_size,
+    ...     d_model=d_model,
+    ...     nhead=8,
+    ...     num_encoder_layers=1,
+    ...     num_decoder_layers=0,
+    ...     d_ffn=256,
+    ...     activation=torch.nn.GELU,
+    ... )
+    >>> n_channels=6
+    >>> ctc_lin = Linear(input_size=d_model, n_neurons=vocab_size)
+    >>> seq_lin = Linear(input_size=d_model, n_neurons=vocab_size)
+    >>> eos_index = 2
+    >>> ctc_scorer = CTCScorer(
+    ...     ctc_fc=ctc_lin,
+    ...     blank_index=0,
+    ...     eos_index=eos_index,
+    ... )
+    >>> transformerlm_scorer = TransformerLMScorer(
+    ...     language_model=lm_model,
+    ...     temperature=1.15,
+    ... )
+    >>> ctc_weight_decode=0.4
+    >>> lm_weight=0.6
+    >>> scorer = ScorerBuilder(
+    ...     full_scorers=[transformerlm_scorer, ctc_scorer],
+    ...     weights={'transformerlm': lm_weight, 'ctc': ctc_weight_decode}
+    ... )
+    >>> beam_size=5
+    >>> searcher = S2STransformerBeamSearcher(
+    ...     modules=[net, seq_lin],
+    ...     bos_index=1,
+    ...     eos_index=eos_index,
+    ...     min_decode_ratio=0.0,
+    ...     max_decode_ratio=1.0,
+    ...     using_eos_threshold=False,
+    ...     beam_size=beam_size,
+    ...     temperature=1.15,
+    ...     scorer=scorer
+    ... )
+    >>> batch_size=2
+    >>> wav_len = torch.ones([batch_size])
+    >>> src = torch.rand([batch_size, n_channels, input_size])
+    >>> tgt = torch.randint(0, vocab_size, [batch_size, n_channels])
+    >>> enc, dec = net.forward(src, tgt)
+    >>> hyps, _, _, _ = searcher(enc, wav_len)
+    """
+
+    def __init__(self, language_model, temperature=1.0):
+        self.lm = language_model
+        self.lm.eval()
+        self.temperature = temperature
+        self.softmax = sb.nnet.activations.Softmax(apply_log=True)
+
+    def score(self, inp_tokens, memory, candidates, attn):
+        """This method scores the new beams based on the
+        TransformerLM scores computed over the previous tokens.
+
+        Arguments
+        ---------
+        inp_tokens : torch.Tensor
+            The input tensor of the current timestep.
+        memory : No limit
+            The scorer states for this timestep.
+        candidates : torch.Tensor
+            (batch_size x beam_size, scorer_beam_size).
+            The top-k candidates to be scored after the full scorers.
+            If None, scorers will score on full vocabulary set.
+        attn : torch.Tensor
+            The attention weight to be used in CoverageScorer or CTCScorer.
+        """
+        with torch.no_grad():
+            if memory is None:
+                memory = torch.empty(
+                    inp_tokens.size(0), 0, device=inp_tokens.device
+                )
+            # Append the predicted token of the previous step to existing memory.
+            memory = torch.cat([memory, inp_tokens.unsqueeze(1)], dim=-1)
+            if not next(self.lm.parameters()).is_cuda:
+                self.lm.to(inp_tokens.device)
+            logits = self.lm(memory)
+            log_probs = self.softmax(logits / self.temperature)
+        return log_probs[:, -1, :], memory
+
+    def permute_mem(self, memory, index):
+        """This method permutes the scorer memory to synchronize
+        the memory index with the current output and perform
+        batched beam search.
+
+        Arguments
+        ---------
+        memory : No limit
+            The memory variables input for this timestep.
+        index : torch.Tensor
+            (batch_size, beam_size). The index of the previous path.
+        """
+        memory = torch.index_select(memory, dim=0, index=index)
+        return memory
+
+    def reset_mem(self, x, enc_lens):
+        """This method implement the resetting of
+        memory variables for the RNNLM scorer.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            The precomputed encoder states to be used when decoding.
+            (ex. the encoded speech representation to be attended).
+        enc_lens : torch.Tensor
+            The speechbrain-style relative length.
+        """
+        return None
+
+
+class KenLMScorer(BaseScorerInterface):
+    """KenLM N-gram scorer.
+
+    This scorer is based on KenLM, which is a fast and efficient
+    N-gram language model toolkit. It is used to provide the n-gram scores
+    of the next input tokens.
+
+    This scorer is dependent on the KenLM package. It can be installed
+    with the following command:
+            > pip install https://github.com/kpu/kenlm/archive/master.zip
+
+    Note: The KenLM scorer is computationally expensive. It is recommended
+    to use it as a partial scorer to score on the top-k candidates instead
+    of the full vocabulary set.
+
+    Arguments
+    ---------
+    lm_path : str
+        The path of ngram model.
+    vocab_size: int
+        The total number of tokens.
+    token_list : list
+        The tokens set.
+
+    # Example
+    # -------
+    # >>> from speechbrain.nnet.linear import Linear
+    # >>> from speechbrain.nnet.RNN import AttentionalRNNDecoder
+    # >>> from speechbrain.decoders import S2SRNNBeamSearcher, KenLMScorer, ScorerBuilder
+    # >>> input_size=17
+    # >>> vocab_size=11
+    # >>> lm_path='path/to/kenlm_model.arpa' # or .bin
+    # >>> token_list=['<pad>', '<bos>', '<eos>', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i']
+    # >>> emb = torch.nn.Embedding(
+    # ...     embedding_dim=input_size,
+    # ...     num_embeddings=vocab_size,
+    # ... )
+    # >>> d_model=7
+    # >>> dec = AttentionalRNNDecoder(
+    # ...     rnn_type="gru",
+    # ...     attn_type="content",
+    # ...     hidden_size=3,
+    # ...     attn_dim=3,
+    # ...     num_layers=1,
+    # ...     enc_dim=d_model,
+    # ...     input_size=input_size,
+    # ... )
+    # >>> n_channels=3
+    # >>> seq_lin = Linear(input_shape=[d_model, n_channels], n_neurons=vocab_size)
+    # >>> kenlm_weight = 0.4
+    # >>> kenlm_model = KenLMScorer(
+    # ...     lm_path=lm_path,
+    # ...     vocab_size=vocab_size,
+    # ...     token_list=token_list,
+    # ... )
+    # >>> scorer = ScorerBuilder(
+    # ...     full_scorers=[kenlm_model],
+    # ...     weights={'kenlm': kenlm_weight}
+    # ... )
+    # >>> beam_size=5
+    # >>> searcher = S2SRNNBeamSearcher(
+    # ...     embedding=emb,
+    # ...     decoder=dec,
+    # ...     linear=seq_lin,
+    # ...     bos_index=1,
+    # ...     eos_index=2,
+    # ...     min_decode_ratio=0.0,
+    # ...     max_decode_ratio=1.0,
+    # ...     topk=2,
+    # ...     using_eos_threshold=False,
+    # ...     beam_size=beam_size,
+    # ...     temperature=1.25,
+    # ...     scorer=scorer
+    # ... )
+    # >>> batch_size=2
+    # >>> enc = torch.rand([batch_size, n_channels, d_model])
+    # >>> wav_len = torch.ones([batch_size])
+    # >>> hyps, _, _, _ = searcher(enc, wav_len)
+    """
+
+    def __init__(self, lm_path, vocab_size, token_list):
+        try:
+            import kenlm
+
+            self.kenlm = kenlm
+        except ImportError:
+            MSG = """Couldn't import KenLM
+            It is an optional dependency; it is not installed with SpeechBrain
+            by default. Install it with:
+            > pip install https://github.com/kpu/kenlm/archive/master.zip
+            """
+            raise ImportError(MSG)
+        self.lm = self.kenlm.Model(lm_path)
+        self.vocab_size = vocab_size
+        self.full_candidates = np.arange(self.vocab_size)
+        self.minus_inf = -1e20
+        if len(token_list) != vocab_size:
+            MSG = "The size of the token_list and vocab_size are not matched."
+            raise ValueError(MSG)
+        self.id2char = token_list
+
+    def score(self, inp_tokens, memory, candidates, attn):
+        """This method scores the new beams based on the
+        n-gram scores.
+
+        Arguments
+        ---------
+        inp_tokens : torch.Tensor
+            The input tensor of the current timestep.
+        memory : No limit
+            The scorer states for this timestep.
+        candidates : torch.Tensor
+            (batch_size x beam_size, scorer_beam_size).
+            The top-k candidates to be scored after the full scorers.
+            If None, scorers will score on full vocabulary set.
+        attn : torch.Tensor
+            The attention weight to be used in CoverageScorer or CTCScorer.
+        """
+        n_bh = inp_tokens.size(0)
+        scale = 1.0 / np.log10(np.e)
+
+        if memory is None:
+            state = self.kenlm.State()
+            state = np.array([state] * n_bh)
+            scoring_table = np.ones(n_bh)
+        else:
+            state, scoring_table = memory
+
+        # Perform full scorer mode, not recommend
+        if candidates is None:
+            candidates = [self.full_candidates] * n_bh
+
+        # Store new states and scores
+        scores = np.ones((n_bh, self.vocab_size)) * self.minus_inf
+        new_memory = np.zeros((n_bh, self.vocab_size), dtype=object)
+        new_scoring_table = np.ones((n_bh, self.vocab_size)) * -1
+        # Scoring
+        for i in range(n_bh):
+            if scoring_table[i] == -1:
+                continue
+            parent_state = state[i]
+            for token_id in candidates[i]:
+                char = self.id2char[token_id.item()]
+                out_state = self.kenlm.State()
+                score = scale * self.lm.BaseScore(parent_state, char, out_state)
+                scores[i, token_id] = score
+                new_memory[i, token_id] = out_state
+                new_scoring_table[i, token_id] = 1
+        scores = torch.from_numpy(scores).float().to(inp_tokens.device)
+        return scores, (new_memory, new_scoring_table)
+
+    def permute_mem(self, memory, index):
+        """This method permutes the scorer memory to synchronize
+        the memory index with the current output and perform
+        batched beam search.
+
+        Arguments
+        ---------
+        memory : No limit
+            The memory variables input for this timestep.
+        index : torch.Tensor
+            (batch_size, beam_size). The index of the previous path.
+        """
+        state, scoring_table = memory
+
+        index = index.cpu().numpy()
+        # The first index of each sentence.
+        beam_size = index.shape[1]
+        beam_offset = self.batch_index * beam_size
+        hyp_index = (
+            index
+            + np.broadcast_to(np.expand_dims(beam_offset, 1), index.shape)
+            * self.vocab_size
+        )
+        hyp_index = hyp_index.reshape(-1)
+        # Update states
+        state = state.reshape(-1)
+        state = state[hyp_index]
+        scoring_table = scoring_table.reshape(-1)
+        scoring_table = scoring_table[hyp_index]
+        return state, scoring_table
+
+    def reset_mem(self, x, enc_lens):
+        """This method implement the resetting of
+        memory variables for the KenLM scorer.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            The precomputed encoder states to be used when decoding.
+            (ex. the encoded speech representation to be attended).
+        enc_lens : torch.Tensor
+            The speechbrain-style relative length.
+        """
+        state = self.kenlm.State()
+        self.lm.NullContextWrite(state)
+        self.batch_index = np.arange(x.size(0))
+        return None
+
+
+class CoverageScorer(BaseScorerInterface):
+    """A coverage penalty scorer to prevent looping of hyps,
+    where ```coverage``` is the cumulative attention probability vector.
+    Reference: https://arxiv.org/pdf/1612.02695.pdf,
+               https://arxiv.org/pdf/1808.10792.pdf
+
+    Arguments
+    ---------
+    vocab_size: int
+        The total number of tokens.
+    threshold: float
+        The penalty increases when the coverage of a frame is more
+        than given threshold. (default: 0.5)
+
+    Example
+    -------
+    >>> from speechbrain.nnet.linear import Linear
+    >>> from speechbrain.lobes.models.RNNLM import RNNLM
+    >>> from speechbrain.nnet.RNN import AttentionalRNNDecoder
+    >>> from speechbrain.decoders import S2SRNNBeamSearcher, RNNLMScorer, CoverageScorer, ScorerBuilder
+    >>> input_size=17
+    >>> vocab_size=11
+    >>> emb = torch.nn.Embedding(
+    ...     num_embeddings=vocab_size,
+    ...     embedding_dim=input_size
+    ... )
+    >>> d_model=7
+    >>> dec = AttentionalRNNDecoder(
+    ...     rnn_type="gru",
+    ...     attn_type="content",
+    ...     hidden_size=3,
+    ...     attn_dim=3,
+    ...     num_layers=1,
+    ...     enc_dim=d_model,
+    ...     input_size=input_size,
+    ... )
+    >>> n_channels=3
+    >>> seq_lin = Linear(input_shape=[d_model, n_channels], n_neurons=vocab_size)
+    >>> lm_weight = 0.4
+    >>> coverage_penalty = 1.0
+    >>> lm_model = RNNLM(
+    ...     embedding_dim=d_model,
+    ...     output_neurons=vocab_size,
+    ...     dropout=0.0,
+    ...     rnn_neurons=128,
+    ...     dnn_neurons=64,
+    ...     return_hidden=True,
+    ... )
+    >>> rnnlm_scorer = RNNLMScorer(
+    ...     language_model=lm_model,
+    ...     temperature=1.25,
+    ... )
+    >>> coverage_scorer = CoverageScorer(vocab_size=vocab_size)
+    >>> scorer = ScorerBuilder(
+    ...     full_scorers=[rnnlm_scorer, coverage_scorer],
+    ...     weights={'rnnlm': lm_weight, 'coverage': coverage_penalty}
+    ... )
+    >>> beam_size=5
+    >>> searcher = S2SRNNBeamSearcher(
+    ...     embedding=emb,
+    ...     decoder=dec,
+    ...     linear=seq_lin,
+    ...     bos_index=1,
+    ...     eos_index=2,
+    ...     min_decode_ratio=0.0,
+    ...     max_decode_ratio=1.0,
+    ...     topk=2,
+    ...     using_eos_threshold=False,
+    ...     beam_size=beam_size,
+    ...     temperature=1.25,
+    ...     scorer=scorer
+    ... )
+    >>> batch_size=2
+    >>> enc = torch.rand([batch_size, n_channels, d_model])
+    >>> wav_len = torch.ones([batch_size])
+    >>> hyps, _, _, _ = searcher(enc, wav_len)
+    """
+
+    def __init__(self, vocab_size, threshold=0.5):
+        self.vocab_size = vocab_size
+        self.threshold = threshold
+        # Use time_step to normalize the coverage over steps
+        self.time_step = 0
+
+    def score(self, inp_tokens, coverage, candidates, attn):
+        """This method scores the new beams based on the
+        Coverage scorer.
+
+        Arguments
+        ---------
+        inp_tokens : torch.Tensor
+            The input tensor of the current timestep.
+        coverage : No limit
+            The scorer states for this timestep.
+        candidates : torch.Tensor
+            (batch_size x beam_size, scorer_beam_size).
+            The top-k candidates to be scored after the full scorers.
+            If None, scorers will score on full vocabulary set.
+        attn : torch.Tensor
+            The attention weight to be used in CoverageScorer or CTCScorer.
+        """
+        n_bh = attn.size(0)
+        self.time_step += 1
+
+        if coverage is None:
+            coverage = torch.zeros_like(attn, device=attn.device)
+
+        # Current coverage
+        if len(attn.size()) > 2:
+            # the attn of transformer is [batch_size x beam_size, current_step, source_len]
+            coverage = torch.sum(attn, dim=1)
+        else:
+            coverage = coverage + attn
+
+        # Compute coverage penalty and add it to scores
+        penalty = torch.max(
+            coverage, coverage.clone().fill_(self.threshold)
+        ).sum(-1)
+        penalty = penalty - coverage.size(-1) * self.threshold
+        penalty = penalty.view(n_bh).unsqueeze(1).expand(-1, self.vocab_size)
+        return -1 * penalty / self.time_step, coverage
+
+    def permute_mem(self, coverage, index):
+        """This method permutes the scorer memory to synchronize
+        the memory index with the current output and perform
+        batched beam search.
+
+        Arguments
+        ---------
+        coverage : No limit
+            The memory variables input for this timestep.
+        index : torch.Tensor
+            (batch_size, beam_size). The index of the previous path.
+        """
+        # Update coverage
+        coverage = torch.index_select(coverage, dim=0, index=index)
+        return coverage
+
+    def reset_mem(self, x, enc_lens):
+        """This method implement the resetting of
+        memory variables for the RNNLM scorer.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            The precomputed encoder states to be used when decoding.
+            (ex. the encoded speech representation to be attended).
+        enc_lens : torch.Tensor
+            The speechbrain-style relative length.
+        """
+        self.time_step = 0
+        return None
+
+
+class LengthScorer(BaseScorerInterface):
+    """A length rewarding scorer.
+
+    The LengthScorer is used to provide the length rewarding scores.
+    It is used to prevent the beam search from favoring short hypotheses.
+
+    Note: length_normalization is not compatible with this scorer. Make sure
+    to set is to False when using LengthScorer.
+
+    Arguments
+    ---------
+    vocab_size: int
+        The total number of tokens.
+
+    Example
+    -------
+    >>> from speechbrain.nnet.linear import Linear
+    >>> from speechbrain.lobes.models.RNNLM import RNNLM
+    >>> from speechbrain.nnet.RNN import AttentionalRNNDecoder
+    >>> from speechbrain.decoders import S2SRNNBeamSearcher, RNNLMScorer, CoverageScorer, ScorerBuilder
+    >>> input_size=17
+    >>> vocab_size=11
+    >>> emb = torch.nn.Embedding(
+    ...     num_embeddings=vocab_size,
+    ...     embedding_dim=input_size
+    ... )
+    >>> d_model=7
+    >>> dec = AttentionalRNNDecoder(
+    ...     rnn_type="gru",
+    ...     attn_type="content",
+    ...     hidden_size=3,
+    ...     attn_dim=3,
+    ...     num_layers=1,
+    ...     enc_dim=d_model,
+    ...     input_size=input_size,
+    ... )
+    >>> n_channels=3
+    >>> seq_lin = Linear(input_shape=[d_model, n_channels], n_neurons=vocab_size)
+    >>> lm_weight = 0.4
+    >>> length_weight = 1.0
+    >>> lm_model = RNNLM(
+    ...     embedding_dim=d_model,
+    ...     output_neurons=vocab_size,
+    ...     dropout=0.0,
+    ...     rnn_neurons=128,
+    ...     dnn_neurons=64,
+    ...     return_hidden=True,
+    ... )
+    >>> rnnlm_scorer = RNNLMScorer(
+    ...     language_model=lm_model,
+    ...     temperature=1.25,
+    ... )
+    >>> length_scorer = LengthScorer(vocab_size=vocab_size)
+    >>> scorer = ScorerBuilder(
+    ...     full_scorers=[rnnlm_scorer, length_scorer],
+    ...     weights={'rnnlm': lm_weight, 'length': length_weight}
+    ... )
+    >>> beam_size=5
+    >>> searcher = S2SRNNBeamSearcher(
+    ...     embedding=emb,
+    ...     decoder=dec,
+    ...     linear=seq_lin,
+    ...     bos_index=1,
+    ...     eos_index=2,
+    ...     min_decode_ratio=0.0,
+    ...     max_decode_ratio=1.0,
+    ...     topk=2,
+    ...     using_eos_threshold=False,
+    ...     beam_size=beam_size,
+    ...     temperature=1.25,
+    ...     length_normalization=False,
+    ...     scorer=scorer
+    ... )
+    >>> batch_size=2
+    >>> enc = torch.rand([batch_size, n_channels, d_model])
+    >>> wav_len = torch.ones([batch_size])
+    >>> hyps, _, _, _ = searcher(enc, wav_len)
+    """
+
+    def __init__(self, vocab_size):
+        self.vocab_size = vocab_size
+
+    def score(self, inp_tokens, memory, candidates, attn):
+        """This method scores the new beams based on the
+        Length scorer.
+
+        Arguments
+        ---------
+        inp_tokens : torch.Tensor
+            The input tensor of the current timestep.
+        memory : No limit
+            The scorer states for this timestep.
+        candidates : torch.Tensor
+            (batch_size x beam_size, scorer_beam_size).
+            The top-k candidates to be scored after the full scorers.
+            If None, scorers will score on full vocabulary set.
+        attn : torch.Tensor
+            The attention weight to be used in CoverageScorer or CTCScorer.
+        """
+        return (
+            torch.tensor(
+                [1.0], device=inp_tokens.device, dtype=inp_tokens.dtype
+            ).expand(inp_tokens.size(0), self.vocab_size),
+            None,
+        )
+
+
+class ScorerBuilder:
+    """ Builds scorer instance for beamsearch.
+
+    The ScorerBuilder class is responsible for building a scorer instance for
+    beam search. It takes weights for full and partial scorers, as well as
+    instances of full and partial scorer classes. It combines the scorers based
+    on the weights specified and provides methods for scoring tokens, permuting
+    scorer memory, and resetting scorer memory.
+
+    This is the class to be used for building scorer instances for beam search.
+
+    See speechbrain.decoders.seq2seq.S2SBeamSearcher()
+
+    Arguments
+    ---------
+    weights : dict
+        Weights of full/partial scorers specified.
+    full_scorers : list
+        Scorers that score on full vocabulary set.
+    partial_scorers : list
+        Scorers that score on pruned tokens to prevent computation overhead.
+        Partial scoring is performed after full scorers.
+    scorer_beam_scale : float
+        The scale decides the number of pruned tokens for partial scorers:
+        int(beam_size * scorer_beam_scale).
+
+    Example
+    -------
+    >>> from speechbrain.nnet.linear import Linear
+    >>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR
+    >>> from speechbrain.lobes.models.transformer.TransformerLM import TransformerLM
+    >>> from speechbrain.decoders import S2STransformerBeamSearcher, TransformerLMScorer, CoverageScorer, CTCScorer, ScorerBuilder
+    >>> input_size=17
+    >>> vocab_size=11
+    >>> d_model=128
+    >>> net = TransformerASR(
+    ...     tgt_vocab=vocab_size,
+    ...     input_size=input_size,
+    ...     d_model=d_model,
+    ...     nhead=8,
+    ...     num_encoder_layers=1,
+    ...     num_decoder_layers=1,
+    ...     d_ffn=256,
+    ...     activation=torch.nn.GELU
+    ... )
+    >>> lm_model = TransformerLM(
+    ...     vocab=vocab_size,
+    ...     d_model=d_model,
+    ...     nhead=8,
+    ...     num_encoder_layers=1,
+    ...     num_decoder_layers=0,
+    ...     d_ffn=256,
+    ...     activation=torch.nn.GELU,
+    ... )
+    >>> n_channels=6
+    >>> ctc_lin = Linear(input_size=d_model, n_neurons=vocab_size)
+    >>> seq_lin = Linear(input_size=d_model, n_neurons=vocab_size)
+    >>> eos_index = 2
+    >>> ctc_scorer = CTCScorer(
+    ...     ctc_fc=ctc_lin,
+    ...     blank_index=0,
+    ...     eos_index=eos_index,
+    ... )
+    >>> transformerlm_scorer = TransformerLMScorer(
+    ...     language_model=lm_model,
+    ...     temperature=1.15,
+    ... )
+    >>> coverage_scorer = CoverageScorer(vocab_size=vocab_size)
+    >>> ctc_weight_decode=0.4
+    >>> lm_weight=0.6
+    >>> coverage_penalty = 1.0
+    >>> scorer = ScorerBuilder(
+    ...     full_scorers=[transformerlm_scorer, coverage_scorer],
+    ...     partial_scorers=[ctc_scorer],
+    ...     weights={'transformerlm': lm_weight, 'ctc': ctc_weight_decode, 'coverage': coverage_penalty}
+    ... )
+    >>> beam_size=5
+    >>> searcher = S2STransformerBeamSearcher(
+    ...     modules=[net, seq_lin],
+    ...     bos_index=1,
+    ...     eos_index=eos_index,
+    ...     min_decode_ratio=0.0,
+    ...     max_decode_ratio=1.0,
+    ...     using_eos_threshold=False,
+    ...     beam_size=beam_size,
+    ...     topk=3,
+    ...     temperature=1.15,
+    ...     scorer=scorer
+    ... )
+    >>> batch_size=2
+    >>> wav_len = torch.ones([batch_size])
+    >>> src = torch.rand([batch_size, n_channels, input_size])
+    >>> tgt = torch.randint(0, vocab_size, [batch_size, n_channels])
+    >>> enc, dec = net.forward(src, tgt)
+    >>> hyps, _, _, _  = searcher(enc, wav_len)
+    """
+
+    def __init__(
+        self,
+        weights=dict(),
+        full_scorers=list(),
+        partial_scorers=list(),
+        scorer_beam_scale=2,
+    ):
+        assert len(weights) == len(full_scorers) + len(
+            partial_scorers
+        ), "Weights and scorers are not matched."
+
+        self.scorer_beam_scale = scorer_beam_scale
+        all_scorer_names = [
+            k.lower().split("scorer")[0]
+            for k in globals().keys()
+            if k.endswith("Scorer")
+        ]
+        full_scorer_names = [
+            impl.__class__.__name__.lower().split("scorer")[0]
+            for impl in full_scorers
+        ]
+        partial_scorer_names = [
+            impl.__class__.__name__.lower().split("scorer")[0]
+            for impl in partial_scorers
+        ]
+
+        # Have a default 0.0 weight for scorer not specified
+        init_weights = {k: 0.0 for k in all_scorer_names}
+        self.weights = {**init_weights, **weights}
+        self.full_scorers = dict(zip(full_scorer_names, full_scorers))
+        self.partial_scorers = dict(zip(partial_scorer_names, partial_scorers))
+
+        # Check if scorers are valid
+        self._validate_scorer(all_scorer_names)
+
+    def score(self, inp_tokens, memory, attn, log_probs, beam_size):
+        """This method scores tokens in vocabulary based on defined full scorers
+        and partial scorers. Scores will be added to the log probs for beamsearch.
+
+        Arguments
+        ---------
+        inp_tokens : torch.Tensor
+            See BaseScorerInterface().
+        memory : dict[str, scorer memory]
+            The states of scorers for this timestep.
+        attn : torch.Tensor
+            See BaseScorerInterface().
+        log_probs : torch.Tensor
+            (batch_size x beam_size, vocab_size). The log probs at this timestep.
+        beam_size : int
+            The beam size.
+
+        Returns
+        ---------
+        log_probs : torch.Tensor
+            (batch_size x beam_size, vocab_size). Log probs updated by scorers.
+        new_memory : dict[str, scorer memory]
+            The updated states of scorers.
+        """
+        new_memory = dict()
+        # score full candidates
+        for k, impl in self.full_scorers.items():
+            if k == "ctc":
+                # block blank token if CTC is used
+                log_probs[:, impl.blank_index] = impl.ctc_score.minus_inf
+
+            score, new_memory[k] = impl.score(inp_tokens, memory[k], None, attn)
+            log_probs += score * self.weights[k]
+
+        # select candidates from the results of full scorers for partial scorers
+        _, candidates = log_probs.topk(
+            int(beam_size * self.scorer_beam_scale), dim=-1
+        )
+
+        # score pruned tokens candidates
+        for k, impl in self.partial_scorers.items():
+            score, new_memory[k] = impl.score(
+                inp_tokens, memory[k], candidates, attn
+            )
+            log_probs += score * self.weights[k]
+
+        return log_probs, new_memory
+
+    def permute_scorer_mem(self, memory, index, candidates):
+        """Update memory variables of scorers to synchronize
+        the memory index with the current output and perform
+        batched beam search.
+
+        Arguments
+        ---------
+        memory : dict[str, scorer memory]
+            The states of scorers for this timestep.
+        index : torch.Tensor
+            (batch_size x beam_size). The index of the previous path.
+        candidates : torch.Tensor
+            (batch_size, beam_size). The index of the topk candidates.
+        """
+        for k, impl in self.full_scorers.items():
+            # ctc scorer should always be scored by candidates
+            if k == "ctc" or k == "kenlm":
+                memory[k] = impl.permute_mem(memory[k], candidates)
+                continue
+            memory[k] = impl.permute_mem(memory[k], index)
+        for k, impl in self.partial_scorers.items():
+            memory[k] = impl.permute_mem(memory[k], candidates)
+        return memory
+
+    def reset_scorer_mem(self, x, enc_lens):
+        """Reset memory variables for scorers.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            See BaseScorerInterface().
+        wav_len : torch.Tensor
+            See BaseScorerInterface().
+        """
+        memory = dict()
+        for k, impl in {**self.full_scorers, **self.partial_scorers}.items():
+            memory[k] = impl.reset_mem(x, enc_lens)
+        return memory
+
+    def _validate_scorer(self, scorer_names):
+        """These error messages indicate scorers are not properly set.
+
+        Arguments
+        ---------
+        scorer_names : list
+            Prefix of scorers defined in speechbrain.decoders.scorer.
+        """
+        if len(self.weights) > len(scorer_names):
+            raise ValueError(
+                "The keys of weights should be named in {}".format(scorer_names)
+            )
+
+        if not 0.0 <= self.weights["ctc"] <= 1.0:
+            raise ValueError("ctc_weight should not > 1.0 and < 0.0")
+
+        if self.weights["ctc"] == 1.0:
+            if "ctc" not in self.full_scorers.keys():
+                raise ValueError(
+                    "CTC scorer should be a full scorer when it's weight is 1.0"
+                )
+            if self.weights["coverage"] > 0.0:
+                raise ValueError(
+                    "Pure CTC scorer doesn't have attention weights for coverage scorer"
+                )
+
+
+class BaseRescorerInterface(BaseScorerInterface):
+    """A scorer abstraction intended for inheritance by other scoring approaches used in beam search.
+
+    In this approach, a neural network is employed to assign scores to potential text transcripts.
+    The beam search decoding process produces a collection of the top K hypotheses.
+    These candidates are subsequently sent to a language model (LM) for ranking.
+    The ranking is carried out by the LM, which assigns a score to each candidate.
+
+    The score is computed as follows:
+
+    score = beam_search_score + lm_weight * rescorer_score
+
+    See:
+        - speechbrain.decoders.scorer.RNNLMRescorer
+        - speechbrain.decoders.scorer.TransformerLMRescorer
+        - speechbrain.decoders.scorer.HuggingFaceLMRescorer
+    """
+
+    def normalize_text(self, text):
+        """This method should implement the normalization of the text before scoring.
+
+        Arguments
+        ---------
+        text : list of str
+            The text to be normalized.
+        """
+        return text
+
+    def preprocess_func(self, hyps):
+        """This method should implement the preprocessing of the hypotheses before scoring.
+
+        Arguments
+        ---------
+        hyps : list of str
+            The hypotheses to be preprocessed.
+        """
+        raise NotImplementedError
+
+    def rescore_hyps(self, hyps):
+        """This method should implement the rescoring of the hypotheses.
+
+        Arguments
+        ---------
+        hyps : list of str
+            The hypotheses to be rescored.
+        """
+        raise NotImplementedError
+
+    def to_device(self, device=None):
+        """This method should implement the moving of the scorer to a device.
+
+        If device is None, the scorer should be moved to the default device provided
+        in the constructor.
+
+        Arguments
+        ---------
+        device : str
+            The device to move the scorer to.
+        """
+        raise NotImplementedError
+
+
+class RNNLMRescorer(BaseRescorerInterface):
+    """A wrapper of RNNLM based on the BaseRescorerInterface.
+
+    Arguments
+    ---------
+    language_model : torch.nn.Module
+        A RNN-based language model.
+    tokenizer : SentencePieceProcessor
+        A SentencePiece tokenizer.
+    device : str
+        The device to move the scorer to.
+    temperature : float
+        Temperature factor applied to softmax. It changes the probability
+        distribution, being softer when T>1 and sharper with T<1. (default: 1.0)
+    bos_index : int
+        The index of the beginning-of-sequence (bos) token.
+    eos_index : int
+        The index of the end-of-sequence (eos) token.
+    pad_index : int
+        The index of the padding token.
+
+    NOTE
+    ----
+    This class is intented to be used with a pretrained TransformerLM model.
+    Please see: https://huggingface.co/speechbrain/asr-crdnn-rnnlm-librispeech
+
+    By default, this model is using SentencePiece tokenizer.
+
+    Example
+    -------
+    >>> import torch
+    >>> from sentencepiece import SentencePieceProcessor
+    >>> from speechbrain.lobes.models.RNNLM import RNNLM
+    >>> from speechbrain.utils.parameter_transfer import Pretrainer
+    >>> source = "speechbrain/asr-crdnn-rnnlm-librispeech"
+    >>> lm_model_path = source + "/lm.ckpt"
+    >>> tokenizer_path = source + "/tokenizer.ckpt"
+    >>> # define your tokenizer and RNNLM from the HF hub
+    >>> tokenizer = SentencePieceProcessor()
+    >>> lm_model = RNNLM(
+    ...    output_neurons = 1000,
+    ...    embedding_dim = 128,
+    ...    activation = torch.nn.LeakyReLU,
+    ...    dropout = 0.0,
+    ...    rnn_layers = 2,
+    ...    rnn_neurons = 2048,
+    ...    dnn_blocks = 1,
+    ...    dnn_neurons = 512,
+    ...    return_hidden = True,
+    ... )
+    >>> pretrainer = Pretrainer(
+    ...     collect_in = getfixture("tmp_path"),
+    ...    loadables = {
+    ...     "lm" : lm_model,
+    ...     "tokenizer" : tokenizer,
+    ...     },
+    ...    paths = {
+    ...     "lm" : lm_model_path,
+    ...     "tokenizer" : tokenizer_path,
+    ... })
+    >>> _ = pretrainer.collect_files()
+    >>> pretrainer.load_collected()
+    >>> from speechbrain.decoders.scorer import RNNLMRescorer, RescorerBuilder
+    >>> rnnlm_rescorer = RNNLMRescorer(
+    ...    language_model = lm_model,
+    ...    tokenizer = tokenizer,
+    ...    temperature = 1.0,
+    ...    bos_index = 0,
+    ...    eos_index = 0,
+    ...    pad_index = 0,
+    ... )
+    >>> # Define a rescorer builder
+    >>> rescorer = RescorerBuilder(
+    ...    rescorers=[rnnlm_rescorer],
+    ...    weights={"rnnlm":1.0}
+    ... )
+    >>> # topk hyps
+    >>> topk_hyps = [["HELLO", "HE LLO", "H E L L O"]]
+    >>> topk_scores = [[-2, -2, -2]]
+    >>> rescored_hyps, rescored_scores = rescorer.rescore(topk_hyps, topk_scores)
+    >>> # NOTE: the returned hypotheses are already sorted by score.
+    >>> rescored_hyps # doctest: +SKIP
+    [['HELLO', 'H E L L O', 'HE LLO']]
+    >>> # NOTE: as we are returning log-probs, the more it is closer to 0, the better.
+    >>> rescored_scores # doctest: +SKIP
+    [[-17.863974571228027, -25.12890625, -26.075977325439453]]
+    """
+
+    def __init__(
+        self,
+        language_model,
+        tokenizer,
+        device="cuda",
+        temperature=1.0,
+        bos_index=0,
+        eos_index=0,
+        pad_index=0,
+    ):
+        self.lm = language_model
+        self.lm.eval()
+        self.tokenizer = tokenizer
+        self.temperature = temperature
+        self.softmax = sb.nnet.activations.Softmax(apply_log=True)
+
+        self.device = device
+        self.bos_index = bos_index
+        self.eos_index = eos_index
+        self.pad_index = pad_index
+
+    def normalize_text(self, text):
+        """This method should implement the normalization of the text before scoring.
+
+        Default to uppercasing the text because the (current) language models are trained on
+        LibriSpeech which is all uppercase.
+
+        Arguments
+        ---------
+        text : str
+            The text to be normalized.
+
+        Returns
+        -------
+        str
+            The normalized text.
+        """
+        return text.upper()
+
+    def to_device(self, device=None):
+        """This method moves the scorer to a device.
+
+        If device is None, the scorer is moved to the default device provided
+        in the constructor.
+
+        Arguments
+        ---------
+        device : str
+            The device to move the scorer to.
+        """
+        if device is None:
+            self.lm.to(self.device)
+        else:
+            self.lm.to(device)
+
+    def preprocess_func(self, topk_hyps):
+        """This method preprocesses the hypotheses before scoring.
+
+        Arguments
+        ---------
+        topk_hyps : list of list of str
+            The hypotheses to be preprocessed.
+
+        Returns
+        -------
+        padded_hyps : torch.Tensor
+            The padded hypotheses.
+        enc_hyps_length : list of int
+            The length of each hypothesis.
+        """
+        # 1. normalize text
+        decoded_seq = []
+        for batch in topk_hyps:
+            for seq in batch:
+                decoded_seq.append(self.normalize_text(seq))
+
+        # 2. encode text
+        enc_hyps = []
+        for seq in decoded_seq:
+            enc_hyps.append(
+                torch.tensor(
+                    [self.bos_index]
+                    + self.tokenizer.encode_as_ids(seq)
+                    + [self.eos_index]
+                )
+            )
+
+        enc_hyps_length = [enc_seq.shape[0] for enc_seq in enc_hyps]
+
+        # 3. pad sequences
+        padded_hyps = torch.nn.utils.rnn.pad_sequence(
+            enc_hyps, batch_first=True, padding_value=self.pad_index
+        ).to(self.lm.parameters().__next__().device)
+
+        return padded_hyps, enc_hyps_length
+
+    @torch.no_grad()
+    def rescore_hyps(self, topk_hyps):
+        """This method implement the rescoring of the hypotheses.
+
+        Arguments
+        ---------
+        topk_hyps : list of list of str
+            The hypotheses to be rescored.
+
+        Returns
+        -------
+        log_probs_scores : torch.Tensor[B * Topk, 1]
+            The rescored hypotheses scores
+        """
+        # preprocess hypotheses
+        padded_hyps, enc_hyps_length = self.preprocess_func(topk_hyps)
+
+        bool_mask = [
+            [1 if i < length else 0 for i in range(max(enc_hyps_length))]
+            for length in enc_hyps_length
+        ]
+
+        bool_mask_tensor = torch.tensor(
+            bool_mask, dtype=torch.bool, device=padded_hyps.device
+        )
+
+        if not next(self.lm.parameters()).is_cuda:
+            self.lm.to(padded_hyps.device)
+
+        # compute scores
+        logits, _ = self.lm(padded_hyps)
+        log_probs = self.softmax(logits / self.temperature)
+
+        target_log_probs = (
+            log_probs[:, :-1]
+            .gather(2, padded_hyps[:, 1:].unsqueeze(2))
+            .squeeze(2)
+        )
+
+        log_probs_scores = torch.nansum(
+            target_log_probs * bool_mask_tensor[:, 1:], dim=-1
+        )
+
+        return log_probs_scores
+
+
+class TransformerLMRescorer(BaseRescorerInterface):
+    """ A wrapper of TransformerLM based on the BaseRescorerInterface.
+
+    Arguments
+    ---------
+    language_model : torch.nn.Module
+        A Transformer-based language model.
+    tokenizer : SentencePieceProcessor
+        A SentencePiece tokenizer.
+    device : str
+        The device to move the scorer to.
+    temperature : float
+        Temperature factor applied to softmax. It changes the probability
+        distribution, being softer when T>1 and sharper with T<1. (default: 1.0)
+    bos_index : int
+        The index of the beginning-of-sequence (bos) token.
+    eos_index : int
+        The index of the end-of-sequence (eos) token.
+    pad_index : int
+        The index of the padding token.
+
+    NOTE
+    ----
+    This class is intented to be used with a pretrained TransformerLM model.
+    Please see: https://huggingface.co/speechbrain/asr-transformer-transformerlm-librispeech
+
+    By default, this model is using SentencePiece tokenizer.
+
+    Example
+    -------
+    >>> import torch
+    >>> from sentencepiece import SentencePieceProcessor
+    >>> from speechbrain.lobes.models.transformer.TransformerLM import TransformerLM
+    >>> from speechbrain.utils.parameter_transfer import Pretrainer
+    >>> source = "speechbrain/asr-transformer-transformerlm-librispeech"
+    >>> lm_model_path = source + "/lm.ckpt"
+    >>> tokenizer_path = source + "/tokenizer.ckpt"
+    >>> tokenizer = SentencePieceProcessor()
+    >>> lm_model = TransformerLM(
+    ...     vocab=5000,
+    ...     d_model=768,
+    ...     nhead=12,
+    ...     num_encoder_layers=12,
+    ...     num_decoder_layers=0,
+    ...     d_ffn=3072,
+    ...     dropout=0.0,
+    ...     activation=torch.nn.GELU,
+    ...     normalize_before=False,
+    ... )
+    >>> pretrainer = Pretrainer(
+    ...     collect_in = getfixture("tmp_path"),
+    ...     loadables={
+    ...         "lm": lm_model,
+    ...         "tokenizer": tokenizer,
+    ...     },
+    ...     paths={
+    ...         "lm": lm_model_path,
+    ...         "tokenizer": tokenizer_path,
+    ...     }
+    ... )
+    >>> _ = pretrainer.collect_files()
+    >>> pretrainer.load_collected()
+    >>> from speechbrain.decoders.scorer import TransformerLMRescorer, RescorerBuilder
+    >>> transformerlm_rescorer = TransformerLMRescorer(
+    ...     language_model=lm_model,
+    ...     tokenizer=tokenizer,
+    ...     temperature=1.0,
+    ...     bos_index=1,
+    ...     eos_index=2,
+    ...     pad_index=0,
+    ... )
+    >>> rescorer = RescorerBuilder(
+    ...     rescorers=[transformerlm_rescorer],
+    ...     weights={"transformerlm": 1.0}
+    ... )
+    >>> topk_hyps = [["HELLO", "HE LLO", "H E L L O"]]
+    >>> topk_scores = [[-2, -2, -2]]
+    >>> rescored_hyps, rescored_scores = rescorer.rescore(topk_hyps, topk_scores)
+    >>> # NOTE: the returned hypotheses are already sorted by score.
+    >>> rescored_hyps # doctest: +SKIP
+    [["HELLO", "HE L L O", "HE LLO"]]
+    >>> # NOTE: as we are returning log-probs, the more it is closer to 0, the better.
+    >>> rescored_scores  # doctest: +SKIP
+    [[-17.863974571228027, -25.12890625, -26.075977325439453]]
+    """
+
+    def __init__(
+        self,
+        language_model,
+        tokenizer,
+        device="cuda",
+        temperature=1.0,
+        bos_index=0,
+        eos_index=0,
+        pad_index=0,
+    ):
+        self.lm = language_model
+        self.lm.eval()
+
+        self.tokenizer = tokenizer
+        self.temperature = temperature
+        self.softmax = sb.nnet.activations.Softmax(apply_log=True)
+
+        self.device = device
+        self.bos_index = bos_index
+        self.eos_index = eos_index
+        self.pad_index = pad_index
+
+    def normalize_text(self, text):
+        """This method should implement the normalization of the text before scoring.
+
+        Default to uppercasing the text because the language models are trained on
+        LibriSpeech.
+
+        Arguments
+        ---------
+        text : str
+            The text to be normalized.
+
+        Returns
+        -------
+        str
+            The normalized text.
+        """
+        return text.upper()
+
+    def to_device(self, device=None):
+        """This method moves the scorer to a device.
+
+        If device is None, the scorer is moved to the default device provided
+        in the constructor.
+
+        This method is dynamically called in the recipes when the stage is equal
+        to TEST.
+
+        Arguments
+        ---------
+        device : str
+            The device to move the scorer to.
+        """
+        if device is None:
+            self.lm.to(self.device)
+        else:
+            self.lm.to(device)
+
+    def preprocess_func(self, topk_hyps):
+        """This method preprocesses the hypotheses before scoring.
+
+        Arguments
+        ---------
+        topk_hyps : list of list of str
+            The hypotheses to be preprocessed.
+
+        Returns
+        -------
+        padded_hyps : torch.Tensor
+            The padded hypotheses.
+        enc_hyps_length : list of int
+            The length of each hypothesis.
+        """
+        # 1. normalize
+        decoded_seq = []
+        for batch in topk_hyps:
+            for seq in batch:
+                decoded_seq.append(self.normalize_text(seq))
+
+        # 2. encode text
+        enc_hyps = []
+        for seq in decoded_seq:
+            enc_hyps.append(
+                torch.tensor(
+                    [self.bos_index]
+                    + self.tokenizer.encode_as_ids(seq)
+                    + [self.eos_index]
+                )
+            )
+
+        enc_hyps_length = [enc_seq.shape[0] for enc_seq in enc_hyps]
+
+        # 3. pad sequences
+        padded_hyps = torch.nn.utils.rnn.pad_sequence(
+            enc_hyps, batch_first=True, padding_value=self.pad_index
+        ).to(self.lm.parameters().__next__().device)
+
+        return padded_hyps, enc_hyps_length
+
+    @torch.no_grad()
+    def rescore_hyps(self, topk_hyps):
+        """This method implement the rescoring of the hypotheses.
+
+        Arguments
+        ---------
+        topk_hyps : list of list of str
+            The hypotheses to be rescored.
+
+        Returns
+        -------
+        log_probs_scores : torch.Tensor[B * Topk, 1]
+            The rescored hypotheses scores
+        """
+        # preprocess hypotheses
+        padded_hyps, enc_hyps_length = self.preprocess_func(topk_hyps)
+
+        bool_mask = [
+            [1 if i < length else 0 for i in range(max(enc_hyps_length))]
+            for length in enc_hyps_length
+        ]
+
+        bool_mask_tensor = torch.tensor(
+            bool_mask, dtype=torch.bool, device=padded_hyps.device
+        )
+
+        if not next(self.lm.parameters()).is_cuda:
+            self.lm.to(padded_hyps.device)
+
+        # compute scores
+        logits = self.lm(padded_hyps)
+        log_probs = self.softmax(logits / self.temperature)
+
+        log_probs[:, :, self.pad_index] = float("-inf")
+
+        target_log_probs = (
+            log_probs[:, :-1]
+            .gather(2, padded_hyps[:, 1:].unsqueeze(2))
+            .squeeze(2)
+        )
+
+        target_log_probs = target_log_probs - log_probs[:, :-1].logsumexp(
+            dim=-1
+        )
+        log_probs_scores = torch.nansum(
+            target_log_probs * bool_mask_tensor[:, 1:], dim=-1
+        )
+
+        return log_probs_scores
+
+
+class HuggingFaceLMRescorer(BaseRescorerInterface):
+    """ A wrapper of HuggingFace's TransformerLM based on the BaseRescorerInterface.
+
+    Arguments
+    ---------
+    model_name : str
+        The name of the model to be loaded.
+    device : str
+        The device to be used for scoring. (default: "cuda")
+
+    Example
+    -------
+    >>> from speechbrain.decoders.scorer import HuggingFaceLMRescorer, RescorerBuilder
+    >>> source = "gpt2-medium"
+    >>> huggingfacelm_rescorer = HuggingFaceLMRescorer(
+    ...     model_name=source,
+    ... )
+    >>> rescorer = RescorerBuilder(
+    ...     rescorers=[huggingfacelm_rescorer],
+    ...     weights={"huggingfacelm": 1.0}
+    ... )
+    >>> topk_hyps = [["Hello everyone.", "Hell o every one.", "Hello every one"]]
+    >>> topk_scores = [[-2, -2, -2]]
+    >>> rescored_hyps, rescored_scores = rescorer.rescore(topk_hyps, topk_scores)
+    >>> # NOTE: the returned hypotheses are already sorted by score.
+    >>> rescored_hyps # doctest: +SKIP
+    [['Hello everyone.', 'Hello every one', 'Hell o every one.']]
+    >>> # NOTE: as we are returning log-probs, the more it is closer to 0, the better.
+    >>> rescored_scores # doctest: +SKIP
+    [[-20.03631591796875, -27.615638732910156, -42.662353515625]]
+    """
+
+    def __init__(
+        self, model_name, device="cuda",
+    ):
+        self.model_name = model_name
+        self.device = device
+
+        try:
+            from transformers import AutoModelForCausalLM, AutoTokenizer
+        except ImportError:
+            raise ImportError(
+                "Please install transformers with: pip install transformers"
+            )
+
+        self.lm = AutoModelForCausalLM.from_pretrained(
+            self.model_name, is_decoder=True
+        ).eval()
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.model_name, use_fast=True, add_special_tokens=False
+        )
+
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = "<|pad|>"
+            self.tokenizer.add_special_tokens(
+                {"additional_special_tokens": [self.tokenizer.pad_token]}
+            )
+            self.lm.resize_token_embeddings(
+                len(self.tokenizer), pad_to_multiple_of=32
+            )
+
+        self.bos_token = self.tokenizer.bos_token
+        self.eos_token = self.tokenizer.eos_token
+
+    def to_device(self, device=None):
+        """This method moves the scorer to a device.
+
+        If device is None, the scorer is moved to the default device provided
+        in the constructor.
+
+        This method is dynamically called in the recipes when the stage is equal
+        to TEST.
+
+        Arguments
+        ---------
+        device : str
+            The device to move the scorer to.
+        """
+        if device is None:
+            self.lm.to(self.device)
+        else:
+            self.lm.to(device)
+
+    def normalize_text(self, text):
+        """This method should implement the normalization of the text before scoring.
+
+        Arguments
+        ---------
+        text : str
+            The text to be normalized.
+
+        Returns
+        -------
+        normalized_text : str
+            The normalized text.
+            In this case we do not apply any normalization. However, this method
+            can be overriden to apply any normalization.
+        """
+        return text
+
+    def _add_special_tokens(self, text):
+        """This method adds the special tokens to the text.
+
+        Arguments
+        ---------
+        text : str
+            The text to be augmented.
+
+        Returns
+        -------
+        augmented_text : str
+            The augmented text.
+        """
+        return self.bos_token + text + self.eos_token
+
+    def preprocess_func(self, topk_hyps):
+        """This method preprocesses the hypotheses before scoring.
+
+        Arguments
+        ---------
+        topk_hyps : list of str
+            The hypotheses to be preprocessed.
+
+        Returns
+        -------
+        encoding : tensor
+            The encoding of the hypotheses.
+        """
+        # 1. normalize
+        normalized_hyps = []
+        for batch in topk_hyps:
+            for seq in batch:
+                normalized_hyps.append(self.normalize_text(seq))
+
+        text_augmented_with_tokens = list(
+            map(self._add_special_tokens, normalized_hyps)
+        )
+        encoding = self.tokenizer.batch_encode_plus(
+            text_augmented_with_tokens, return_tensors="pt", padding=True
+        )
+        return encoding
+
+    @torch.no_grad()
+    def rescore_hyps(self, topk_hyps):
+        """This method implement the rescoring of the hypotheses.
+
+        Arguments
+        ---------
+        topk_hyps : list of list of str
+            The hypotheses to be rescored.
+
+        Returns
+        -------
+        log_probs_scores : torch.Tensor[B * Topk, 1]
+            The rescored hypotheses scores
+        """
+        encoding = self.preprocess_func(topk_hyps)
+
+        ids = encoding["input_ids"].to(self.lm.device)
+        attention_mask = encoding["attention_mask"].to(self.lm.device)
+        logits = self.lm(ids, attention_mask=attention_mask)[0]
+
+        logits[:, :, self.tokenizer.pad_token_id :] = float("-inf")
+
+        target_log_probs = (
+            logits[:, :-1].gather(2, ids[:, 1:].unsqueeze(2)).squeeze(2)
+        )
+
+        target_log_probs = target_log_probs - logits[:, :-1].logsumexp(dim=-1)
+        log_probs_scores = torch.nansum(
+            target_log_probs * attention_mask[:, 1:], dim=-1
+        )
+
+        return log_probs_scores
+
+
+class RescorerBuilder:
+    """ Builds rescorer instance for beamsearch.
+
+    The RecorerBuilder class is responsible for building a scorer instance for
+    beam search. It takes weights and rescorers classes. It combines the scorers based
+    on the weights specified and provides methods for rescoring text.
+
+    This is the class to be used for building rescorer instances for beam search.
+
+    Arguments
+    ---------
+    weights : dict
+        Weights of rescorers specified.
+    rescorers : list
+        Rescorers that re-ranks topk hypotheses.
+    """
+
+    def __init__(
+        self, weights=dict(), rescorers=list(),
+    ):
+        assert len(weights) == len(
+            rescorers
+        ), "Weights and rescorers are not matched."
+
+        self.weights = weights
+
+        all_rescorer_names = [
+            k.lower().split("rescorer")[0]
+            for k in globals().keys()
+            if k.endswith("Rescorer")
+        ]
+        full_rescorer_names = [
+            impl.__class__.__name__.lower().split("rescorer")[0]
+            for impl in rescorers
+        ]
+
+        # Have a default 0.0 weight for scorer not specified
+        init_weights = {k: 0.0 for k in all_rescorer_names}
+        self.weights = {**init_weights, **weights}
+        self.rescorers = dict(zip(full_rescorer_names, rescorers))
+
+        self._validate_scorer(all_rescorer_names)
+
+    def rescore(self, topk_candidates, topk_scores):
+        """This method rescores the topk candidates.
+
+        Arguments
+        ---------
+        topk_candidates : list of list of str
+            The topk candidates to be rescored.
+        topk_scores : list of list of float
+            The scores of the topk candidates.
+
+        Returns
+        -------
+        output_candidates : list of list of str
+            The rescored candidates.
+        output_scores : list of list of float
+            The rescored scores.
+        """
+        new_scores = topk_scores.copy()
+
+        for k, impl in self.rescorers.items():
+            scores = impl.rescore_hyps(topk_candidates)
+
+            index_scores = 0
+            for i in range(len(new_scores)):
+                for j in range(len(new_scores[i])):
+                    new_scores[i][j] += (
+                        self.weights[k] * scores[index_scores].item()
+                    )
+                    index_scores += 1
+
+        sorted_candidates = [
+            list(
+                zip(
+                    *sorted(
+                        zip(sublist, score), key=lambda x: x[1], reverse=True
+                    )
+                )
+                for sublist, score in zip(topk_candidates, new_scores)
+            )
+        ]
+
+        output_candidates = []
+        output_scores = []
+        for sublist in sorted_candidates:
+            for item in sublist:
+                texts, scores = item
+                output_candidates.append(list(texts))
+                output_scores.append(list(scores))
+
+        return output_candidates, output_scores
+
+    def _validate_scorer(self, rescorer_names):
+        """These error messages indicate rescorers are not properly set.
+
+        Arguments
+        ---------
+        rescorer_names : list
+            Prefix of rescorers defined in speechbrain.decoders.scorer.
+        """
+        if len(self.weights) > len(rescorer_names):
+            raise ValueError(
+                "The keys of weights should be named in {}".format(
+                    rescorer_names
+                )
+            )
+
+    def move_rescorers_to_device(self, device=None):
+        """Moves rescorers to device.
+
+        Usefull to avoid having on GPU rescorers while being
+        on TRAIN and VALID Stages.
+
+        Arguments
+        ---------
+        device : str
+            The device to be used for scoring. (default: None)
+        """
+        for _, impl in self.rescorers.items():
+            impl.to_device(device)
diff --git a/speechbrain/decoders/seq2seq.py b/speechbrain/decoders/seq2seq.py
index 977537f7cb5e919d576a98fa24f5fb0e3aa6ce82..b955c6bcf9bf00dfc25a2a906734117af7786012 100644
--- a/speechbrain/decoders/seq2seq.py
+++ b/speechbrain/decoders/seq2seq.py
@@ -1,16 +1,41 @@
 """Decoding methods for seq2seq autoregressive model.
 
 Authors
- * Adel Moumen 2022
+ * Adel Moumen 2022, 2023
  * Ju-Chieh Chou 2020
  * Peter Plantinga 2020
  * Mirco Ravanelli 2020
  * Sung-Lin Yeh 2020
 """
 import torch
+from speechbrain.decoders.utils import (
+    inflate_tensor,
+    mask_by_condition,
+    _update_mem,
+)
+from speechbrain.utils.data_utils import undo_padding
 
-import speechbrain as sb
-from speechbrain.decoders.ctc import CTCPrefixScorer
+
+class AlivedHypotheses(torch.nn.Module):
+    """ This class handle the data for the hypotheses during the decoding.
+
+    Arguments
+    ---------
+    alived_seq : torch.Tensor
+        The sequence of tokens for each hypothesis.
+    alived_log_probs : torch.Tensor
+        The log probabilities of each token for each hypothesis.
+    sequence_scores : torch.Tensor
+        The sum of log probabilities for each hypothesis.
+    """
+
+    def __init__(
+        self, alived_seq, alived_log_probs, sequence_scores,
+    ):
+        super().__init__()
+        self.alived_seq = alived_seq
+        self.alived_log_probs = alived_log_probs
+        self.sequence_scores = sequence_scores
 
 
 class S2SBaseSearcher(torch.nn.Module):
@@ -22,7 +47,7 @@ class S2SBaseSearcher(torch.nn.Module):
     bos_index : int
         The index of the beginning-of-sequence (bos) token.
     eos_index : int
-        The index of end-of-sequence token.
+        The index of end-of-sequence (eos) token.
     min_decode_radio : float
         The ratio of minimum decoding steps to the length of encoder states.
     max_decode_radio : float
@@ -30,13 +55,15 @@ class S2SBaseSearcher(torch.nn.Module):
 
     Returns
     -------
-    predictions
-        Outputs as Python list of lists, with "ragged" dimensions; padding
-        has been removed.
-    scores
-        The sum of log probabilities (and possibly
-        additional heuristic scores) for each prediction.
-
+    hyps
+        The predicted tokens, as a list of lists or, if return_topk is True,
+        a Tensor of shape (batch, topk, max length of token_id sequences).
+    top_lengths
+        The length of each topk sequence in the batch.
+    top_scores
+        This final scores of topk hypotheses.
+    top_log_probs
+        The log probabilities of each hypotheses.
     """
 
     def __init__(
@@ -68,9 +95,9 @@ class S2SBaseSearcher(torch.nn.Module):
         Arguments
         ---------
         inp_tokens : torch.Tensor
-            The input tensor of the current timestep.
+            The input tensor of the current step.
         memory : No limit
-            The memory variables input for this timestep.
+            The memory variables input for this step.
             (ex. RNN hidden states).
         enc_states : torch.Tensor
             The encoder states to be attended.
@@ -80,9 +107,9 @@ class S2SBaseSearcher(torch.nn.Module):
         Returns
         -------
         log_probs : torch.Tensor
-            Log-probabilities of the current timestep output.
+            Log-probabilities of the current step output.
         memory : No limit
-            The memory variables generated in this timestep.
+            The memory variables generated in this step.
             (ex. RNN hidden states).
         attn : torch.Tensor
             The attention weight for doing penalty.
@@ -108,51 +135,17 @@ class S2SBaseSearcher(torch.nn.Module):
         """
         raise NotImplementedError
 
-    def lm_forward_step(self, inp_tokens, memory):
-        """This method should implement one step of
-        forwarding operation for language model.
-
-        Arguments
-        ---------
-        inp_tokens : torch.Tensor
-            The input tensor of the current timestep.
-        memory : No limit
-            The momory variables input for this timestep.
-            (e.g., RNN hidden states).
-
-        Return
-        ------
-        log_probs : torch.Tensor
-            Log-probabilities of the current timestep output.
-        memory : No limit
-            The memory variables generated in this timestep.
-            (e.g., RNN hidden states).
-        """
-        raise NotImplementedError
-
-    def reset_lm_mem(self, batch_size, device):
-        """This method should implement the resetting of
-        memory variables in the language model.
-        E.g., initializing zero vector as initial hidden states.
-
-        Arguments
-        ---------
-        batch_size : int
-            The size of the batch.
-        device : torch.device
-            The device to put the initial variables.
-
-        Return
-        ------
-        memory : No limit
-            The initial memory variable.
-        """
-        raise NotImplementedError
-
     def change_max_decoding_length(self, min_decode_steps, max_decode_steps):
         """set the minimum/maximum length the decoder can take."""
         return min_decode_steps, max_decode_steps
 
+    def set_n_out(self):
+        """set the number of output tokens.
+        Overrides this function if the fc layer is embedded
+        in the model, e.g., Whisper.
+        """
+        return self.fc.w.out_features
+
 
 class S2SGreedySearcher(S2SBaseSearcher):
     """This class implements the general forward-pass of
@@ -161,6 +154,7 @@ class S2SGreedySearcher(S2SBaseSearcher):
 
     def forward(self, enc_states, wav_len):
         """This method performs a greedy search.
+
         Arguments
         ---------
         enc_states : torch.Tensor
@@ -168,6 +162,16 @@ class S2SGreedySearcher(S2SBaseSearcher):
             (ex. the encoded speech representation to be attended).
         wav_len : torch.Tensor
             The speechbrain-style relative length.
+
+        Returns
+        -------
+        hyps : List containing hypotheses.
+        top_lengths : torch.Tensor (batch)
+            This tensor contains the final scores of hypotheses.
+        top_scores : torch.Tensor (batch)
+            The length of each topk sequence in the batch.
+        top_log_probs : torch.Tensor (batch, max length of token_id sequences)
+            The log probabilities of each hypotheses.
         """
         enc_lens = torch.round(enc_states.shape[1] * wav_len).int()
         device = enc_states.device
@@ -183,13 +187,14 @@ class S2SGreedySearcher(S2SBaseSearcher):
         log_probs_lst = []
         max_decode_steps = int(enc_states.shape[1] * self.max_decode_ratio)
 
-        # the decoding steps can be based on the max number of tokens that a decoder can process (e.g., 448 for Whisper).
+        # the decoding steps can be based on the max number of tokens that a decoder can process
+        # (e.g., 448 for Whisper).
         _, max_decode_steps = self.change_max_decoding_length(
             0, max_decode_steps
         )
 
         has_ended = enc_states.new_zeros(batch_size).bool()
-        for t in range(max_decode_steps):
+        for _ in range(max_decode_steps):
             log_probs, memory, _ = self.forward_step(
                 inp_tokens, memory, enc_states, enc_lens
             )
@@ -205,117 +210,68 @@ class S2SGreedySearcher(S2SBaseSearcher):
         mask = scores == float("inf")
         scores[mask] = 0
         predictions[mask] = self.eos_index
-        scores = scores.sum(dim=1).tolist()
-        predictions = batch_filter_seq2seq_output(
-            predictions, eos_id=self.eos_index
-        )
-
-        return predictions, scores
-
-
-class S2SWhisperGreedySearch(S2SGreedySearcher):
-    """
-    This class implements the greedy decoding
-    for Whisper neural nets made by OpenAI in
-    https://cdn.openai.com/papers/whisper.pdf.
-
-    Arguments
-    ---------
-    model : HuggingFaceWhisper
-        The Whisper model.
-    language_token : int
-        The language token to be used for the decoder input.
-    bos_token : int
-        The beginning of sentence token to be used for the decoder input.
-    task_token : int
-        The task token to be used for the decoder input.
-    timestamp_token : int
-        The timestamp token to be used for the decoder input.
-    max_length : int
-        The maximum decoding steps to perform.
-        The Whisper model has a maximum length of 448.
-    **kwargs
-        see S2SBaseSearcher, arguments are directly passed.
-    """
-
-    def __init__(
-        self,
-        model,
-        language_token=50259,
-        bos_token=50258,
-        task_token=50359,
-        timestamp_token=50363,
-        max_length=448,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        self.model = model
-        self.softmax = torch.nn.LogSoftmax(dim=-1)
-        self.decoder_input_tokens = None
-        self.language_token = language_token  # default language is english
-        self.bos_token = bos_token  # always this value
-        self.task_token = task_token  # default task is transcribe
-        self.timestamp_token = timestamp_token  # default is notimestamp
-        self.max_length = max_length - 3  # 3 tokens are added to the input
 
-    def set_language_token(self, language_token):
-        """set the language token to be used for the decoder input."""
-        self.language_token = language_token
+        (
+            top_hyps,
+            top_lengths,
+            top_scores,
+            top_log_probs,
+        ) = self._get_top_prediction(predictions, scores, log_probs)
 
-    def set_bos_token(self, bos_token):
-        """set the bos token to be used for the decoder input."""
-        self.bos_token = bos_token
+        # Convert best hypothesis to list
+        hyps = undo_padding(top_hyps[:, 0], top_lengths)
 
-    def set_task_token(self, task_token):
-        """set the task token to be used for the decoder input."""
-        self.task_token = task_token
+        return hyps, top_lengths, top_scores, top_log_probs
 
-    def set_timestamp_token(self, timestamp_token):
-        """set the timestamp token to be used for the decoder input."""
-        self.timestamp_token = timestamp_token
-        # need to reset bos_index too as timestamp_token is the first
-        # inp_token and need to be the first so that the first input gave
-        # to the model is [bos, language, task, timestamp] (order matters).
-        self.bos_index = self.timestamp_token
+    def _get_top_prediction(self, hyps, scores, log_probs):
+        """This method sorts the scores and return corresponding hypothesis and log probs.
 
-    def set_decoder_input_tokens(self, decoder_input_tokens):
-        """decoder_input_tokens are the tokens used as input to the decoder.
-        They are directly taken from the tokenizer.prefix_tokens attribute.
+        Arguments
+        ---------
+        hyps : torch.Tensor (batch, max length of token_id sequences)
+            This tensor stores the predicted hypothesis.
+        scores : torch.Tensor (batch)
+            The score of each hypotheses.
+        log_probs : torch.Tensor (batch, max length of token_id sequences)
+            The log probabilities of each hypotheses.
 
-        decoder_input_tokens = [bos_token, language_token, task_token, timestamp_token]
+        Returns
+        -------
+        top_hyps : torch.Tensor (batch, max length of token_id sequences)
+            This tensor stores the topk predicted hypothesis.
+        top_lengths : torch.Tensor (batch)
+            This tensor contains the final scores of hypotheses.
+        top_scores : torch.Tensor (batch)
+            The length of each topk sequence in the batch.
+        top_log_probs : torch.Tensor (batch, max length of token_id sequences)
+            The log probabilities of each hypotheses.
         """
-        self.set_bos_token(decoder_input_tokens[0])
-        self.set_language_token(decoder_input_tokens[1])
-        self.set_task_token(decoder_input_tokens[2])
-        self.set_timestamp_token(decoder_input_tokens[3])
-
-        # bos will be timestamp in our case.
-        self.decoder_input_tokens = [
-            self.bos_token,
-            self.language_token,
-            self.task_token,
-        ]
-
-    def reset_mem(self, batch_size, device):
-        """This method set the first tokens to be decoder_input_tokens during search."""
-        return torch.tensor([self.decoder_input_tokens] * batch_size).to(device)
-
-    def forward_step(self, inp_tokens, memory, enc_states, enc_lens):
-        """Performs a step in the implemented beamsearcher."""
-        memory = _update_mem(inp_tokens, memory)
+        batch_size = hyps.size(0)
+        max_length = hyps.size(1)
+        top_lengths = [max_length] * batch_size
+
+        # Collect lengths of top hyps
+        for pred_index in range(batch_size):
+            pred = hyps[pred_index]
+            pred_length = (pred == self.eos_index).nonzero(as_tuple=False)
+            if len(pred_length) > 0:
+                top_lengths[pred_index] = pred_length[0].item()
+        # Convert lists to tensors
+        top_lengths = torch.tensor(
+            top_lengths, dtype=torch.float, device=hyps.device
+        )
 
-        # WARNING: the max_decode_ratio need to be under 449 because
-        #  of positinal encoding
-        dec_out, attn = self.model.forward_decoder(enc_states, memory)
-        log_probs = self.softmax(dec_out[:, -1])
+        # Pick top log probabilities
+        top_log_probs = log_probs
 
-        return log_probs, memory, attn
+        # Use SpeechBrain style lengths
+        top_lengths = (top_lengths - 1).abs() / max_length
 
-    def change_max_decoding_length(self, min_decode_steps, max_decode_steps):
-        """set the minimum/maximum length the decoder can take."""
         return (
-            int(self.min_decode_ratio * self.max_length),
-            int(self.max_decode_ratio * self.max_length),
+            hyps.unsqueeze(1),
+            top_lengths.unsqueeze(1),
+            scores.unsqueeze(1),
+            top_log_probs.unsqueeze(1),
         )
 
 
@@ -338,6 +294,8 @@ class S2SRNNGreedySearcher(S2SGreedySearcher):
 
     Example
     -------
+    >>> import speechbrain as sb
+    >>> from speechbrain.decoders import S2SRNNGreedySearcher
     >>> emb = torch.nn.Embedding(5, 3)
     >>> dec = sb.nnet.RNN.AttentionalRNNDecoder(
     ...     "gru", "content", 3, 3, 1, enc_dim=7, input_size=3
@@ -347,14 +305,15 @@ class S2SRNNGreedySearcher(S2SGreedySearcher):
     ...     embedding=emb,
     ...     decoder=dec,
     ...     linear=lin,
-    ...     bos_index=4,
-    ...     eos_index=4,
+    ...     bos_index=0,
+    ...     eos_index=1,
     ...     min_decode_ratio=0,
     ...     max_decode_ratio=1,
     ... )
-    >>> enc = torch.rand([2, 6, 7])
-    >>> wav_len = torch.rand([2])
-    >>> hyps, scores = searcher(enc, wav_len)
+    >>> batch_size = 2
+    >>> enc = torch.rand([batch_size, 6, 7])
+    >>> wav_len = torch.ones([batch_size])
+    >>> top_hyps, top_lengths, _, _ = searcher(enc, wav_len)
     """
 
     def __init__(self, embedding, decoder, linear, **kwargs):
@@ -400,49 +359,29 @@ class S2SBeamSearcher(S2SBaseSearcher):
         The ratio of maximum decoding steps to length of encoder states.
     beam_size : int
         The width of beam.
+    scorer: speechbrain.decoders.scorers.ScorerBuilder
+        Scorer instance. Default: None.
+    return_topk : bool
+        Whether to return topk hypotheses. The topk hypotheses will be
+        padded to the same length. Default: False.
     topk : int
-        The number of hypothesis to return. (default: 1)
-    return_log_probs : bool
-        Whether to return log-probabilities. (default: False)
+        If return_topk is True, then return topk hypotheses. Default: 1.
     using_eos_threshold : bool
-        Whether to use eos threshold. (default: true)
+        Whether to use eos threshold. Default: True.
     eos_threshold : float
-        The threshold coefficient for eos token (default: 1.5). See 3.1.2 in
-        reference: https://arxiv.org/abs/1904.02619
+        The threshold coefficient for eos token. Default: 1.5.
+        See 3.1.2 in reference: https://arxiv.org/abs/1904.02619
     length_normalization : bool
-        Whether to divide the scores by the length. (default: True)
-    length_rewarding : float
-        The coefficient of length rewarding (γ).
-        log P(y|x) + λ log P_LM(y) + γ*len(y). (default: 0.0)
-    coverage_penalty: float
-        The coefficient of coverage penalty (η).
-        log P(y|x) + λ log P_LM(y) + γ*len(y) + η*coverage(x,y). (default: 0.0)
-        Reference: https://arxiv.org/pdf/1612.02695.pdf, https://arxiv.org/pdf/1808.10792.pdf
-    lm_weight : float
-        The weight of LM when performing beam search (λ).
-        log P(y|x) + λ log P_LM(y). (default: 0.0)
-    ctc_weight : float
-        The weight of CTC probabilities when performing beam search (λ).
-        (1-λ) log P(y|x) + λ log P_CTC(y|x). (default: 0.0)
-    blank_index : int
-        The index of the blank token.
-    ctc_score_mode: str
-        Default: "full"
-        CTC prefix scoring on "partial" token or "full: token.
-    ctc_window_size: int
-        Default: 0
-        Compute the ctc scores over the time frames using windowing based on attention peaks.
-        If 0, no windowing applied.
+        Whether to divide the scores by the length. Default: True.
     using_max_attn_shift: bool
-        Whether using the max_attn_shift constraint. (default: False)
+        Whether using the max_attn_shift constraint. Default: False.
     max_attn_shift: int
         Beam search will block the beams that attention shift more
-        than max_attn_shift.
+        than max_attn_shift. Default: 60.
         Reference: https://arxiv.org/abs/1904.02619
     minus_inf : float
-        DefaultL -1e20
         The value of minus infinity to block some path
-        of the search.
+        of the search. Default: -1e20.
     """
 
     def __init__(
@@ -452,19 +391,12 @@ class S2SBeamSearcher(S2SBaseSearcher):
         min_decode_ratio,
         max_decode_ratio,
         beam_size,
+        scorer=None,
+        return_topk=False,
         topk=1,
-        return_log_probs=False,
         using_eos_threshold=True,
         eos_threshold=1.5,
         length_normalization=True,
-        length_rewarding=0,
-        coverage_penalty=0.0,
-        lm_weight=0.0,
-        lm_modules=None,
-        ctc_weight=0.0,
-        blank_index=0,
-        ctc_score_mode="full",
-        ctc_window_size=0,
         using_max_attn_shift=False,
         max_attn_shift=60,
         minus_inf=-1e20,
@@ -473,46 +405,40 @@ class S2SBeamSearcher(S2SBaseSearcher):
             bos_index, eos_index, min_decode_ratio, max_decode_ratio,
         )
         self.beam_size = beam_size
+        self.scorer = scorer
+        self.return_topk = return_topk
         self.topk = topk
-        self.return_log_probs = return_log_probs
         self.length_normalization = length_normalization
-        self.length_rewarding = length_rewarding
-        self.coverage_penalty = coverage_penalty
-        self.coverage = None
-
-        if self.length_normalization and self.length_rewarding > 0:
-            raise ValueError(
-                "length normalization is not compatible with length rewarding."
-            )
-
         self.using_eos_threshold = using_eos_threshold
         self.eos_threshold = eos_threshold
         self.using_max_attn_shift = using_max_attn_shift
         self.max_attn_shift = max_attn_shift
-        self.lm_weight = lm_weight
-        self.lm_modules = lm_modules
-
-        # ctc related
-        self.ctc_weight = ctc_weight
-        self.blank_index = blank_index
-        self.att_weight = 1.0 - ctc_weight
-
-        assert (
-            0.0 <= self.ctc_weight <= 1.0
-        ), "ctc_weight should not > 1.0 and < 0.0"
+        self.attn_weight = 1.0
+        self.ctc_weight = 0.0
+        self.minus_inf = minus_inf
 
-        if self.ctc_weight > 0.0:
-            if len({self.bos_index, self.eos_index, self.blank_index}) < 3:
+        if self.scorer is not None:
+            # Check length normalization
+            if length_normalization and self.scorer.weights["length"] > 0.0:
                 raise ValueError(
-                    "To perform joint ATT/CTC decoding, set blank, eos and bos to different indexes."
+                    "Length normalization is not compatible with length rewarding."
                 )
+            if self.scorer.weights["ctc"] > 0.0:
+                # Check indices for ctc
+                all_scorers = {
+                    **self.scorer.full_scorers,
+                    **self.scorer.partial_scorers,
+                }
+                blank_index = all_scorers["ctc"].blank_index
+                if len({bos_index, eos_index, blank_index}) < 3:
+                    raise ValueError(
+                        "Set blank, eos and bos to different indexes for joint ATT/CTC or CTC decoding"
+                    )
 
-        # ctc already initialized
-        self.minus_inf = minus_inf
-        self.ctc_score_mode = ctc_score_mode
-        self.ctc_window_size = ctc_window_size
+                self.ctc_weight = self.scorer.weights["ctc"]
+                self.attn_weight = 1.0 - self.ctc_weight
 
-    def _check_full_beams(self, hyps, beam_size):
+    def _check_full_beams(self, hyps):
         """This method checks whether hyps has been full.
 
         Arguments
@@ -520,8 +446,6 @@ class S2SBeamSearcher(S2SBaseSearcher):
         hyps : List
             This list contains batch_size number.
             Each inside list contains a list stores all the hypothesis for this sentence.
-        beam_size : int
-            The number of beam_size.
 
         Returns
         -------
@@ -529,11 +453,8 @@ class S2SBeamSearcher(S2SBaseSearcher):
             Whether the hyps has been full.
         """
         hyps_len = [len(lst) for lst in hyps]
-        beam_size = [self.beam_size for _ in range(len(hyps_len))]
-        if hyps_len == beam_size:
-            return True
-        else:
-            return False
+        beams_size = [self.beam_size for _ in range(len(hyps_len))]
+        return hyps_len == beams_size
 
     def _check_attn_shift(self, attn, prev_attn_peak):
         """This method checks whether attention shift is more than attn_shift.
@@ -563,15 +484,14 @@ class S2SBeamSearcher(S2SBaseSearcher):
         return cond, attn_peak
 
     def _check_eos_threshold(self, log_probs):
-        """
-        This method checks whether eos log-probabilities exceed threshold.
+        """This method checks whether eos log-probabilities exceed threshold.
 
         Arguments
         ---------
         log_probs : torch.Tensor
             The log-probabilities.
 
-        Return
+        Returns
         ------
         cond : torch.BoolTensor
             Each element represents whether the eos log-probabilities will be kept.
@@ -581,134 +501,446 @@ class S2SBeamSearcher(S2SBaseSearcher):
         cond = eos_probs > (self.eos_threshold * max_probs)
         return cond
 
-    def _update_hyp_and_scores(
-        self,
-        inp_tokens,
-        alived_seq,
-        alived_log_probs,
-        hyps_and_scores,
-        scores,
-        timesteps,
+    def init_hypotheses(self):
+        """This method initializes the AlivedHypotheses object.
+
+        Returns
+        -------
+        AlivedHypotheses
+            The alived hypotheses filled with the initial values.
+        """
+        return AlivedHypotheses(
+            alived_seq=torch.empty(self.n_bh, 0, device=self.device).long(),
+            alived_log_probs=torch.empty(self.n_bh, 0, device=self.device),
+            sequence_scores=torch.empty(self.n_bh, device=self.device)
+            .fill_(float("-inf"))
+            .index_fill_(0, self.beam_offset, 0.0),
+        )
+
+    def _attn_weight_step(
+        self, inp_tokens, memory, enc_states, enc_lens, attn, log_probs
     ):
-        """This method will update hyps and scores if inp_tokens are eos.
+        """This method computes a forward_step if attn_weight is superior to 0.
 
         Arguments
         ---------
         inp_tokens : torch.Tensor
-            The current output.
-        alived_seq : torch.Tensor
-            The tensor to store the alived_seq.
-        alived_log_probs : torch.Tensor
-            The tensor to store the alived_log_probs.
-        hyps_and_scores : list
-            To store generated hypotheses and scores.
-        scores : torch.Tensor
-            The final scores of beam search.
-        timesteps : float
-            The current timesteps. This is for length rewarding.
+            The input tensor of the current step.
+        memory : No limit
+            The memory variables input for this step.
+            (ex. RNN hidden states).
+        enc_states : torch.Tensor
+            The encoder states to be attended.
+        enc_lens : torch.Tensor
+            The actual length of each enc_states sequence.
+        attn : torch.Tensor
+            The attention weight.
+        log_probs : torch.Tensor
+            The log-probabilities of the current step output.
 
         Returns
         -------
-        is_eos : torch.BoolTensor
-            Each element represents whether the token is eos.
+        log_probs : torch.Tensor
+            Log-probabilities of the current step output.
+        memory : No limit
+            The memory variables generated in this step.
+            (ex. RNN hidden states).
+        attn : torch.Tensor
+            The attention weight.
         """
-        is_eos = inp_tokens.eq(self.eos_index)
-        (eos_indices,) = torch.nonzero(is_eos, as_tuple=True)
-
-        # Store the hypothesis and their scores when reaching eos.
-        if eos_indices.shape[0] > 0:
-            for index in eos_indices:
-                # convert to int
-                index = index.item()
-                batch_id = torch.div(
-                    index, self.beam_size, rounding_mode="floor"
-                )
-                if len(hyps_and_scores[batch_id]) == self.beam_size:
-                    continue
-                hyp = alived_seq[index, :]
-                log_probs = alived_log_probs[index, :]
-                final_scores = scores[index] + self.length_rewarding * (
-                    timesteps + 1
-                )
-                hyps_and_scores[batch_id].append((hyp, log_probs, final_scores))
-        return is_eos
+        if self.attn_weight > 0:
+            log_probs, memory, attn = self.forward_step(
+                inp_tokens, memory, enc_states, enc_lens
+            )
+            log_probs = self.attn_weight * log_probs
+        return log_probs, memory, attn
 
-    def _get_top_score_prediction(self, hyps_and_scores, topk):
-        """This method sorts the scores and return corresponding hypothesis and log probs.
+    def _max_attn_shift_step(self, attn, prev_attn_peak, log_probs):
+        """This method will block the beams that attention shift more
+        than max_attn_shift.
 
         Arguments
         ---------
-        hyps_and_scores : list
-            To store generated hypotheses and scores.
-        topk : int
-            Number of hypothesis to return.
+        attn : torch.Tensor
+            The attention weight.
+        prev_attn_peak : torch.Tensor
+            The previous attention peak place.
+        log_probs : torch.Tensor
+            The log-probabilities of the current step output.
 
         Returns
         -------
-        topk_hyps : torch.Tensor (batch, topk, max length of token_id sequences)
-            This tensor stores the topk predicted hypothesis.
-        topk_scores : torch.Tensor (batch, topk)
-            The length of each topk sequence in the batch.
-        topk_lengths : torch.Tensor (batch, topk)
-            This tensor contains the final scores of topk hypotheses.
-        topk_log_probs : list
-            The log probabilities of each hypotheses.
+        log_probs : torch.Tensor
+            Log-probabilities of the current step output.
+        prev_attn_peak : torch.Tensor
+            The previous attention peak place.
         """
-        top_hyps, top_log_probs, top_scores, top_lengths = [], [], [], []
-        batch_size = len(hyps_and_scores)
-
-        # Collect hypotheses
-        for i in range(len(hyps_and_scores)):
-            hyps, log_probs, scores = zip(*hyps_and_scores[i])
-            top_hyps += hyps
-            top_scores += scores
-            top_log_probs += log_probs
-            top_lengths += [len(hyp) for hyp in hyps]
-        top_hyps = torch.nn.utils.rnn.pad_sequence(
-            top_hyps, batch_first=True, padding_value=0
-        )
-        top_scores = torch.stack((top_scores), dim=0).view(batch_size, -1)
-        top_lengths = torch.tensor(
-            top_lengths, dtype=torch.int, device=top_scores.device
-        )
-        # Get topk indices
-        topk_scores, indices = top_scores.topk(self.topk, dim=-1)
-        indices = (indices + self.beam_offset.unsqueeze(1)).view(
-            batch_size * self.topk
-        )
-        # Select topk hypotheses
-        topk_hyps = torch.index_select(top_hyps, dim=0, index=indices,)
-        topk_hyps = topk_hyps.view(batch_size, self.topk, -1)
-        topk_lengths = torch.index_select(top_lengths, dim=0, index=indices,)
-        topk_lengths = topk_lengths.view(batch_size, self.topk)
-        topk_log_probs = [top_log_probs[index.item()] for index in indices]
+        if self.using_max_attn_shift:
+            cond, prev_attn_peak = self._check_attn_shift(attn, prev_attn_peak)
+            log_probs = mask_by_condition(
+                log_probs, cond, fill_value=self.minus_inf
+            )
+        return log_probs, prev_attn_peak
 
-        return topk_hyps, topk_scores, topk_lengths, topk_log_probs
+    def _scorer_step(self, inp_tokens, scorer_memory, attn, log_probs):
+        """This method call the scorers if scorer is not None.
 
-    def forward(self, enc_states, wav_len):  # noqa: C901
-        """Applies beamsearch and returns the predicted tokens."""
-        enc_lens = torch.round(enc_states.shape[1] * wav_len).int()
-        device = enc_states.device
-        batch_size = enc_states.shape[0]
+        Arguments
+        ---------
+        inp_tokens : torch.Tensor
+            The input tensor of the current step.
+        scorer_memory : No limit
+            The memory variables input for this step.
+            (ex. RNN hidden states).
+        attn : torch.Tensor
+            The attention weight.
+        log_probs : torch.Tensor
+            The log-probabilities of the current step output.
 
-        memory = self.reset_mem(batch_size * self.beam_size, device=device)
+        Returns
+        -------
+        log_probs : torch.Tensor
+            Log-probabilities of the current step output.
+        scorer_memory : No limit
+            The memory variables generated in this step.
+        """
+        if self.scorer is not None:
+            log_probs, scorer_memory = self.scorer.score(
+                inp_tokens, scorer_memory, attn, log_probs, self.beam_size,
+            )
+        return log_probs, scorer_memory
 
-        if self.lm_weight > 0:
-            lm_memory = self.reset_lm_mem(batch_size * self.beam_size, device)
+    def _set_eos_minus_inf_step(self, log_probs, step, min_decode_steps):
+        """This method set the log_probs of eos to minus infinity if the step is less than min_decode_steps.
 
-        if self.ctc_weight > 0:
-            # (batch_size * beam_size, L, vocab_size)
-            ctc_outputs = self.ctc_forward_step(enc_states)
-            ctc_scorer = CTCPrefixScorer(
-                ctc_outputs,
-                enc_lens,
-                batch_size,
-                self.beam_size,
-                self.blank_index,
-                self.eos_index,
-                self.ctc_window_size,
+        Arguments
+        ---------
+        log_probs : torch.Tensor
+            The log-probabilities of the current step output.
+        step : int
+            The current decoding step.
+        min_decode_steps : int
+            The minimum decoding steps.
+
+        Returns
+        -------
+        log_probs : torch.Tensor
+            Log-probabilities of the current step output.
+        """
+        if step < min_decode_steps:
+            log_probs[:, self.eos_index] = self.minus_inf
+        return log_probs
+
+    def _eos_threshold_step(self, log_probs):
+        """This method set the log_probs of eos to minus infinity if the eos log-probabilities is less than eos_threshold.
+
+        Arguments
+        ---------
+        log_probs : torch.Tensor
+            The log-probabilities of the current step output.
+
+        Returns
+        -------
+        log_probs : torch.Tensor
+            Log-probabilities of the current step output.
+        """
+        if self.using_eos_threshold:
+            cond = self._check_eos_threshold(log_probs)
+            log_probs[:, self.eos_index] = mask_by_condition(
+                log_probs[:, self.eos_index], cond, fill_value=self.minus_inf,
+            )
+        return log_probs
+
+    def _attn_weight_permute_memory_step(self, memory, predecessors):
+        """This method permute the memory if attn_weight is superior to 0.
+
+        Arguments
+        ---------
+        memory : No limit
+            The memory variables input for this step.
+            (ex. RNN hidden states).
+        predecessors : torch.Tensor
+            The index of which beam the current top-K output came from in (t-1) steps.
+
+        Returns
+        -------
+        memory : No limit
+            The memory variables generated in this step.
+            (ex. RNN hidden states).
+        """
+        if self.attn_weight > 0:
+            memory = self.permute_mem(memory, index=predecessors)
+        return memory
+
+    def _scorer_permute_memory_step(
+        self, scorer_memory, predecessors, candidates
+    ):
+        """This method permute the scorer_memory if scorer is not None.
+
+        Arguments
+        ---------
+        scorer_memory : No limit
+            The memory variables input for this step.
+            (ex. RNN hidden states).
+        predecessors : torch.Tensor
+            The index of which beam the current top-K output came from in (t-1) steps.
+        candidates : torch.Tensor
+            The index of the current top-K output.
+
+        Returns
+        -------
+        scorer_memory : No limit
+            The memory variables generated in this step.
+        """
+        if self.scorer is not None:
+            scorer_memory = self.scorer.permute_scorer_mem(
+                scorer_memory, index=predecessors, candidates=candidates
+            )
+        return scorer_memory
+
+    def _max_attn_shift_permute_memory_step(self, prev_attn_peak, predecessors):
+        """This method permute the prev_attn_peak if using_max_attn_shift is True.
+
+        Arguments
+        ---------
+        prev_attn_peak : torch.Tensor
+            The previous attention peak place.
+        predecessors : torch.Tensor
+            The index of which beam the current top-K output came from in (t-1) steps.
+
+        Returns
+        -------
+        prev_attn_peak : torch.Tensor
+            The previous attention peak place.
+        """
+        if self.using_max_attn_shift:
+            prev_attn_peak = torch.index_select(
+                prev_attn_peak, dim=0, index=predecessors
             )
-            ctc_memory = None
+        return prev_attn_peak
+
+    def _update_reset_memory(self, enc_states, enc_lens):
+        """ Call reset memory for each module.
+
+        Arguments
+        ---------
+        enc_states : torch.Tensor
+            The encoder states to be attended.
+        enc_lens : torch.Tensor
+            The actual length of each enc_states sequence.
+
+        Returns
+        -------
+        memory : No limit
+            The memory variables generated in this step.
+        scorer_memory : No limit
+            The memory variables generated in this step.
+        """
+        memory = self.reset_mem(self.n_bh, device=self.device)
+        scorer_memory = None
+        if self.scorer is not None:
+            scorer_memory = self.scorer.reset_scorer_mem(enc_states, enc_lens)
+        return memory, scorer_memory
+
+    def _update_permute_memory(
+        self, memory, scorer_memory, predecessors, candidates, prev_attn_peak
+    ):
+        """Call permute memory for each module. It allows us to synchronize the memory with the output.
+
+        Arguments
+        ---------
+        memory : No limit
+            The memory variables input for this step.
+            (ex. RNN hidden states).
+        scorer_memory : No limit
+            The memory variables input for this step.
+            (ex. RNN hidden states).
+        predecessors : torch.Tensor
+            The index of which beam the current top-K output came from in (t-1) steps.
+        candidates : torch.Tensor
+            The index of the current top-K output.
+        prev_attn_peak : torch.Tensor
+            The previous attention peak place.
+
+        Returns
+        -------
+        memory : No limit
+            The memory variables generated in this step.
+        scorer_memory : No limit
+            The memory variables generated in this step.
+        prev_attn_peak : torch.Tensor
+            The previous attention peak place.
+        """
+        memory = self._attn_weight_permute_memory_step(memory, predecessors)
+
+        scorer_memory = self._scorer_permute_memory_step(
+            scorer_memory, predecessors, candidates
+        )
+
+        # If using_max_attn_shift, then the previous attn peak has to be permuted too.
+        prev_attn_peak = self._max_attn_shift_permute_memory_step(
+            prev_attn_peak, predecessors
+        )
+
+        return memory, scorer_memory, prev_attn_peak
+
+    def _update_sequences_and_log_probs(
+        self, log_probs, inp_tokens, predecessors, candidates, alived_hyps,
+    ):
+        """This method update sequences and log probabilities by adding the new inp_tokens.
+
+        Arguments
+        ---------
+        log_probs : torch.Tensor
+            The log-probabilities of the current step output.
+        inp_tokens : torch.Tensor
+            The input tensor of the current step.
+        predecessors : torch.Tensor
+            The index of which beam the current top-K output came from in (t-1) steps.
+        candidates : torch.Tensor
+            The index of the current top-K output.
+        alived_hyps : AlivedHypotheses
+            The alived hypotheses.
+
+        Returns
+        -------
+        alived_hyps : AlivedHypotheses
+            The alived hypotheses.
+        """
+        # Update alived_seq
+        alived_hyps.alived_seq = torch.cat(
+            [
+                torch.index_select(
+                    alived_hyps.alived_seq, dim=0, index=predecessors
+                ),
+                inp_tokens.unsqueeze(1),
+            ],
+            dim=-1,
+        )
+
+        # Takes the log-probabilities
+        beam_log_probs = log_probs[
+            torch.arange(self.batch_size).unsqueeze(1), candidates
+        ].reshape(self.n_bh)
+
+        # Update alived_log_probs
+        alived_hyps.alived_log_probs = torch.cat(
+            [
+                torch.index_select(
+                    alived_hyps.alived_log_probs, dim=0, index=predecessors
+                ),
+                beam_log_probs.unsqueeze(1),
+            ],
+            dim=-1,
+        )
+
+        return alived_hyps
+
+    def _compute_scores_and_next_inp_tokens(self, alived_hyps, log_probs, step):
+        """Compute scores and next input tokens.
+
+        Arguments
+        ---------
+        alived_hyps : AlivedHypotheses
+            The alived hypotheses.
+        log_probs : torch.Tensor
+            The log-probabilities of the current step output.
+        step : int
+            The current decoding step.
+
+        Returns
+        -------
+        scores : torch.Tensor
+            The scores of the current step output.
+        candidates : torch.Tensor
+            The index of the current top-K output.
+        predecessors : torch.Tensor
+            The index of which beam the current top-K output came from in (t-1) steps.
+        inp_tokens : torch.Tensor
+            The input tensor of the current step.
+        alived_hyps : AlivedHypotheses
+            The alived hypotheses.
+        """
+        scores = alived_hyps.sequence_scores.unsqueeze(1).expand(-1, self.n_out)
+        scores = scores + log_probs
+
+        # length normalization
+        if self.length_normalization:
+            scores = scores / (step + 1)
+
+        # keep topk beams
+        scores, candidates = scores.view(self.batch_size, -1).topk(
+            self.beam_size, dim=-1
+        )
+
+        # The input for the next step, also the output of current step.
+        inp_tokens = (candidates % self.n_out).view(self.n_bh)
+
+        scores = scores.view(self.n_bh)
+        alived_hyps.sequence_scores = scores
+
+        # recover the length normalization
+        if self.length_normalization:
+            alived_hyps.sequence_scores = alived_hyps.sequence_scores * (
+                step + 1
+            )
+
+        # The index of which beam the current top-K output came from in (t-1) steps.
+        predecessors = (
+            torch.div(candidates, self.n_out, rounding_mode="floor")
+            + self.beam_offset.unsqueeze(1).expand_as(candidates)
+        ).view(self.n_bh)
+
+        return (
+            scores,
+            candidates,
+            predecessors,
+            inp_tokens,
+            alived_hyps,
+        )
+
+    def init_beam_search_data(self, enc_states, wav_len):
+        """Initialize the beam search data.
+
+        Arguments
+        ---------
+        enc_states : torch.Tensor
+            The encoder states to be attended.
+        wav_len : torch.Tensor
+            The actual length of each enc_states sequence.
+
+        Returns
+        -------
+        alived_hyps : AlivedHypotheses
+            The alived hypotheses.
+        inp_tokens : torch.Tensor
+            The input tensor of the current step.
+        log_probs : torch.Tensor
+            The log-probabilities of the current step output.
+        eos_hyps_and_log_probs_scores : list
+            Generated hypotheses (the one that haved reached eos) and log probs scores.
+        memory : No limit
+            The memory variables generated in this step.
+        scorer_memory : No limit
+            The memory variables generated in this step.
+        attn : torch.Tensor
+            The attention weight.
+        prev_attn_peak : torch.Tensor
+            The previous attention peak place.
+        enc_states : torch.Tensor
+            The encoder states to be attended.
+        enc_lens : torch.Tensor
+            The actual length of each enc_states sequence.
+        """
+        enc_lens = torch.round(enc_states.shape[1] * wav_len).int()
+
+        self.device = enc_states.device
+        self.batch_size = enc_states.shape[0]
+        self.n_bh = self.batch_size * self.beam_size
+
+        self.n_out = self.set_n_out()
+
+        memory, scorer_memory = self._update_reset_memory(enc_states, enc_lens)
 
         # Inflate the enc_states and enc_len by beam_size times
         enc_states = inflate_tensor(enc_states, times=self.beam_size, dim=0)
@@ -716,502 +948,436 @@ class S2SBeamSearcher(S2SBaseSearcher):
 
         # Using bos as the first input
         inp_tokens = (
-            torch.zeros(batch_size * self.beam_size, device=device)
+            torch.zeros(self.n_bh, device=self.device)
             .fill_(self.bos_index)
             .long()
         )
 
         # The first index of each sentence.
         self.beam_offset = (
-            torch.arange(batch_size, device=device) * self.beam_size
+            torch.arange(self.batch_size, device=self.device) * self.beam_size
         )
 
         # initialize sequence scores variables.
-        sequence_scores = torch.empty(
-            batch_size * self.beam_size, device=device
+        sequence_scores = torch.empty(self.n_bh, device=self.device).fill_(
+            self.minus_inf
         )
-        sequence_scores.fill_(float("-inf"))
 
         # keep only the first to make sure no redundancy.
         sequence_scores.index_fill_(0, self.beam_offset, 0.0)
 
         # keep the hypothesis that reaches eos and their corresponding score and log_probs.
-        hyps_and_scores = [[] for _ in range(batch_size)]
-
-        # keep the sequences that still not reaches eos.
-        alived_seq = torch.empty(
-            batch_size * self.beam_size, 0, device=device
-        ).long()
-
-        # Keep the log-probabilities of alived sequences.
-        alived_log_probs = torch.empty(
-            batch_size * self.beam_size, 0, device=device
-        )
+        eos_hyps_and_log_probs_scores = [[] for _ in range(self.batch_size)]
 
-        min_decode_steps = int(enc_states.shape[1] * self.min_decode_ratio)
-        max_decode_steps = int(enc_states.shape[1] * self.max_decode_ratio)
+        self.min_decode_steps = int(enc_states.shape[1] * self.min_decode_ratio)
+        self.max_decode_steps = int(enc_states.shape[1] * self.max_decode_ratio)
 
-        # the decoding steps can be based on the max number of tokens that a decoder can process (e.g., 448 for Whisper).
-        min_decode_steps, max_decode_steps = self.change_max_decoding_length(
-            min_decode_steps, max_decode_steps
+        # the decoding steps can be based on the max number of tokens that a decoder can process
+        # (e.g., 448 for Whisper).
+        (
+            self.min_decode_steps,
+            self.max_decode_steps,
+        ) = self.change_max_decoding_length(
+            self.min_decode_steps, self.max_decode_steps
         )
 
         # Initialize the previous attention peak to zero
         # This variable will be used when using_max_attn_shift=True
-        prev_attn_peak = torch.zeros(batch_size * self.beam_size, device=device)
+        prev_attn_peak = torch.zeros(self.n_bh, device=self.device)
+        attn = None
 
-        for t in range(max_decode_steps):
-            # terminate condition
-            if self._check_full_beams(hyps_and_scores, self.beam_size):
-                break
+        log_probs = torch.full((self.n_bh, self.n_out), 0.0, device=self.device)
 
-            log_probs, memory, attn = self.forward_step(
-                inp_tokens, memory, enc_states, enc_lens
-            )
-            log_probs = self.att_weight * log_probs
+        alived_hyps = self.init_hypotheses()
 
-            # Keep the original value
-            log_probs_clone = log_probs.clone().reshape(batch_size, -1)
-            vocab_size = log_probs.shape[-1]
+        return (
+            alived_hyps,
+            inp_tokens,
+            log_probs,
+            eos_hyps_and_log_probs_scores,
+            memory,
+            scorer_memory,
+            attn,
+            prev_attn_peak,
+            enc_states,
+            enc_lens,
+        )
 
-            if self.using_max_attn_shift:
-                # Block the candidates that exceed the max shift
-                cond, attn_peak = self._check_attn_shift(attn, prev_attn_peak)
-                log_probs = mask_by_condition(
-                    log_probs, cond, fill_value=self.minus_inf
+    def _update_hyps_and_scores_if_eos_token(
+        self, inp_tokens, alived_hyps, eos_hyps_and_log_probs_scores, scores,
+    ):
+        """This method will update hyps and scores if inp_tokens are eos.
+
+        Arguments
+        ---------
+        inp_tokens : torch.Tensor
+            The current output.
+        alived_hyps : AlivedHypotheses
+            alived_seq : torch.Tensor
+            alived_log_probs : torch.Tensor
+        eos_hyps_and_log_probs_scores : list
+            Generated hypotheses (the one that haved reached eos) and log probs scores.
+        scores : torch.Tensor
+            Scores at the current step.
+
+        Returns
+        -------
+        is_eos : torch.BoolTensor
+            Each element represents whether the token is eos.
+        """
+        is_eos = inp_tokens.eq(self.eos_index)
+        (eos_indices,) = torch.nonzero(is_eos, as_tuple=True)
+
+        # Store the hypothesis and their scores when reaching eos.
+        if eos_indices.shape[0] > 0:
+            for index in eos_indices:
+                # convert to int
+                index = index.item()
+                batch_id = torch.div(
+                    index, self.beam_size, rounding_mode="floor"
                 )
-                prev_attn_peak = attn_peak
-
-            # Set eos to minus_inf when less than minimum steps.
-            if t < min_decode_steps:
-                log_probs[:, self.eos_index] = self.minus_inf
-
-            # Set the eos prob to minus_inf when it doesn't exceed threshold.
-            if self.using_eos_threshold:
-                cond = self._check_eos_threshold(log_probs)
-                log_probs[:, self.eos_index] = mask_by_condition(
-                    log_probs[:, self.eos_index],
-                    cond,
-                    fill_value=self.minus_inf,
+                if (
+                    len(eos_hyps_and_log_probs_scores[batch_id])
+                    == self.beam_size
+                ):
+                    continue
+                hyp = alived_hyps.alived_seq[index, :]
+                log_probs = alived_hyps.alived_log_probs[index, :]
+                final_scores = scores[index].clone()
+                eos_hyps_and_log_probs_scores[batch_id].append(
+                    (hyp, log_probs, final_scores)
                 )
 
-            # adding LM scores to log_prob if lm_weight > 0
-            if self.lm_weight > 0:
-                lm_log_probs, lm_memory = self.lm_forward_step(
-                    inp_tokens, lm_memory
-                )
-                log_probs = log_probs + self.lm_weight * lm_log_probs
-
-            # adding CTC scores to log_prob if ctc_weight > 0
-            if self.ctc_weight > 0:
-                g = alived_seq
-                # block blank token
-                log_probs[:, self.blank_index] = self.minus_inf
-                if self.ctc_weight != 1.0 and self.ctc_score_mode == "partial":
-                    # pruning vocab for ctc_scorer
-                    _, ctc_candidates = log_probs.topk(
-                        self.beam_size * 2, dim=-1
-                    )
-                else:
-                    ctc_candidates = None
+        return is_eos
 
-                ctc_log_probs, ctc_memory = ctc_scorer.forward_step(
-                    g, ctc_memory, ctc_candidates, attn
-                )
-                log_probs = log_probs + self.ctc_weight * ctc_log_probs
+    def _get_topk_prediction(self, eos_hyps_and_log_probs_scores):
+        """This method sorts the scores and return corresponding hypothesis and log probs.
 
-            scores = sequence_scores.unsqueeze(1).expand(-1, vocab_size)
-            scores = scores + log_probs
+        Arguments
+        ---------
+        eos_hyps_and_log_probs_scores : list
+            Generated hypotheses (the one that haved reached eos) and log probs scores.
 
-            # length normalization
-            if self.length_normalization:
-                scores = scores / (t + 1)
+        Returns
+        -------
+        topk_hyps : torch.Tensor (batch, topk, max length of token_id sequences)
+            This tensor stores the topk predicted hypothesis.
+        topk_lengths : torch.Tensor (batch, topk)
+            This tensor contains the final scores of topk hypotheses.
+        topk_scores : torch.Tensor (batch, topk)
+            The length of each topk sequence in the batch.
+        topk_log_probs : torch.Tensor (batch, topk, max length of token_id sequences)
+            The log probabilities of each hypotheses.
+        """
+        top_hyps, top_log_probs, top_scores, top_lengths = [], [], [], []
+        batch_size = len(eos_hyps_and_log_probs_scores)
 
-            # keep topk beams
-            scores, candidates = scores.view(batch_size, -1).topk(
-                self.beam_size, dim=-1
-            )
+        # Collect hypotheses
+        for i in range(len(eos_hyps_and_log_probs_scores)):
+            hyps, log_probs, scores = zip(*eos_hyps_and_log_probs_scores[i])
+            top_hyps += hyps
+            top_scores += scores
+            top_log_probs += log_probs
+            top_lengths += [len(hyp) for hyp in hyps]
 
-            # The input for the next step, also the output of current step.
-            inp_tokens = (candidates % vocab_size).view(
-                batch_size * self.beam_size
-            )
+        # Convert lists to tensors
+        top_hyps = torch.nn.utils.rnn.pad_sequence(
+            top_hyps, batch_first=True, padding_value=0
+        )
+        top_log_probs = torch.nn.utils.rnn.pad_sequence(
+            top_log_probs, batch_first=True, padding_value=0
+        )
+        top_lengths = torch.tensor(
+            top_lengths, dtype=torch.float, device=top_hyps.device
+        )
+        top_scores = torch.stack((top_scores), dim=0).view(batch_size, -1)
 
-            scores = scores.view(batch_size * self.beam_size)
-            sequence_scores = scores
+        # Use SpeechBrain style lengths
+        top_lengths = (top_lengths - 1).abs() / top_hyps.size(1)
 
-            # recover the length normalization
-            if self.length_normalization:
-                sequence_scores = sequence_scores * (t + 1)
+        # Get topk indices
+        topk_scores, indices = top_scores.topk(self.topk, dim=-1)
+        indices = (indices + self.beam_offset.unsqueeze(1)).view(
+            batch_size * self.topk
+        )
+        # Select topk hypotheses
+        topk_hyps = torch.index_select(top_hyps, dim=0, index=indices,)
+        topk_hyps = topk_hyps.view(batch_size, self.topk, -1)
+        topk_lengths = torch.index_select(top_lengths, dim=0, index=indices,)
+        topk_lengths = topk_lengths.view(batch_size, self.topk)
+        topk_log_probs = torch.index_select(
+            top_log_probs, dim=0, index=indices,
+        )
+        topk_log_probs = topk_log_probs.view(batch_size, self.topk, -1)
 
-            # The index of which beam the current top-K output came from in (t-1) timesteps.
-            predecessors = (
-                torch.div(candidates, vocab_size, rounding_mode="floor")
-                + self.beam_offset.unsqueeze(1).expand_as(candidates)
-            ).view(batch_size * self.beam_size)
+        return topk_hyps, topk_lengths, topk_scores, topk_log_probs
 
-            # Permute the memory to synchoronize with the output.
-            memory = self.permute_mem(memory, index=predecessors)
-            if self.lm_weight > 0:
-                lm_memory = self.permute_lm_mem(lm_memory, index=predecessors)
+    def search_step(
+        self,
+        alived_hyps,
+        inp_tokens,
+        log_probs,
+        eos_hyps_and_log_probs_scores,
+        memory,
+        scorer_memory,
+        attn,
+        prev_attn_peak,
+        enc_states,
+        enc_lens,
+        step,
+    ):
+        """A search step for the next most likely tokens.
 
-            if self.ctc_weight > 0:
-                ctc_memory = ctc_scorer.permute_mem(ctc_memory, candidates)
+        Arguments
+        ---------
+        alived_hyps : AlivedHypotheses
+            The alived hypotheses.
+        inp_tokens : torch.Tensor
+            The input tensor of the current step.
+        log_probs : torch.Tensor
+            The log-probabilities of the current step output.
+        eos_hyps_and_log_probs_scores : list
+            Generated hypotheses (the one that haved reached eos) and log probs scores.
+        memory : No limit
+            The memory variables input for this step.
+            (ex. RNN hidden states).
+        scorer_memory : No limit
+            The memory variables input for this step.
+            (ex. RNN hidden states).
+        attn : torch.Tensor
+            The attention weight.
+        prev_attn_peak : torch.Tensor
+            The previous attention peak place.
+        enc_states : torch.Tensor
+            The encoder states to be attended.
+        enc_lens : torch.Tensor
+            The actual length of each enc_states sequence.
+        step : int
+            The current decoding step.
 
-            # If using_max_attn_shift, then the previous attn peak has to be permuted too.
-            if self.using_max_attn_shift:
-                prev_attn_peak = torch.index_select(
-                    prev_attn_peak, dim=0, index=predecessors
-                )
+        Returns
+        -------
+        alived_hyps : AlivedHypotheses
+            The alived hypotheses.
+        inp_tokens : torch.Tensor
+            The input tensor of the current step.
+        log_probs : torch.Tensor
+            The log-probabilities of the current step output.
+        eos_hyps_and_log_probs_scores : list
+            Generated hypotheses (the one that haved reached eos) and log probs scores.
+        memory : No limit
+            The memory variables generated in this step.
+        scorer_memory : No limit
+            The memory variables generated in this step.
+        attn : torch.Tensor
+            The attention weight.
+        prev_attn_peak : torch.Tensor
+            The previous attention peak place.
+        scores : torch.Tensor
+            The scores of the current step output.
+        """
+        (log_probs, memory, attn,) = self._attn_weight_step(
+            inp_tokens, memory, enc_states, enc_lens, attn, log_probs,
+        )
 
-            # Add coverage penalty
-            if self.coverage_penalty > 0:
-                cur_attn = torch.index_select(attn, dim=0, index=predecessors)
-
-                # coverage: cumulative attention probability vector
-                if t == 0:
-                    # Init coverage
-                    self.coverage = cur_attn
-
-                # the attn of transformer is [batch_size*beam_size, current_step, source_len]
-                if len(cur_attn.size()) > 2:
-                    self.converage = torch.sum(cur_attn, dim=1)
-                else:
-                    # Update coverage
-                    self.coverage = torch.index_select(
-                        self.coverage, dim=0, index=predecessors
-                    )
-                    self.coverage = self.coverage + cur_attn
-
-                # Compute coverage penalty and add it to scores
-                penalty = torch.max(
-                    self.coverage, self.coverage.clone().fill_(0.5)
-                ).sum(-1)
-                penalty = penalty - self.coverage.size(-1) * 0.5
-                penalty = penalty.view(batch_size * self.beam_size)
-                penalty = (
-                    penalty / (t + 1) if self.length_normalization else penalty
-                )
-                scores = scores - penalty * self.coverage_penalty
-
-            # Update alived_seq
-            alived_seq = torch.cat(
-                [
-                    torch.index_select(alived_seq, dim=0, index=predecessors),
-                    inp_tokens.unsqueeze(1),
-                ],
-                dim=-1,
-            )
+        # Keep the original value
+        log_probs_clone = log_probs.clone().reshape(self.batch_size, -1)
 
-            # Takes the log-probabilities
-            beam_log_probs = log_probs_clone[
-                torch.arange(batch_size).unsqueeze(1), candidates
-            ].reshape(batch_size * self.beam_size)
-            alived_log_probs = torch.cat(
-                [
-                    torch.index_select(
-                        alived_log_probs, dim=0, index=predecessors
-                    ),
-                    beam_log_probs.unsqueeze(1),
-                ],
-                dim=-1,
-            )
+        (log_probs, prev_attn_peak,) = self._max_attn_shift_step(
+            attn, prev_attn_peak, log_probs,
+        )
 
-            is_eos = self._update_hyp_and_scores(
-                inp_tokens,
-                alived_seq,
-                alived_log_probs,
-                hyps_and_scores,
-                scores,
-                timesteps=t,
-            )
+        log_probs = self._set_eos_minus_inf_step(
+            log_probs, step, self.min_decode_steps,
+        )
 
-            # Block the paths that have reached eos.
-            sequence_scores.masked_fill_(is_eos, float("-inf"))
+        log_probs = self._eos_threshold_step(log_probs)
 
-        if not self._check_full_beams(hyps_and_scores, self.beam_size):
-            # Using all eos to fill-up the hyps.
-            eos = (
-                torch.zeros(batch_size * self.beam_size, device=device)
-                .fill_(self.eos_index)
-                .long()
-            )
-            _ = self._update_hyp_and_scores(
-                eos,
-                alived_seq,
-                alived_log_probs,
-                hyps_and_scores,
-                scores,
-                timesteps=max_decode_steps,
-            )
+        (log_probs, scorer_memory,) = self._scorer_step(
+            inp_tokens, scorer_memory, attn, log_probs,
+        )
 
         (
-            topk_hyps,
-            topk_scores,
-            topk_lengths,
-            log_probs,
-        ) = self._get_top_score_prediction(hyps_and_scores, topk=self.topk,)
-        # pick the best hyp
-        predictions = topk_hyps[:, 0, :]
-        predictions = batch_filter_seq2seq_output(
-            predictions, eos_id=self.eos_index
+            scores,
+            candidates,
+            predecessors,
+            inp_tokens,
+            alived_hyps,
+        ) = self._compute_scores_and_next_inp_tokens(
+            alived_hyps, log_probs, step,
         )
 
-        if self.return_log_probs:
-            return predictions, topk_scores, log_probs
-        else:
-            return predictions, topk_scores
-
-    def ctc_forward_step(self, x):
-        """Applies a ctc step during bramsearch."""
-        logits = self.ctc_fc(x)
-        log_probs = self.softmax(logits)
-        return log_probs
+        memory, scorer_memory, prev_attn_peak = self._update_permute_memory(
+            memory, scorer_memory, predecessors, candidates, prev_attn_peak
+        )
 
-    def permute_mem(self, memory, index):
-        """This method permutes the seq2seq model memory
-        to synchronize the memory index with the current output.
+        alived_hyps = self._update_sequences_and_log_probs(
+            log_probs_clone, inp_tokens, predecessors, candidates, alived_hyps,
+        )
 
-        Arguments
-        ---------
-        memory : No limit
-            The memory variable to be permuted.
-        index : torch.Tensor
-            The index of the previous path.
+        is_eos = self._update_hyps_and_scores_if_eos_token(
+            inp_tokens, alived_hyps, eos_hyps_and_log_probs_scores, scores,
+        )
 
-        Return
-        ------
-        The variable of the memory being permuted.
+        # Block the paths that have reached eos.
+        alived_hyps.sequence_scores.masked_fill_(is_eos, float("-inf"))
 
-        """
-        raise NotImplementedError
+        return (
+            alived_hyps,
+            inp_tokens,
+            log_probs,
+            eos_hyps_and_log_probs_scores,
+            memory,
+            scorer_memory,
+            attn,
+            prev_attn_peak,
+            scores,
+        )
 
-    def permute_lm_mem(self, memory, index):
-        """This method permutes the language model memory
-        to synchronize the memory index with the current output.
+    def _fill_alived_hyps_with_eos_token(
+        self, alived_hyps, eos_hyps_and_log_probs_scores, scores,
+    ):
+        """Fill the alived_hyps that have not reached eos with eos.
 
         Arguments
         ---------
-        memory : No limit
-            The memory variable to be permuted.
-        index : torch.Tensor
-            The index of the previous path.
+        alived_hyps : AlivedHypotheses
+            The alived hypotheses.
+        eos_hyps_and_log_probs_scores : list
+            Generated hypotheses (the one that haved reached eos) and log probs scores.
+        scores : torch.Tensor
+            The scores of the current step output.
 
         Returns
         -------
-        The variable of the memory being permuted.
+        eos_hyps_and_log_probs_scores : list
+            Generated hypotheses (the one that haved reached eos) and log probs scores.
         """
-        raise NotImplementedError
-
-
-class S2SRNNBeamSearcher(S2SBeamSearcher):
-    """
-    This class implements the beam search decoding
-    for AttentionalRNNDecoder (speechbrain/nnet/RNN.py).
-    See also S2SBaseSearcher(), S2SBeamSearcher().
-
-    Arguments
-    ---------
-    embedding : torch.nn.Module
-        An embedding layer.
-    decoder : torch.nn.Module
-        Attentional RNN decoder.
-    linear : torch.nn.Module
-        A linear output layer.
-    temperature : float
-        Temperature factor applied to softmax. It changes the probability
-        distribution, being softer when T>1 and sharper with T<1.
-    **kwargs
-        see S2SBeamSearcher, arguments are directly passed.
-
-    Example
-    -------
-    >>> emb = torch.nn.Embedding(5, 3)
-    >>> dec = sb.nnet.RNN.AttentionalRNNDecoder(
-    ...     "gru", "content", 3, 3, 1, enc_dim=7, input_size=3
-    ... )
-    >>> lin = sb.nnet.linear.Linear(n_neurons=5, input_size=3)
-    >>> ctc_lin = sb.nnet.linear.Linear(n_neurons=5, input_size=7)
-    >>> searcher = S2SRNNBeamSearcher(
-    ...     embedding=emb,
-    ...     decoder=dec,
-    ...     linear=lin,
-    ...     ctc_linear=ctc_lin,
-    ...     bos_index=4,
-    ...     eos_index=4,
-    ...     blank_index=4,
-    ...     min_decode_ratio=0,
-    ...     max_decode_ratio=1,
-    ...     beam_size=2,
-    ... )
-    >>> enc = torch.rand([2, 6, 7])
-    >>> wav_len = torch.rand([2])
-    >>> hyps, scores = searcher(enc, wav_len)
-    """
-
-    def __init__(
-        self,
-        embedding,
-        decoder,
-        linear,
-        ctc_linear=None,
-        temperature=1.0,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        self.emb = embedding
-        self.dec = decoder
-        self.fc = linear
-        self.ctc_fc = ctc_linear
-        if self.ctc_weight > 0.0 and self.ctc_fc is None:
-            raise ValueError(
-                "To perform joint ATT/CTC decoding, ctc_fc is required."
+        if not self._check_full_beams(eos_hyps_and_log_probs_scores):
+            # Using all eos to fill-up the hyps.
+            inp_tokens = (
+                torch.zeros(self.n_bh, device=self.device)
+                .fill_(self.eos_index)
+                .long()
+            )
+            self._update_hyps_and_scores_if_eos_token(
+                inp_tokens, alived_hyps, eos_hyps_and_log_probs_scores, scores,
             )
 
-        self.softmax = torch.nn.LogSoftmax(dim=-1)
-        self.temperature = temperature
-
-    def reset_mem(self, batch_size, device):
-        """Needed to reset the memory during beamsearch."""
-        hs = None
-        self.dec.attn.reset()
-        c = torch.zeros(batch_size, self.dec.attn_dim, device=device)
-        return hs, c
+        return eos_hyps_and_log_probs_scores
 
-    def forward_step(self, inp_tokens, memory, enc_states, enc_lens):
-        """Performs a step in the implemented beamsearcher."""
-        with torch.no_grad():
-            hs, c = memory
-            e = self.emb(inp_tokens)
-            dec_out, hs, c, w = self.dec.forward_step(
-                e, hs, c, enc_states, enc_lens
-            )
-            log_probs = self.softmax(self.fc(dec_out) / self.temperature)
-        # average attn weight of heads when attn_type is multiheadlocation
-        if self.dec.attn_type == "multiheadlocation":
-            w = torch.mean(w, dim=1)
-        return log_probs, (hs, c), w
+    def forward(self, enc_states, wav_len):  # noqa: C901
+        """Applies beamsearch and returns the predicted tokens.
 
-    def permute_mem(self, memory, index):
-        """Memory permutation during beamsearch."""
-        hs, c = memory
+        Arguments
+        ---------
+        enc_states : torch.Tensor
+            The encoder states to be attended.
+        wav_len : torch.Tensor
+            The actual length of each enc_states sequence.
 
-        # shape of hs: [num_layers, batch_size, n_neurons]
-        if isinstance(hs, tuple):
-            hs_0 = torch.index_select(hs[0], dim=1, index=index)
-            hs_1 = torch.index_select(hs[1], dim=1, index=index)
-            hs = (hs_0, hs_1)
-        else:
-            hs = torch.index_select(hs, dim=1, index=index)
+        Returns
+        -------
+        hyps : list
+            The predicted tokens.
+        best_lens : torch.Tensor
+            The length of each predicted tokens.
+        best_scores : torch.Tensor
+            The scores of each predicted tokens.
+        best_log_probs : torch.Tensor
+            The log probabilities of each predicted tokens.
+        """
+        (
+            alived_hyps,
+            inp_tokens,
+            log_probs,
+            eos_hyps_and_log_probs_scores,
+            memory,
+            scorer_memory,
+            attn,
+            prev_attn_peak,
+            enc_states,
+            enc_lens,
+        ) = self.init_beam_search_data(enc_states, wav_len)
+
+        for step in range(self.max_decode_steps):
+            # terminate condition
+            if self._check_full_beams(eos_hyps_and_log_probs_scores):
+                break
 
-        c = torch.index_select(c, dim=0, index=index)
-        if self.dec.attn_type == "location":
-            self.dec.attn.prev_attn = torch.index_select(
-                self.dec.attn.prev_attn, dim=0, index=index
+            (
+                alived_hyps,
+                inp_tokens,
+                log_probs,
+                eos_hyps_and_log_probs_scores,
+                memory,
+                scorer_memory,
+                attn,
+                prev_attn_peak,
+                scores,
+            ) = self.search_step(
+                alived_hyps,
+                inp_tokens,
+                log_probs,
+                eos_hyps_and_log_probs_scores,
+                memory,
+                scorer_memory,
+                attn,
+                prev_attn_peak,
+                enc_states,
+                enc_lens,
+                step,
             )
-        return (hs, c)
 
+        finals_hyps_and_log_probs_scores = self._fill_alived_hyps_with_eos_token(
+            alived_hyps, eos_hyps_and_log_probs_scores, scores,
+        )
 
-class S2SRNNBeamSearchLM(S2SRNNBeamSearcher):
-    """This class implements the beam search decoding
-    for AttentionalRNNDecoder (speechbrain/nnet/RNN.py) with LM.
-    See also S2SBaseSearcher(), S2SBeamSearcher(), S2SRNNBeamSearcher().
+        (
+            topk_hyps,
+            topk_lengths,
+            topk_scores,
+            topk_log_probs,
+        ) = self._get_topk_prediction(finals_hyps_and_log_probs_scores)
 
-    Arguments
-    ---------
-    embedding : torch.nn.Module
-        An embedding layer.
-    decoder : torch.nn.Module
-        Attentional RNN decoder.
-    linear : torch.nn.Module
-        A linear output layer.
-    language_model : torch.nn.Module
-        A language model.
-    temperature_lm : float
-        Temperature factor applied to softmax. It changes the probability
-        distribution, being softer when T>1 and sharper with T<1.
-    **kwargs
-        Arguments to pass to S2SBeamSearcher.
+        if self.return_topk:
+            return topk_hyps, topk_lengths, topk_scores, topk_log_probs
+        else:
+            # select the best hyps
+            best_hyps = topk_hyps[:, 0, :]
+            best_lens = topk_lengths[:, 0]
+            best_scores = topk_scores[:, 0]
+            best_log_probs = topk_log_probs[:, 0, :]
 
-    Example
-    -------
-    >>> from speechbrain.lobes.models.RNNLM import RNNLM
-    >>> emb = torch.nn.Embedding(5, 3)
-    >>> dec = sb.nnet.RNN.AttentionalRNNDecoder(
-    ...     "gru", "content", 3, 3, 1, enc_dim=7, input_size=3
-    ... )
-    >>> lin = sb.nnet.linear.Linear(n_neurons=5, input_size=3)
-    >>> lm = RNNLM(output_neurons=5, return_hidden=True)
-    >>> searcher = S2SRNNBeamSearchLM(
-    ...     embedding=emb,
-    ...     decoder=dec,
-    ...     linear=lin,
-    ...     language_model=lm,
-    ...     bos_index=4,
-    ...     eos_index=4,
-    ...     blank_index=4,
-    ...     min_decode_ratio=0,
-    ...     max_decode_ratio=1,
-    ...     beam_size=2,
-    ...     lm_weight=0.5,
-    ... )
-    >>> enc = torch.rand([2, 6, 7])
-    >>> wav_len = torch.rand([2])
-    >>> hyps, scores = searcher(enc, wav_len)
-    """
+            # Convert best hypothesis to list
+            hyps = undo_padding(best_hyps, best_lens)
 
-    def __init__(
-        self,
-        embedding,
-        decoder,
-        linear,
-        language_model,
-        temperature_lm=1.0,
-        **kwargs,
-    ):
-        super().__init__(
-            embedding, decoder, linear, **kwargs
-        )
+            return hyps, best_lens, best_scores, best_log_probs
 
-        self.lm = language_model
-        self.lm.eval()
-        self.log_softmax = sb.nnet.activations.Softmax(apply_log=True)
-        self.temperature_lm = temperature_lm
+    def permute_mem(self, memory, index):
+        """This method permutes the seq2seq model memory
+        to synchronize the memory index with the current output.
 
-    def lm_forward_step(self, inp_tokens, memory):
-        """Applies a step to the LM during beamsearch."""
-        with torch.no_grad():
-            logits, hs = self.lm(inp_tokens, hx=memory)
-            log_probs = self.log_softmax(logits / self.temperature_lm)
+        Arguments
+        ---------
+        memory : No limit
+            The memory variable to be permuted.
+        index : torch.Tensor
+            The index of the previous path.
 
-        return log_probs, hs
+        Return
+        ------
+        The variable of the memory being permuted.
 
-    def permute_lm_mem(self, memory, index):
-        """This is to permute lm memory to synchronize with current index
-        during beam search. The order of beams will be shuffled by scores
-        every timestep to allow batched beam search.
-        Further details please refer to speechbrain/decoder/seq2seq.py.
         """
-
-        if isinstance(memory, tuple):
-            memory_0 = torch.index_select(memory[0], dim=1, index=index)
-            memory_1 = torch.index_select(memory[1], dim=1, index=index)
-            memory = (memory_0, memory_1)
-        else:
-            memory = torch.index_select(memory, dim=1, index=index)
-        return memory
-
-    def reset_lm_mem(self, batch_size, device):
-        """Needed to reset the LM memory during beamsearch."""
-        # set hidden_state=None, pytorch RNN will automatically set it to
-        # zero vectors.
-        return None
+        raise NotImplementedError
 
 
-class S2SRNNBeamSearchTransformerLM(S2SRNNBeamSearcher):
-    """This class implements the beam search decoding
-    for AttentionalRNNDecoder (speechbrain/nnet/RNN.py) with LM.
-    See also S2SBaseSearcher(), S2SBeamSearcher(), S2SRNNBeamSearcher().
+class S2SRNNBeamSearcher(S2SBeamSearcher):
+    """
+    This class implements the beam search decoding
+    for AttentionalRNNDecoder (speechbrain/nnet/RNN.py).
+    See also S2SBaseSearcher(), S2SBeamSearcher().
 
     Arguments
     ---------
@@ -1221,127 +1387,158 @@ class S2SRNNBeamSearchTransformerLM(S2SRNNBeamSearcher):
         Attentional RNN decoder.
     linear : torch.nn.Module
         A linear output layer.
-    language_model : torch.nn.Module
-        A language model.
-    temperature_lm : float
+    temperature : float
         Temperature factor applied to softmax. It changes the probability
         distribution, being softer when T>1 and sharper with T<1.
     **kwargs
-        Arguments to pass to S2SBeamSearcher.
+        see S2SBeamSearcher, arguments are directly passed.
 
     Example
     -------
-    >>> from speechbrain.lobes.models.transformer.TransformerLM import TransformerLM
-    >>> emb = torch.nn.Embedding(5, 3)
+    >>> import speechbrain as sb
+    >>> vocab_size = 5
+    >>> emb = torch.nn.Embedding(vocab_size, 3)
     >>> dec = sb.nnet.RNN.AttentionalRNNDecoder(
     ...     "gru", "content", 3, 3, 1, enc_dim=7, input_size=3
     ... )
-    >>> lin = sb.nnet.linear.Linear(n_neurons=5, input_size=3)
-    >>> lm = TransformerLM(5, 512, 8, 1, 0, 1024, activation=torch.nn.GELU)
-    >>> searcher = S2SRNNBeamSearchTransformerLM(
+    >>> lin = sb.nnet.linear.Linear(n_neurons=vocab_size, input_size=3)
+    >>> coverage_scorer = sb.decoders.scorer.CoverageScorer(vocab_size)
+    >>> scorer = sb.decoders.scorer.ScorerBuilder(
+    ...     full_scorers = [coverage_scorer],
+    ...     partial_scorers = [],
+    ...     weights= dict(coverage=1.5)
+    ... )
+    >>> searcher = S2SRNNBeamSearcher(
     ...     embedding=emb,
     ...     decoder=dec,
     ...     linear=lin,
-    ...     language_model=lm,
     ...     bos_index=4,
     ...     eos_index=4,
-    ...     blank_index=4,
     ...     min_decode_ratio=0,
     ...     max_decode_ratio=1,
     ...     beam_size=2,
-    ...     lm_weight=0.5,
+    ...     scorer=scorer,
     ... )
-    >>> enc = torch.rand([2, 6, 7])
-    >>> wav_len = torch.rand([2])
-    >>> hyps, scores = searcher(enc, wav_len)
+    >>> batch_size = 2
+    >>> enc = torch.rand([batch_size, 6, 7])
+    >>> wav_len = torch.ones([batch_size])
+    >>> hyps, _, _, _ = searcher(enc, wav_len)
     """
 
     def __init__(
-        self,
-        embedding,
-        decoder,
-        linear,
-        language_model,
-        temperature_lm=1.0,
-        **kwargs,
+        self, embedding, decoder, linear, temperature=1.0, **kwargs,
     ):
-        super().__init__(
-            embedding, decoder, linear, **kwargs
-        )
+        super().__init__(**kwargs)
+        self.emb = embedding
+        self.dec = decoder
+        self.fc = linear
+        self.softmax = torch.nn.LogSoftmax(dim=-1)
+        self.temperature = temperature
 
-        self.lm = language_model
-        self.lm.eval()
-        self.log_softmax = sb.nnet.activations.Softmax(apply_log=True)
-        self.temperature_lm = temperature_lm
+    def reset_mem(self, batch_size, device):
+        """Needed to reset the memory during beamsearch."""
+        hs = None
+        self.dec.attn.reset()
+        c = torch.zeros(batch_size, self.dec.attn_dim, device=device)
+        return hs, c
 
-    def lm_forward_step(self, inp_tokens, memory):
-        """Performs a step in the LM during beamsearch."""
-        memory = _update_mem(inp_tokens, memory)
-        if not next(self.lm.parameters()).is_cuda:
-            self.lm.to(inp_tokens.device)
-        logits = self.lm(memory)
-        log_probs = self.softmax(logits / self.temperature_lm)
-        return log_probs[:, -1, :], memory
-
-    def permute_lm_mem(self, memory, index):
-        """Permutes the LM ,emory during beamsearch"""
-        memory = torch.index_select(memory, dim=0, index=index)
-        return memory
+    def forward_step(self, inp_tokens, memory, enc_states, enc_lens):
+        """Performs a step in the implemented beamsearcher."""
+        with torch.no_grad():
+            hs, c = memory
+            e = self.emb(inp_tokens)
+            dec_out, hs, c, w = self.dec.forward_step(
+                e, hs, c, enc_states, enc_lens
+            )
+            log_probs = self.softmax(self.fc(dec_out) / self.temperature)
+            # average attn weight of heads when attn_type is multiheadlocation
+            if self.dec.attn_type == "multiheadlocation":
+                w = torch.mean(w, dim=1)
+        return log_probs, (hs, c), w
 
-    def reset_lm_mem(self, batch_size, device):
-        """Needed to reset the LM memory during beamsearch"""
-        # set hidden_state=None, pytorch RNN will automatically set it to
-        # zero vectors.
-        return None
+    def permute_mem(self, memory, index):
+        """Memory permutation during beamsearch."""
+        hs, c = memory
+
+        # shape of hs: [num_layers, batch_size, n_neurons]
+        if isinstance(hs, tuple):
+            hs_0 = torch.index_select(hs[0], dim=1, index=index)
+            hs_1 = torch.index_select(hs[1], dim=1, index=index)
+            hs = (hs_0, hs_1)
+        else:
+            hs = torch.index_select(hs, dim=1, index=index)
+
+        c = torch.index_select(c, dim=0, index=index)
+        if self.dec.attn_type == "location":
+            self.dec.attn.prev_attn = torch.index_select(
+                self.dec.attn.prev_attn, dim=0, index=index
+            )
+        return (hs, c)
 
 
-class S2STransformerBeamSearch(S2SBeamSearcher):
+class S2STransformerBeamSearcher(S2SBeamSearcher):
     """This class implements the beam search decoding
     for Transformer.
     See also S2SBaseSearcher(), S2SBeamSearcher().
-
     Arguments
     ---------
-    model : torch.nn.Module
-        The model to use for decoding.
+    modules : list with the followings one:
+        model : torch.nn.Module
+            A Transformer model.
+        seq_lin : torch.nn.Module
+            A linear output layer.
     linear : torch.nn.Module
         A linear output layer.
     **kwargs
         Arguments to pass to S2SBeamSearcher
-
-    Example:
-    --------
-    >>> # see recipes/LibriSpeech/ASR_transformer/experiment.py
+    Example
+    -------
+    >>> from speechbrain.nnet.linear import Linear
+    >>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR
+    >>> from speechbrain.decoders import S2STransformerBeamSearcher
+    >>> batch_size=8
+    >>> n_channels=6
+    >>> input_size=40
+    >>> d_model=128
+    >>> tgt_vocab=140
+    >>> src = torch.rand([batch_size, n_channels, input_size])
+    >>> tgt = torch.randint(0, tgt_vocab, [batch_size, n_channels])
+    >>> net = TransformerASR(
+    ...    tgt_vocab, input_size, d_model, 8, 1, 1, 1024, activation=torch.nn.GELU
+    ... )
+    >>> ctc_lin = Linear(input_shape=(1, 40, d_model), n_neurons=tgt_vocab)
+    >>> lin = Linear(input_shape=(1, 40, d_model), n_neurons=tgt_vocab)
+    >>> searcher = S2STransformerBeamSearcher(
+    ...     modules=[net, lin],
+    ...     bos_index=1,
+    ...     eos_index=2,
+    ...     min_decode_ratio=0.0,
+    ...     max_decode_ratio=1.0,
+    ...     using_eos_threshold=False,
+    ...     beam_size=7,
+    ...     temperature=1.15,
+    ... )
+    >>> enc, dec = net.forward(src, tgt)
+    >>> hyps, _, _, _  = searcher(enc, torch.ones(batch_size))
     """
 
     def __init__(
-        self, modules, temperature=1.0, temperature_lm=1.0, **kwargs,
+        self, modules, temperature=1.0, **kwargs,
     ):
         super().__init__(**kwargs)
 
         self.model = modules[0]
         self.fc = modules[1]
-        self.ctc_fc = modules[2]
         self.softmax = torch.nn.LogSoftmax(dim=-1)
 
         self.temperature = temperature
-        self.temperature_lm = temperature_lm
 
     def reset_mem(self, batch_size, device):
         """Needed to reset the memory during beamsearch."""
         return None
 
-    def reset_lm_mem(self, batch_size, device):
-        """Needed to reset the LM memory during beamsearch."""
-        return None
-
     def permute_mem(self, memory, index):
-        """Permutes the memory."""
-        memory = torch.index_select(memory, dim=0, index=index)
-        return memory
-
-    def permute_lm_mem(self, memory, index):
-        """Permutes the memory of the language model."""
+        """Memory permutation during beamsearch."""
         memory = torch.index_select(memory, dim=0, index=index)
         return memory
 
@@ -1352,14 +1549,114 @@ class S2STransformerBeamSearch(S2SBeamSearcher):
         prob_dist = self.softmax(self.fc(pred) / self.temperature)
         return prob_dist[:, -1, :], memory, attn
 
-    def lm_forward_step(self, inp_tokens, memory):
-        """Performs a step in the implemented LM module."""
+
+class S2SWhisperGreedySearch(S2SGreedySearcher):
+    """
+    This class implements the greedy decoding
+    for Whisper neural nets made by OpenAI in
+    https://cdn.openai.com/papers/whisper.pdf.
+    Arguments
+    ---------
+    model : HuggingFaceWhisper
+        The Whisper model.
+    language_token : int
+        The language token to be used for the decoder input.
+    bos_token : int
+        The beginning of sentence token to be used for the decoder input.
+    task_token : int
+        The task token to be used for the decoder input.
+    timestamp_token : int
+        The timestamp token to be used for the decoder input.
+    max_length : int
+        The maximum decoding steps to perform.
+        The Whisper model has a maximum length of 448.
+    **kwargs
+        see S2SBaseSearcher, arguments are directly passed.
+    """
+
+    def __init__(
+        self,
+        model,
+        language_token=50259,
+        bos_token=50258,
+        task_token=50359,
+        timestamp_token=50363,
+        max_length=448,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.model = model
+        self.softmax = torch.nn.LogSoftmax(dim=-1)
+        self.decoder_input_tokens = None
+        self.language_token = language_token  # default language is english
+        self.bos_token = bos_token  # always this value
+        self.task_token = task_token  # default task is transcribe
+        self.timestamp_token = timestamp_token  # default is notimestamp
+        self.max_length = max_length - 3  # 3 tokens are added to the input
+
+    def set_language_token(self, language_token):
+        """set the language token to be used for the decoder input."""
+        self.language_token = language_token
+
+    def set_bos_token(self, bos_token):
+        """set the bos token to be used for the decoder input."""
+        self.bos_token = bos_token
+
+    def set_task_token(self, task_token):
+        """set the task token to be used for the decoder input."""
+        self.task_token = task_token
+
+    def set_timestamp_token(self, timestamp_token):
+        """set the timestamp token to be used for the decoder input."""
+        self.timestamp_token = timestamp_token
+        # need to reset bos_index too as timestamp_token is the first
+        # inp_token and need to be the first so that the first input gave
+        # to the model is [bos, language, task, timestamp] (order matters).
+        self.bos_index = self.timestamp_token
+
+    def set_decoder_input_tokens(self, decoder_input_tokens):
+        """decoder_input_tokens are the tokens used as input to the decoder.
+        They are directly taken from the tokenizer.prefix_tokens attribute.
+        decoder_input_tokens = [bos_token, language_token, task_token, timestamp_token]
+        """
+        self.set_bos_token(decoder_input_tokens[0])
+        self.set_language_token(decoder_input_tokens[1])
+        self.set_task_token(decoder_input_tokens[2])
+        self.set_timestamp_token(decoder_input_tokens[3])
+
+        # bos will be timestamp in our case.
+        self.decoder_input_tokens = [
+            self.bos_token,
+            self.language_token,
+            self.task_token,
+        ]
+
+    def reset_mem(self, batch_size, device):
+        """This method set the first tokens to be decoder_input_tokens during search."""
+        return torch.tensor([self.decoder_input_tokens] * batch_size).to(device)
+
+    def permute_mem(self, memory, index):
+        """Memory permutation during beamsearch."""
+        memory = torch.index_select(memory, dim=0, index=index)
+        return memory
+
+    def forward_step(self, inp_tokens, memory, enc_states, enc_lens):
+        """Performs a step in the implemented beamsearcher."""
         memory = _update_mem(inp_tokens, memory)
-        if not next(self.lm_modules.parameters()).is_cuda:
-            self.lm_modules.to(inp_tokens.device)
-        logits = self.lm_modules(memory)
-        log_probs = self.softmax(logits / self.temperature_lm)
-        return log_probs[:, -1, :], memory
+
+        # WARNING: the max_decode_ratio need to be under 448 because
+        #  of positinal encoding
+        dec_out, attn = self.model.forward_decoder(enc_states, memory)
+        log_probs = self.softmax(dec_out[:, -1])
+
+        return log_probs, memory, attn
+
+    def change_max_decoding_length(self, min_decode_steps, max_decode_steps):
+        """set the minimum/maximum length the decoder can take."""
+        return (
+            int(self.min_decode_ratio * self.max_length),
+            int(self.max_decode_ratio * self.max_length),
+        )
 
 
 class S2STransformerGreedySearch(S2SGreedySearcher):
@@ -1406,7 +1703,6 @@ class S2SWhisperBeamSearch(S2SBeamSearcher):
     """This class implements the beam search decoding
     for Whisper neural nets made by OpenAI in
     https://cdn.openai.com/papers/whisper.pdf.
-
     Arguments
     ---------
     module : list with the followings one:
@@ -1433,24 +1729,20 @@ class S2SWhisperBeamSearch(S2SBeamSearcher):
         self,
         module,
         temperature=1.0,
-        temperature_lm=1.0,
         language_token=50259,
         bos_token=50258,
         task_token=50359,
         timestamp_token=50363,
-        max_length=447,
+        max_length=448,
         **kwargs,
     ):
         super().__init__(**kwargs)
 
         self.model = module[0]
-        if len(module) == 2:
-            self.ctc_fc = module[1]
 
         self.softmax = torch.nn.LogSoftmax(dim=-1)
 
         self.temperature = temperature
-        self.temperature_lm = temperature_lm
 
         self.decoder_input_tokens = None
         self.language_token = language_token  # default language is english
@@ -1490,7 +1782,6 @@ class S2SWhisperBeamSearch(S2SBeamSearcher):
     def set_decoder_input_tokens(self, decoder_input_tokens):
         """decoder_input_tokens are the tokens used as input to the decoder.
         They are directly taken from the tokenizer.prefix_tokens attribute.
-
         decoder_input_tokens = [bos_token, language_token, task_token, timestamp_token]
         """
         self.set_bos_token(decoder_input_tokens[0])
@@ -1509,175 +1800,59 @@ class S2SWhisperBeamSearch(S2SBeamSearcher):
         """This method set the first tokens to be decoder_input_tokens during search."""
         return torch.tensor([self.decoder_input_tokens] * batch_size).to(device)
 
-    def reset_lm_mem(self, batch_size, device):
-        """Needed to reset the LM memory during beamsearch."""
-        return None
-
     def permute_mem(self, memory, index):
         """Permutes the memory."""
         memory = torch.index_select(memory, dim=0, index=index)
         return memory
 
-    def permute_lm_mem(self, memory, index):
-        """Permutes the memory of the language model."""
-        memory = torch.index_select(memory, dim=0, index=index)
-        return memory
+    def set_n_out(self):
+        """set the number of output tokens."""
+        return self.model.model.decoder.embed_tokens.weight.shape[0]
 
     def forward_step(self, inp_tokens, memory, enc_states, enc_lens):
         """Performs a step in the implemented beamsearcher."""
         memory = _update_mem(inp_tokens, memory)
         dec_out, attn, = self.model.forward_decoder(enc_states, memory)
-        log_probs = self.softmax(dec_out[:, -1])
+        log_probs = self.softmax(dec_out[:, -1] / self.temperature)
         return log_probs, memory, attn
 
-    def lm_forward_step(self, inp_tokens, memory):
-        """Performs a step in the implemented LM module."""
-        memory = _update_mem(inp_tokens, memory)
-        if not next(self.lm_modules.parameters()).is_cuda:
-            self.lm_modules.to(inp_tokens.device)
-        logits = self.lm_modules(memory)
-        log_probs = self.softmax(logits / self.temperature_lm)
-        return log_probs[:, -1, :], memory
-
-
-def batch_filter_seq2seq_output(prediction, eos_id=-1):
-    """Calling batch_size times of filter_seq2seq_output.
-
-    Arguments
-    ---------
-    prediction : list of torch.Tensor
-        A list containing the output ints predicted by the seq2seq system.
-    eos_id : int, string
-        The id of the eos.
-
-    Returns
-    ------
-    list
-        The output predicted by seq2seq model.
-
-    Example
-    -------
-    >>> predictions = [torch.IntTensor([1,2,3,4]), torch.IntTensor([2,3,4,5,6])]
-    >>> predictions = batch_filter_seq2seq_output(predictions, eos_id=4)
-    >>> predictions
-    [[1, 2, 3], [2, 3]]
-    """
-    outputs = []
-    for p in prediction:
-        res = filter_seq2seq_output(p.tolist(), eos_id=eos_id)
-        outputs.append(res)
-    return outputs
-
-
-def filter_seq2seq_output(string_pred, eos_id=-1):
-    """Filter the output until the first eos occurs (exclusive).
-
-    Arguments
-    ---------
-    string_pred : list
-        A list containing the output strings/ints predicted by the seq2seq system.
-    eos_id : int, string
-        The id of the eos.
-
-    Returns
-    ------
-    list
-        The output predicted by seq2seq model.
-
-    Example
-    -------
-    >>> string_pred = ['a','b','c','d','eos','e']
-    >>> string_out = filter_seq2seq_output(string_pred, eos_id='eos')
-    >>> string_out
-    ['a', 'b', 'c', 'd']
-    """
-    if isinstance(string_pred, list):
-        try:
-            eos_index = next(
-                i for i, v in enumerate(string_pred) if v == eos_id
-            )
-        except StopIteration:
-            eos_index = len(string_pred)
-        string_out = string_pred[:eos_index]
-    else:
-        raise ValueError("The input must be a list.")
-    return string_out
 
-
-def inflate_tensor(tensor, times, dim):
-    """This function inflates the tensor for times along dim.
+class S2SHFTextBasedBeamSearcher(S2STransformerBeamSearcher):
+    """This class implements the beam search decoding
+    for the text-based HF seq2seq models, such as mBART or NLLB.
+    It is NOT significantly different from S2STransformerBeamSearcher.
+    This is why it inherits S2STransformerBeamSearcher.
+    The main difference might arise when one wishes to use directly
+    the lm_head of the text-based HF model rather than making a new
+    projection layer (self.fc = None).
 
     Arguments
     ---------
-    tensor : torch.Tensor
-        The tensor to be inflated.
-    times : int
-        The tensor will inflate for this number of times.
-    dim : int
-        The dim to be inflated.
-
-    Returns
-    -------
-    torch.Tensor
-        The inflated tensor.
-
-    Example
-    -------
-    >>> tensor = torch.Tensor([[1,2,3], [4,5,6]])
-    >>> new_tensor = inflate_tensor(tensor, 2, dim=0)
-    >>> new_tensor
-    tensor([[1., 2., 3.],
-            [1., 2., 3.],
-            [4., 5., 6.],
-            [4., 5., 6.]])
+    modules : list with the followings one:
+        model : torch.nn.Module
+            A Transformer model.
+        seq_lin : torch.nn.Module
+            A linear output layer.
+            Normally set to None for this usecase.
+    vocab_size : int
+        The dimension of the lm_head.
+    **kwargs
+        Arguments to pass to S2SBeamSearcher
     """
-    return torch.repeat_interleave(tensor, times, dim=dim)
-
 
-def mask_by_condition(tensor, cond, fill_value):
-    """This function will mask some element in the tensor with fill_value, if condition=False.
-
-    Arguments
-    ---------
-    tensor : torch.Tensor
-        The tensor to be masked.
-    cond : torch.BoolTensor
-        This tensor has to be the same size as tensor.
-        Each element represents whether to keep the value in tensor.
-    fill_value : float
-        The value to fill in the masked element.
+    def __init__(self, modules, vocab_size, **kwargs):
+        super(S2SHFTextBasedBeamSearcher, self).__init__(modules, **kwargs)
+        self.vocab_size = vocab_size
 
-    Returns
-    -------
-    torch.Tensor
-        The masked tensor.
+    def forward_step(self, inp_tokens, memory, enc_states, enc_lens):
+        """Performs a step in the implemented beamsearcher."""
+        memory = _update_mem(inp_tokens, memory)
+        pred, attn = self.model.decode(memory, enc_states, enc_lens)
+        if self.fc is not None:
+            pred = self.fc(pred)
+        prob_dist = self.softmax(pred / self.temperature)
+        return prob_dist[:, -1, :], memory, attn
 
-    Example
-    -------
-    >>> tensor = torch.Tensor([[1,2,3], [4,5,6]])
-    >>> cond = torch.BoolTensor([[True, True, False], [True, False, False]])
-    >>> mask_by_condition(tensor, cond, 0)
-    tensor([[1., 2., 0.],
-            [4., 0., 0.]])
-    """
-    tensor = torch.where(
-        cond, tensor, torch.Tensor([fill_value]).to(tensor.device)
-    )
-    return tensor
-
-
-def _update_mem(inp_tokens, memory):
-    """This function is for updating the memory for transformer searches.
-    it is called at each decoding step. When being called, it appends the
-    predicted token of the previous step to existing memory.
-
-    Arguments:
-    -----------
-    inp_tokens : tensor
-        Predicted token of the previous decoding step.
-    memory : tensor
-        Contains all the predicted tokens.
-    """
-    if memory is None:
-        return inp_tokens.unsqueeze(1)
-    return torch.cat([memory, inp_tokens.unsqueeze(1)], dim=-1)
+    def set_n_out(self):
+        """set the number of output tokens."""
+        return self.vocab_size
diff --git a/speechbrain/decoders/transducer.py b/speechbrain/decoders/transducer.py
index e9db6f7f1d700ec9933f35f183eedf321982e3d8..403eb820a22b30c11872aa4a1f69c8babea2cb02 100644
--- a/speechbrain/decoders/transducer.py
+++ b/speechbrain/decoders/transducer.py
@@ -5,7 +5,19 @@ Author:
     Sung-Lin Yeh 2020
 """
 import torch
+from dataclasses import dataclass
 from functools import partial
+from typing import Optional, Any
+
+
+@dataclass
+class TransducerGreedySearcherStreamingContext(torch.nn.Module):
+    """Simple wrapper for the hidden state of the transducer greedy searcher.
+    Used by :meth:`~TransducerBeamSearcher.transducer_greedy_decode_streaming`.
+    """
+
+    hidden: Optional[Any] = None
+    """Hidden state; typically a tensor or a tuple of tensors."""
 
 
 class TransducerBeamSearcher(torch.nn.Module):
@@ -81,7 +93,7 @@ class TransducerBeamSearcher(torch.nn.Module):
     ...     lm_weight=0.0,
     ... )
     >>> enc = torch.rand([1, 20, 10])
-    >>> hyps, scores, _, _ = searcher(enc)
+    >>> hyps, _, _, _ = searcher(enc)
     """
 
     def __init__(
@@ -135,7 +147,9 @@ class TransducerBeamSearcher(torch.nn.Module):
         hyps = self.searcher(tn_output)
         return hyps
 
-    def transducer_greedy_decode(self, tn_output):
+    def transducer_greedy_decode(
+        self, tn_output, hidden_state=None, return_hidden=False
+    ):
         """Transducer greedy decoder is a greedy decoder over batch which apply Transducer rules:
             1- for each time step in the Transcription Network (TN) output:
                 -> Update the ith utterance only if
@@ -149,18 +163,43 @@ class TransducerBeamSearcher(torch.nn.Module):
             Output from transcription network with shape
             [batch, time_len, hiddens].
 
+        hidden_state : (torch.Tensor, torch.Tensor)
+            Hidden state to initially feed the decode network with. This is
+            useful in conjunction with `return_hidden` to be able to perform
+            beam search in a streaming context, so that you can reuse the last
+            hidden state as an initial state across calls.
+
+        return_hidden : bool
+            Whether the return tuple should contain an extra 5th element with
+            the hidden state at of the last step. See `hidden_state`.
+
         Returns
         -------
-        torch.tensor
+        Tuple of 4 or 5 elements (if `return_hidden`).
+
+        First element: List[List[int]]
+            List of decoded tokens
+
+        Second element: torch.Tensor
             Outputs a logits tensor [B,T,1,Output_Dim]; padding
             has not been removed.
+
+        Third element: None
+            nbest; irrelevant for greedy decode
+
+        Fourth element: None
+            nbest scores; irrelevant for greedy decode
+
+        Fifth element: Present if `return_hidden`, (torch.Tensor, torch.Tensor)
+            Tuple representing the hidden state required to call
+            `transducer_greedy_decode` where you left off in a streaming
+            context.
         """
         hyp = {
             "prediction": [[] for _ in range(tn_output.size(0))],
             "logp_scores": [0.0 for _ in range(tn_output.size(0))],
         }
         # prepare BOS = Blank for the Prediction Network (PN)
-        hidden = None
         input_PN = (
             torch.ones(
                 (tn_output.size(0), 1),
@@ -169,8 +208,13 @@ class TransducerBeamSearcher(torch.nn.Module):
             )
             * self.blank_id
         )
-        # First forward-pass on PN
-        out_PN, hidden = self._forward_PN(input_PN, self.decode_network_lst)
+
+        if hidden_state is None:
+            # First forward-pass on PN
+            out_PN, hidden = self._forward_PN(input_PN, self.decode_network_lst)
+        else:
+            out_PN, hidden = hidden_state
+
         # For each time step
         for t_step in range(tn_output.size(1)):
             # do unsqueeze over since tjoint must be have a 4 dim [B,T,U,Hidden]
@@ -180,7 +224,7 @@ class TransducerBeamSearcher(torch.nn.Module):
             )
             # Sort outputs at time
             logp_targets, positions = torch.max(
-                self.softmax(log_probs).squeeze(1).squeeze(1), dim=1
+                log_probs.squeeze(1).squeeze(1), dim=1
             )
             # Batch hidden update
             have_update_hyp = []
@@ -210,13 +254,42 @@ class TransducerBeamSearcher(torch.nn.Module):
                     have_update_hyp, selected_hidden, hidden
                 )
 
-        return (
+        ret = (
             hyp["prediction"],
             torch.Tensor(hyp["logp_scores"]).exp().mean(),
             None,
             None,
         )
 
+        if return_hidden:
+            # append the `(out_PN, hidden)` tuple to ret
+            ret += ((out_PN, hidden,),)
+
+        return ret
+
+    def transducer_greedy_decode_streaming(
+        self, x: torch.Tensor, context: TransducerGreedySearcherStreamingContext
+    ):
+        """Tiny wrapper for
+        :meth:`~TransducerBeamSearcher.transducer_greedy_decode` with an API
+        that makes it suitable to be passed as a `decoding_function` for
+        streaming.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Outputs of the prediction network (equivalent to `tn_output`)
+        context : TransducerGreedySearcherStreamingContext
+            Mutable streaming context object, which must be specified and reused
+            across calls when streaming.
+            You can obtain an initial context by initializing a default object.
+        """
+        (hyp, _scores, _, _, hidden) = self.transducer_greedy_decode(
+            x, context.hidden, return_hidden=True
+        )
+        context.hidden = hidden
+        return hyp
+
     def transducer_beam_search_decode(self, tn_output):
         """Transducer beam search decoder is a beam search decoder over batch which apply Transducer rules:
             1- for each utterance:
diff --git a/speechbrain/decoders/utils.py b/speechbrain/decoders/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2c3e283f459aae58f1186f5debf27ba8282f79a
--- /dev/null
+++ b/speechbrain/decoders/utils.py
@@ -0,0 +1,144 @@
+""" Utils functions for the decoding modules.
+
+Authors
+ * Adel Moumen 2023
+ * Ju-Chieh Chou 2020
+ * Peter Plantinga 2020
+ * Mirco Ravanelli 2020
+ * Sung-Lin Yeh 2020
+"""
+
+import torch
+
+
+def _update_mem(inp_tokens, memory):
+    """This function is for updating the memory for transformer searches.
+    it is called at each decoding step. When being called, it appends the
+    predicted token of the previous step to existing memory.
+    Arguments:
+    -----------
+    inp_tokens : tensor
+        Predicted token of the previous decoding step.
+    memory : tensor
+        Contains all the predicted tokens.
+    """
+    if memory is None:
+        memory = torch.empty(inp_tokens.size(0), 0, device=inp_tokens.device)
+    return torch.cat([memory, inp_tokens.unsqueeze(1)], dim=-1)
+
+
+def inflate_tensor(tensor, times, dim):
+    """This function inflates the tensor for times along dim.
+
+    Arguments
+    ---------
+    tensor : torch.Tensor
+        The tensor to be inflated.
+    times : int
+        The tensor will inflate for this number of times.
+    dim : int
+        The dim to be inflated.
+
+    Returns
+    -------
+    torch.Tensor
+        The inflated tensor.
+
+    Example
+    -------
+    >>> tensor = torch.Tensor([[1,2,3], [4,5,6]])
+    >>> new_tensor = inflate_tensor(tensor, 2, dim=0)
+    >>> new_tensor
+    tensor([[1., 2., 3.],
+            [1., 2., 3.],
+            [4., 5., 6.],
+            [4., 5., 6.]])
+    """
+    return torch.repeat_interleave(tensor, times, dim=dim)
+
+
+def mask_by_condition(tensor, cond, fill_value):
+    """This function will mask some element in the tensor with fill_value, if condition=False.
+
+    Arguments
+    ---------
+    tensor : torch.Tensor
+        The tensor to be masked.
+    cond : torch.BoolTensor
+        This tensor has to be the same size as tensor.
+        Each element represents whether to keep the value in tensor.
+    fill_value : float
+        The value to fill in the masked element.
+
+    Returns
+    -------
+    torch.Tensor
+        The masked tensor.
+
+    Example
+    -------
+    >>> tensor = torch.Tensor([[1,2,3], [4,5,6]])
+    >>> cond = torch.BoolTensor([[True, True, False], [True, False, False]])
+    >>> mask_by_condition(tensor, cond, 0)
+    tensor([[1., 2., 0.],
+            [4., 0., 0.]])
+    """
+    return torch.where(cond, tensor, fill_value)
+
+
+def batch_filter_seq2seq_output(prediction, eos_id=-1):
+    """Calling batch_size times of filter_seq2seq_output.
+    Arguments
+    ---------
+    prediction : list of torch.Tensor
+        A list containing the output ints predicted by the seq2seq system.
+    eos_id : int, string
+        The id of the eos.
+    Returns
+    ------
+    list
+        The output predicted by seq2seq model.
+    Example
+    -------
+    >>> predictions = [torch.IntTensor([1,2,3,4]), torch.IntTensor([2,3,4,5,6])]
+    >>> predictions = batch_filter_seq2seq_output(predictions, eos_id=4)
+    >>> predictions
+    [[1, 2, 3], [2, 3]]
+    """
+    outputs = []
+    for p in prediction:
+        res = filter_seq2seq_output(p.tolist(), eos_id=eos_id)
+        outputs.append(res)
+    return outputs
+
+
+def filter_seq2seq_output(string_pred, eos_id=-1):
+    """Filter the output until the first eos occurs (exclusive).
+    Arguments
+    ---------
+    string_pred : list
+        A list containing the output strings/ints predicted by the seq2seq system.
+    eos_id : int, string
+        The id of the eos.
+    Returns
+    ------
+    list
+        The output predicted by seq2seq model.
+    Example
+    -------
+    >>> string_pred = ['a','b','c','d','eos','e']
+    >>> string_out = filter_seq2seq_output(string_pred, eos_id='eos')
+    >>> string_out
+    ['a', 'b', 'c', 'd']
+    """
+    if isinstance(string_pred, list):
+        try:
+            eos_index = next(
+                i for i, v in enumerate(string_pred) if v == eos_id
+            )
+        except StopIteration:
+            eos_index = len(string_pred)
+        string_out = string_pred[:eos_index]
+    else:
+        raise ValueError("The input must be a list.")
+    return string_out
diff --git a/speechbrain/inference/ASR.py b/speechbrain/inference/ASR.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbb18fa160bcc97113557b1b119b85547dd20c36
--- /dev/null
+++ b/speechbrain/inference/ASR.py
@@ -0,0 +1,878 @@
+""" Specifies the inference interfaces for Automatic speech Recognition (ASR) modules.
+
+Authors:
+ * Aku Rouhe 2021
+ * Peter Plantinga 2021
+ * Loren Lugosch 2020
+ * Mirco Ravanelli 2020
+ * Titouan Parcollet 2021
+ * Abdel Heba 2021
+ * Andreas Nautsch 2022, 2023
+ * Pooneh Mousavi 2023
+ * Sylvain de Langen 2023, 2024
+ * Adel Moumen 2023, 2024
+ * Pradnya Kandarkar 2023
+"""
+from dataclasses import dataclass
+from typing import Any, Optional, List
+import itertools
+import torch
+import torchaudio
+import sentencepiece
+import speechbrain
+from speechbrain.inference.interfaces import Pretrained
+import functools
+from speechbrain.utils.fetching import fetch
+from speechbrain.utils.data_utils import split_path
+from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
+from speechbrain.utils.streaming import split_fixed_chunks
+
+
+class EncoderDecoderASR(Pretrained):
+    """A ready-to-use Encoder-Decoder ASR model
+
+    The class can be used either to run only the encoder (encode()) to extract
+    features or to run the entire encoder-decoder model
+    (transcribe()) to transcribe speech. The given YAML must contain the fields
+    specified in the *_NEEDED[] lists.
+
+    Example
+    -------
+    >>> from speechbrain.inference.ASR import EncoderDecoderASR
+    >>> tmpdir = getfixture("tmpdir")
+    >>> asr_model = EncoderDecoderASR.from_hparams(
+    ...     source="speechbrain/asr-crdnn-rnnlm-librispeech",
+    ...     savedir=tmpdir,
+    ... )  # doctest: +SKIP
+    >>> asr_model.transcribe_file("tests/samples/single-mic/example2.flac")  # doctest: +SKIP
+    "MY FATHER HAS REVEALED THE CULPRIT'S NAME"
+    """
+
+    HPARAMS_NEEDED = ["tokenizer"]
+    MODULES_NEEDED = ["encoder", "decoder"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.tokenizer = self.hparams.tokenizer
+        self.transducer_beam_search = False
+        self.transformer_beam_search = False
+        if hasattr(self.hparams, "transducer_beam_search"):
+            self.transducer_beam_search = self.hparams.transducer_beam_search
+        if hasattr(self.hparams, "transformer_beam_search"):
+            self.transformer_beam_search = self.hparams.transformer_beam_search
+
+    def transcribe_file(self, path, **kwargs):
+        """Transcribes the given audiofile into a sequence of words.
+
+        Arguments
+        ---------
+        path : str
+            Path to audio file which to transcribe.
+
+        Returns
+        -------
+        str
+            The audiofile transcription produced by this ASR system.
+        """
+        waveform = self.load_audio(path, **kwargs)
+        # Fake a batch:
+        batch = waveform.unsqueeze(0)
+        rel_length = torch.tensor([1.0])
+        predicted_words, predicted_tokens = self.transcribe_batch(
+            batch, rel_length
+        )
+        return predicted_words[0]
+
+    def encode_batch(self, wavs, wav_lens):
+        """Encodes the input audio into a sequence of hidden states
+
+        The waveforms should already be in the model's desired format.
+        You can call:
+        ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
+        to get a correctly converted signal in most cases.
+
+        Arguments
+        ---------
+        wavs : torch.Tensor
+            Batch of waveforms [batch, time, channels] or [batch, time]
+            depending on the model.
+        wav_lens : torch.Tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        torch.Tensor
+            The encoded batch
+        """
+        wavs = wavs.float()
+        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
+        encoder_out = self.mods.encoder(wavs, wav_lens)
+        if self.transformer_beam_search:
+            encoder_out = self.mods.transformer.encode(encoder_out, wav_lens)
+        return encoder_out
+
+    def transcribe_batch(self, wavs, wav_lens):
+        """Transcribes the input audio into a sequence of words
+
+        The waveforms should already be in the model's desired format.
+        You can call:
+        ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
+        to get a correctly converted signal in most cases.
+
+        Arguments
+        ---------
+        wavs : torch.Tensor
+            Batch of waveforms [batch, time, channels] or [batch, time]
+            depending on the model.
+        wav_lens : torch.Tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        list
+            Each waveform in the batch transcribed.
+        tensor
+            Each predicted token id.
+        """
+        with torch.no_grad():
+            wav_lens = wav_lens.to(self.device)
+            encoder_out = self.encode_batch(wavs, wav_lens)
+            if self.transducer_beam_search:
+                inputs = [encoder_out]
+            else:
+                inputs = [encoder_out, wav_lens]
+            predicted_tokens, _, _, _ = self.mods.decoder(*inputs)
+            predicted_words = [
+                self.tokenizer.decode_ids(token_seq)
+                for token_seq in predicted_tokens
+            ]
+        return predicted_words, predicted_tokens
+
+    def forward(self, wavs, wav_lens):
+        """Runs full transcription - note: no gradients through decoding"""
+        return self.transcribe_batch(wavs, wav_lens)
+
+
+class EncoderASR(Pretrained):
+    """A ready-to-use Encoder ASR model
+
+    The class can be used either to run only the encoder (encode()) to extract
+    features or to run the entire encoder + decoder function model
+    (transcribe()) to transcribe speech. The given YAML must contain the fields
+    specified in the *_NEEDED[] lists.
+
+    Example
+    -------
+    >>> from speechbrain.inference.ASR import EncoderASR
+    >>> tmpdir = getfixture("tmpdir")
+    >>> asr_model = EncoderASR.from_hparams(
+    ...     source="speechbrain/asr-wav2vec2-commonvoice-fr",
+    ...     savedir=tmpdir,
+    ... ) # doctest: +SKIP
+    >>> asr_model.transcribe_file("samples/audio_samples/example_fr.wav") # doctest: +SKIP
+    """
+
+    HPARAMS_NEEDED = ["tokenizer", "decoding_function"]
+    MODULES_NEEDED = ["encoder"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.tokenizer = self.hparams.tokenizer
+        self.set_decoding_function()
+
+    def set_decoding_function(self):
+        """Set the decoding function based on the parameters defined in the hyperparameter file.
+
+        The decoding function is determined by the `decoding_function` specified in the hyperparameter file.
+        It can be either a functools.partial object representing a decoding function or an instance of
+        `speechbrain.decoders.ctc.CTCBaseSearcher` for beam search decoding.
+
+        Raises:
+            ValueError: If the decoding function is neither a functools.partial nor an instance of
+                        speechbrain.decoders.ctc.CTCBaseSearcher.
+
+        Note:
+            - For greedy decoding (functools.partial), the provided `decoding_function` is assigned directly.
+            - For CTCBeamSearcher decoding, an instance of the specified `decoding_function` is created, and
+            additional parameters are added based on the tokenizer type.
+        """
+        # Greedy Decoding case
+        if isinstance(self.hparams.decoding_function, functools.partial):
+            self.decoding_function = self.hparams.decoding_function
+        # CTCBeamSearcher case
+        else:
+            # 1. check if the decoding function is an instance of speechbrain.decoders.CTCBaseSearcher
+            if issubclass(
+                self.hparams.decoding_function,
+                speechbrain.decoders.ctc.CTCBaseSearcher,
+            ):
+                # If so, we need to retrieve the vocab list from the tokenizer.
+                # We also need to check if the tokenizer is a sentencepiece or a CTCTextEncoder.
+                if isinstance(
+                    self.tokenizer, speechbrain.dataio.encoder.CTCTextEncoder
+                ):
+                    ind2lab = self.tokenizer.ind2lab
+                    vocab_list = [ind2lab[x] for x in range(len(ind2lab))]
+                elif isinstance(
+                    self.tokenizer, sentencepiece.SentencePieceProcessor
+                ):
+                    vocab_list = [
+                        self.tokenizer.id_to_piece(i)
+                        for i in range(self.tokenizer.vocab_size())
+                    ]
+                else:
+                    raise ValueError(
+                        "The tokenizer must be sentencepiece or CTCTextEncoder"
+                    )
+
+                # We can now instantiate the decoding class and add all the parameters
+                if hasattr(self.hparams, "test_beam_search"):
+                    opt_beam_search_params = self.hparams.test_beam_search
+                    # check if the kenlm_model_path is provided and fetch it if necessary
+                    if "kenlm_model_path" in opt_beam_search_params:
+                        source, fl = split_path(
+                            opt_beam_search_params["kenlm_model_path"]
+                        )
+                        kenlm_model_path = str(
+                            fetch(fl, source=source, savedir=".")
+                        )
+                        # we need to update the kenlm_model_path in the opt_beam_search_params
+                        opt_beam_search_params[
+                            "kenlm_model_path"
+                        ] = kenlm_model_path
+                else:
+                    opt_beam_search_params = {}
+                self.decoding_function = self.hparams.decoding_function(
+                    **opt_beam_search_params, vocab_list=vocab_list
+                )
+            else:
+                raise ValueError(
+                    "The decoding function must be an instance of speechbrain.decoders.CTCBaseSearcher"
+                )
+
+    def transcribe_file(self, path, **kwargs):
+        """Transcribes the given audiofile into a sequence of words.
+
+        Arguments
+        ---------
+        path : str
+            Path to audio file which to transcribe.
+
+        Returns
+        -------
+        str
+            The audiofile transcription produced by this ASR system.
+        """
+        waveform = self.load_audio(path, **kwargs)
+        # Fake a batch:
+        batch = waveform.unsqueeze(0)
+        rel_length = torch.tensor([1.0])
+        predicted_words, predicted_tokens = self.transcribe_batch(
+            batch, rel_length
+        )
+        return str(predicted_words[0])
+
+    def encode_batch(self, wavs, wav_lens):
+        """Encodes the input audio into a sequence of hidden states
+
+        The waveforms should already be in the model's desired format.
+        You can call:
+        ``normalized = EncoderASR.normalizer(signal, sample_rate)``
+        to get a correctly converted signal in most cases.
+
+        Arguments
+        ---------
+        wavs : torch.Tensor
+            Batch of waveforms [batch, time, channels] or [batch, time]
+            depending on the model.
+        wav_lens : torch.Tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        torch.Tensor
+            The encoded batch
+        """
+        wavs = wavs.float()
+        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
+        encoder_out = self.mods.encoder(wavs, wav_lens)
+        return encoder_out
+
+    def transcribe_batch(self, wavs, wav_lens):
+        """Transcribes the input audio into a sequence of words
+
+        The waveforms should already be in the model's desired format.
+        You can call:
+        ``normalized = EncoderASR.normalizer(signal, sample_rate)``
+        to get a correctly converted signal in most cases.
+
+        Arguments
+        ---------
+        wavs : torch.Tensor
+            Batch of waveforms [batch, time, channels] or [batch, time]
+            depending on the model.
+        wav_lens : torch.Tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        list
+            Each waveform in the batch transcribed.
+        tensor
+            Each predicted token id.
+        """
+        with torch.no_grad():
+            wav_lens = wav_lens.to(self.device)
+            encoder_out = self.encode_batch(wavs, wav_lens)
+            predictions = self.decoding_function(encoder_out, wav_lens)
+            is_ctc_text_encoder_tokenizer = isinstance(
+                self.tokenizer, speechbrain.dataio.encoder.CTCTextEncoder
+            )
+            if isinstance(self.hparams.decoding_function, functools.partial):
+                if is_ctc_text_encoder_tokenizer:
+                    predicted_words = [
+                        "".join(self.tokenizer.decode_ndim(token_seq))
+                        for token_seq in predictions
+                    ]
+                else:
+                    predicted_words = [
+                        self.tokenizer.decode_ids(token_seq)
+                        for token_seq in predictions
+                    ]
+            else:
+                predicted_words = [hyp[0].text for hyp in predictions]
+
+        return predicted_words, predictions
+
+    def forward(self, wavs, wav_lens):
+        """Runs the encoder"""
+        return self.encode_batch(wavs, wav_lens)
+
+
+class WhisperASR(Pretrained):
+    """A ready-to-use Whisper ASR model
+
+    The class can be used  to  run the entire encoder-decoder whisper model
+    (transcribe()) to transcribe speech. The given YAML must contains the fields
+    specified in the *_NEEDED[] lists.
+
+    Example
+    -------
+    >>> from speechbrain.inference.ASR import WhisperASR
+    >>> tmpdir = getfixture("tmpdir")
+    >>> asr_model = WhisperASR.from_hparams(source="speechbrain/asr-whisper-medium-commonvoice-it", savedir=tmpdir,) # doctest: +SKIP
+    >>> asr_model.transcribe_file("speechbrain/asr-whisper-medium-commonvoice-it/example-it.wav")  # doctest: +SKIP
+    """
+
+    HPARAMS_NEEDED = ["language"]
+    MODULES_NEEDED = ["whisper", "decoder"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.tokenizer = self.hparams.whisper.tokenizer
+        self.tokenizer.set_prefix_tokens(
+            self.hparams.language, "transcribe", False
+        )
+        self.hparams.decoder.set_decoder_input_tokens(
+            self.tokenizer.prefix_tokens
+        )
+
+    def transcribe_file(self, path):
+        """Transcribes the given audiofile into a sequence of words.
+
+        Arguments
+        ---------
+        path : str
+            Path to audio file which to transcribe.
+
+        Returns
+        -------
+        str
+            The audiofile transcription produced by this ASR system.
+        """
+        waveform = self.load_audio(path)
+        # Fake a batch:
+        batch = waveform.unsqueeze(0)
+        rel_length = torch.tensor([1.0])
+        predicted_words, predicted_tokens = self.transcribe_batch(
+            batch, rel_length
+        )
+        return " ".join(predicted_words[0])
+
+    def encode_batch(self, wavs, wav_lens):
+        """Encodes the input audio into a sequence of hidden states
+
+        The waveforms should already be in the model's desired format.
+        You can call:
+        ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
+        to get a correctly converted signal in most cases.
+
+        Arguments
+        ---------
+        wavs : torch.tensor
+            Batch of waveforms [batch, time, channels].
+        wav_lens : torch.tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        torch.tensor
+            The encoded batch
+        """
+        wavs = wavs.float()
+        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
+        encoder_out = self.mods.whisper.forward_encoder(wavs)
+        return encoder_out
+
+    def transcribe_batch(self, wavs, wav_lens):
+        """Transcribes the input audio into a sequence of words
+
+        The waveforms should already be in the model's desired format.
+        You can call:
+        ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
+        to get a correctly converted signal in most cases.
+
+        Arguments
+        ---------
+        wavs : torch.tensor
+            Batch of waveforms [batch, time, channels].
+        wav_lens : torch.tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        list
+            Each waveform in the batch transcribed.
+        tensor
+            Each predicted token id.
+        """
+        with torch.no_grad():
+            wav_lens = wav_lens.to(self.device)
+            encoder_out = self.encode_batch(wavs, wav_lens)
+            predicted_tokens, _, _, _ = self.mods.decoder(encoder_out, wav_lens)
+            predicted_words = self.tokenizer.batch_decode(
+                predicted_tokens, skip_special_tokens=True
+            )
+            if self.hparams.normalized_transcripts:
+                predicted_words = [
+                    self.tokenizer._normalize(text).split(" ")
+                    for text in predicted_words
+                ]
+
+        return predicted_words, predicted_tokens
+
+    def forward(self, wavs, wav_lens):
+        """Runs full transcription - note: no gradients through decoding"""
+        return self.transcribe_batch(wavs, wav_lens)
+
+
+@dataclass
+class ASRStreamingContext:
+    """Streaming metadata, initialized by
+    :meth:`~StreamingASR.make_streaming_context` (see there for details on
+    initialization of fields here).
+
+    This object is intended to be mutate: the same object should be passed
+    across calls as streaming progresses (namely when using the lower-level
+    :meth:`~StreamingASR.encode_chunk`, etc. APIs).
+
+    Holds some references to opaque streaming contexts, so the context is
+    model-agnostic to an extent."""
+
+    config: DynChunkTrainConfig
+    """Dynamic chunk training configuration used to initialize the streaming
+    context. Cannot be modified on the fly."""
+
+    fea_extractor_context: Any
+    """Opaque feature extractor streaming context."""
+
+    encoder_context: Any
+    """Opaque encoder streaming context."""
+
+    decoder_context: Any
+    """Opaque decoder streaming context."""
+
+    tokenizer_context: Optional[List[Any]]
+    """Opaque streaming context for the tokenizer. Initially `None`. Initialized
+    to a list of tokenizer contexts once batch size can be determined."""
+
+
+class StreamingASR(Pretrained):
+    """A ready-to-use, streaming-capable ASR model.
+
+    Example
+    -------
+    >>> from speechbrain.inference.ASR import StreamingASR
+    >>> from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
+    >>> tmpdir = getfixture("tmpdir")
+    >>> asr_model = StreamingASR.from_hparams(source="speechbrain/asr-conformer-streaming-librispeech", savedir=tmpdir,) # doctest: +SKIP
+    >>> asr_model.transcribe_file("speechbrain/asr-conformer-streaming-librispeech/test-en.wav", DynChunkTrainConfig(24, 8)) # doctest: +SKIP
+    """
+
+    HPARAMS_NEEDED = [
+        "fea_streaming_extractor",
+        "make_decoder_streaming_context",
+        "decoding_function",
+        "make_tokenizer_streaming_context",
+        "tokenizer_decode_streaming",
+    ]
+    MODULES_NEEDED = ["enc", "proj_enc"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.filter_props = self.hparams.fea_streaming_extractor.properties
+
+    def _get_audio_stream(
+        self, streamer: torchaudio.io.StreamReader, frames_per_chunk: int
+    ):
+        """From a :class:`torchaudio.io.StreamReader`, identifies the audio
+        stream and returns an iterable stream of chunks (after resampling and
+        downmixing to mono).
+
+        Arguments
+        ---------
+        streamer : torchaudio.io.StreamReader
+            The stream object. Must hold exactly one source stream of an
+            audio type.
+        frames_per_chunk : int
+            The number of frames per chunk. For a streaming model, this should
+            be determined from the DynChunkTrain configuration.
+        """
+
+        stream_infos = [
+            streamer.get_src_stream_info(i)
+            for i in range(streamer.num_src_streams)
+        ]
+
+        audio_stream_infos = [
+            (i, stream_info)
+            for i, stream_info in enumerate(stream_infos)
+            if stream_info.media_type == "audio"
+        ]
+
+        if len(audio_stream_infos) != 1:
+            raise ValueError(
+                f"Expected stream to have only 1 stream (with any number of channels), got {len(audio_stream_infos)} (with streams: {stream_infos})"
+            )
+
+        # find the index of the first (and only) audio stream
+        audio_stream_index = audio_stream_infos[0][0]
+
+        # output stream #0
+        streamer.add_basic_audio_stream(
+            frames_per_chunk=frames_per_chunk,
+            stream_index=audio_stream_index,
+            sample_rate=self.audio_normalizer.sample_rate,
+            format="fltp",  # torch.float32
+            num_channels=1,
+        )
+
+        for (chunk,) in streamer.stream():
+            chunk = chunk.squeeze(-1)  # we deal with mono, remove that dim
+            chunk = chunk.unsqueeze(0)  # create a fake batch dim
+            yield chunk
+
+    def transcribe_file_streaming(
+        self,
+        path,
+        dynchunktrain_config: DynChunkTrainConfig,
+        use_torchaudio_streaming: bool = True,
+        **kwargs,
+    ):
+        """Transcribes the given audio file into a sequence of words, in a
+        streaming fashion, meaning that text is being yield from this
+        generator, in the form of strings to concatenate.
+
+        Arguments
+        ---------
+        path : str
+            URI/path to the audio to transcribe. When
+            ``use_torchaudio_streaming`` is ``False``, uses SB fetching to allow
+            fetching from HF or a local file. When ``True``, resolves the URI
+            through ffmpeg, as documented in
+            :class:`torchaudio.io.StreamReader`.
+        dynchunktrain_config : DynChunkTrainConfig
+            Streaming configuration. Sane values and how much time chunks
+            actually represent is model-dependent.
+        use_torchaudio_streaming : bool
+            Whether the audio file can be loaded in a streaming fashion. If not,
+            transcription is still performed through chunks of audio, but the
+            entire audio file is fetched and loaded at once.
+            This skips the usual fetching method and instead resolves the URI
+            using torchaudio (via ffmpeg).
+
+        Returns
+        -------
+        generator of str
+            An iterator yielding transcribed chunks (strings). There is a yield
+            for every chunk, even if the transcribed string for that chunk is an
+            empty string.
+        """
+
+        chunk_size = self.get_chunk_size_frames(dynchunktrain_config)
+
+        if use_torchaudio_streaming:
+            streamer = torchaudio.io.StreamReader(path)
+            chunks = self._get_audio_stream(streamer, chunk_size)
+        else:
+            waveform = self.load_audio(path, **kwargs)
+            batch = waveform.unsqueeze(0)  # create batch dim
+            chunks = split_fixed_chunks(batch, chunk_size)
+
+        rel_length = torch.tensor([1.0])
+        context = self.make_streaming_context(dynchunktrain_config)
+
+        final_chunks = [
+            torch.zeros((1, chunk_size), device=self.device)
+        ] * self.hparams.fea_streaming_extractor.get_recommended_final_chunk_count(
+            chunk_size
+        )
+
+        for chunk in itertools.chain(chunks, final_chunks):
+            predicted_words = self.transcribe_chunk(context, chunk, rel_length)
+            yield predicted_words[0]
+
+    def transcribe_file(
+        self,
+        path,
+        dynchunktrain_config: DynChunkTrainConfig,
+        use_torchaudio_streaming: bool = True,
+    ):
+        """Transcribes the given audio file into a sequence of words.
+
+        Arguments
+        ---------
+        path : str
+            URI/path to the audio to transcribe. When
+            ``use_torchaudio_streaming`` is ``False``, uses SB fetching to allow
+            fetching from HF or a local file. When ``True``, resolves the URI
+            through ffmpeg, as documented in
+            :class:`torchaudio.io.StreamReader`.
+        dynchunktrain_config : DynChunkTrainConfig
+            Streaming configuration. Sane values and how much time chunks
+            actually represent is model-dependent.
+        use_torchaudio_streaming : bool
+            Whether the audio file can be loaded in a streaming fashion. If not,
+            transcription is still performed through chunks of audio, but the
+            entire audio file is fetched and loaded at once.
+            This skips the usual fetching method and instead resolves the URI
+            using torchaudio (via ffmpeg).
+
+        Returns
+        -------
+        str
+            The audio file transcription produced by this ASR system.
+        """
+
+        pred = ""
+
+        for text_chunk in self.transcribe_file_streaming(
+            path, dynchunktrain_config, use_torchaudio_streaming
+        ):
+            pred += text_chunk
+
+        return pred
+
+    def make_streaming_context(self, dynchunktrain_config: DynChunkTrainConfig):
+        """Create a blank streaming context to be passed around for chunk
+        encoding/transcription.
+
+        Arguments
+        ---------
+        dynchunktrain_config : DynChunkTrainConfig
+            Streaming configuration. Sane values and how much time chunks
+            actually represent is model-dependent."""
+
+        return ASRStreamingContext(
+            config=dynchunktrain_config,
+            fea_extractor_context=self.hparams.fea_streaming_extractor.make_streaming_context(),
+            encoder_context=self.mods.enc.make_streaming_context(
+                dynchunktrain_config
+            ),
+            decoder_context=self.hparams.make_decoder_streaming_context(),
+            tokenizer_context=None,
+        )
+
+    def get_chunk_size_frames(
+        self, dynchunktrain_config: DynChunkTrainConfig
+    ) -> int:
+        """Returns the chunk size in actual audio samples, i.e. the exact
+        expected length along the time dimension of an input chunk tensor (as
+        passed to :meth:`~StreamingASR.encode_chunk` and similar low-level
+        streaming functions).
+
+        Arguments
+        ---------
+        dynchunktrain_config : DynChunkTrainConfig
+            The streaming configuration to determine the chunk frame count of.
+        """
+
+        return (self.filter_props.stride - 1) * dynchunktrain_config.chunk_size
+
+    @torch.no_grad()
+    def encode_chunk(
+        self,
+        context: ASRStreamingContext,
+        chunk: torch.Tensor,
+        chunk_len: Optional[torch.Tensor] = None,
+    ):
+        """Encoding of a batch of audio chunks into a batch of encoded
+        sequences.
+        For full speech-to-text offline transcription, use `transcribe_batch` or
+        `transcribe_file`.
+        Must be called over a given context in the correct order of chunks over
+        time.
+
+        Arguments
+        ---------
+        context : ASRStreamingContext
+            Mutable streaming context object, which must be specified and reused
+            across calls when streaming.
+            You can obtain an initial context by calling
+            `asr.make_streaming_context(config)`.
+
+        chunk : torch.Tensor
+            The tensor for an audio chunk of shape `[batch size, time]`.
+            The time dimension must strictly match
+            `asr.get_chunk_size_frames(config)`.
+            The waveform is expected to be in the model's expected format (i.e.
+            the sampling rate must be correct).
+
+        chunk_len : torch.Tensor, optional
+            The relative chunk length tensor of shape `[batch size]`. This is to
+            be used when the audio in one of the chunks of the batch is ending
+            within this chunk.
+            If unspecified, equivalent to `torch.ones((batch_size,))`.
+
+        Returns
+        -------
+        torch.Tensor
+            Encoded output, of a model-dependent shape."""
+
+        if chunk_len is None:
+            chunk_len = torch.ones((chunk.size(0),))
+
+        chunk = chunk.float()
+        chunk, chunk_len = chunk.to(self.device), chunk_len.to(self.device)
+
+        assert chunk.shape[-1] <= self.get_chunk_size_frames(context.config)
+
+        x = self.hparams.fea_streaming_extractor(
+            chunk, context=context.fea_extractor_context, lengths=chunk_len
+        )
+        x = self.mods.enc.forward_streaming(x, context.encoder_context)
+        x = self.mods.proj_enc(x)
+        return x
+
+    @torch.no_grad()
+    def decode_chunk(
+        self, context: ASRStreamingContext, x: torch.Tensor
+    ) -> tuple[list, list]:
+        """Decodes the output of the encoder into tokens and the associated
+        transcription.
+        Must be called over a given context in the correct order of chunks over
+        time.
+
+        Arguments
+        ---------
+        context : ASRStreamingContext
+            Mutable streaming context object, which should be the same object
+            that was passed to `encode_chunk`.
+
+        x : torch.Tensor
+            The output of `encode_chunk` for a given chunk.
+
+        Returns
+        -------
+        list of str
+            Decoded tokens of length `batch_size`. The decoded strings can be
+            of 0-length.
+        list of list of output token hypotheses
+            List of length `batch_size`, each holding a list of tokens of any
+            length `>=0`.
+        """
+        tokens = self.hparams.decoding_function(x, context.decoder_context)
+
+        # initialize token context for real now that we know the batch size
+        if context.tokenizer_context is None:
+            context.tokenizer_context = [
+                self.hparams.make_tokenizer_streaming_context()
+                for _ in range(len(tokens))
+            ]
+
+        words = [
+            self.hparams.tokenizer_decode_streaming(
+                self.hparams.tokenizer, cur_tokens, context.tokenizer_context[i]
+            )
+            for i, cur_tokens in enumerate(tokens)
+        ]
+
+        return words, tokens
+
+    def transcribe_chunk(
+        self,
+        context: ASRStreamingContext,
+        chunk: torch.Tensor,
+        chunk_len: Optional[torch.Tensor] = None,
+    ):
+        """Transcription of a batch of audio chunks into transcribed text.
+        Must be called over a given context in the correct order of chunks over
+        time.
+
+        Arguments
+        ---------
+        context : ASRStreamingContext
+            Mutable streaming context object, which must be specified and reused
+            across calls when streaming.
+            You can obtain an initial context by calling
+            `asr.make_streaming_context(config)`.
+
+        chunk : torch.Tensor
+            The tensor for an audio chunk of shape `[batch size, time]`.
+            The time dimension must strictly match
+            `asr.get_chunk_size_frames(config)`.
+            The waveform is expected to be in the model's expected format (i.e.
+            the sampling rate must be correct).
+
+        chunk_len : torch.Tensor, optional
+            The relative chunk length tensor of shape `[batch size]`. This is to
+            be used when the audio in one of the chunks of the batch is ending
+            within this chunk.
+            If unspecified, equivalent to `torch.ones((batch_size,))`.
+
+        Returns
+        -------
+        str
+            Transcribed string for this chunk, might be of length zero.
+        """
+
+        if chunk_len is None:
+            chunk_len = torch.ones((chunk.size(0),))
+
+        chunk = chunk.float()
+        chunk, chunk_len = chunk.to(self.device), chunk_len.to(self.device)
+
+        x = self.encode_chunk(context, chunk, chunk_len)
+        words, _tokens = self.decode_chunk(context, x)
+
+        return words
diff --git a/speechbrain/inference/SLU.py b/speechbrain/inference/SLU.py
new file mode 100644
index 0000000000000000000000000000000000000000..89d7a5c8024a19bfccdb4d7d19beeb34629d92df
--- /dev/null
+++ b/speechbrain/inference/SLU.py
@@ -0,0 +1,132 @@
+""" Specifies the inference interfaces for Spoken Language Understanding (SLU) modules.
+
+Authors:
+ * Aku Rouhe 2021
+ * Peter Plantinga 2021
+ * Loren Lugosch 2020
+ * Mirco Ravanelli 2020
+ * Titouan Parcollet 2021
+ * Abdel Heba 2021
+ * Andreas Nautsch 2022, 2023
+ * Pooneh Mousavi 2023
+ * Sylvain de Langen 2023
+ * Adel Moumen 2023
+ * Pradnya Kandarkar 2023
+"""
+import torch
+from speechbrain.inference.interfaces import Pretrained
+from speechbrain.inference.ASR import EncoderDecoderASR
+
+
+class EndToEndSLU(Pretrained):
+    """An end-to-end SLU model.
+
+    The class can be used either to run only the encoder (encode()) to extract
+    features or to run the entire model (decode()) to map the speech to its semantics.
+
+    Example
+    -------
+    >>> from speechbrain.inference.SLU import EndToEndSLU
+    >>> tmpdir = getfixture("tmpdir")
+    >>> slu_model = EndToEndSLU.from_hparams(
+    ...     source="speechbrain/slu-timers-and-such-direct-librispeech-asr",
+    ...     savedir=tmpdir,
+    ... )  # doctest: +SKIP
+    >>> slu_model.decode_file("tests/samples/single-mic/example6.wav") # doctest: +SKIP
+    "{'intent': 'SimpleMath', 'slots': {'number1': 37.67, 'number2': 75.7, 'op': ' minus '}}"
+    """
+
+    HPARAMS_NEEDED = ["tokenizer", "asr_model_source"]
+    MODULES_NEEDED = ["slu_enc", "beam_searcher"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.tokenizer = self.hparams.tokenizer
+        self.asr_model = EncoderDecoderASR.from_hparams(
+            source=self.hparams.asr_model_source,
+            run_opts={"device": self.device},
+        )
+
+    def decode_file(self, path, **kwargs):
+        """Maps the given audio file to a string representing the
+        semantic dictionary for the utterance.
+
+        Arguments
+        ---------
+        path : str
+            Path to audio file to decode.
+
+        Returns
+        -------
+        str
+            The predicted semantics.
+        """
+        waveform = self.load_audio(path, **kwargs)
+        waveform = waveform.to(self.device)
+        # Fake a batch:
+        batch = waveform.unsqueeze(0)
+        rel_length = torch.tensor([1.0])
+        predicted_words, predicted_tokens = self.decode_batch(batch, rel_length)
+        return predicted_words[0]
+
+    def encode_batch(self, wavs, wav_lens):
+        """Encodes the input audio into a sequence of hidden states
+
+        Arguments
+        ---------
+        wavs : torch.Tensor
+            Batch of waveforms [batch, time, channels] or [batch, time]
+            depending on the model.
+        wav_lens : torch.Tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        torch.Tensor
+            The encoded batch
+        """
+        wavs = wavs.float()
+        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
+        ASR_encoder_out = self.asr_model.encode_batch(wavs.detach(), wav_lens)
+        encoder_out = self.mods.slu_enc(ASR_encoder_out)
+        return encoder_out
+
+    def decode_batch(self, wavs, wav_lens):
+        """Maps the input audio to its semantics
+
+        Arguments
+        ---------
+        wavs : torch.Tensor
+            Batch of waveforms [batch, time, channels] or [batch, time]
+            depending on the model.
+        wav_lens : torch.Tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        list
+            Each waveform in the batch decoded.
+        tensor
+            Each predicted token id.
+        """
+        with torch.no_grad():
+            wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
+            encoder_out = self.encode_batch(wavs, wav_lens)
+            predicted_tokens, scores, _, _ = self.mods.beam_searcher(
+                encoder_out, wav_lens
+            )
+            predicted_words = [
+                self.tokenizer.decode_ids(token_seq)
+                for token_seq in predicted_tokens
+            ]
+        return predicted_words, predicted_tokens
+
+    def forward(self, wavs, wav_lens):
+        """Runs full decoding - note: no gradients through decoding"""
+        return self.decode_batch(wavs, wav_lens)
diff --git a/speechbrain/inference/ST.py b/speechbrain/inference/ST.py
new file mode 100644
index 0000000000000000000000000000000000000000..c53290df3cb4d74439a4b4e9f34d0750a48a65c6
--- /dev/null
+++ b/speechbrain/inference/ST.py
@@ -0,0 +1,125 @@
+""" Specifies the inference interfaces for Speech Translation (ST) modules.
+
+Authors:
+ * Aku Rouhe 2021
+ * Peter Plantinga 2021
+ * Loren Lugosch 2020
+ * Mirco Ravanelli 2020
+ * Titouan Parcollet 2021
+ * Abdel Heba 2021
+ * Andreas Nautsch 2022, 2023
+ * Pooneh Mousavi 2023
+ * Sylvain de Langen 2023
+ * Adel Moumen 2023
+ * Pradnya Kandarkar 2023
+"""
+import torch
+from speechbrain.inference.interfaces import Pretrained
+
+
+class EncoderDecoderS2UT(Pretrained):
+    """A ready-to-use Encoder Decoder for speech-to-unit translation model
+
+    The class can be used  to  run the entire encoder-decoder S2UT model
+    (translate_file()) to translate speech. The given YAML must contains the fields
+    specified in the *_NEEDED[] lists.
+
+    Example
+    -------
+    >>> from speechbrain.inference.ST import EncoderDecoderS2UT
+    >>> tmpdir = getfixture("tmpdir")
+    >>> s2ut_model = EncoderDecoderS2UT.from_hparams(source="speechbrain/s2st-transformer-fr-en-hubert-l6-k100-cvss", savedir=tmpdir) # doctest: +SKIP
+    >>> s2ut_model.translate_file("speechbrain/s2st-transformer-fr-en-hubert-l6-k100-cvss/example-fr.wav") # doctest: +SKIP
+    """
+
+    HPARAMS_NEEDED = ["sample_rate"]
+    MODULES_NEEDED = ["encoder", "decoder"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.sample_rate = self.hparams.sample_rate
+
+    def translate_file(self, path):
+        """Translates the given audiofile into a sequence speech unit.
+
+        Arguments
+        ---------
+        path : str
+            Path to audio file which to translate.
+
+        Returns
+        -------
+        int[]
+            The audiofile translation produced by this speech-to-unit translationmodel.
+        """
+
+        audio = self.load_audio(path)
+        audio = audio.to(self.device)
+        # Fake a batch:
+        batch = audio.unsqueeze(0)
+        rel_length = torch.tensor([1.0])
+        predicted_tokens = self.translate_batch(batch, rel_length)
+        return predicted_tokens[0]
+
+    def encode_batch(self, wavs, wav_lens):
+        """Encodes the input audio into a sequence of hidden states
+
+        The waveforms should already be in the model's desired format.
+        You can call:
+        ``normalized = EncoderDecoderS2UT.normalizer(signal, sample_rate)``
+        to get a correctly converted signal in most cases.
+
+        Arguments
+        ---------
+        wavs : torch.tensor
+            Batch of waveforms [batch, time, channels].
+        wav_lens : torch.tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        torch.tensor
+            The encoded batch
+        """
+        wavs = wavs.float()
+        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
+        encoder_out = self.mods.encoder(wavs, wav_lens)
+        return encoder_out
+
+    def translate_batch(self, wavs, wav_lens):
+        """Translates the input audio into a sequence of words
+
+        The waveforms should already be in the model's desired format.
+        You can call:
+        ``normalized = EncoderDecoderS2UT.normalizer(signal, sample_rate)``
+        to get a correctly converted signal in most cases.
+
+        Arguments
+        ---------
+        wavs : torch.tensor
+            Batch of waveforms [batch, time, channels].
+        wav_lens : torch.tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        list
+            Each waveform in the batch translated.
+        tensor
+            Each predicted token id.
+        """
+        with torch.no_grad():
+            wav_lens = wav_lens.to(self.device)
+            encoder_out = self.encode_batch(wavs, wav_lens)
+            predicted_tokens, _, _, _ = self.mods.decoder(encoder_out, wav_lens)
+        return predicted_tokens
+
+    def forward(self, wavs, wav_lens):
+        """Runs full translation"""
+        return self.encode_batch(wavs, wav_lens)
diff --git a/speechbrain/inference/TTS.py b/speechbrain/inference/TTS.py
new file mode 100644
index 0000000000000000000000000000000000000000..37b8bd3339762b75c0306f695130996370d72a31
--- /dev/null
+++ b/speechbrain/inference/TTS.py
@@ -0,0 +1,848 @@
+""" Specifies the inference interfaces for Text-To-Speech (TTS) modules.
+
+Authors:
+ * Aku Rouhe 2021
+ * Peter Plantinga 2021
+ * Loren Lugosch 2020
+ * Mirco Ravanelli 2020
+ * Titouan Parcollet 2021
+ * Abdel Heba 2021
+ * Andreas Nautsch 2022, 2023
+ * Pooneh Mousavi 2023
+ * Sylvain de Langen 2023
+ * Adel Moumen 2023
+ * Pradnya Kandarkar 2023
+"""
+import re
+import logging
+import torch
+import torchaudio
+import random
+import speechbrain
+from speechbrain.utils.fetching import fetch
+from speechbrain.inference.interfaces import Pretrained
+from speechbrain.utils.text_to_sequence import text_to_sequence
+from speechbrain.inference.text import GraphemeToPhoneme
+from speechbrain.inference.encoders import MelSpectrogramEncoder
+from speechbrain.inference.classifiers import EncoderClassifier
+
+
+logger = logging.getLogger(__name__)
+
+
+class Tacotron2(Pretrained):
+    """
+    A ready-to-use wrapper for Tacotron2 (text -> mel_spec).
+
+    Arguments
+    ---------
+    hparams
+        Hyperparameters (from HyperPyYAML)
+
+    Example
+    -------
+    >>> tmpdir_tts = getfixture('tmpdir') / "tts"
+    >>> tacotron2 = Tacotron2.from_hparams(source="speechbrain/tts-tacotron2-ljspeech", savedir=tmpdir_tts)
+    >>> mel_output, mel_length, alignment = tacotron2.encode_text("Mary had a little lamb")
+    >>> items = [
+    ...   "A quick brown fox jumped over the lazy dog",
+    ...   "How much wood would a woodchuck chuck?",
+    ...   "Never odd or even"
+    ... ]
+    >>> mel_outputs, mel_lengths, alignments = tacotron2.encode_batch(items)
+
+    >>> # One can combine the TTS model with a vocoder (that generates the final waveform)
+    >>> # Intialize the Vocoder (HiFIGAN)
+    >>> tmpdir_vocoder = getfixture('tmpdir') / "vocoder"
+    >>> from speechbrain.inference.vocoders import HIFIGAN
+    >>> hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-ljspeech", savedir=tmpdir_vocoder)
+    >>> # Running the TTS
+    >>> mel_output, mel_length, alignment = tacotron2.encode_text("Mary had a little lamb")
+    >>> # Running Vocoder (spectrogram-to-waveform)
+    >>> waveforms = hifi_gan.decode_batch(mel_output)
+    """
+
+    HPARAMS_NEEDED = ["model", "text_to_sequence"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.text_cleaners = getattr(
+            self.hparams, "text_cleaners", ["english_cleaners"]
+        )
+        self.infer = self.hparams.model.infer
+
+    def text_to_seq(self, txt):
+        """Encodes raw text into a tensor with a customer text-to-sequence function"""
+        sequence = self.hparams.text_to_sequence(txt, self.text_cleaners)
+        return sequence, len(sequence)
+
+    def encode_batch(self, texts):
+        """Computes mel-spectrogram for a list of texts
+
+        Texts must be sorted in decreasing order on their lengths
+
+        Arguments
+        ---------
+        texts: List[str]
+            texts to be encoded into spectrogram
+
+        Returns
+        -------
+        tensors of output spectrograms, output lengths and alignments
+        """
+        with torch.no_grad():
+            inputs = [
+                {
+                    "text_sequences": torch.tensor(
+                        self.text_to_seq(item)[0], device=self.device
+                    )
+                }
+                for item in texts
+            ]
+            inputs = speechbrain.dataio.batch.PaddedBatch(inputs)
+
+            lens = [self.text_to_seq(item)[1] for item in texts]
+            assert lens == sorted(
+                lens, reverse=True
+            ), "input lengths must be sorted in decreasing order"
+            input_lengths = torch.tensor(lens, device=self.device)
+
+            mel_outputs_postnet, mel_lengths, alignments = self.infer(
+                inputs.text_sequences.data, input_lengths
+            )
+        return mel_outputs_postnet, mel_lengths, alignments
+
+    def encode_text(self, text):
+        """Runs inference for a single text str"""
+        return self.encode_batch([text])
+
+    def forward(self, texts):
+        "Encodes the input texts."
+        return self.encode_batch(texts)
+
+
+class MSTacotron2(Pretrained):
+    """
+    A ready-to-use wrapper for Zero-Shot Multi-Speaker Tacotron2.
+    For voice cloning: (text, reference_audio) -> (mel_spec).
+    For generating a random speaker voice: (text) -> (mel_spec).
+
+
+    Example
+    -------
+    >>> tmpdir_tts = getfixture('tmpdir') / "tts"
+    >>> mstacotron2 = MSTacotron2.from_hparams(source="speechbrain/tts-mstacotron2-libritts", savedir=tmpdir_tts) # doctest: +SKIP
+    >>> # Sample rate of the reference audio must be greater or equal to the sample rate of the speaker embedding model
+    >>> reference_audio_path = "tests/samples/single-mic/example1.wav"
+    >>> input_text = "Mary had a little lamb."
+    >>> mel_output, mel_length, alignment = mstacotron2.clone_voice(input_text, reference_audio_path) # doctest: +SKIP
+    >>> # One can combine the TTS model with a vocoder (that generates the final waveform)
+    >>> # Intialize the Vocoder (HiFIGAN)
+    >>> tmpdir_vocoder = getfixture('tmpdir') / "vocoder"
+    >>> from speechbrain.inference.vocoders import HIFIGAN
+    >>> hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-libritts-22050Hz", savedir=tmpdir_vocoder) # doctest: +SKIP
+    >>> # Running the TTS
+    >>> mel_output, mel_length, alignment = mstacotron2.clone_voice(input_text, reference_audio_path) # doctest: +SKIP
+    >>> # Running Vocoder (spectrogram-to-waveform)
+    >>> waveforms = hifi_gan.decode_batch(mel_output) # doctest: +SKIP
+    >>> # For generating a random speaker voice, use the following
+    >>> mel_output, mel_length, alignment = mstacotron2.generate_random_voice(input_text) # doctest: +SKIP
+    """
+
+    HPARAMS_NEEDED = ["model"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.text_cleaners = ["english_cleaners"]
+        self.infer = self.hparams.model.infer
+        self.custom_mel_spec_encoder = self.hparams.custom_mel_spec_encoder
+
+        self.g2p = GraphemeToPhoneme.from_hparams(
+            self.hparams.g2p, run_opts={"device": self.device}
+        )
+
+        self.spk_emb_encoder = None
+        if self.custom_mel_spec_encoder:
+            self.spk_emb_encoder = MelSpectrogramEncoder.from_hparams(
+                source=self.hparams.spk_emb_encoder,
+                run_opts={"device": self.device},
+            )
+        else:
+            self.spk_emb_encoder = EncoderClassifier.from_hparams(
+                source=self.hparams.spk_emb_encoder,
+                run_opts={"device": self.device},
+            )
+
+    def __text_to_seq(self, txt):
+        """Encodes raw text into a tensor with a customer text-to-equence fuction
+        """
+        sequence = text_to_sequence(txt, self.text_cleaners)
+        return sequence, len(sequence)
+
+    def clone_voice(self, texts, audio_path):
+        """
+        Generates mel-spectrogram using input text and reference audio
+
+        Arguments
+        ---------
+        texts : str or list
+            Input text
+        audio_path : str
+            Reference audio
+
+        Returns
+        -------
+        tensors of output spectrograms, output lengths and alignments
+        """
+
+        # Loads audio
+        ref_signal, signal_sr = torchaudio.load(audio_path)
+
+        # Resamples the audio if required
+        if signal_sr != self.hparams.spk_emb_sample_rate:
+            ref_signal = torchaudio.functional.resample(
+                ref_signal, signal_sr, self.hparams.spk_emb_sample_rate
+            )
+        ref_signal = ref_signal.to(self.device)
+
+        # Computes speaker embedding
+        if self.custom_mel_spec_encoder:
+            spk_emb = self.spk_emb_encoder.encode_waveform(ref_signal)
+        else:
+            spk_emb = self.spk_emb_encoder.encode_batch(ref_signal)
+
+        spk_emb = spk_emb.squeeze(0)
+
+        # Converts input texts into the corresponding phoneme sequences
+        if isinstance(texts, str):
+            texts = [texts]
+        phoneme_seqs = self.g2p(texts)
+        for i in range(len(phoneme_seqs)):
+            phoneme_seqs[i] = " ".join(phoneme_seqs[i])
+            phoneme_seqs[i] = "{" + phoneme_seqs[i] + "}"
+
+        # Repeats the speaker embedding to match the number of input texts
+        spk_embs = spk_emb.repeat(len(texts), 1)
+
+        # Calls __encode_batch to generate the mel-spectrograms
+        return self.__encode_batch(phoneme_seqs, spk_embs)
+
+    def generate_random_voice(self, texts):
+        """
+        Generates mel-spectrogram using input text and a random speaker voice
+
+        Arguments
+        ---------
+        texts : str or list
+            Input text
+
+        Returns
+        -------
+        tensors of output spectrograms, output lengths and alignments
+        """
+
+        spk_emb = self.__sample_random_speaker().float()
+        spk_emb = spk_emb.to(self.device)
+
+        # Converts input texts into the corresponding phoneme sequences
+        if isinstance(texts, str):
+            texts = [texts]
+        phoneme_seqs = self.g2p(texts)
+        for i in range(len(phoneme_seqs)):
+            phoneme_seqs[i] = " ".join(phoneme_seqs[i])
+            phoneme_seqs[i] = "{" + phoneme_seqs[i] + "}"
+
+        # Repeats the speaker embedding to match the number of input texts
+        spk_embs = spk_emb.repeat(len(texts), 1)
+
+        # Calls __encode_batch to generate the mel-spectrograms
+        return self.__encode_batch(phoneme_seqs, spk_embs)
+
+    def __encode_batch(self, texts, spk_embs):
+        """Computes mel-spectrograms for a list of texts
+        Texts are sorted in decreasing order on their lengths
+
+        Arguments
+        ---------
+        texts: List[str]
+            texts to be encoded into spectrogram
+        spk_embs: torch.Tensor
+            speaker embeddings
+
+        Returns
+        -------
+        tensors of output spectrograms, output lengths and alignments
+        """
+
+        with torch.no_grad():
+            inputs = [
+                {
+                    "text_sequences": torch.tensor(
+                        self.__text_to_seq(item)[0], device=self.device
+                    )
+                }
+                for item in texts
+            ]
+
+            inputs = sorted(
+                inputs,
+                key=lambda x: x["text_sequences"].size()[0],
+                reverse=True,
+            )
+
+            lens = [entry["text_sequences"].size()[0] for entry in inputs]
+
+            inputs = speechbrain.dataio.batch.PaddedBatch(inputs)
+
+            assert lens == sorted(
+                lens, reverse=True
+            ), "ipnut lengths must be sorted in decreasing order"
+            input_lengths = torch.tensor(lens, device=self.device)
+
+            mel_outputs_postnet, mel_lengths, alignments = self.infer(
+                inputs.text_sequences.data, spk_embs, input_lengths
+            )
+        return mel_outputs_postnet, mel_lengths, alignments
+
+    def __sample_random_speaker(self):
+        """Samples a random speaker embedding from a pretrained GMM
+
+        Returns
+        -------
+        x: torch.Tensor
+            A randomly sampled speaker embedding
+        """
+
+        # Fetches and Loads GMM trained on speaker embeddings
+        speaker_gmm_local_path = fetch(
+            filename=self.hparams.random_speaker_sampler,
+            source=self.hparams.random_speaker_sampler_source,
+            savedir=self.hparams.pretrainer.collect_in,
+        )
+        random_speaker_gmm = torch.load(speaker_gmm_local_path)
+        gmm_n_components = random_speaker_gmm["gmm_n_components"]
+        gmm_means = random_speaker_gmm["gmm_means"]
+        gmm_covariances = random_speaker_gmm["gmm_covariances"]
+
+        # Randomly selects a speaker
+        counts = torch.zeros(gmm_n_components)
+        counts[random.randint(0, gmm_n_components - 1)] = 1
+        x = torch.empty(0, device=counts.device)
+
+        # Samples an embedding for the speaker
+        for k in torch.arange(gmm_n_components)[counts > 0]:
+            # Considers full covariance type
+            d_k = torch.distributions.multivariate_normal.MultivariateNormal(
+                gmm_means[k], gmm_covariances[k]
+            )
+            x_k = torch.stack([d_k.sample() for _ in range(int(counts[k]))])
+
+            x = torch.cat((x, x_k), dim=0)
+
+        return x
+
+
+class FastSpeech2(Pretrained):
+    """
+    A ready-to-use wrapper for Fastspeech2 (text -> mel_spec).
+    Arguments
+    ---------
+    hparams
+        Hyperparameters (from HyperPyYAML)
+    Example
+    -------
+    >>> tmpdir_tts = getfixture('tmpdir') / "tts"
+    >>> fastspeech2 = FastSpeech2.from_hparams(source="speechbrain/tts-fastspeech2-ljspeech", savedir=tmpdir_tts) # doctest: +SKIP
+    >>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text(["Mary had a little lamb."]) # doctest: +SKIP
+    >>> items = [
+    ...   "A quick brown fox jumped over the lazy dog",
+    ...   "How much wood would a woodchuck chuck?",
+    ...   "Never odd or even"
+    ... ]
+    >>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text(items) # doctest: +SKIP
+    >>>
+    >>> # One can combine the TTS model with a vocoder (that generates the final waveform)
+    >>> # Intialize the Vocoder (HiFIGAN)
+    >>> tmpdir_vocoder = getfixture('tmpdir') / "vocoder"
+    >>> from speechbrain.inference.vocoders import HIFIGAN
+    >>> hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-ljspeech", savedir=tmpdir_vocoder) # doctest: +SKIP
+    >>> # Running the TTS
+    >>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text(["Mary had a little lamb."]) # doctest: +SKIP
+    >>> # Running Vocoder (spectrogram-to-waveform)
+    >>> waveforms = hifi_gan.decode_batch(mel_outputs) # doctest: +SKIP
+    """
+
+    HPARAMS_NEEDED = ["spn_predictor", "model", "input_encoder"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        lexicon = self.hparams.lexicon
+        lexicon = ["@@"] + lexicon
+        self.input_encoder = self.hparams.input_encoder
+        self.input_encoder.update_from_iterable(lexicon, sequence_input=False)
+        self.input_encoder.add_unk()
+
+        self.g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
+
+        self.spn_token_encoded = (
+            self.input_encoder.encode_sequence_torch(["spn"]).int().item()
+        )
+
+    def encode_text(self, texts, pace=1.0, pitch_rate=1.0, energy_rate=1.0):
+        """Computes mel-spectrogram for a list of texts
+
+        Arguments
+        ---------
+        texts: List[str]
+            texts to be converted to spectrogram
+        pace: float
+            pace for the speech synthesis
+        pitch_rate : float
+            scaling factor for phoneme pitches
+        energy_rate : float
+            scaling factor for phoneme energies
+
+        Returns
+        -------
+        tensors of output spectrograms, output lengths and alignments
+        """
+
+        # Preprocessing required at the inference time for the input text
+        # "label" below contains input text
+        # "phoneme_labels" contain the phoneme sequences corresponding to input text labels
+        # "last_phonemes_combined" is used to indicate whether the index position is for a last phoneme of a word
+        # "punc_positions" is used to add back the silence for punctuations
+        phoneme_labels = list()
+        last_phonemes_combined = list()
+        punc_positions = list()
+
+        for label in texts:
+            phoneme_label = list()
+            last_phonemes = list()
+            punc_position = list()
+
+            words = label.split()
+            words = [word.strip() for word in words]
+            words_phonemes = self.g2p(words)
+
+            for i in range(len(words_phonemes)):
+                words_phonemes_seq = words_phonemes[i]
+                for phoneme in words_phonemes_seq:
+                    if not phoneme.isspace():
+                        phoneme_label.append(phoneme)
+                        last_phonemes.append(0)
+                        punc_position.append(0)
+                last_phonemes[-1] = 1
+                if words[i][-1] in ":;-,.!?":
+                    punc_position[-1] = 1
+
+            phoneme_labels.append(phoneme_label)
+            last_phonemes_combined.append(last_phonemes)
+            punc_positions.append(punc_position)
+
+        # Inserts silent phonemes in the input phoneme sequence
+        all_tokens_with_spn = list()
+        max_seq_len = -1
+        for i in range(len(phoneme_labels)):
+            phoneme_label = phoneme_labels[i]
+            token_seq = (
+                self.input_encoder.encode_sequence_torch(phoneme_label)
+                .int()
+                .to(self.device)
+            )
+            last_phonemes = torch.LongTensor(last_phonemes_combined[i]).to(
+                self.device
+            )
+
+            # Runs the silent phoneme predictor
+            spn_preds = (
+                self.hparams.modules["spn_predictor"]
+                .infer(token_seq.unsqueeze(0), last_phonemes.unsqueeze(0))
+                .int()
+            )
+
+            spn_to_add = torch.nonzero(spn_preds).reshape(-1).tolist()
+
+            for j in range(len(punc_positions[i])):
+                if punc_positions[i][j] == 1:
+                    spn_to_add.append(j)
+
+            tokens_with_spn = list()
+
+            for token_idx in range(token_seq.shape[0]):
+                tokens_with_spn.append(token_seq[token_idx].item())
+                if token_idx in spn_to_add:
+                    tokens_with_spn.append(self.spn_token_encoded)
+
+            tokens_with_spn = torch.LongTensor(tokens_with_spn).to(self.device)
+            all_tokens_with_spn.append(tokens_with_spn)
+            if max_seq_len < tokens_with_spn.shape[-1]:
+                max_seq_len = tokens_with_spn.shape[-1]
+
+        # "tokens_with_spn_tensor" holds the input phoneme sequence with silent phonemes
+        tokens_with_spn_tensor_padded = torch.LongTensor(
+            len(texts), max_seq_len
+        ).to(self.device)
+        tokens_with_spn_tensor_padded.zero_()
+
+        for seq_idx, seq in enumerate(all_tokens_with_spn):
+            tokens_with_spn_tensor_padded[seq_idx, : len(seq)] = seq
+
+        return self.encode_batch(
+            tokens_with_spn_tensor_padded,
+            pace=pace,
+            pitch_rate=pitch_rate,
+            energy_rate=energy_rate,
+        )
+
+    def encode_phoneme(
+        self, phonemes, pace=1.0, pitch_rate=1.0, energy_rate=1.0
+    ):
+        """Computes mel-spectrogram for a list of phoneme sequences
+
+        Arguments
+        ---------
+        phonemes: List[List[str]]
+            phonemes to be converted to spectrogram
+        pace: float
+            pace for the speech synthesis
+        pitch_rate : float
+            scaling factor for phoneme pitches
+        energy_rate : float
+            scaling factor for phoneme energies
+
+        Returns
+        -------
+        tensors of output spectrograms, output lengths and alignments
+        """
+
+        all_tokens = []
+        max_seq_len = -1
+        for phoneme in phonemes:
+            token_seq = (
+                self.input_encoder.encode_sequence_torch(phoneme)
+                .int()
+                .to(self.device)
+            )
+            if max_seq_len < token_seq.shape[-1]:
+                max_seq_len = token_seq.shape[-1]
+            all_tokens.append(token_seq)
+
+        tokens_padded = torch.LongTensor(len(phonemes), max_seq_len).to(
+            self.device
+        )
+        tokens_padded.zero_()
+
+        for seq_idx, seq in enumerate(all_tokens):
+            tokens_padded[seq_idx, : len(seq)] = seq
+
+        return self.encode_batch(
+            tokens_padded,
+            pace=pace,
+            pitch_rate=pitch_rate,
+            energy_rate=energy_rate,
+        )
+
+    def encode_batch(
+        self, tokens_padded, pace=1.0, pitch_rate=1.0, energy_rate=1.0
+    ):
+        """Batch inference for a tensor of phoneme sequences
+        Arguments
+        ---------
+        tokens_padded : torch.Tensor
+            A sequence of encoded phonemes to be converted to spectrogram
+        pace : float
+            pace for the speech synthesis
+        pitch_rate : float
+            scaling factor for phoneme pitches
+        energy_rate : float
+            scaling factor for phoneme energies
+        """
+        with torch.no_grad():
+            (
+                _,
+                post_mel_outputs,
+                durations,
+                pitch,
+                _,
+                energy,
+                _,
+                _,
+            ) = self.hparams.model(
+                tokens_padded,
+                pace=pace,
+                pitch_rate=pitch_rate,
+                energy_rate=energy_rate,
+            )
+
+            # Transposes to make in compliant with HiFI GAN expected format
+            post_mel_outputs = post_mel_outputs.transpose(-1, 1)
+
+        return post_mel_outputs, durations, pitch, energy
+
+    def forward(self, text, pace=1.0, pitch_rate=1.0, energy_rate=1.0):
+        """Batch inference for a tensor of phoneme sequences
+        Arguments
+        ---------
+        text : str
+            A text to be converted to spectrogram
+        pace : float
+            pace for the speech synthesis
+        pitch_rate : float
+            scaling factor for phoneme pitches
+        energy_rate : float
+            scaling factor for phoneme energies
+        """
+        return self.encode_text(
+            [text], pace=pace, pitch_rate=pitch_rate, energy_rate=energy_rate
+        )
+
+
+class FastSpeech2InternalAlignment(Pretrained):
+    """
+    A ready-to-use wrapper for Fastspeech2 with internal alignment(text -> mel_spec).
+    Arguments
+    ---------
+    hparams
+        Hyperparameters (from HyperPyYAML)
+    Example
+    -------
+    >>> tmpdir_tts = getfixture('tmpdir') / "tts"
+    >>> fastspeech2 = FastSpeech2InternalAlignment.from_hparams(source="speechbrain/tts-fastspeech2-internal-alignment-ljspeech", savedir=tmpdir_tts) # doctest: +SKIP
+    >>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text(["Mary had a little lamb."]) # doctest: +SKIP
+    >>> items = [
+    ...   "A quick brown fox jumped over the lazy dog",
+    ...   "How much wood would a woodchuck chuck?",
+    ...   "Never odd or even"
+    ... ]
+    >>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text(items) # doctest: +SKIP
+    >>> # One can combine the TTS model with a vocoder (that generates the final waveform)
+    >>> # Intialize the Vocoder (HiFIGAN)
+    >>> tmpdir_vocoder = getfixture('tmpdir') / "vocoder"
+    >>> from speechbrain.inference.vocoders import HIFIGAN
+    >>> hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-ljspeech", savedir=tmpdir_vocoder) # doctest: +SKIP
+    >>> # Running the TTS
+    >>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text(["Mary had a little lamb."]) # doctest: +SKIP
+    >>> # Running Vocoder (spectrogram-to-waveform)
+    >>> waveforms = hifi_gan.decode_batch(mel_outputs) # doctest: +SKIP
+    """
+
+    HPARAMS_NEEDED = ["model", "input_encoder"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        lexicon = self.hparams.lexicon
+        lexicon = ["@@"] + lexicon
+        self.input_encoder = self.hparams.input_encoder
+        self.input_encoder.update_from_iterable(lexicon, sequence_input=False)
+        self.input_encoder.add_unk()
+
+        self.g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
+
+    def encode_text(self, texts, pace=1.0, pitch_rate=1.0, energy_rate=1.0):
+        """Computes mel-spectrogram for a list of texts
+
+        Arguments
+        ---------
+        texts: List[str]
+            texts to be converted to spectrogram
+        pace: float
+            pace for the speech synthesis
+        pitch_rate : float
+            scaling factor for phoneme pitches
+        energy_rate : float
+            scaling factor for phoneme energies
+
+        Returns
+        -------
+        tensors of output spectrograms, output lengths and alignments
+        """
+
+        # Preprocessing required at the inference time for the input text
+        # "label" below contains input text
+        # "phoneme_labels" contain the phoneme sequences corresponding to input text labels
+
+        phoneme_labels = list()
+        max_seq_len = -1
+
+        for label in texts:
+            phonemes_with_punc = self._g2p_keep_punctuations(self.g2p, label)
+            if max_seq_len < len(phonemes_with_punc):
+                max_seq_len = len(phonemes_with_punc)
+            token_seq = (
+                self.input_encoder.encode_sequence_torch(phonemes_with_punc)
+                .int()
+                .to(self.device)
+            )
+            phoneme_labels.append(token_seq)
+
+        tokens_padded = torch.LongTensor(len(texts), max_seq_len).to(
+            self.device
+        )
+        tokens_padded.zero_()
+
+        for seq_idx, seq in enumerate(phoneme_labels):
+            tokens_padded[seq_idx, : len(seq)] = seq
+
+        return self.encode_batch(
+            tokens_padded,
+            pace=pace,
+            pitch_rate=pitch_rate,
+            energy_rate=energy_rate,
+        )
+
+    def _g2p_keep_punctuations(self, g2p_model, text):
+        """do grapheme to phoneme and keep the punctuations between the words"""
+        # find the words where a "-" or "'" or "." or ":" appears in the middle
+        special_words = re.findall(r"\w+[-':\.][-':\.\w]*\w+", text)
+
+        # remove intra-word punctuations ("-':."), this does not change the output of speechbrain g2p
+        for special_word in special_words:
+            rmp = special_word.replace("-", "")
+            rmp = rmp.replace("'", "")
+            rmp = rmp.replace(":", "")
+            rmp = rmp.replace(".", "")
+            text = text.replace(special_word, rmp)
+
+        # keep inter-word punctuations
+        all_ = re.findall(r"[\w]+|[-!'(),.:;? ]", text)
+        try:
+            phonemes = g2p_model(text)
+        except RuntimeError:
+            logger.info(f"error with text: {text}")
+            quit()
+        word_phonemes = "-".join(phonemes).split(" ")
+
+        phonemes_with_punc = []
+        count = 0
+        try:
+            # if the g2p model splits the words correctly
+            for i in all_:
+                if i not in "-!'(),.:;? ":
+                    phonemes_with_punc.extend(word_phonemes[count].split("-"))
+                    count += 1
+                else:
+                    phonemes_with_punc.append(i)
+        except IndexError:
+            # sometimes the g2p model cannot split the words correctly
+            logger.warning(
+                f"Do g2p word by word because of unexpected ouputs from g2p for text: {text}"
+            )
+
+            for i in all_:
+                if i not in "-!'(),.:;? ":
+                    p = g2p_model.g2p(i)
+                    p_without_space = [i for i in p if i != " "]
+                    phonemes_with_punc.extend(p_without_space)
+                else:
+                    phonemes_with_punc.append(i)
+
+        while "" in phonemes_with_punc:
+            phonemes_with_punc.remove("")
+        return phonemes_with_punc
+
+    def encode_phoneme(
+        self, phonemes, pace=1.0, pitch_rate=1.0, energy_rate=1.0
+    ):
+        """Computes mel-spectrogram for a list of phoneme sequences
+
+        Arguments
+        ---------
+        phonemes: List[List[str]]
+            phonemes to be converted to spectrogram
+        pace: float
+            pace for the speech synthesis
+        pitch_rate : float
+            scaling factor for phoneme pitches
+        energy_rate : float
+            scaling factor for phoneme energies
+
+        Returns
+        -------
+        tensors of output spectrograms, output lengths and alignments
+        """
+
+        all_tokens = []
+        max_seq_len = -1
+        for phoneme in phonemes:
+            token_seq = (
+                self.input_encoder.encode_sequence_torch(phoneme)
+                .int()
+                .to(self.device)
+            )
+            if max_seq_len < token_seq.shape[-1]:
+                max_seq_len = token_seq.shape[-1]
+            all_tokens.append(token_seq)
+
+        tokens_padded = torch.LongTensor(len(phonemes), max_seq_len).to(
+            self.device
+        )
+        tokens_padded.zero_()
+
+        for seq_idx, seq in enumerate(all_tokens):
+            tokens_padded[seq_idx, : len(seq)] = seq
+
+        return self.encode_batch(
+            tokens_padded,
+            pace=pace,
+            pitch_rate=pitch_rate,
+            energy_rate=energy_rate,
+        )
+
+    def encode_batch(
+        self, tokens_padded, pace=1.0, pitch_rate=1.0, energy_rate=1.0
+    ):
+        """Batch inference for a tensor of phoneme sequences
+        Arguments
+        ---------
+        tokens_padded : torch.Tensor
+            A sequence of encoded phonemes to be converted to spectrogram
+        pace : float
+            pace for the speech synthesis
+        pitch_rate : float
+            scaling factor for phoneme pitches
+        energy_rate : float
+            scaling factor for phoneme energies
+        """
+        with torch.no_grad():
+            (
+                _,
+                post_mel_outputs,
+                durations,
+                pitch,
+                _,
+                energy,
+                _,
+                _,
+                _,
+                _,
+                _,
+                _,
+            ) = self.hparams.model(
+                tokens_padded,
+                pace=pace,
+                pitch_rate=pitch_rate,
+                energy_rate=energy_rate,
+            )
+
+            # Transposes to make in compliant with HiFI GAN expected format
+            post_mel_outputs = post_mel_outputs.transpose(-1, 1)
+
+        return post_mel_outputs, durations, pitch, energy
+
+    def forward(self, text, pace=1.0, pitch_rate=1.0, energy_rate=1.0):
+        """Batch inference for a tensor of phoneme sequences
+        Arguments
+        ---------
+        text : str
+            A text to be converted to spectrogram
+        pace : float
+            pace for the speech synthesis
+        pitch_rate : float
+            scaling factor for phoneme pitches
+        energy_rate : float
+            scaling factor for phoneme energies
+        """
+        return self.encode_text(
+            [text], pace=pace, pitch_rate=pitch_rate, energy_rate=energy_rate
+        )
diff --git a/speechbrain/inference/VAD.py b/speechbrain/inference/VAD.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f3d55e9c47432603aa7335a25a48e4564eda0ce
--- /dev/null
+++ b/speechbrain/inference/VAD.py
@@ -0,0 +1,961 @@
+""" Specifies the inference interfaces for Voice Activity Detection (VAD) modules.
+
+Authors:
+ * Aku Rouhe 2021
+ * Peter Plantinga 2021
+ * Loren Lugosch 2020
+ * Mirco Ravanelli 2020
+ * Titouan Parcollet 2021
+ * Abdel Heba 2021
+ * Andreas Nautsch 2022, 2023
+ * Pooneh Mousavi 2023
+ * Sylvain de Langen 2023
+ * Adel Moumen 2023
+ * Pradnya Kandarkar 2023
+"""
+import torch
+import torchaudio
+from speechbrain.utils.data_utils import split_path
+from speechbrain.utils.fetching import fetch
+from speechbrain.inference.interfaces import Pretrained
+
+
+class VAD(Pretrained):
+    """A ready-to-use class for Voice Activity Detection (VAD) using a
+    pre-trained model.
+
+    Example
+    -------
+    >>> import torchaudio
+    >>> from speechbrain.inference.VAD import VAD
+    >>> # Model is downloaded from the speechbrain HuggingFace repo
+    >>> tmpdir = getfixture("tmpdir")
+    >>> VAD = VAD.from_hparams(
+    ...     source="speechbrain/vad-crdnn-libriparty",
+    ...     savedir=tmpdir,
+    ... )
+
+    >>> # Perform VAD
+    >>> boundaries = VAD.get_speech_segments("tests/samples/single-mic/example1.wav")
+    """
+
+    HPARAMS_NEEDED = ["sample_rate", "time_resolution", "device"]
+
+    MODULES_NEEDED = ["compute_features", "mean_var_norm", "model"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.time_resolution = self.hparams.time_resolution
+        self.sample_rate = self.hparams.sample_rate
+
+    def get_speech_prob_file(
+        self,
+        audio_file,
+        large_chunk_size=30,
+        small_chunk_size=10,
+        overlap_small_chunk=False,
+    ):
+        """Outputs the frame-level speech probability of the input audio file
+        using the neural model specified in the hparam file. To make this code
+        both parallelizable and scalable to long sequences, it uses a
+        double-windowing approach.  First, we sequentially read non-overlapping
+        large chunks of the input signal.  We then split the large chunks into
+        smaller chunks and we process them in parallel.
+
+        Arguments
+        ---------
+        audio_file: path
+            Path of the audio file containing the recording. The file is read
+            with torchaudio.
+        large_chunk_size: float
+            Size (in seconds) of the large chunks that are read sequentially
+            from the input audio file.
+        small_chunk_size:
+            Size (in seconds) of the small chunks extracted from the large ones.
+            The audio signal is processed in parallel within the small chunks.
+            Note that large_chunk_size/small_chunk_size must be an integer.
+        overlap_small_chunk: bool
+            True, creates overlapped small chunks. The probabilities of the
+            overlapped chunks are combined using hamming windows.
+
+        Returns
+        -------
+        prob_vad: torch.Tensor
+            Tensor containing the frame-level speech probabilities for the
+            input audio file.
+        """
+        # Getting the total size of the input file
+        sample_rate, audio_len = self._get_audio_info(audio_file)
+
+        if sample_rate != self.sample_rate:
+            raise ValueError(
+                "The detected sample rate is different from that set in the hparam file"
+            )
+
+        # Computing the length (in samples) of the large and small chunks
+        long_chunk_len = int(sample_rate * large_chunk_size)
+        small_chunk_len = int(sample_rate * small_chunk_size)
+
+        # Setting the step size of the small chunk (50% overlapping windows are supported)
+        small_chunk_step = small_chunk_size
+        if overlap_small_chunk:
+            small_chunk_step = small_chunk_size / 2
+
+        # Computing the length (in sample) of the small_chunk step size
+        small_chunk_len_step = int(sample_rate * small_chunk_step)
+
+        # Loop over big chunks
+        prob_chunks = []
+        last_chunk = False
+        begin_sample = 0
+        while True:
+            # Check if the current chunk is the last one
+            if begin_sample + long_chunk_len >= audio_len:
+                last_chunk = True
+
+            # Reading the big chunk
+            large_chunk, fs = torchaudio.load(
+                str(audio_file),
+                frame_offset=begin_sample,
+                num_frames=long_chunk_len,
+            )
+            large_chunk = large_chunk.to(self.device)
+
+            # Manage padding of the last small chunk
+            if last_chunk or large_chunk.shape[-1] < small_chunk_len:
+                padding = torch.zeros(
+                    1, small_chunk_len, device=large_chunk.device
+                )
+                large_chunk = torch.cat([large_chunk, padding], dim=1)
+
+            # Splitting the big chunk into smaller (overlapped) ones
+            small_chunks = torch.nn.functional.unfold(
+                large_chunk.unsqueeze(1).unsqueeze(2),
+                kernel_size=(1, small_chunk_len),
+                stride=(1, small_chunk_len_step),
+            )
+            small_chunks = small_chunks.squeeze(0).transpose(0, 1)
+
+            # Getting (in parallel) the frame-level speech probabilities
+            small_chunks_prob = self.get_speech_prob_chunk(small_chunks)
+            small_chunks_prob = small_chunks_prob[:, :-1, :]
+
+            # Manage overlapping chunks
+            if overlap_small_chunk:
+                small_chunks_prob = self._manage_overlapped_chunks(
+                    small_chunks_prob
+                )
+
+            # Prepare for folding
+            small_chunks_prob = small_chunks_prob.permute(2, 1, 0)
+
+            # Computing lengths in samples
+            out_len = int(
+                large_chunk.shape[-1] / (sample_rate * self.time_resolution)
+            )
+            kernel_len = int(small_chunk_size / self.time_resolution)
+            step_len = int(small_chunk_step / self.time_resolution)
+
+            # Folding the frame-level predictions
+            small_chunks_prob = torch.nn.functional.fold(
+                small_chunks_prob,
+                output_size=(1, out_len),
+                kernel_size=(1, kernel_len),
+                stride=(1, step_len),
+            )
+
+            # Appending the frame-level speech probabilities of the large chunk
+            small_chunks_prob = small_chunks_prob.squeeze(1).transpose(-1, -2)
+            prob_chunks.append(small_chunks_prob)
+
+            # Check stop condition
+            if last_chunk:
+                break
+
+            # Update counter to process the next big chunk
+            begin_sample = begin_sample + long_chunk_len
+
+        # Converting the list to a tensor
+        prob_vad = torch.cat(prob_chunks, dim=1)
+        last_elem = int(audio_len / (self.time_resolution * sample_rate))
+        prob_vad = prob_vad[:, 0:last_elem, :]
+
+        return prob_vad
+
+    def _manage_overlapped_chunks(self, small_chunks_prob):
+        """This support function manages overlapped the case in which the
+        small chunks have a 50% overlap."""
+
+        # Weighting the frame-level probabilities with a hamming window
+        # reduces uncertainty when overlapping chunks are used.
+        hamming_window = torch.hamming_window(
+            small_chunks_prob.shape[1], device=self.device
+        )
+
+        # First and last chunks require special care
+        half_point = int(small_chunks_prob.shape[1] / 2)
+        small_chunks_prob[0, half_point:] = small_chunks_prob[
+            0, half_point:
+        ] * hamming_window[half_point:].unsqueeze(1)
+        small_chunks_prob[-1, 0:half_point] = small_chunks_prob[
+            -1, 0:half_point
+        ] * hamming_window[0:half_point].unsqueeze(1)
+
+        # Applying the window to all the other probabilities
+        small_chunks_prob[1:-1] = small_chunks_prob[
+            1:-1
+        ] * hamming_window.unsqueeze(0).unsqueeze(2)
+
+        return small_chunks_prob
+
+    def get_speech_prob_chunk(self, wavs, wav_lens=None):
+        """Outputs the frame-level posterior probability for the input audio chunks
+        Outputs close to zero refers to time steps with a low probability of speech
+        activity, while outputs closer to one likely contain speech.
+
+        Arguments
+        ---------
+        wavs : torch.Tensor
+            Batch of waveforms [batch, time, channels] or [batch, time]
+            depending on the model. Make sure the sample rate is fs=16000 Hz.
+        wav_lens : torch.Tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        torch.Tensor
+            The encoded batch
+        """
+        # Manage single waveforms in input
+        if len(wavs.shape) == 1:
+            wavs = wavs.unsqueeze(0)
+
+        # Assign full length if wav_lens is not assigned
+        if wav_lens is None:
+            wav_lens = torch.ones(wavs.shape[0], device=self.device)
+
+        # Storing waveform in the specified device
+        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
+        wavs = wavs.float()
+
+        # Computing features and embeddings
+        feats = self.mods.compute_features(wavs)
+        feats = self.mods.mean_var_norm(feats, wav_lens)
+        outputs = self.mods.cnn(feats)
+
+        outputs = outputs.reshape(
+            outputs.shape[0],
+            outputs.shape[1],
+            outputs.shape[2] * outputs.shape[3],
+        )
+
+        outputs, h = self.mods.rnn(outputs)
+        outputs = self.mods.dnn(outputs)
+        output_prob = torch.sigmoid(outputs)
+
+        return output_prob
+
+    def apply_threshold(
+        self, vad_prob, activation_th=0.5, deactivation_th=0.25
+    ):
+        """Scans the frame-level speech probabilities and applies a threshold
+        on them. Speech starts when a value larger than activation_th is
+        detected, while it ends when observing a value lower than
+        the deactivation_th.
+
+        Arguments
+        ---------
+        vad_prob: torch.Tensor
+            Frame-level speech probabilities.
+        activation_th:  float
+            Threshold for starting a speech segment.
+        deactivation_th: float
+            Threshold for ending a speech segment.
+
+        Returns
+        -------
+        vad_th: torch.Tensor
+            Tensor containing 1 for speech regions and 0 for non-speech regions.
+        """
+        vad_activation = (vad_prob >= activation_th).int()
+        vad_deactivation = (vad_prob >= deactivation_th).int()
+        vad_th = vad_activation + vad_deactivation
+
+        # Loop over batches and time steps
+        for batch in range(vad_th.shape[0]):
+            for time_step in range(vad_th.shape[1] - 1):
+                if (
+                    vad_th[batch, time_step] == 2
+                    and vad_th[batch, time_step + 1] == 1
+                ):
+                    vad_th[batch, time_step + 1] = 2
+
+        vad_th[vad_th == 1] = 0
+        vad_th[vad_th == 2] = 1
+        return vad_th
+
+    def get_boundaries(self, prob_th, output_value="seconds"):
+        """Computes the time boundaries where speech activity is detected.
+        It takes in input frame-level binary decisions
+        (1 for speech, 0 for non-speech) and outputs the begin/end second
+        (or sample) of each detected speech region.
+
+        Arguments
+        ---------
+        prob_th: torch.Tensor
+            Frame-level binary decisions (1 for speech frame, 0 for a
+            non-speech one).  The tensor can be obtained from apply_threshold.
+        output_value: 'seconds' or 'samples'
+            When the option 'seconds' is set, the returned boundaries are in
+            seconds, otherwise, it reports them in samples.
+
+        Returns
+        -------
+        boundaries: torch.Tensor
+            Tensor containing the start second (or sample) of speech segments
+            in even positions and their corresponding end in odd positions
+            (e.g, [1.0, 1.5, 5,.0 6.0] means that we have two speech segment;
+             one from 1.0 to 1.5 seconds and another from 5.0 to 6.0 seconds).
+        """
+        # Shifting frame-levels binary decision by 1
+        # This allows detecting changes in speech/non-speech activities
+        prob_th_shifted = torch.roll(prob_th, dims=1, shifts=1)
+        prob_th_shifted[:, 0, :] = 0
+        prob_th = prob_th + prob_th_shifted
+
+        # Needed to first and last time step
+        prob_th[:, 0, :] = (prob_th[:, 0, :] >= 1).int()
+        prob_th[:, -1, :] = (prob_th[:, -1, :] >= 1).int()
+
+        # Fix edge cases (when a speech starts in the last frames)
+        if (prob_th == 1).nonzero().shape[0] % 2 == 1:
+            prob_th = torch.cat(
+                (
+                    prob_th,
+                    torch.Tensor([1.0])
+                    .unsqueeze(0)
+                    .unsqueeze(2)
+                    .to(self.device),
+                ),
+                dim=1,
+            )
+
+        # Where prob_th is 1 there is a change
+        indexes = (prob_th == 1).nonzero()[:, 1].reshape(-1, 2)
+
+        # Remove 1 from end samples
+        indexes[:, -1] = indexes[:, -1] - 1
+
+        # From indexes to samples
+        seconds = (indexes * self.time_resolution).float()
+        samples = (self.sample_rate * seconds).round().int()
+
+        if output_value == "seconds":
+            boundaries = seconds
+        else:
+            boundaries = samples
+        return boundaries
+
+    def merge_close_segments(self, boundaries, close_th=0.250):
+        """Merges segments that are shorter than the given threshold.
+
+        Arguments
+        ---------
+        boundaries : str
+            Tensor containing the speech boundaries. It can be derived using the
+            get_boundaries method.
+        close_th: float
+            If the distance between boundaries is smaller than close_th, the
+            segments will be merged.
+
+        Returns
+        -------
+        new_boundaries
+            The new boundaries with the merged segments.
+        """
+
+        new_boundaries = []
+
+        # Single segment case
+        if boundaries.shape[0] == 0:
+            return boundaries
+
+        # Getting beg and end of previous segment
+        prev_beg_seg = boundaries[0, 0].float()
+        prev_end_seg = boundaries[0, 1].float()
+
+        # Process all the segments
+        for i in range(1, boundaries.shape[0]):
+            beg_seg = boundaries[i, 0]
+            segment_distance = beg_seg - prev_end_seg
+
+            # Merging close segments
+            if segment_distance <= close_th:
+                prev_end_seg = boundaries[i, 1]
+
+            else:
+                # Appending new segments
+                new_boundaries.append([prev_beg_seg, prev_end_seg])
+                prev_beg_seg = beg_seg
+                prev_end_seg = boundaries[i, 1]
+
+        new_boundaries.append([prev_beg_seg, prev_end_seg])
+        new_boundaries = torch.FloatTensor(new_boundaries).to(boundaries.device)
+        return new_boundaries
+
+    def remove_short_segments(self, boundaries, len_th=0.250):
+        """Removes segments that are too short.
+
+        Arguments
+        ---------
+        boundaries : torch.Tensor
+            Tensor containing the speech boundaries. It can be derived using the
+            get_boundaries method.
+        len_th: float
+            If the length of the segment is smaller than close_th, the segments
+            will be merged.
+
+        Returns
+        -------
+        new_boundaries
+            The new boundaries without the short segments.
+        """
+        new_boundaries = []
+
+        # Process the segments
+        for i in range(boundaries.shape[0]):
+            # Computing segment length
+            seg_len = boundaries[i, 1] - boundaries[i, 0]
+
+            # Accept segment only if longer than len_th
+            if seg_len > len_th:
+                new_boundaries.append([boundaries[i, 0], boundaries[i, 1]])
+        new_boundaries = torch.FloatTensor(new_boundaries).to(boundaries.device)
+
+        return new_boundaries
+
+    def save_boundaries(
+        self, boundaries, save_path=None, print_boundaries=True, audio_file=None
+    ):
+        """Saves the boundaries on a file (and/or prints them)  in a readable format.
+
+        Arguments
+        ---------
+        boundaries: torch.Tensor
+            Tensor containing the speech boundaries. It can be derived using the
+            get_boundaries method.
+        save_path: path
+            When to store the text file containing the speech/non-speech intervals.
+        print_boundaries: Bool
+            Prints the speech/non-speech intervals in the standard outputs.
+        audio_file: path
+            Path of the audio file containing the recording. The file is read
+            with torchaudio. It is used here to detect the length of the
+            signal.
+        """
+        # Create a new file if needed
+        if save_path is not None:
+            f = open(save_path, mode="w", encoding="utf-8")
+
+        # Getting the total size of the input file
+        if audio_file is not None:
+            sample_rate, audio_len = self._get_audio_info(audio_file)
+            audio_len = audio_len / sample_rate
+
+        # Setting the rights format for second- or sample-based boundaries
+        if boundaries.dtype == torch.int:
+            value_format = "% i"
+        else:
+            value_format = "% .2f "
+
+        # Printing speech and non-speech intervals
+        last_end = 0
+        cnt_seg = 0
+        for i in range(boundaries.shape[0]):
+            begin_value = boundaries[i, 0]
+            end_value = boundaries[i, 1]
+
+            if last_end != begin_value:
+                cnt_seg = cnt_seg + 1
+                print_str = (
+                    "segment_%03d " + value_format + value_format + "NON_SPEECH"
+                )
+                if print_boundaries:
+                    print(print_str % (cnt_seg, last_end, begin_value))
+                if save_path is not None:
+                    f.write(print_str % (cnt_seg, last_end, begin_value) + "\n")
+
+            cnt_seg = cnt_seg + 1
+            print_str = "segment_%03d " + value_format + value_format + "SPEECH"
+            if print_boundaries:
+                print(print_str % (cnt_seg, begin_value, end_value))
+            if save_path is not None:
+                f.write(print_str % (cnt_seg, begin_value, end_value) + "\n")
+
+            last_end = end_value
+
+        # Managing last segment
+        if audio_file is not None:
+            if last_end < audio_len:
+                cnt_seg = cnt_seg + 1
+                print_str = (
+                    "segment_%03d " + value_format + value_format + "NON_SPEECH"
+                )
+                if print_boundaries:
+                    print(print_str % (cnt_seg, end_value, audio_len))
+                if save_path is not None:
+                    f.write(print_str % (cnt_seg, end_value, audio_len) + "\n")
+
+        if save_path is not None:
+            f.close()
+
+    def energy_VAD(
+        self,
+        audio_file,
+        boundaries,
+        activation_th=0.5,
+        deactivation_th=0.0,
+        eps=1e-6,
+    ):
+        """Applies energy-based VAD within the detected speech segments.The neural
+        network VAD often creates longer segments and tends to merge segments that
+        are close with each other.
+
+        The energy VAD post-processes can be useful for having a fine-grained voice
+        activity detection.
+
+        The energy VAD computes the energy within the small chunks. The energy is
+        normalized within the segment to have mean 0.5 and +-0.5 of std.
+        This helps to set the energy threshold.
+
+        Arguments
+        ---------
+        audio_file: path
+            Path of the audio file containing the recording. The file is read
+            with torchaudio.
+        boundaries : torch.Tensor
+            Tensor containing the speech boundaries. It can be derived using the
+            get_boundaries method.
+        activation_th: float
+            A new speech segment is started it the energy is above activation_th.
+        deactivation_th: float
+            The segment is considered ended when the energy is <= deactivation_th.
+        eps: float
+            Small constant for numerical stability.
+
+
+        Returns
+        -------
+        new_boundaries
+            The new boundaries that are post-processed by the energy VAD.
+        """
+
+        # Getting the total size of the input file
+        sample_rate, audio_len = self._get_audio_info(audio_file)
+
+        if sample_rate != self.sample_rate:
+            raise ValueError(
+                "The detected sample rate is different from that set in the hparam file"
+            )
+
+        # Computing the chunk length of the energy window
+        chunk_len = int(self.time_resolution * sample_rate)
+        new_boundaries = []
+
+        # Processing speech segments
+        for i in range(boundaries.shape[0]):
+            begin_sample = int(boundaries[i, 0] * sample_rate)
+            end_sample = int(boundaries[i, 1] * sample_rate)
+            seg_len = end_sample - begin_sample
+
+            # Reading the speech segment
+            segment, _ = torchaudio.load(
+                audio_file, frame_offset=begin_sample, num_frames=seg_len
+            )
+
+            # Create chunks
+            segment_chunks = self.create_chunks(
+                segment, chunk_size=chunk_len, chunk_stride=chunk_len
+            )
+
+            # Energy computation within each chunk
+            energy_chunks = segment_chunks.abs().sum(-1) + eps
+            energy_chunks = energy_chunks.log()
+
+            # Energy normalization
+            energy_chunks = (
+                (energy_chunks - energy_chunks.mean())
+                / (2 * energy_chunks.std())
+            ) + 0.5
+            energy_chunks = energy_chunks.unsqueeze(0).unsqueeze(2)
+
+            # Apply threshold based on the energy value
+            energy_vad = self.apply_threshold(
+                energy_chunks,
+                activation_th=activation_th,
+                deactivation_th=deactivation_th,
+            )
+
+            # Get the boundaries
+            energy_boundaries = self.get_boundaries(
+                energy_vad, output_value="seconds"
+            )
+
+            # Get the final boundaries in the original signal
+            for j in range(energy_boundaries.shape[0]):
+                start_en = boundaries[i, 0] + energy_boundaries[j, 0]
+                end_end = boundaries[i, 0] + energy_boundaries[j, 1]
+                new_boundaries.append([start_en, end_end])
+
+        # Convert boundaries to tensor
+        new_boundaries = torch.FloatTensor(new_boundaries).to(boundaries.device)
+        return new_boundaries
+
+    def create_chunks(self, x, chunk_size=16384, chunk_stride=16384):
+        """Splits the input into smaller chunks of size chunk_size with
+        an overlap chunk_stride. The chunks are concatenated over
+        the batch axis.
+
+        Arguments
+        ---------
+        x: torch.Tensor
+            Signal to split into chunks.
+        chunk_size : str
+            The size of each chunk.
+        chunk_stride:
+            The stride (hop) of each chunk.
+
+
+        Returns
+        -------
+        x: torch.Tensor
+            A new tensors with the chunks derived from the input signal.
+
+        """
+        x = x.unfold(1, chunk_size, chunk_stride)
+        x = x.reshape(x.shape[0] * x.shape[1], -1)
+        return x
+
+    def _get_audio_info(self, audio_file):
+        """Returns the sample rate and the length of the input audio file"""
+
+        # Getting the total size of the input file
+        metadata = torchaudio.info(str(audio_file))
+        sample_rate = metadata.sample_rate
+        audio_len = metadata.num_frames
+        return sample_rate, audio_len
+
+    def upsample_VAD(self, vad_out, audio_file, time_resolution=0.01):
+        """Upsamples the output of the vad to help visualization. It creates a
+        signal that is 1 when there is speech and 0 when there is no speech.
+        The vad signal has the same resolution as the input one and can be
+        opened with it (e.g, using audacity) to visually figure out VAD regions.
+
+        Arguments
+        ---------
+        vad_out: torch.Tensor
+            Tensor containing 1 for each frame of speech and 0 for each non-speech
+            frame.
+        audio_file: path
+            The original audio file used to compute vad_out
+        time_resolution : float
+            Time resolution of the vad_out signal.
+
+        Returns
+        -------
+        vad_signal
+            The upsampled version of the vad_out tensor.
+        """
+
+        # Getting the total size of the input file
+        sample_rate, sig_len = self._get_audio_info(audio_file)
+
+        if sample_rate != self.sample_rate:
+            raise ValueError(
+                "The detected sample rate is different from that set in the hparam file"
+            )
+
+        beg_samp = 0
+        step_size = int(time_resolution * sample_rate)
+        end_samp = step_size
+        index = 0
+
+        # Initialize upsampled signal
+        vad_signal = torch.zeros(1, sig_len, device=vad_out.device)
+
+        # Upsample signal
+        while end_samp < sig_len:
+            vad_signal[0, beg_samp:end_samp] = vad_out[0, index, 0]
+            index = index + 1
+            beg_samp = beg_samp + step_size
+            end_samp = beg_samp + step_size
+        return vad_signal
+
+    def upsample_boundaries(self, boundaries, audio_file):
+        """Based on the input boundaries, this method creates a signal that is 1
+        when there is speech and 0 when there is no speech.
+        The vad signal has the same resolution as the input one and can be
+        opened with it (e.g, using audacity) to visually figure out VAD regions.
+
+        Arguments
+        ---------
+        boundaries: torch.Tensor
+            Tensor containing the boundaries of the speech segments.
+        audio_file: path
+            The original audio file used to compute vad_out
+
+        Returns
+        -------
+        vad_signal
+            The output vad signal with the same resolution of the input one.
+        """
+
+        # Getting the total size of the input file
+        sample_rate, sig_len = self._get_audio_info(audio_file)
+
+        if sample_rate != self.sample_rate:
+            raise ValueError(
+                "The detected sample rate is different from that set in the hparam file"
+            )
+
+        # Initialization of the output signal
+        vad_signal = torch.zeros(1, sig_len, device=boundaries.device)
+
+        # Composing the vad signal from boundaries
+        for i in range(boundaries.shape[0]):
+            beg_sample = int(boundaries[i, 0] * sample_rate)
+            end_sample = int(boundaries[i, 1] * sample_rate)
+            vad_signal[0, beg_sample:end_sample] = 1.0
+        return vad_signal
+
+    def double_check_speech_segments(
+        self, boundaries, audio_file, speech_th=0.5
+    ):
+        """Takes in input the boundaries of the detected speech segments and
+        double checks (using the neural VAD) that they actually contain speech.
+
+        Arguments
+        ---------
+        boundaries: torch.Tensor
+            Tensor containing the boundaries of the speech segments.
+        audio_file: path
+            The original audio file used to compute vad_out.
+        speech_th: float
+            Threshold on the mean posterior probability over which speech is
+            confirmed. Below that threshold, the segment is re-assigned to a
+            non-speech region.
+
+        Returns
+        -------
+        new_boundaries
+            The boundaries of the segments where speech activity is confirmed.
+        """
+
+        # Getting the total size of the input file
+        sample_rate, sig_len = self._get_audio_info(audio_file)
+
+        # Double check the segments
+        new_boundaries = []
+        for i in range(boundaries.shape[0]):
+            beg_sample = int(boundaries[i, 0] * sample_rate)
+            end_sample = int(boundaries[i, 1] * sample_rate)
+            len_seg = end_sample - beg_sample
+
+            # Read the candidate speech segment
+            segment, fs = torchaudio.load(
+                str(audio_file), frame_offset=beg_sample, num_frames=len_seg
+            )
+            speech_prob = self.get_speech_prob_chunk(segment)
+            if speech_prob.mean() > speech_th:
+                # Accept this as a speech segment
+                new_boundaries.append([boundaries[i, 0], boundaries[i, 1]])
+
+        # Convert boundaries from list to tensor
+        new_boundaries = torch.FloatTensor(new_boundaries).to(boundaries.device)
+        return new_boundaries
+
+    def get_segments(
+        self, boundaries, audio_file, before_margin=0.1, after_margin=0.1
+    ):
+        """Returns a list containing all the detected speech segments.
+
+        Arguments
+        ---------
+        boundaries: torch.Tensor
+            Tensor containing the boundaries of the speech segments.
+        audio_file: path
+            The original audio file used to compute vad_out.
+        before_margin: float
+            Used to cut the segments samples a bit before the detected margin.
+        after_margin: float
+            Use to cut the segments samples a bit after the detected margin.
+
+        Returns
+        -------
+        segments: list
+            List containing the detected speech segments
+        """
+        sample_rate, sig_len = self._get_audio_info(audio_file)
+
+        if sample_rate != self.sample_rate:
+            raise ValueError(
+                "The detected sample rate is different from that set in the hparam file"
+            )
+
+        segments = []
+        for i in range(boundaries.shape[0]):
+            beg_sample = boundaries[i, 0] * sample_rate
+            end_sample = boundaries[i, 1] * sample_rate
+
+            beg_sample = int(max(0, beg_sample - before_margin * sample_rate))
+            end_sample = int(
+                min(sig_len, end_sample + after_margin * sample_rate)
+            )
+
+            len_seg = end_sample - beg_sample
+            vad_segment, fs = torchaudio.load(
+                audio_file, frame_offset=beg_sample, num_frames=len_seg
+            )
+            segments.append(vad_segment)
+        return segments
+
+    def get_speech_segments(
+        self,
+        audio_file,
+        large_chunk_size=30,
+        small_chunk_size=10,
+        overlap_small_chunk=False,
+        apply_energy_VAD=False,
+        double_check=True,
+        close_th=0.250,
+        len_th=0.250,
+        activation_th=0.5,
+        deactivation_th=0.25,
+        en_activation_th=0.5,
+        en_deactivation_th=0.0,
+        speech_th=0.50,
+    ):
+        """Detects speech segments within the input file. The input signal can
+        be both a short or a long recording. The function computes the
+        posterior probabilities on large chunks (e.g, 30 sec), that are read
+        sequentially (to avoid storing big signals in memory).
+        Each large chunk is, in turn, split into smaller chunks (e.g, 10 seconds)
+        that are processed in parallel. The pipeline for detecting the speech
+        segments is the following:
+            1- Compute posteriors probabilities at the frame level.
+            2- Apply a threshold on the posterior probability.
+            3- Derive candidate speech segments on top of that.
+            4- Apply energy VAD within each candidate segment (optional).
+            5- Merge segments that are too close.
+            6- Remove segments that are too short.
+            7- Double check speech segments (optional).
+
+
+        Arguments
+        ---------
+        audio_file : str
+            Path to audio file.
+        large_chunk_size: float
+            Size (in seconds) of the large chunks that are read sequentially
+            from the input audio file.
+        small_chunk_size: float
+            Size (in seconds) of the small chunks extracted from the large ones.
+            The audio signal is processed in parallel within the small chunks.
+            Note that large_chunk_size/small_chunk_size must be an integer.
+        overlap_small_chunk: bool
+            If True, it creates overlapped small chunks (with 50% overlap).
+            The probabilities of the overlapped chunks are combined using
+            hamming windows.
+        apply_energy_VAD: bool
+            If True, a energy-based VAD is used on the detected speech segments.
+            The neural network VAD often creates longer segments and tends to
+            merge close segments together. The energy VAD post-processes can be
+            useful for having a fine-grained voice activity detection.
+            The energy thresholds is  managed by activation_th and
+            deactivation_th (see below).
+        double_check: bool
+            If True, double checks (using the neural VAD) that the candidate
+            speech segments actually contain speech. A threshold on the mean
+            posterior probabilities provided by the neural network is applied
+            based on the speech_th parameter (see below).
+        activation_th:  float
+            Threshold of the neural posteriors above which starting a speech segment.
+        deactivation_th: float
+            Threshold of the neural posteriors below which ending a speech segment.
+        en_activation_th: float
+            A new speech segment is started it the energy is above activation_th.
+            This is active only if apply_energy_VAD is True.
+        en_deactivation_th: float
+            The segment is considered ended when the energy is <= deactivation_th.
+            This is active only if apply_energy_VAD is True.
+        speech_th: float
+            Threshold on the mean posterior probability within the candidate
+            speech segment. Below that threshold, the segment is re-assigned to
+            a non-speech region. This is active only if double_check is True.
+        close_th: float
+            If the distance between boundaries is smaller than close_th, the
+            segments will be merged.
+        len_th: float
+            If the length of the segment is smaller than close_th, the segments
+            will be merged.
+
+        Returns
+        -------
+        boundaries: torch.Tensor
+            Tensor containing the start second of speech segments in even
+            positions and their corresponding end in odd positions
+            (e.g, [1.0, 1.5, 5,.0 6.0] means that we have two speech segment;
+             one from 1.0 to 1.5 seconds and another from 5.0 to 6.0 seconds).
+        """
+
+        # Fetch audio file from web if not local
+        source, fl = split_path(audio_file)
+        audio_file = fetch(fl, source=source)
+
+        # Computing speech vs non speech probabilities
+        prob_chunks = self.get_speech_prob_file(
+            audio_file,
+            large_chunk_size=large_chunk_size,
+            small_chunk_size=small_chunk_size,
+            overlap_small_chunk=overlap_small_chunk,
+        )
+
+        # Apply a threshold to get candidate speech segments
+        prob_th = self.apply_threshold(
+            prob_chunks,
+            activation_th=activation_th,
+            deactivation_th=deactivation_th,
+        ).float()
+
+        # Compute the boundaries of the speech segments
+        boundaries = self.get_boundaries(prob_th, output_value="seconds")
+
+        # Apply energy-based VAD on the detected speech segments
+        if apply_energy_VAD:
+            boundaries = self.energy_VAD(
+                audio_file,
+                boundaries,
+                activation_th=en_activation_th,
+                deactivation_th=en_deactivation_th,
+            )
+
+        # Merge short segments
+        boundaries = self.merge_close_segments(boundaries, close_th=close_th)
+
+        # Remove short segments
+        boundaries = self.remove_short_segments(boundaries, len_th=len_th)
+
+        # Double check speech segments
+        if double_check:
+            boundaries = self.double_check_speech_segments(
+                boundaries, audio_file, speech_th=speech_th
+            )
+
+        return boundaries
+
+    def forward(self, wavs, wav_lens=None):
+        """Gets frame-level speech-activity predictions"""
+        return self.get_speech_prob_chunk(wavs, wav_lens)
diff --git a/speechbrain/inference/__init__.py b/speechbrain/inference/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a73cef81d2185b3845fd5b2364191009fe1955c5
--- /dev/null
+++ b/speechbrain/inference/__init__.py
@@ -0,0 +1,18 @@
+"""Importing all the inference interfaces"""
+
+from . import *  # noqa
+
+from .ASR import *  # noqa
+from .classifiers import *  # noqa
+from .diarization import *  # noqa
+from .encoders import *  # noqa
+from .enhancement import *  # noqa
+from .interfaces import *  # noqa
+from .separation import *  # noqa
+from .SLU import *  # noqa
+from .speaker import *  # noqa
+from .ST import *  # noqa
+from .text import *  # noqa
+from .TTS import *  # noqa
+from .VAD import *  # noqa
+from .vocoders import *  # noqa
diff --git a/speechbrain/inference/classifiers.py b/speechbrain/inference/classifiers.py
new file mode 100644
index 0000000000000000000000000000000000000000..a59c9ff551305734c770b2ead89693d0849dcc9a
--- /dev/null
+++ b/speechbrain/inference/classifiers.py
@@ -0,0 +1,314 @@
+""" Specifies the inference interfaces for Audio Classification modules.
+
+Authors:
+ * Aku Rouhe 2021
+ * Peter Plantinga 2021
+ * Loren Lugosch 2020
+ * Mirco Ravanelli 2020
+ * Titouan Parcollet 2021
+ * Abdel Heba 2021
+ * Andreas Nautsch 2022, 2023
+ * Pooneh Mousavi 2023
+ * Sylvain de Langen 2023
+ * Adel Moumen 2023
+ * Pradnya Kandarkar 2023
+"""
+import torch
+import torchaudio
+import speechbrain
+from speechbrain.utils.fetching import fetch
+from speechbrain.utils.data_utils import split_path
+from speechbrain.inference.interfaces import Pretrained
+
+
+class EncoderClassifier(Pretrained):
+    """A ready-to-use class for utterance-level classification (e.g, speaker-id,
+    language-id, emotion recognition, keyword spotting, etc).
+
+    The class assumes that an encoder called "embedding_model" and a model
+    called "classifier" are defined in the yaml file. If you want to
+    convert the predicted index into a corresponding text label, please
+    provide the path of the label_encoder in a variable called 'lab_encoder_file'
+    within the yaml.
+
+    The class can be used either to run only the encoder (encode_batch()) to
+    extract embeddings or to run a classification step (classify_batch()).
+    ```
+
+    Example
+    -------
+    >>> import torchaudio
+    >>> from speechbrain.inference.classifiers import EncoderClassifier
+    >>> # Model is downloaded from the speechbrain HuggingFace repo
+    >>> tmpdir = getfixture("tmpdir")
+    >>> classifier = EncoderClassifier.from_hparams(
+    ...     source="speechbrain/spkrec-ecapa-voxceleb",
+    ...     savedir=tmpdir,
+    ... )
+    >>> classifier.hparams.label_encoder.ignore_len()
+
+    >>> # Compute embeddings
+    >>> signal, fs = torchaudio.load("tests/samples/single-mic/example1.wav")
+    >>> embeddings = classifier.encode_batch(signal)
+
+    >>> # Classification
+    >>> prediction = classifier.classify_batch(signal)
+    """
+
+    MODULES_NEEDED = [
+        "compute_features",
+        "mean_var_norm",
+        "embedding_model",
+        "classifier",
+    ]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def encode_batch(self, wavs, wav_lens=None, normalize=False):
+        """Encodes the input audio into a single vector embedding.
+
+        The waveforms should already be in the model's desired format.
+        You can call:
+        ``normalized = <this>.normalizer(signal, sample_rate)``
+        to get a correctly converted signal in most cases.
+
+        Arguments
+        ---------
+        wavs : torch.Tensor
+            Batch of waveforms [batch, time, channels] or [batch, time]
+            depending on the model. Make sure the sample rate is fs=16000 Hz.
+        wav_lens : torch.Tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+        normalize : bool
+            If True, it normalizes the embeddings with the statistics
+            contained in mean_var_norm_emb.
+
+        Returns
+        -------
+        torch.Tensor
+            The encoded batch
+        """
+        # Manage single waveforms in input
+        if len(wavs.shape) == 1:
+            wavs = wavs.unsqueeze(0)
+
+        # Assign full length if wav_lens is not assigned
+        if wav_lens is None:
+            wav_lens = torch.ones(wavs.shape[0], device=self.device)
+
+        # Storing waveform in the specified device
+        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
+        wavs = wavs.float()
+
+        # Computing features and embeddings
+        feats = self.mods.compute_features(wavs)
+        feats = self.mods.mean_var_norm(feats, wav_lens)
+        embeddings = self.mods.embedding_model(feats, wav_lens)
+        if normalize:
+            embeddings = self.hparams.mean_var_norm_emb(
+                embeddings, torch.ones(embeddings.shape[0], device=self.device)
+            )
+        return embeddings
+
+    def classify_batch(self, wavs, wav_lens=None):
+        """Performs classification on the top of the encoded features.
+
+        It returns the posterior probabilities, the index and, if the label
+        encoder is specified it also the text label.
+
+        Arguments
+        ---------
+        wavs : torch.Tensor
+            Batch of waveforms [batch, time, channels] or [batch, time]
+            depending on the model. Make sure the sample rate is fs=16000 Hz.
+        wav_lens : torch.Tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        out_prob
+            The log posterior probabilities of each class ([batch, N_class])
+        score:
+            It is the value of the log-posterior for the best class ([batch,])
+        index
+            The indexes of the best class ([batch,])
+        text_lab:
+            List with the text labels corresponding to the indexes.
+            (label encoder should be provided).
+        """
+        emb = self.encode_batch(wavs, wav_lens)
+        out_prob = self.mods.classifier(emb).squeeze(1)
+        score, index = torch.max(out_prob, dim=-1)
+        text_lab = self.hparams.label_encoder.decode_torch(index)
+        return out_prob, score, index, text_lab
+
+    def classify_file(self, path, **kwargs):
+        """Classifies the given audiofile into the given set of labels.
+
+        Arguments
+        ---------
+        path : str
+            Path to audio file to classify.
+
+        Returns
+        -------
+        out_prob
+            The log posterior probabilities of each class ([batch, N_class])
+        score:
+            It is the value of the log-posterior for the best class ([batch,])
+        index
+            The indexes of the best class ([batch,])
+        text_lab:
+            List with the text labels corresponding to the indexes.
+            (label encoder should be provided).
+        """
+        waveform = self.load_audio(path, **kwargs)
+        # Fake a batch:
+        batch = waveform.unsqueeze(0)
+        rel_length = torch.tensor([1.0])
+        emb = self.encode_batch(batch, rel_length)
+        out_prob = self.mods.classifier(emb).squeeze(1)
+        score, index = torch.max(out_prob, dim=-1)
+        text_lab = self.hparams.label_encoder.decode_torch(index)
+        return out_prob, score, index, text_lab
+
+    def forward(self, wavs, wav_lens=None):
+        """Runs the classification"""
+        return self.classify_batch(wavs, wav_lens)
+
+
+class AudioClassifier(Pretrained):
+    """A ready-to-use class for utterance-level classification (e.g, speaker-id,
+    language-id, emotion recognition, keyword spotting, etc).
+
+    The class assumes that an encoder called "embedding_model" and a model
+    called "classifier" are defined in the yaml file. If you want to
+    convert the predicted index into a corresponding text label, please
+    provide the path of the label_encoder in a variable called 'lab_encoder_file'
+    within the yaml.
+
+    The class can be used either to run only the encoder (encode_batch()) to
+    extract embeddings or to run a classification step (classify_batch()).
+    ```
+
+    Example
+    -------
+    >>> import torchaudio
+    >>> from speechbrain.inference.classifiers import AudioClassifier
+    >>> tmpdir = getfixture("tmpdir")
+    >>> classifier = AudioClassifier.from_hparams(
+    ...     source="speechbrain/cnn14-esc50",
+    ...     savedir=tmpdir,
+    ... )
+    >>> signal = torch.randn(1, 16000)
+    >>> prediction, _, _, text_lab = classifier.classify_batch(signal)
+    >>> print(prediction.shape)
+    torch.Size([1, 1, 50])
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def classify_batch(self, wavs, wav_lens=None):
+        """Performs classification on the top of the encoded features.
+
+        It returns the posterior probabilities, the index and, if the label
+        encoder is specified it also the text label.
+
+        Arguments
+        ---------
+        wavs : torch.Tensor
+            Batch of waveforms [batch, time, channels] or [batch, time]
+            depending on the model. Make sure the sample rate is fs=16000 Hz.
+        wav_lens : torch.Tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        out_prob
+            The log posterior probabilities of each class ([batch, N_class])
+        score:
+            It is the value of the log-posterior for the best class ([batch,])
+        index
+            The indexes of the best class ([batch,])
+        text_lab:
+            List with the text labels corresponding to the indexes.
+            (label encoder should be provided).
+        """
+        wavs = wavs.to(self.device)
+        X_stft = self.mods.compute_stft(wavs)
+        X_stft_power = speechbrain.processing.features.spectral_magnitude(
+            X_stft, power=self.hparams.spec_mag_power
+        )
+
+        if self.hparams.use_melspectra:
+            net_input = self.mods.compute_fbank(X_stft_power)
+        else:
+            net_input = torch.log1p(X_stft_power)
+
+        # Embeddings + sound classifier
+        embeddings = self.mods.embedding_model(net_input)
+        if embeddings.ndim == 4:
+            embeddings = embeddings.mean((-1, -2))
+
+        out_probs = self.mods.classifier(embeddings)
+        score, index = torch.max(out_probs, dim=-1)
+        text_lab = self.hparams.label_encoder.decode_torch(index)
+        return out_probs, score, index, text_lab
+
+    def classify_file(self, path, savedir="audio_cache"):
+        """Classifies the given audiofile into the given set of labels.
+
+        Arguments
+        ---------
+        path : str
+            Path to audio file to classify.
+
+        Returns
+        -------
+        out_prob
+            The log posterior probabilities of each class ([batch, N_class])
+        score:
+            It is the value of the log-posterior for the best class ([batch,])
+        index
+            The indexes of the best class ([batch,])
+        text_lab:
+            List with the text labels corresponding to the indexes.
+            (label encoder should be provided).
+        """
+        source, fl = split_path(path)
+        path = fetch(fl, source=source, savedir=savedir)
+
+        batch, fs_file = torchaudio.load(path)
+        batch = batch.to(self.device)
+        fs_model = self.hparams.sample_rate
+
+        # resample the data if needed
+        if fs_file != fs_model:
+            print(
+                "Resampling the audio from {} Hz to {} Hz".format(
+                    fs_file, fs_model
+                )
+            )
+            tf = torchaudio.transforms.Resample(
+                orig_freq=fs_file, new_freq=fs_model
+            ).to(self.device)
+            batch = batch.mean(dim=0, keepdim=True)
+            batch = tf(batch)
+
+        out_probs, score, index, text_lab = self.classify_batch(batch)
+        return out_probs, score, index, text_lab
+
+    def forward(self, wavs, wav_lens=None):
+        """Runs the classification"""
+        return self.classify_batch(wavs, wav_lens)
diff --git a/speechbrain/inference/diarization.py b/speechbrain/inference/diarization.py
new file mode 100644
index 0000000000000000000000000000000000000000..91a3015e9a9fb5315386744b9bc775f69988cb10
--- /dev/null
+++ b/speechbrain/inference/diarization.py
@@ -0,0 +1,230 @@
+""" Specifies the inference interfaces for diarization modules.
+
+Authors:
+ * Aku Rouhe 2021
+ * Peter Plantinga 2021
+ * Loren Lugosch 2020
+ * Mirco Ravanelli 2020
+ * Titouan Parcollet 2021
+ * Abdel Heba 2021
+ * Andreas Nautsch 2022, 2023
+ * Pooneh Mousavi 2023
+ * Sylvain de Langen 2023
+ * Adel Moumen 2023
+ * Pradnya Kandarkar 2023
+"""
+import torch
+from speechbrain.inference.interfaces import Pretrained
+
+
+class Speech_Emotion_Diarization(Pretrained):
+    """A ready-to-use SED interface (audio -> emotions and their durations)
+
+    Arguments
+    ---------
+    hparams
+        Hyperparameters (from HyperPyYAML)
+
+    Example
+    -------
+    >>> from speechbrain.inference.diarization import Speech_Emotion_Diarization
+    >>> tmpdir = getfixture("tmpdir")
+    >>> sed_model = Speech_Emotion_Diarization.from_hparams(source="speechbrain/emotion-diarization-wavlm-large", savedir=tmpdir,) # doctest: +SKIP
+    >>> sed_model.diarize_file("speechbrain/emotion-diarization-wavlm-large/example.wav") # doctest: +SKIP
+    """
+
+    MODULES_NEEDED = ["input_norm", "wav2vec", "output_mlp"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def diarize_file(self, path):
+        """Get emotion diarization of a spoken utterance.
+
+        Arguments
+        ---------
+        path : str
+            Path to audio file which to diarize.
+
+        Returns
+        -------
+        list of dictionary: List[Dict[List]]
+            The emotions and their temporal boundaries.
+        """
+        waveform = self.load_audio(path)
+        # Fake a batch:
+        batch = waveform.unsqueeze(0)
+        rel_length = torch.tensor([1.0])
+        frame_class = self.diarize_batch(batch, rel_length, [path])
+        return frame_class
+
+    def encode_batch(self, wavs, wav_lens):
+        """Encodes audios into fine-grained emotional embeddings
+
+        Arguments
+        ---------
+        wavs : torch.tensor
+            Batch of waveforms [batch, time, channels].
+        wav_lens : torch.tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        torch.tensor
+            The encoded batch
+        """
+        if len(wavs.shape) == 1:
+            wavs = wavs.unsqueeze(0)
+
+        # Assign full length if wav_lens is not assigned
+        if wav_lens is None:
+            wav_lens = torch.ones(wavs.shape[0], device=self.device)
+
+        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
+
+        wavs = self.mods.input_norm(wavs, wav_lens)
+        outputs = self.mods.wav2vec2(wavs)
+        return outputs
+
+    def diarize_batch(self, wavs, wav_lens, batch_id):
+        """Get emotion diarization of a batch of waveforms.
+
+        The waveforms should already be in the model's desired format.
+        You can call:
+        ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
+        to get a correctly converted signal in most cases.
+
+        Arguments
+        ---------
+        wavs : torch.tensor
+            Batch of waveforms [batch, time, channels].
+        wav_lens : torch.tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+        batch_id : torch.tensor
+            id of each batch (file names etc.)
+
+        Returns
+        -------
+        list of dictionary: List[Dict[List]]
+            The emotions and their temporal boundaries.
+        """
+        outputs = self.encode_batch(wavs, wav_lens)
+        averaged_out = self.hparams.avg_pool(outputs)
+        outputs = self.mods.output_mlp(averaged_out)
+        outputs = self.hparams.log_softmax(outputs)
+        score, index = torch.max(outputs, dim=-1)
+        preds = self.hparams.label_encoder.decode_torch(index)
+        results = self.preds_to_diarization(preds, batch_id)
+        return results
+
+    def preds_to_diarization(self, prediction, batch_id):
+        """Convert frame-wise predictions into a dictionary of
+        diarization results.
+
+        Returns
+        -------
+        dictionary
+            A dictionary with the start/end of each emotion
+        """
+        results = {}
+
+        for i in range(len(prediction)):
+            pred = prediction[i]
+            lol = []
+            for j in range(len(pred)):
+                start = round(self.hparams.stride * 0.02 * j, 2)
+                end = round(start + self.hparams.window_length * 0.02, 2)
+                lol.append([batch_id[i], start, end, pred[j]])
+
+            lol = self.merge_ssegs_same_emotion_adjacent(lol)
+            results[batch_id[i]] = [
+                {"start": k[1], "end": k[2], "emotion": k[3]} for k in lol
+            ]
+            return results
+
+    def forward(self, wavs, wav_lens, batch_id):
+        """Get emotion diarization for a batch of waveforms."""
+        return self.diarize_batch(wavs, wav_lens, batch_id)
+
+    def is_overlapped(self, end1, start2):
+        """Returns True if segments are overlapping.
+
+        Arguments
+        ---------
+        end1 : float
+            End time of the first segment.
+        start2 : float
+            Start time of the second segment.
+
+        Returns
+        -------
+        overlapped : bool
+            True of segments overlapped else False.
+
+        Example
+        -------
+        >>> from speechbrain.processing import diarization as diar
+        >>> diar.is_overlapped(5.5, 3.4)
+        True
+        >>> diar.is_overlapped(5.5, 6.4)
+        False
+        """
+
+        if start2 > end1:
+            return False
+        else:
+            return True
+
+    def merge_ssegs_same_emotion_adjacent(self, lol):
+        """Merge adjacent sub-segs if they are the same emotion.
+        Arguments
+        ---------
+        lol : list of list
+            Each list contains [utt_id, sseg_start, sseg_end, emo_label].
+        Returns
+        -------
+        new_lol : list of list
+            new_lol contains adjacent segments merged from the same emotion ID.
+        Example
+        -------
+        >>> from speechbrain.utils.EDER import merge_ssegs_same_emotion_adjacent
+        >>> lol=[['u1', 0.0, 7.0, 'a'],
+        ... ['u1', 7.0, 9.0, 'a'],
+        ... ['u1', 9.0, 11.0, 'n'],
+        ... ['u1', 11.0, 13.0, 'n'],
+        ... ['u1', 13.0, 15.0, 'n'],
+        ... ['u1', 15.0, 16.0, 'a']]
+        >>> merge_ssegs_same_emotion_adjacent(lol)
+        [['u1', 0.0, 9.0, 'a'], ['u1', 9.0, 15.0, 'n'], ['u1', 15.0, 16.0, 'a']]
+        """
+        new_lol = []
+
+        # Start from the first sub-seg
+        sseg = lol[0]
+        flag = False
+        for i in range(1, len(lol)):
+            next_sseg = lol[i]
+            # IF sub-segments overlap AND has same emotion THEN merge
+            if (
+                self.is_overlapped(sseg[2], next_sseg[1])
+                and sseg[3] == next_sseg[3]
+            ):
+                sseg[2] = next_sseg[2]  # just update the end time
+                # This is important. For the last sseg, if it is the same emotion then merge
+                # Make sure we don't append the last segment once more. Hence, set FLAG=True
+                if i == len(lol) - 1:
+                    flag = True
+                    new_lol.append(sseg)
+            else:
+                new_lol.append(sseg)
+                sseg = next_sseg
+        # Add last segment only when it was skipped earlier.
+        if flag is False:
+            new_lol.append(lol[-1])
+        return new_lol
diff --git a/speechbrain/inference/encoders.py b/speechbrain/inference/encoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a5ecdd38fd114ea310a3de4ae520022a085e144
--- /dev/null
+++ b/speechbrain/inference/encoders.py
@@ -0,0 +1,261 @@
+""" Specifies the inference interfaces for speech and audio encoders.
+
+Authors:
+ * Aku Rouhe 2021
+ * Peter Plantinga 2021
+ * Loren Lugosch 2020
+ * Mirco Ravanelli 2020
+ * Titouan Parcollet 2021
+ * Abdel Heba 2021
+ * Andreas Nautsch 2022, 2023
+ * Pooneh Mousavi 2023
+ * Sylvain de Langen 2023
+ * Adel Moumen 2023
+ * Pradnya Kandarkar 2023
+"""
+import torch
+from speechbrain.inference.interfaces import Pretrained
+
+
+class WaveformEncoder(Pretrained):
+    """A ready-to-use waveformEncoder model
+
+    It can be used to wrap different embedding models such as SSL ones (wav2vec2)
+    or speaker ones (Xvector) etc. Two functions are available: encode_batch and
+    encode_file. They can be used to obtain the embeddings directly from an audio
+    file or from a batch of audio tensors respectively.
+
+    The given YAML must contain the fields specified in the *_NEEDED[] lists.
+
+    Example
+    -------
+    >>> from speechbrain.inference.encoders import WaveformEncoder
+    >>> tmpdir = getfixture("tmpdir")
+    >>> ssl_model = WaveformEncoder.from_hparams(
+    ...     source="speechbrain/ssl-wav2vec2-base-libri",
+    ...     savedir=tmpdir,
+    ... ) # doctest: +SKIP
+    >>> ssl_model.encode_file("samples/audio_samples/example_fr.wav") # doctest: +SKIP
+    """
+
+    MODULES_NEEDED = ["encoder"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def encode_file(self, path, **kwargs):
+        """Encode the given audiofile into a sequence of embeddings.
+
+        Arguments
+        ---------
+        path : str
+            Path to audio file which to encode.
+
+        Returns
+        -------
+        torch.Tensor
+            The audiofile embeddings produced by this system.
+        """
+        waveform = self.load_audio(path, **kwargs)
+        # Fake a batch:
+        batch = waveform.unsqueeze(0)
+        rel_length = torch.tensor([1.0])
+        results = self.encode_batch(batch, rel_length)
+        return results["embeddings"]
+
+    def encode_batch(self, wavs, wav_lens):
+        """Encodes the input audio into a sequence of hidden states
+
+        The waveforms should already be in the model's desired format.
+
+        Arguments
+        ---------
+        wavs : torch.Tensor
+            Batch of waveforms [batch, time, channels] or [batch, time]
+            depending on the model.
+        wav_lens : torch.Tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        torch.Tensor
+            The encoded batch
+        """
+        wavs = wavs.float()
+        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
+        encoder_out = self.mods.encoder(wavs, wav_lens)
+        return encoder_out
+
+    def forward(self, wavs, wav_lens):
+        """Runs the encoder"""
+        return self.encode_batch(wavs, wav_lens)
+
+
+class MelSpectrogramEncoder(Pretrained):
+    """A MelSpectrogramEncoder class created for the Zero-Shot Multi-Speaker TTS models.
+
+    This is for speaker encoder models using the PyTorch MelSpectrogram transform for compatibility with the
+    current TTS pipeline.
+
+    This class can be used to encode a single waveform, a single mel-spectrogram, or a batch of mel-spectrograms.
+    ```
+
+    Example
+    -------
+    >>> import torchaudio
+    >>> from speechbrain.inference.encoders import MelSpectrogramEncoder
+    >>> # Model is downloaded from the speechbrain HuggingFace repo
+    >>> tmpdir = getfixture("tmpdir")
+    >>> encoder = MelSpectrogramEncoder.from_hparams(
+    ...     source="speechbrain/tts-ecapa-voxceleb",
+    ...     savedir=tmpdir,
+    ... ) # doctest: +SKIP
+
+    >>> # Compute embedding from a waveform (sample_rate must match the sample rate of the encoder)
+    >>> signal, fs = torchaudio.load("tests/samples/single-mic/example1.wav") # doctest: +SKIP
+    >>> spk_emb = encoder.encode_waveform(signal) # doctest: +SKIP
+
+    >>> # Compute embedding from a mel-spectrogram (sample_rate must match the sample rate of the ecoder)
+    >>> mel_spec = encoder.mel_spectogram(audio=signal) # doctest: +SKIP
+    >>> spk_emb = encoder.encode_mel_spectrogram(mel_spec) # doctest: +SKIP
+
+    >>> # Compute embeddings for a batch of mel-spectrograms
+    >>> spk_embs = encoder.encode_mel_spectrogram_batch(mel_spec) # doctest: +SKIP
+    """
+
+    MODULES_NEEDED = ["normalizer", "embedding_model"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def dynamic_range_compression(self, x, C=1, clip_val=1e-5):
+        """Dynamic range compression for audio signals
+        """
+        return torch.log(torch.clamp(x, min=clip_val) * C)
+
+    def mel_spectogram(self, audio):
+        """calculates MelSpectrogram for a raw audio signal
+
+        Arguments
+        ---------
+        audio : torch.tensor
+            input audio signal
+
+        Returns
+        -------
+        mel : torch.Tensor
+            Mel-spectrogram
+        """
+        from torchaudio import transforms
+
+        audio_to_mel = transforms.MelSpectrogram(
+            sample_rate=self.hparams.sample_rate,
+            hop_length=self.hparams.hop_length,
+            win_length=self.hparams.win_length,
+            n_fft=self.hparams.n_fft,
+            n_mels=self.hparams.n_mel_channels,
+            f_min=self.hparams.mel_fmin,
+            f_max=self.hparams.mel_fmax,
+            power=self.hparams.power,
+            normalized=self.hparams.mel_normalized,
+            norm=self.hparams.norm,
+            mel_scale=self.hparams.mel_scale,
+        ).to(audio.device)
+
+        mel = audio_to_mel(audio)
+
+        if self.hparams.dynamic_range_compression:
+            mel = self.dynamic_range_compression(mel)
+
+        return mel
+
+    def encode_waveform(self, wav):
+        """
+        Encodes a single waveform
+
+        Arguments
+        ---------
+
+        wav : torch.Tensor
+            waveform
+
+        Returns
+        -------
+        encoder_out : torch.Tensor
+            Speaker embedding for the input waveform
+        """
+
+        # Moves tensor to the appropriate device
+        wav = wav.to(self.device)
+
+        # Computes mel-spectrogram
+        mel_spec = self.mel_spectogram(audio=wav)
+
+        # Calls encode_mel_spectrogram to compute the speaker embedding
+        return self.encode_mel_spectrogram(mel_spec)
+
+    def encode_mel_spectrogram(self, mel_spec):
+        """
+        Encodes a single mel-spectrograms
+
+        Arguments
+        ---------
+
+        mel_spec : torch.Tensor
+            Mel-spectrograms
+
+        Returns
+        -------
+        encoder_out : torch.Tensor
+            Speaker embedding for the input mel-spectrogram
+        """
+
+        # Fakes a batch
+        batch = mel_spec
+        if len(mel_spec.shape) == 2:
+            batch = mel_spec.unsqueeze(0)
+        rel_length = torch.tensor([1.0])
+
+        # Calls encode_mel_spectrogram_batch to compute speaker embeddings
+        results = self.encode_mel_spectrogram_batch(batch, rel_length)
+
+        return results
+
+    def encode_mel_spectrogram_batch(self, mel_specs, lens=None):
+        """
+        Encodes a batch of mel-spectrograms
+
+        Arguments
+        ---------
+
+        mel_specs : torch.Tensor
+            Mel-spectrograms
+        lens : torch.Tensor
+            Relative lengths of the mel-spectrograms
+
+        Returns
+        -------
+        encoder_out : torch.Tensor
+            Speaker embedding for the input mel-spectrogram batch
+        """
+
+        # Assigns full length if lens is not assigned
+        if lens is None:
+            lens = torch.ones(mel_specs.shape[0], device=self.device)
+
+        # Moves the tensors to the appropriate device
+        mel_specs, lens = mel_specs.to(self.device), lens.to(self.device)
+
+        # Computes speaker embeddings
+        mel_specs = torch.transpose(mel_specs, 1, 2)
+        feats = self.hparams.normalizer(mel_specs, lens)
+        encoder_out = self.hparams.embedding_model(feats)
+
+        return encoder_out
+
+    def __forward(self, mel_specs, lens):
+        """Runs the encoder"""
+        return self.encode_batch(mel_specs, lens)
diff --git a/speechbrain/inference/enhancement.py b/speechbrain/inference/enhancement.py
new file mode 100644
index 0000000000000000000000000000000000000000..2eb2072ac847f2e1f1b92095dfae40268d8818e6
--- /dev/null
+++ b/speechbrain/inference/enhancement.py
@@ -0,0 +1,186 @@
+""" Specifies the inference interfaces for speech enhancement modules.
+
+Authors:
+ * Aku Rouhe 2021
+ * Peter Plantinga 2021
+ * Loren Lugosch 2020
+ * Mirco Ravanelli 2020
+ * Titouan Parcollet 2021
+ * Abdel Heba 2021
+ * Andreas Nautsch 2022, 2023
+ * Pooneh Mousavi 2023
+ * Sylvain de Langen 2023
+ * Adel Moumen 2023
+ * Pradnya Kandarkar 2023
+"""
+import torch
+import torchaudio
+from speechbrain.inference.interfaces import Pretrained
+from speechbrain.utils.callchains import lengths_arg_exists
+
+
+class SpectralMaskEnhancement(Pretrained):
+    """A ready-to-use model for speech enhancement.
+
+    Arguments
+    ---------
+    See ``Pretrained``.
+
+    Example
+    -------
+    >>> import torch
+    >>> from speechbrain.inference.enhancement import SpectralMaskEnhancement
+    >>> # Model is downloaded from the speechbrain HuggingFace repo
+    >>> tmpdir = getfixture("tmpdir")
+    >>> enhancer = SpectralMaskEnhancement.from_hparams(
+    ...     source="speechbrain/metricgan-plus-voicebank",
+    ...     savedir=tmpdir,
+    ... )
+    >>> enhanced = enhancer.enhance_file(
+    ...     "speechbrain/metricgan-plus-voicebank/example.wav"
+    ... )
+    """
+
+    HPARAMS_NEEDED = ["compute_stft", "spectral_magnitude", "resynth"]
+    MODULES_NEEDED = ["enhance_model"]
+
+    def compute_features(self, wavs):
+        """Compute the log spectral magnitude features for masking.
+
+        Arguments
+        ---------
+        wavs : torch.Tensor
+            A batch of waveforms to convert to log spectral mags.
+        """
+        feats = self.hparams.compute_stft(wavs)
+        feats = self.hparams.spectral_magnitude(feats)
+        return torch.log1p(feats)
+
+    def enhance_batch(self, noisy, lengths=None):
+        """Enhance a batch of noisy waveforms.
+
+        Arguments
+        ---------
+        noisy : torch.Tensor
+            A batch of waveforms to perform enhancement on.
+        lengths : torch.Tensor
+            The lengths of the waveforms if the enhancement model handles them.
+
+        Returns
+        -------
+        torch.Tensor
+            A batch of enhanced waveforms of the same shape as input.
+        """
+        noisy = noisy.to(self.device)
+        noisy_features = self.compute_features(noisy)
+
+        # Perform masking-based enhancement, multiplying output with input.
+        if lengths is not None:
+            mask = self.mods.enhance_model(noisy_features, lengths=lengths)
+        else:
+            mask = self.mods.enhance_model(noisy_features)
+        enhanced = torch.mul(mask, noisy_features)
+
+        # Return resynthesized waveforms
+        return self.hparams.resynth(torch.expm1(enhanced), noisy)
+
+    def enhance_file(self, filename, output_filename=None, **kwargs):
+        """Enhance a wav file.
+
+        Arguments
+        ---------
+        filename : str
+            Location on disk to load file for enhancement.
+        output_filename : str
+            If provided, writes enhanced data to this file.
+        """
+        noisy = self.load_audio(filename, **kwargs)
+        noisy = noisy.to(self.device)
+
+        # Fake a batch:
+        batch = noisy.unsqueeze(0)
+        if lengths_arg_exists(self.enhance_batch):
+            enhanced = self.enhance_batch(batch, lengths=torch.tensor([1.0]))
+        else:
+            enhanced = self.enhance_batch(batch)
+
+        if output_filename is not None:
+            torchaudio.save(
+                uri=output_filename,
+                src=enhanced,
+                sample_rate=self.hparams.compute_stft.sample_rate,
+            )
+
+        return enhanced.squeeze(0)
+
+
+class WaveformEnhancement(Pretrained):
+    """A ready-to-use model for speech enhancement.
+
+    Arguments
+    ---------
+    See ``Pretrained``.
+
+    Example
+    -------
+    >>> from speechbrain.inference.enhancement import WaveformEnhancement
+    >>> # Model is downloaded from the speechbrain HuggingFace repo
+    >>> tmpdir = getfixture("tmpdir")
+    >>> enhancer = WaveformEnhancement.from_hparams(
+    ...     source="speechbrain/mtl-mimic-voicebank",
+    ...     savedir=tmpdir,
+    ... )
+    >>> enhanced = enhancer.enhance_file(
+    ...     "speechbrain/mtl-mimic-voicebank/example.wav"
+    ... )
+    """
+
+    MODULES_NEEDED = ["enhance_model"]
+
+    def enhance_batch(self, noisy, lengths=None):
+        """Enhance a batch of noisy waveforms.
+
+        Arguments
+        ---------
+        noisy : torch.Tensor
+            A batch of waveforms to perform enhancement on.
+        lengths : torch.Tensor
+            The lengths of the waveforms if the enhancement model handles them.
+
+        Returns
+        -------
+        torch.Tensor
+            A batch of enhanced waveforms of the same shape as input.
+        """
+        noisy = noisy.to(self.device)
+        enhanced_wav, _ = self.mods.enhance_model(noisy)
+        return enhanced_wav
+
+    def enhance_file(self, filename, output_filename=None, **kwargs):
+        """Enhance a wav file.
+
+        Arguments
+        ---------
+        filename : str
+            Location on disk to load file for enhancement.
+        output_filename : str
+            If provided, writes enhanced data to this file.
+        """
+        noisy = self.load_audio(filename, **kwargs)
+
+        # Fake a batch:
+        batch = noisy.unsqueeze(0)
+        enhanced = self.enhance_batch(batch)
+
+        if output_filename is not None:
+            torchaudio.save(
+                uri=output_filename,
+                src=enhanced,
+                sample_rate=self.audio_normalizer.sample_rate,
+            )
+
+        return enhanced.squeeze(0)
+
+    def forward(self, noisy, lengths=None):
+        """Runs enhancement on the noisy input"""
+        return self.enhance_batch(noisy, lengths)
diff --git a/speechbrain/inference/interfaces.py b/speechbrain/inference/interfaces.py
new file mode 100644
index 0000000000000000000000000000000000000000..c84d9aad15ac6bb67f25dd070910eebae0254b6e
--- /dev/null
+++ b/speechbrain/inference/interfaces.py
@@ -0,0 +1,702 @@
+"""Defines interfaces for simple inference with pretrained models
+
+Authors:
+ * Aku Rouhe 2021
+ * Peter Plantinga 2021
+ * Loren Lugosch 2020
+ * Mirco Ravanelli 2020
+ * Titouan Parcollet 2021
+ * Abdel Heba 2021
+ * Andreas Nautsch 2022, 2023
+ * Pooneh Mousavi 2023
+ * Sylvain de Langen 2023
+ * Adel Moumen 2023
+ * Pradnya Kandarkar 2023
+"""
+import logging
+import hashlib
+import sys
+import warnings
+import torch
+import torchaudio
+from types import SimpleNamespace
+from torch.nn import SyncBatchNorm
+from torch.nn import DataParallel as DP
+from hyperpyyaml import load_hyperpyyaml
+from speechbrain.utils.fetching import fetch
+from speechbrain.dataio.preprocess import AudioNormalizer
+from torch.nn.parallel import DistributedDataParallel as DDP
+from speechbrain.utils.data_utils import split_path
+from speechbrain.utils.distributed import run_on_main
+from speechbrain.dataio.batch import PaddedBatch, PaddedData
+from speechbrain.utils.data_pipeline import DataPipeline
+from speechbrain.utils.superpowers import import_from_path
+
+logger = logging.getLogger(__name__)
+
+
+def foreign_class(
+    source,
+    hparams_file="hyperparams.yaml",
+    pymodule_file="custom.py",
+    classname="CustomInterface",
+    overrides={},
+    overrides_must_match=True,
+    savedir=None,
+    use_auth_token=False,
+    download_only=False,
+    huggingface_cache_dir=None,
+    **kwargs,
+):
+    """Fetch and load an interface from an outside source
+
+    The source can be a location on the filesystem or online/huggingface
+
+    The pymodule file should contain a class with the given classname. An
+    instance of that class is returned. The idea is to have a custom Pretrained
+    subclass in the file. The pymodule file is also added to the python path
+    before the Hyperparams YAML file is loaded, so it can contain any custom
+    implementations that are needed.
+
+    The hyperparams file should contain a "modules" key, which is a
+    dictionary of torch modules used for computation.
+
+    The hyperparams file should contain a "pretrainer" key, which is a
+    speechbrain.utils.parameter_transfer.Pretrainer
+
+    Arguments
+    ---------
+    source : str or Path or FetchSource
+        The location to use for finding the model. See
+        ``speechbrain.pretrained.fetching.fetch`` for details.
+    hparams_file : str
+        The name of the hyperparameters file to use for constructing
+        the modules necessary for inference. Must contain two keys:
+        "modules" and "pretrainer", as described.
+    pymodule_file : str
+        The name of the Python file that should be fetched.
+    classname : str
+        The name of the Class, of which an instance is created and returned
+    overrides : dict
+        Any changes to make to the hparams file when it is loaded.
+    overrides_must_match : bool
+        Whether an error will be thrown when an override does not match
+        a corresponding key in the yaml_stream.
+    savedir : str or Path
+        Where to put the pretraining material. If not given, will use
+        ./pretrained_models/<class-name>-hash(source).
+    use_auth_token : bool (default: False)
+        If true Hugginface's auth_token will be used to load private models from the HuggingFace Hub,
+        default is False because the majority of models are public.
+    download_only : bool (default: False)
+        If true, class and instance creation is skipped.
+    huggingface_cache_dir : str
+        Path to HuggingFace cache; if None -> "~/.cache/huggingface" (default: None)
+
+    Returns
+    -------
+    object
+        An instance of a class with the given classname from the given pymodule file.
+    """
+    if savedir is None:
+        savedir = f"./pretrained_models/{classname}-{hashlib.md5(source.encode('UTF-8', errors='replace')).hexdigest()}"
+    hparams_local_path = fetch(
+        filename=hparams_file,
+        source=source,
+        savedir=savedir,
+        overwrite=False,
+        save_filename=None,
+        use_auth_token=use_auth_token,
+        revision=None,
+        huggingface_cache_dir=huggingface_cache_dir,
+    )
+    pymodule_local_path = fetch(
+        filename=pymodule_file,
+        source=source,
+        savedir=savedir,
+        overwrite=False,
+        save_filename=None,
+        use_auth_token=use_auth_token,
+        revision=None,
+        huggingface_cache_dir=huggingface_cache_dir,
+    )
+    sys.path.append(str(pymodule_local_path.parent))
+
+    # Load the modules:
+    with open(hparams_local_path) as fin:
+        hparams = load_hyperpyyaml(fin, overrides, overrides_must_match)
+
+    # Pretraining:
+    pretrainer = hparams["pretrainer"]
+    pretrainer.set_collect_in(savedir)
+    # For distributed setups, have this here:
+    run_on_main(pretrainer.collect_files, kwargs={"default_source": source})
+    # Load on the CPU. Later the params can be moved elsewhere by specifying
+    if not download_only:
+        # run_opts={"device": ...}
+        pretrainer.load_collected()
+
+        # Import class and create instance
+        module = import_from_path(pymodule_local_path)
+        cls = getattr(module, classname)
+        return cls(modules=hparams["modules"], hparams=hparams, **kwargs)
+
+
+class Pretrained(torch.nn.Module):
+    """Takes a trained model and makes predictions on new data.
+
+    This is a base class which handles some common boilerplate.
+    It intentionally has an interface similar to ``Brain`` - these base
+    classes handle similar things.
+
+    Subclasses of Pretrained should implement the actual logic of how
+    the pretrained system runs, and add methods with descriptive names
+    (e.g. transcribe_file() for ASR).
+
+    Pretrained is a torch.nn.Module so that methods like .to() or .eval() can
+    work. Subclasses should provide a suitable forward() implementation: by
+    convention, it should be a method that takes a batch of audio signals and
+    runs the full model (as applicable).
+
+    Arguments
+    ---------
+    modules : dict of str:torch.nn.Module pairs
+        The Torch modules that make up the learned system. These can be treated
+        in special ways (put on the right device, frozen, etc.). These are available
+        as attributes under ``self.mods``, like self.mods.model(x)
+    hparams : dict
+        Each key:value pair should consist of a string key and a hyperparameter
+        that is used within the overridden methods. These will
+        be accessible via an ``hparams`` attribute, using "dot" notation:
+        e.g., self.hparams.model(x).
+    run_opts : dict
+        Options parsed from command line. See ``speechbrain.parse_arguments()``.
+        List that are supported here:
+         * device
+         * data_parallel_count
+         * data_parallel_backend
+         * distributed_launch
+         * distributed_backend
+         * jit
+         * jit_module_keys
+         * compule
+         * compile_module_keys
+         * compile_mode
+         * compile_using_fullgraph
+         * compile_using_dynamic_shape_tracing
+    freeze_params : bool
+        To freeze (requires_grad=False) parameters or not. Normally in inference
+        you want to freeze the params. Also calls .eval() on all modules.
+    """
+
+    HPARAMS_NEEDED = []
+    MODULES_NEEDED = []
+
+    def __init__(
+        self, modules=None, hparams=None, run_opts=None, freeze_params=True
+    ):
+        super().__init__()
+        # Arguments passed via the run opts dictionary. Set a limited
+        # number of these, since some don't apply to inference.
+        run_opt_defaults = {
+            "device": "cpu",
+            "data_parallel_count": -1,
+            "data_parallel_backend": False,
+            "distributed_launch": False,
+            "distributed_backend": "nccl",
+            "jit": False,
+            "jit_module_keys": None,
+            "compile": False,
+            "compile_module_keys": None,
+            "compile_mode": "reduce-overhead",
+            "compile_using_fullgraph": False,
+            "compile_using_dynamic_shape_tracing": False,
+        }
+        for arg, default in run_opt_defaults.items():
+            if run_opts is not None and arg in run_opts:
+                setattr(self, arg, run_opts[arg])
+            else:
+                # If any arg from run_opt_defaults exist in hparams and
+                # not in command line args "run_opts"
+                if hparams is not None and arg in hparams:
+                    setattr(self, arg, hparams[arg])
+                else:
+                    setattr(self, arg, default)
+
+        # Put modules on the right device, accessible with dot notation
+        self.mods = torch.nn.ModuleDict(modules)
+        for module in self.mods.values():
+            if module is not None:
+                module.to(self.device)
+
+        # Check MODULES_NEEDED and HPARAMS_NEEDED and
+        # make hyperparams available with dot notation
+        if self.HPARAMS_NEEDED and hparams is None:
+            raise ValueError("Need to provide hparams dict.")
+        if hparams is not None:
+            # Also first check that all required params are found:
+            for hp in self.HPARAMS_NEEDED:
+                if hp not in hparams:
+                    raise ValueError(f"Need hparams['{hp}']")
+            self.hparams = SimpleNamespace(**hparams)
+
+        # Prepare modules for computation, e.g. jit
+        self._prepare_modules(freeze_params)
+
+        # Audio normalization
+        self.audio_normalizer = hparams.get(
+            "audio_normalizer", AudioNormalizer()
+        )
+
+    def _prepare_modules(self, freeze_params):
+        """Prepare modules for computation, e.g. jit.
+
+        Arguments
+        ---------
+        freeze_params : bool
+            Whether to freeze the parameters and call ``eval()``.
+        """
+
+        # Make jit-able
+        self._compile()
+        self._wrap_distributed()
+
+        # If we don't want to backprop, freeze the pretrained parameters
+        if freeze_params:
+            self.mods.eval()
+            for p in self.mods.parameters():
+                p.requires_grad = False
+
+    def load_audio(self, path, savedir="."):
+        """Load an audio file with this model's input spec
+
+        When using a speech model, it is important to use the same type of data,
+        as was used to train the model. This means for example using the same
+        sampling rate and number of channels. It is, however, possible to
+        convert a file from a higher sampling rate to a lower one (downsampling).
+        Similarly, it is simple to downmix a stereo file to mono.
+        The path can be a local path, a web url, or a link to a huggingface repo.
+        """
+        source, fl = split_path(path)
+        path = fetch(fl, source=source, savedir=savedir)
+        signal, sr = torchaudio.load(str(path), channels_first=False)
+        return self.audio_normalizer(signal, sr)
+
+    def _compile(self):
+        """Compile requested modules with either JIT or TorchInductor."""
+        compile_available = hasattr(torch, "compile")
+
+        if not compile_available and self.compile_module_keys is not None:
+            raise ValueError(
+                "'compile_module_keys' specified, but this install of PyTorch "
+                "seems to be too old to support it."
+            )
+
+        # Modules to compile with torch.compile
+        compile_module_keys = set()
+        if self.compile:
+            if self.compile_module_keys is None:
+                compile_module_keys = set(self.mods)
+            else:
+                compile_module_keys = set(self.compile_module_keys)
+                logger.warning(
+                    "--compile and --compile_module_keys are both specified. "
+                    "Only modules specified in --compile_module_keys will be compiled."
+                )
+
+        # Modules to compile with jit
+        jit_module_keys = set()
+        if self.jit:
+            if self.jit_module_keys is None:
+                jit_module_keys = set(self.mods)
+            else:
+                jit_module_keys = set(self.jit_module_keys)
+                logger.warning(
+                    "--jit and --jit_module_keys are both specified. "
+                    "Only modules specified in --jit_module_keys will be compiled."
+                )
+
+        # find missing keys
+        for name in compile_module_keys | jit_module_keys:
+            if name not in self.mods:
+                raise ValueError(
+                    f"module {name} is not defined in your hparams file."
+                )
+
+        # try 'torch.compile', remove successful compiles from JIT list
+        for name in compile_module_keys:
+            try:
+                module = torch.compile(
+                    self.mods[name],
+                    mode=self.compile_mode,
+                    fullgraph=self.compile_using_fullgraph,
+                    dynamic=self.compile_using_dynamic_shape_tracing,
+                )
+            except Exception as e:
+                logger.warning(
+                    f"'{name}' in 'compile_module_keys' failed to compile "
+                    f"and will be skipped (may fallback onto JIT, if "
+                    f"specified): {e}"
+                )
+                continue
+
+            self.mods[name] = module.to(self.device)
+            jit_module_keys.discard(name)
+
+        for name in jit_module_keys:
+            module = torch.jit.script(self.mods[name])
+            self.mods[name] = module.to(self.device)
+
+    def _compile_jit(self):
+        warnings.warn("'_compile_jit' is deprecated; use '_compile' instead")
+        self._compile()
+
+    def _wrap_distributed(self):
+        """Wrap modules with distributed wrapper when requested."""
+        if not self.distributed_launch and not self.data_parallel_backend:
+            return
+        elif self.distributed_launch:
+            for name, module in self.mods.items():
+                if any(p.requires_grad for p in module.parameters()):
+                    # for ddp, all module must run on same GPU
+                    module = SyncBatchNorm.convert_sync_batchnorm(module)
+                    module = DDP(module, device_ids=[self.device])
+                    self.mods[name] = module
+        else:
+            # data_parallel_backend
+            for name, module in self.mods.items():
+                if any(p.requires_grad for p in module.parameters()):
+                    # if distributed_count = -1 then use all gpus
+                    # otherwise, specify the set of gpu to use
+                    if self.data_parallel_count == -1:
+                        module = DP(module)
+                    else:
+                        module = DP(
+                            module, [i for i in range(self.data_parallel_count)]
+                        )
+                    self.mods[name] = module
+
+    @classmethod
+    def from_hparams(
+        cls,
+        source,
+        hparams_file="hyperparams.yaml",
+        pymodule_file="custom.py",
+        overrides={},
+        savedir=None,
+        use_auth_token=False,
+        revision=None,
+        download_only=False,
+        huggingface_cache_dir=None,
+        **kwargs,
+    ):
+        """Fetch and load based from outside source based on HyperPyYAML file
+
+        The source can be a location on the filesystem or online/huggingface
+
+        You can use the pymodule_file to include any custom implementations
+        that are needed: if that file exists, then its location is added to
+        sys.path before Hyperparams YAML is loaded, so it can be referenced
+        in the YAML.
+
+        The hyperparams file should contain a "modules" key, which is a
+        dictionary of torch modules used for computation.
+
+        The hyperparams file should contain a "pretrainer" key, which is a
+        speechbrain.utils.parameter_transfer.Pretrainer
+
+        Arguments
+        ---------
+        source : str
+            The location to use for finding the model. See
+            ``speechbrain.pretrained.fetching.fetch`` for details.
+        hparams_file : str
+            The name of the hyperparameters file to use for constructing
+            the modules necessary for inference. Must contain two keys:
+            "modules" and "pretrainer", as described.
+        pymodule_file : str
+            A Python file can be fetched. This allows any custom
+            implementations to be included. The file's location is added to
+            sys.path before the hyperparams YAML file is loaded, so it can be
+            referenced in YAML.
+            This is optional, but has a default: "custom.py". If the default
+            file is not found, this is simply ignored, but if you give a
+            different filename, then this will raise in case the file is not
+            found.
+        overrides : dict
+            Any changes to make to the hparams file when it is loaded.
+        savedir : str or Path
+            Where to put the pretraining material. If not given, will use
+            ./pretrained_models/<class-name>-hash(source).
+        use_auth_token : bool (default: False)
+            If true Hugginface's auth_token will be used to load private models from the HuggingFace Hub,
+            default is False because the majority of models are public.
+        revision : str
+            The model revision corresponding to the HuggingFace Hub model revision.
+            This is particularly useful if you wish to pin your code to a particular
+            version of a model hosted at HuggingFace.
+        download_only : bool (default: False)
+            If true, class and instance creation is skipped.
+        revision : str
+            The model revision corresponding to the HuggingFace Hub model revision.
+            This is particularly useful if you wish to pin your code to a particular
+            version of a model hosted at HuggingFace.
+        huggingface_cache_dir : str
+            Path to HuggingFace cache; if None -> "~/.cache/huggingface" (default: None)
+        """
+        if savedir is None:
+            clsname = cls.__name__
+            savedir = f"./pretrained_models/{clsname}-{hashlib.md5(source.encode('UTF-8', errors='replace')).hexdigest()}"
+        hparams_local_path = fetch(
+            filename=hparams_file,
+            source=source,
+            savedir=savedir,
+            overwrite=False,
+            save_filename=None,
+            use_auth_token=use_auth_token,
+            revision=revision,
+            huggingface_cache_dir=huggingface_cache_dir,
+        )
+        try:
+            pymodule_local_path = fetch(
+                filename=pymodule_file,
+                source=source,
+                savedir=savedir,
+                overwrite=False,
+                save_filename=None,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                huggingface_cache_dir=huggingface_cache_dir,
+            )
+            sys.path.append(str(pymodule_local_path.parent))
+        except ValueError:
+            if pymodule_file == "custom.py":
+                # The optional custom Python module file did not exist
+                # and had the default name
+                pass
+            else:
+                # Custom Python module file not found, but some other
+                # filename than the default was given.
+                raise
+
+        # Load the modules:
+        with open(hparams_local_path) as fin:
+            hparams = load_hyperpyyaml(fin, overrides)
+
+        # Pretraining:
+        pretrainer = hparams["pretrainer"]
+        pretrainer.set_collect_in(savedir)
+        # For distributed setups, have this here:
+        run_on_main(pretrainer.collect_files, kwargs={"default_source": source})
+        # Load on the CPU. Later the params can be moved elsewhere by specifying
+        if not download_only:
+            # run_opts={"device": ...}
+            pretrainer.load_collected()
+
+            # Now return the system
+            return cls(hparams["modules"], hparams, **kwargs)
+
+
+class EncodeDecodePipelineMixin:
+    """
+    A mixin for pretrained models that makes it possible to specify an encoding pipeline and a decoding pipeline
+    """
+
+    def create_pipelines(self):
+        """
+        Initializes the encode and decode pipeline
+        """
+        self._run_init_steps(self.hparams.encode_pipeline)
+        self._run_init_steps(self.hparams.decode_pipeline)
+        self.encode_pipeline = DataPipeline(
+            static_data_keys=self.INPUT_STATIC_KEYS,
+            dynamic_items=self.hparams.encode_pipeline["steps"],
+            output_keys=self.hparams.encode_pipeline["output_keys"],
+        )
+        self.decode_pipeline = DataPipeline(
+            static_data_keys=self.hparams.model_output_keys,
+            dynamic_items=self.hparams.decode_pipeline["steps"],
+            output_keys=self.OUTPUT_KEYS,
+        )
+
+    def _run_init_steps(self, pipeline_definition):
+        """Encode/decode pipelines may include initialization
+        steps, such as filling text encoders with tokens. Calling
+        this method will run them, if defined"""
+        steps = pipeline_definition.get("init", [])
+        for step in steps:
+            step_func = step.get("func")
+            if not step_func or not callable(step_func):
+                raise ValueError("Invalid pipeline init definition")
+            step_func()
+
+    def _run_pipeline(self, pipeline, input, batch):
+        if batch:
+            output = pipeline(input)
+        else:
+            output = [pipeline(item) for item in input]
+        return output
+
+    def _get_encode_pipeline_input(self, input):
+        return input if self.batch_inputs else self._itemize(input)
+
+    def _get_decode_pipeline_input(self, model_output):
+        model_output_keys = getattr(self.hparams, "model_output_keys", None)
+        pipeline_input = model_output
+        if len(model_output_keys) == 1:
+            pipeline_input = (pipeline_input,)
+        # The input to a pipeline is a dictionary. If model_output_keys
+        # is provided, the output of the model is assumed to be a collection
+        # (e.g. a list or a tuple).
+        if model_output_keys:
+            pipeline_input = dict(zip(model_output_keys, pipeline_input))
+
+        # By default, the pipeline will be applied to in batch mode
+        # to the entire model input
+        if not self.batch_outputs:
+            pipeline_input = self._itemize(pipeline_input)
+        return pipeline_input
+
+    def _itemize(self, pipeline_input):
+        first_item = next(iter(pipeline_input.values()))
+        keys, values = pipeline_input.keys(), pipeline_input.values()
+        batch_length = len(first_item)
+        return [
+            dict(zip(keys, [value[idx] for value in values]))
+            for idx in range(batch_length)
+        ]
+
+    def to_dict(self, data):
+        """
+        Converts padded batches to dictionaries, leaves
+        other data types as is
+
+        Arguments
+        ---------
+        data: object
+            a dictionary or a padded batch
+
+        Returns
+        -------
+        results: dict
+            the dictionary
+        """
+        if isinstance(data, PaddedBatch):
+            data = {
+                key: self._get_value(data, key)
+                for key in self.hparams.encode_pipeline["output_keys"]
+            }
+        return data
+
+    def _get_value(self, data, key):
+        """
+        Retrieves the value associated with the specified key, dereferencing
+        .data where applicable
+
+        Arguments
+        ---------
+        data: PaddedBatch
+            a padded batch
+        key: str
+            the key
+
+        Returns
+        -------
+        result: object
+            the result
+        """
+        value = getattr(data, key)
+        if not self.input_use_padded_data and isinstance(value, PaddedData):
+            value = value.data
+        return value
+
+    @property
+    def batch_inputs(self):
+        """
+        Determines whether the input pipeline
+        operates on batches or individual examples
+        (true means batched)
+
+        Returns
+        -------
+        batch_inputs: bool
+        """
+        return self.hparams.encode_pipeline.get("batch", True)
+
+    @property
+    def input_use_padded_data(self):
+        """
+        If turned on, raw PaddedData instances will be passed to
+        the model. If turned off, only .data will be used
+
+        Returns
+        -------
+        result: bool
+            whether padded data is used as is
+        """
+        return self.hparams.encode_pipeline.get("use_padded_data", False)
+
+    @property
+    def batch_outputs(self):
+        """
+        Determines whether the output pipeline
+        operates on batches or individual examples
+        (true means batched)
+
+        Returns
+        -------
+        batch_outputs: bool
+        """
+        return self.hparams.decode_pipeline.get("batch", True)
+
+    def _collate(self, data):
+        if not self.batch_inputs:
+            collate_fn = getattr(self.hparams, "collate_fn", PaddedBatch)
+            data = collate_fn(data)
+        return data
+
+    def encode_input(self, input):
+        """
+        Encodes the inputs using the pipeline
+
+        Arguments
+        ---------
+        input: dict
+            the raw inputs
+
+        Returns
+        -------
+        results: object
+
+        """
+        pipeline_input = self._get_encode_pipeline_input(input)
+        model_input = self._run_pipeline(
+            pipeline=self.encode_pipeline,
+            input=pipeline_input,
+            batch=self.batch_inputs,
+        )
+        model_input = self._collate(model_input)
+        if hasattr(model_input, "to"):
+            model_input = model_input.to(self.device)
+        return self.to_dict(model_input)
+
+    def decode_output(self, output):
+        """
+        Decodes the raw model outputs
+
+        Arguments
+        ---------
+        output: tuple
+            raw model outputs
+
+        Returns
+        -------
+        result: dict or list
+            the output of the pipeline
+        """
+        pipeline_input = self._get_decode_pipeline_input(output)
+        return self._run_pipeline(
+            pipeline=self.decode_pipeline,
+            input=pipeline_input,
+            batch=self.batch_outputs,
+        )
diff --git a/speechbrain/inference/interpretability.py b/speechbrain/inference/interpretability.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6b844db297bd6d813b0ba1963e50e0603637de9
--- /dev/null
+++ b/speechbrain/inference/interpretability.py
@@ -0,0 +1,172 @@
+""" Specifies the inference interfaces for interpretability modules.
+
+Authors:
+ * Aku Rouhe 2021
+ * Peter Plantinga 2021
+ * Loren Lugosch 2020
+ * Mirco Ravanelli 2020
+ * Titouan Parcollet 2021
+ * Abdel Heba 2021
+ * Andreas Nautsch 2022, 2023
+ * Pooneh Mousavi 2023
+ * Sylvain de Langen 2023
+ * Adel Moumen 2023
+ * Pradnya Kandarkar 2023
+"""
+import torch
+import torchaudio
+import torch.nn.functional as F
+import speechbrain
+from speechbrain.utils.fetching import fetch
+from speechbrain.utils.data_utils import split_path
+from speechbrain.processing.NMF import spectral_phase
+from speechbrain.inference.interfaces import Pretrained
+
+
+class PIQAudioInterpreter(Pretrained):
+    """
+    This class implements the interface for the PIQ posthoc interpreter for an audio classifier.
+
+    Example
+    -------
+    >>> from speechbrain.inference.interpretability import PIQAudioInterpreter
+    >>> tmpdir = getfixture("tmpdir")
+    >>> interpreter = PIQAudioInterpreter.from_hparams(
+    ...     source="speechbrain/PIQ-ESC50",
+    ...     savedir=tmpdir,
+    ... )
+    >>> signal = torch.randn(1, 16000)
+    >>> interpretation, _ = interpreter.interpret_batch(signal)
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def preprocess(self, wavs):
+        """Pre-process wavs to calculate STFTs"""
+        X_stft = self.mods.compute_stft(wavs)
+        X_stft_power = speechbrain.processing.features.spectral_magnitude(
+            X_stft, power=self.hparams.spec_mag_power
+        )
+        X_stft_logpower = torch.log1p(X_stft_power)
+
+        return X_stft_logpower, X_stft, X_stft_power
+
+    def classifier_forward(self, X_stft_logpower):
+        """the forward pass for the classifier"""
+        hcat = self.mods.embedding_model(X_stft_logpower)
+        embeddings = hcat.mean((-1, -2))
+        predictions = self.mods.classifier(embeddings).squeeze(1)
+        class_pred = predictions.argmax(1)
+        return hcat, embeddings, predictions, class_pred
+
+    def invert_stft_with_phase(self, X_int, X_stft_phase):
+        """Inverts STFT spectra given phase."""
+        X_stft_phase_sb = torch.cat(
+            (
+                torch.cos(X_stft_phase).unsqueeze(-1),
+                torch.sin(X_stft_phase).unsqueeze(-1),
+            ),
+            dim=-1,
+        )
+
+        X_stft_phase_sb = X_stft_phase_sb[:, : X_int.shape[1], :, :]
+        if X_int.ndim == 3:
+            X_int = X_int.unsqueeze(-1)
+        X_wpsb = X_int * X_stft_phase_sb
+        x_int_sb = self.mods.compute_istft(X_wpsb)
+        return x_int_sb
+
+    def interpret_batch(self, wavs):
+        """Classifies the given audio into the given set of labels.
+        It also provides the interpretation in the audio domain.
+
+        Arguments
+        ---------
+        wavs : torch.Tensor
+            Batch of waveforms [batch, time, channels] or [batch, time]
+            depending on the model. Make sure the sample rate is fs=16000 Hz.
+
+        Returns
+        -------
+        x_int_sound_domain
+            The interpretation in the waveform domain
+        text_lab:
+            The text label for the classification
+        fs_model:
+            The sampling frequency of the model. Useful to save the audio.
+        """
+        wavs = wavs.to(self.device)
+        X_stft_logpower, X_stft, X_stft_power = self.preprocess(wavs)
+        X_stft_phase = spectral_phase(X_stft)
+
+        # Embeddings + sound classifier
+        hcat, embeddings, predictions, class_pred = self.classifier_forward(
+            X_stft_logpower
+        )
+
+        if self.hparams.use_vq:
+            xhat, hcat, z_q_x = self.mods.psi(hcat, class_pred)
+        else:
+            xhat = self.mods.psi.decoder(hcat)
+        xhat = xhat.squeeze(1)
+        Tmax = xhat.shape[1]
+        if self.hparams.use_mask_output:
+            xhat = F.sigmoid(xhat)
+            X_int = xhat * X_stft_logpower[:, :Tmax, :]
+        else:
+            xhat = F.softplus(xhat)
+            th = xhat.max() * self.hparams.mask_th
+            X_int = (xhat > th) * X_stft_logpower[:, :Tmax, :]
+        X_int = torch.expm1(X_int)
+        x_int_sound_domain = self.invert_stft_with_phase(X_int, X_stft_phase)
+        text_lab = self.hparams.label_encoder.decode_torch(
+            class_pred.unsqueeze(0)
+        )
+
+        return x_int_sound_domain, text_lab
+
+    def interpret_file(self, path, savedir="audio_cache"):
+        """Classifies the given audiofile into the given set of labels.
+        It also provides the interpretation in the audio domain.
+
+        Arguments
+        ---------
+        path : str
+            Path to audio file to classify.
+
+        Returns
+        -------
+        x_int_sound_domain
+            The interpretation in the waveform domain
+        text_lab:
+            The text label for the classification
+        fs_model:
+            The sampling frequency of the model. Useful to save the audio.
+        """
+        source, fl = split_path(path)
+        path = fetch(fl, source=source, savedir=savedir)
+
+        batch, fs_file = torchaudio.load(path)
+        batch = batch.to(self.device)
+        fs_model = self.hparams.sample_rate
+
+        # resample the data if needed
+        if fs_file != fs_model:
+            print(
+                "Resampling the audio from {} Hz to {} Hz".format(
+                    fs_file, fs_model
+                )
+            )
+            tf = torchaudio.transforms.Resample(
+                orig_freq=fs_file, new_freq=fs_model
+            ).to(self.device)
+            batch = batch.mean(dim=0, keepdim=True)
+            batch = tf(batch)
+
+        x_int_sound_domain, text_lab = self.interpret_batch(batch)
+        return x_int_sound_domain, text_lab, fs_model
+
+    def forward(self, wavs, wav_lens=None):
+        """Runs the classification"""
+        return self.interpret_batch(wavs, wav_lens)
diff --git a/speechbrain/inference/metrics.py b/speechbrain/inference/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..91d6b785a0a1b23967891185e7f6aefef9acd509
--- /dev/null
+++ b/speechbrain/inference/metrics.py
@@ -0,0 +1,95 @@
+""" Specifies the inference interfaces for metric estimation modules.
+
+Authors:
+ * Aku Rouhe 2021
+ * Peter Plantinga 2021
+ * Loren Lugosch 2020
+ * Mirco Ravanelli 2020
+ * Titouan Parcollet 2021
+ * Abdel Heba 2021
+ * Andreas Nautsch 2022, 2023
+ * Pooneh Mousavi 2023
+ * Sylvain de Langen 2023
+ * Adel Moumen 2023
+ * Pradnya Kandarkar 2023
+"""
+import torch
+from speechbrain.inference.interfaces import Pretrained
+
+
+class SNREstimator(Pretrained):
+    """A "ready-to-use" SNR estimator."""
+
+    MODULES_NEEDED = ["encoder", "encoder_out"]
+    HPARAMS_NEEDED = ["stat_pooling", "snrmax", "snrmin"]
+
+    def estimate_batch(self, mix, predictions):
+        """Run SI-SNR estimation on the estimated sources, and mixture.
+
+        Arguments
+        ---------
+        mix : torch.Tensor
+            The mixture of sources of shape B X T
+        predictions : torch.Tensor
+            of size (B x T x C),
+            where B is batch size
+                  T is number of time points
+                  C is number of sources
+
+        Returns
+        -------
+        tensor
+            Estimate of SNR
+        """
+
+        predictions = predictions.permute(0, 2, 1)
+        predictions = predictions.reshape(-1, predictions.size(-1))
+
+        if hasattr(self.hparams, "separation_norm_type"):
+            if self.hparams.separation_norm_type == "max":
+                predictions = (
+                    predictions / predictions.max(dim=1, keepdim=True)[0]
+                )
+                mix = mix / mix.max(dim=1, keepdim=True)[0]
+
+            elif self.hparams.separation_norm_type == "stnorm":
+                predictions = (
+                    predictions - predictions.mean(dim=1, keepdim=True)
+                ) / predictions.std(dim=1, keepdim=True)
+                mix = (mix - mix.mean(dim=1, keepdim=True)) / mix.std(
+                    dim=1, keepdim=True
+                )
+
+        min_T = min(predictions.shape[1], mix.shape[1])
+        assert predictions.shape[1] == mix.shape[1], "lengths change"
+
+        mix_repeat = mix.repeat(2, 1)
+        inp_cat = torch.cat(
+            [
+                predictions[:, :min_T].unsqueeze(1),
+                mix_repeat[:, :min_T].unsqueeze(1),
+            ],
+            dim=1,
+        )
+
+        enc = self.mods.encoder(inp_cat)
+        enc = enc.permute(0, 2, 1)
+        enc_stats = self.hparams.stat_pooling(enc)
+
+        # this gets the SI-SNR estimate in the compressed range 0-1
+        snrhat = self.mods.encoder_out(enc_stats).squeeze()
+
+        # get the SI-SNR estimate in the true range
+        snrhat = self.gettrue_snrrange(snrhat)
+        return snrhat
+
+    def forward(self, mix, predictions):
+        """Just run the batch estimate"""
+        return self.estimate_batch(mix, predictions)
+
+    def gettrue_snrrange(self, inp):
+        """Convert from 0-1 range to true snr range"""
+        rnge = self.hparams.snrmax - self.hparams.snrmin
+        inp = inp * rnge
+        inp = inp + self.hparams.snrmin
+        return inp
diff --git a/speechbrain/inference/separation.py b/speechbrain/inference/separation.py
new file mode 100644
index 0000000000000000000000000000000000000000..44a930a8d3bf8977cc4691a33d03ad1dbbcabd02
--- /dev/null
+++ b/speechbrain/inference/separation.py
@@ -0,0 +1,125 @@
+""" Specifies the inference interfaces for speech separation modules.
+
+Authors:
+ * Aku Rouhe 2021
+ * Peter Plantinga 2021
+ * Loren Lugosch 2020
+ * Mirco Ravanelli 2020
+ * Titouan Parcollet 2021
+ * Abdel Heba 2021
+ * Andreas Nautsch 2022, 2023
+ * Pooneh Mousavi 2023
+ * Sylvain de Langen 2023
+ * Adel Moumen 2023
+ * Pradnya Kandarkar 2023
+"""
+import torch
+import torchaudio
+import torch.nn.functional as F
+from speechbrain.utils.fetching import fetch
+from speechbrain.utils.data_utils import split_path
+from speechbrain.inference.interfaces import Pretrained
+
+
+class SepformerSeparation(Pretrained):
+    """A "ready-to-use" speech separation model.
+
+    Uses Sepformer architecture.
+
+    Example
+    -------
+    >>> tmpdir = getfixture("tmpdir")
+    >>> model = SepformerSeparation.from_hparams(
+    ...     source="speechbrain/sepformer-wsj02mix",
+    ...     savedir=tmpdir)
+    >>> mix = torch.randn(1, 400)
+    >>> est_sources = model.separate_batch(mix)
+    >>> print(est_sources.shape)
+    torch.Size([1, 400, 2])
+    """
+
+    MODULES_NEEDED = ["encoder", "masknet", "decoder"]
+
+    def separate_batch(self, mix):
+        """Run source separation on batch of audio.
+
+        Arguments
+        ---------
+        mix : torch.Tensor
+            The mixture of sources.
+
+        Returns
+        -------
+        tensor
+            Separated sources
+        """
+
+        # Separation
+        mix = mix.to(self.device)
+        mix_w = self.mods.encoder(mix)
+        est_mask = self.mods.masknet(mix_w)
+        mix_w = torch.stack([mix_w] * self.hparams.num_spks)
+        sep_h = mix_w * est_mask
+
+        # Decoding
+        est_source = torch.cat(
+            [
+                self.mods.decoder(sep_h[i]).unsqueeze(-1)
+                for i in range(self.hparams.num_spks)
+            ],
+            dim=-1,
+        )
+
+        # T changed after conv1d in encoder, fix it here
+        T_origin = mix.size(1)
+        T_est = est_source.size(1)
+        if T_origin > T_est:
+            est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
+        else:
+            est_source = est_source[:, :T_origin, :]
+        return est_source
+
+    def separate_file(self, path, savedir="audio_cache"):
+        """Separate sources from file.
+
+        Arguments
+        ---------
+        path : str
+            Path to file which has a mixture of sources. It can be a local
+            path, a web url, or a huggingface repo.
+        savedir : path
+            Path where to store the wav signals (when downloaded from the web).
+        Returns
+        -------
+        tensor
+            Separated sources
+        """
+        source, fl = split_path(path)
+        path = fetch(fl, source=source, savedir=savedir)
+
+        batch, fs_file = torchaudio.load(path)
+        batch = batch.to(self.device)
+        fs_model = self.hparams.sample_rate
+
+        # resample the data if needed
+        if fs_file != fs_model:
+            print(
+                "Resampling the audio from {} Hz to {} Hz".format(
+                    fs_file, fs_model
+                )
+            )
+            tf = torchaudio.transforms.Resample(
+                orig_freq=fs_file, new_freq=fs_model
+            ).to(self.device)
+            batch = batch.mean(dim=0, keepdim=True)
+            batch = tf(batch)
+
+        est_sources = self.separate_batch(batch)
+        est_sources = (
+            est_sources / est_sources.abs().max(dim=1, keepdim=True)[0]
+        )
+        return est_sources
+
+    def forward(self, mix):
+        """Runs separation on the input mix"""
+        return self.separate_batch(mix)
diff --git a/speechbrain/inference/speaker.py b/speechbrain/inference/speaker.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3cdc1a7b8ba542338cca52343edc10dc2b05b2e
--- /dev/null
+++ b/speechbrain/inference/speaker.py
@@ -0,0 +1,116 @@
+""" Specifies the inference interfaces for speaker recognition modules.
+
+Authors:
+ * Aku Rouhe 2021
+ * Peter Plantinga 2021
+ * Loren Lugosch 2020
+ * Mirco Ravanelli 2020
+ * Titouan Parcollet 2021
+ * Abdel Heba 2021
+ * Andreas Nautsch 2022, 2023
+ * Pooneh Mousavi 2023
+ * Sylvain de Langen 2023
+ * Adel Moumen 2023
+ * Pradnya Kandarkar 2023
+"""
+import torch
+from speechbrain.inference.classifiers import EncoderClassifier
+
+
+class SpeakerRecognition(EncoderClassifier):
+    """A ready-to-use model for speaker recognition. It can be used to
+    perform speaker verification with verify_batch().
+
+    ```
+    Example
+    -------
+    >>> import torchaudio
+    >>> from speechbrain.inference.speaker import SpeakerRecognition
+    >>> # Model is downloaded from the speechbrain HuggingFace repo
+    >>> tmpdir = getfixture("tmpdir")
+    >>> verification = SpeakerRecognition.from_hparams(
+    ...     source="speechbrain/spkrec-ecapa-voxceleb",
+    ...     savedir=tmpdir,
+    ... )
+
+    >>> # Perform verification
+    >>> signal, fs = torchaudio.load("tests/samples/single-mic/example1.wav")
+    >>> signal2, fs = torchaudio.load("tests/samples/single-mic/example2.flac")
+    >>> score, prediction = verification.verify_batch(signal, signal2)
+    """
+
+    MODULES_NEEDED = [
+        "compute_features",
+        "mean_var_norm",
+        "embedding_model",
+        "mean_var_norm_emb",
+    ]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.similarity = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
+
+    def verify_batch(
+        self, wavs1, wavs2, wav1_lens=None, wav2_lens=None, threshold=0.25
+    ):
+        """Performs speaker verification with cosine distance.
+
+        It returns the score and the decision (0 different speakers,
+        1 same speakers).
+
+        Arguments
+        ---------
+        wavs1 : Torch.Tensor
+                Tensor containing the speech waveform1 (batch, time).
+                Make sure the sample rate is fs=16000 Hz.
+        wavs2 : Torch.Tensor
+                Tensor containing the speech waveform2 (batch, time).
+                Make sure the sample rate is fs=16000 Hz.
+        wav1_lens: Torch.Tensor
+                Tensor containing the relative length for each sentence
+                in the length (e.g., [0.8 0.6 1.0])
+        wav2_lens: Torch.Tensor
+                Tensor containing the relative length for each sentence
+                in the length (e.g., [0.8 0.6 1.0])
+        threshold: Float
+                Threshold applied to the cosine distance to decide if the
+                speaker is different (0) or the same (1).
+
+        Returns
+        -------
+        score
+            The score associated to the binary verification output
+            (cosine distance).
+        prediction
+            The prediction is 1 if the two signals in input are from the same
+            speaker and 0 otherwise.
+        """
+        emb1 = self.encode_batch(wavs1, wav1_lens, normalize=False)
+        emb2 = self.encode_batch(wavs2, wav2_lens, normalize=False)
+        score = self.similarity(emb1, emb2)
+        return score, score > threshold
+
+    def verify_files(self, path_x, path_y, **kwargs):
+        """Speaker verification with cosine distance
+
+        Returns the score and the decision (0 different speakers,
+        1 same speakers).
+
+        Returns
+        -------
+        score
+            The score associated to the binary verification output
+            (cosine distance).
+        prediction
+            The prediction is 1 if the two signals in input are from the same
+            speaker and 0 otherwise.
+        """
+        waveform_x = self.load_audio(path_x, **kwargs)
+        waveform_y = self.load_audio(path_y, **kwargs)
+        # Fake batches:
+        batch_x = waveform_x.unsqueeze(0)
+        batch_y = waveform_y.unsqueeze(0)
+        # Verify:
+        score, decision = self.verify_batch(batch_x, batch_y)
+        # Squeeze:
+        return score[0], decision[0]
diff --git a/speechbrain/inference/text.py b/speechbrain/inference/text.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccf1d7ff798b03e5b67ea986eb3574d3bc318df8
--- /dev/null
+++ b/speechbrain/inference/text.py
@@ -0,0 +1,381 @@
+""" Specifies the inference interfaces for text-processing modules.
+
+Authors:
+ * Aku Rouhe 2021
+ * Peter Plantinga 2021
+ * Loren Lugosch 2020
+ * Mirco Ravanelli 2020
+ * Titouan Parcollet 2021
+ * Abdel Heba 2021
+ * Andreas Nautsch 2022, 2023
+ * Pooneh Mousavi 2023
+ * Sylvain de Langen 2023
+ * Adel Moumen 2023
+ * Pradnya Kandarkar 2023
+"""
+import torch
+from itertools import chain
+from speechbrain.inference.interfaces import (
+    Pretrained,
+    EncodeDecodePipelineMixin,
+)
+
+
+class GraphemeToPhoneme(Pretrained, EncodeDecodePipelineMixin):
+    """
+    A pretrained model implementation for Grapheme-to-Phoneme (G2P) models
+    that take raw natural language text as an input and
+
+    Example
+    -------
+    >>> text = ("English is tough. It can be understood "
+    ...         "through thorough thought though")
+    >>> from speechbrain.inference.text import GraphemeToPhoneme
+    >>> tmpdir = getfixture('tmpdir')
+    >>> g2p = GraphemeToPhoneme.from_hparams('path/to/model', savedir=tmpdir) # doctest: +SKIP
+    >>> phonemes = g2p.g2p(text) # doctest: +SKIP
+    """
+
+    INPUT_STATIC_KEYS = ["txt"]
+    OUTPUT_KEYS = ["phonemes"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.create_pipelines()
+        self.load_dependencies()
+
+    @property
+    def phonemes(self):
+        """Returns the available phonemes"""
+        return self.hparams.phonemes
+
+    @property
+    def language(self):
+        """Returns the language for which this model is available"""
+        return self.hparams.language
+
+    def g2p(self, text):
+        """Performs the Grapheme-to-Phoneme conversion
+
+        Arguments
+        ---------
+        text: str or list[str]
+            a single string to be encoded to phonemes - or a
+            sequence of strings
+
+        Returns
+        -------
+        result: list
+            if a single example was provided, the return value is a
+            single list of phonemes
+        """
+        single = isinstance(text, str)
+        if single:
+            text = [text]
+
+        model_inputs = self.encode_input({"txt": text})
+        self._update_graphemes(model_inputs)
+        model_outputs = self.mods.model(**model_inputs)
+        decoded_output = self.decode_output(model_outputs)
+        phonemes = decoded_output["phonemes"]
+        if single:
+            phonemes = phonemes[0]
+        return phonemes
+
+    def _update_graphemes(self, model_inputs):
+        grapheme_sequence_mode = getattr(self.hparams, "grapheme_sequence_mode")
+        if grapheme_sequence_mode and grapheme_sequence_mode != "raw":
+            grapheme_encoded_key = f"grapheme_encoded_{grapheme_sequence_mode}"
+            if grapheme_encoded_key in model_inputs:
+                model_inputs["grapheme_encoded"] = model_inputs[
+                    grapheme_encoded_key
+                ]
+
+    def load_dependencies(self):
+        """Loads any relevant model dependencies"""
+        deps_pretrainer = getattr(self.hparams, "deps_pretrainer", None)
+        if deps_pretrainer:
+            deps_pretrainer.collect_files()
+            deps_pretrainer.load_collected()
+
+    def __call__(self, text):
+        """A convenience callable wrapper - same as G2P
+
+        Arguments
+        ---------
+        text: str or list[str]
+            a single string to be encoded to phonemes - or a
+            sequence of strings
+
+        Returns
+        -------
+        result: list
+            if a single example was provided, the return value is a
+            single list of phonemes
+        """
+        return self.g2p(text)
+
+    def forward(self, noisy, lengths=None):
+        """Runs enhancement on the noisy input"""
+        return self.enhance_batch(noisy, lengths)
+
+
+class ResponseGenerator(Pretrained):
+    """A ready-to-use Response Generator  model
+
+    The class can be used to generate and continue dialogue given the user input.
+    The given YAML must contain the fields specified in the *_NEEDED[] lists.
+    It needs to be used with custom.py to load the expanded  model with added tokens like bos,eos, and speaker's tokens.
+    """
+
+    MODULES_NEEDED = ["model"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        #  Load model
+        self.model = self.hparams.model
+        self.tokenizer = self.model.tokenizer
+        self.history_window = 2 * self.hparams.max_history + 1
+        self.history = []
+
+    def generate_response(self, turn):
+        """
+        Complete a dialogue given the user's input.
+        Arguments
+        ---------
+        turn: str
+            User input which is the last turn of the dialogue.
+
+        Returns
+        -------
+        response
+            Generated response for the user input based on the dialogue history.
+        """
+
+        self.history.append(turn)
+        inputs = self.prepare_input()
+        hyps = self.generate(inputs)
+        predicted_words = self.model.tokenizer.batch_decode(
+            hyps[:, inputs[0].shape[1] :],
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=True,
+        )
+        response = predicted_words[0]
+        self.history.append(response)
+        return response
+
+    def prepare_input(self):
+        """Users should modify this function according to their own tasks."""
+        raise NotImplementedError
+
+    def generate(self):
+        """Users should modify this function according to their own tasks."""
+        raise NotImplementedError
+
+
+class GPTResponseGenerator(ResponseGenerator):
+    """A ready-to-use Response Generator  model
+
+    The class can be used to generate and continue dialogue given the user input.
+    The given YAML must contain the fields specified in the *_NEEDED[] lists.
+    It needs to be used with custom.py to load the expanded GPT model with added tokens like bos,eos, and speaker's tokens.
+
+    Example
+    -------
+    >>> from speechbrain.inference.text import GPTResponseGenerator
+
+    >>> tmpdir = getfixture("tmpdir")
+    >>> res_gen_model = GPTResponseGenerator.from_hparams(source="speechbrain/MultiWOZ-GPT-Response_Generation",
+    ... savedir="tmpdir",
+    ... pymodule_file="custom.py")  # doctest: +SKIP
+    >>> response = res_gen_model.generate_response("I want to book a table for dinner")  # doctest: +SKIP
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # convert special tokens to their ids
+        (
+            self.bos,
+            self.eos,
+            self.system,
+            self.user,
+        ) = self.model.tokenizer.convert_tokens_to_ids(
+            self.hparams.special_tokens
+        )
+
+    def generate(self, inputs):
+        """
+        Complete a dialogue given the user's input.
+        Arguments
+        ---------
+        inputs: tuple
+            history_bos which is the tokenized history+input values with appropriate speaker token appended before each turn and history_token_type which determines
+            the type of each token basd on who is uttered that token (either User or Sytem).
+
+        Returns
+        -------
+        response
+            Generated hypothesis for the user input based on the dialogue history.
+        """
+
+        history_bos, history_token_type = inputs
+        padding_mask = ~self.hparams.padding_mask(
+            history_bos, pad_idx=self.model.tokenizer.unk_token_id
+        )
+        hyps = self.model.generate(
+            history_bos.detach(),
+            history_token_type.detach(),
+            padding_mask.detach(),
+            "beam",
+        )
+        return hyps
+
+    def prepare_input(self):
+        """Convert user input and previous histories to the format acceptable for  GPT model.
+            It appends all previous history and input and truncates it based on max_history value.
+            It then tokenizes the input and generates additional input that determines the type of each token (Sytem or User).
+
+        Arguments
+        ---------
+
+        Returns
+        -------
+        history_bos:
+            Tokenized history+input values with appropriate speaker token appended before each turn.
+        history_token_type:
+            Type of each token basd on who is uttered that token (either User or Sytem)
+        """
+        history_tokens_lists = [
+            self.model.tokenizer.encode(turn) for turn in self.history
+        ]
+        # add speaker tokens to the history turns (user is even, system is odd)
+        # BEFORE:  [Hi how are you?], [I'm fine, thanks]
+        # AFTER:   [SPK_1 Hi how are you?], [SPK_2 I'm fine, thanks]
+        history_input_lists = [
+            [self.user if i % 2 == 0 else self.system] + encoded_turn
+            for i, encoded_turn in enumerate(history_tokens_lists)
+        ]
+        history_ids = history_input_lists[-self.history_window :]
+        # concatenate every token into a single list
+        # list(chain(*[[1, 2], [3, 4], [5]]))
+        # >>> [1, 2, 3, 4, 5]
+        history_ids = torch.LongTensor(list(chain(*history_ids)))
+        # create bos version for the input
+        history_bos = torch.cat(
+            (torch.tensor([self.bos]), history_ids, torch.tensor([self.system]))
+        )
+        # create a mapping that associates each token in the input to a speaker
+        # INPUT: [SPK_1 Hi    how   are   you? ], [SPK_2 I'm   fine, thanks]
+        # TYPE:  [SPK_1 SPK_1 SPK_1 SPK_1 SPK_1], [SPK_2 SPK_2 SPK_2 SPK_2 ]
+        history_token_type_lists = [
+            [self.user if i % 2 == 0 else self.system] * len(encoded_turn)
+            for i, encoded_turn in enumerate(history_input_lists)
+        ]
+        history_token_type = torch.LongTensor(
+            list(
+                chain(
+                    *(
+                        [[self.system]]
+                        + history_token_type_lists[-self.history_window :]
+                        + [[self.system]]
+                    )
+                )
+            )
+        )
+        return history_bos.unsqueeze(0), history_token_type.unsqueeze(0)
+
+
+class Llama2ResponseGenerator(ResponseGenerator):
+    """A ready-to-use Response Generator  model
+
+    The class can be used to generate and continue dialogue given the user input.
+    The given YAML must contain the fields specified in the *_NEEDED[] lists.
+    It needs to be used with custom.py to load the expanded Llama2 model with added tokens like bos,eos, and speaker's tokens.
+
+    Example
+    -------
+    >>> from speechbrain.inference.text import Llama2ResponseGenerator
+
+    >>> tmpdir = getfixture("tmpdir")
+    >>> res_gen_model = Llama2ResponseGenerator.from_hparams(source="speechbrain/MultiWOZ-Llama2-Response_Generation",
+    ... savedir="tmpdir",
+    ... pymodule_file="custom.py")  # doctest: +SKIP
+    >>> response = res_gen_model.generate_response("I want to book a table for dinner")  # doctest: +SKIP
+    """
+
+    def __init__(self, *args, **kwargs):
+        run_opts = {"device": "cuda"}
+        super().__init__(run_opts=run_opts, *args, **kwargs)
+        # self.model = self.model#.to("cuda")
+
+    def generate(self, inputs):
+        """
+        Complete a dialogue given the user's input.
+        Arguments
+        ---------
+        inputs: prompt_bos
+            prompted imputs to be passed to llama2 model for generation.
+
+        Returns
+        -------
+        response
+            Generated hypothesis for the user input based on the dialogue history.
+        """
+        prompt_bos = inputs[0].to(self.model.model.device)
+        padding_mask = ~self.hparams.padding_mask(
+            prompt_bos, pad_idx=self.tokenizer.pad_token_id
+        )
+        hyps = self.model.generate(
+            prompt_bos.detach(), padding_mask.detach(), "beam",
+        )
+        return hyps
+
+    def prepare_input(self):
+        """Convert user input and previous histories to the format acceptable for  Llama2 model.
+            It appends all previous history and input and truncates it based on max_history value.
+            It then tokenizes the input and add propmts.
+
+        Arguments
+        ---------
+
+        Returns
+        -------
+        prompt_bos:
+            Tokenized history+input values with appropriate prompt.
+        """
+
+        def generate_prompt(idx_and_item):
+            """add [INST] and [/INST] prompt to the start and end ogf item.
+
+            Arguments
+            ---------
+            idx_and_item:
+                id and its corresponding text. If the id is even, it is user turn and [ INST] is added.
+            Returns
+            -------
+            prompt_bos:
+                prompted text  for one item.
+            """
+            index, item = idx_and_item
+            if index % 2 == 0:
+                return "[INST] " + item + " [/INST]"
+            else:
+                return item
+
+        prompts = list(map(generate_prompt, enumerate(self.history)))
+
+        # encode each turn of the history
+        propmt_tokens_lists = [self.tokenizer.encode(turn) for turn in prompts]
+
+        prompt_ids = propmt_tokens_lists[-self.history_window :]
+        # concatenate every token into a single list
+        # list(chain(*[[1, 2], [3, 4], [5]]))
+        # >>> [1, 2, 3, 4, 5]
+        prompt_ids = torch.LongTensor(list(chain(*prompt_ids)))
+        # without bos for lm_labels
+
+        # # create bos version for the input
+        prompt_bos = torch.cat(
+            (torch.tensor([self.tokenizer.bos_token_id]), prompt_ids)
+        )
+        return prompt_bos.unsqueeze(0).unsqueeze(dim=0)
diff --git a/speechbrain/inference/vocoders.py b/speechbrain/inference/vocoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..88a6f92b67c4429a9a0002391059dd9188e69558
--- /dev/null
+++ b/speechbrain/inference/vocoders.py
@@ -0,0 +1,360 @@
+""" Specifies the inference interfaces for Text-To-Speech (TTS) modules.
+
+Authors:
+ * Aku Rouhe 2021
+ * Peter Plantinga 2021
+ * Loren Lugosch 2020
+ * Mirco Ravanelli 2020
+ * Titouan Parcollet 2021
+ * Abdel Heba 2021
+ * Andreas Nautsch 2022, 2023
+ * Pooneh Mousavi 2023
+ * Sylvain de Langen 2023
+ * Adel Moumen 2023
+ * Pradnya Kandarkar 2023
+"""
+import logging
+import torch
+from speechbrain.dataio.dataio import length_to_mask
+from speechbrain.inference.interfaces import Pretrained
+
+logger = logging.getLogger(__name__)
+
+
+class HIFIGAN(Pretrained):
+    """
+    A ready-to-use wrapper for HiFiGAN (mel_spec -> waveform).
+    Arguments
+    ---------
+    hparams
+        Hyperparameters (from HyperPyYAML)
+    Example
+    -------
+    >>> tmpdir_vocoder = getfixture('tmpdir') / "vocoder"
+    >>> hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-ljspeech", savedir=tmpdir_vocoder)
+    >>> mel_specs = torch.rand(2, 80,298)
+    >>> waveforms = hifi_gan.decode_batch(mel_specs)
+    >>> # You can use the vocoder coupled with a TTS system
+    >>>	# Initialize TTS (tacotron2)
+    >>> tmpdir_tts = getfixture('tmpdir') / "tts"
+    >>> from speechbrain.inference.TTS import Tacotron2
+    >>>	tacotron2 = Tacotron2.from_hparams(source="speechbrain/tts-tacotron2-ljspeech", savedir=tmpdir_tts)
+    >>>	# Running the TTS
+    >>>	mel_output, mel_length, alignment = tacotron2.encode_text("Mary had a little lamb")
+    >>>	# Running Vocoder (spectrogram-to-waveform)
+    >>>	waveforms = hifi_gan.decode_batch(mel_output)
+    """
+
+    HPARAMS_NEEDED = ["generator"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.infer = self.hparams.generator.inference
+        self.first_call = True
+
+    def decode_batch(self, spectrogram, mel_lens=None, hop_len=None):
+        """Computes waveforms from a batch of mel-spectrograms
+        Arguments
+        ---------
+        spectrogram: torch.Tensor
+            Batch of mel-spectrograms [batch, mels, time]
+        mel_lens: torch.tensor
+            A list of lengths of mel-spectrograms for the batch
+            Can be obtained from the output of Tacotron/FastSpeech
+        hop_len: int
+            hop length used for mel-spectrogram extraction
+            should be the same value as in the .yaml file
+        Returns
+        -------
+        waveforms: torch.Tensor
+            Batch of mel-waveforms [batch, 1, time]
+        """
+        # Prepare for inference by removing the weight norm
+        if self.first_call:
+            self.hparams.generator.remove_weight_norm()
+            self.first_call = False
+        with torch.no_grad():
+            waveform = self.infer(spectrogram.to(self.device))
+
+        # Mask the noise caused by padding during batch inference
+        if mel_lens is not None and hop_len is not None:
+            waveform = self.mask_noise(waveform, mel_lens, hop_len)
+
+        return waveform
+
+    def mask_noise(self, waveform, mel_lens, hop_len):
+        """Mask the noise caused by padding during batch inference
+        Arguments
+        ---------
+        wavform: torch.tensor
+            Batch of generated waveforms [batch, 1, time]
+        mel_lens: torch.tensor
+            A list of lengths of mel-spectrograms for the batch
+            Can be obtained from the output of Tacotron/FastSpeech
+        hop_len: int
+            hop length used for mel-spectrogram extraction
+            same value as in the .yaml file
+        Returns
+        -------
+        waveform: torch.tensor
+            Batch of waveforms without padded noise [batch, 1, time]
+        """
+        waveform = waveform.squeeze(1)
+        # the correct audio length should be hop_len * mel_len
+        mask = length_to_mask(
+            mel_lens * hop_len, waveform.shape[1], device=waveform.device
+        ).bool()
+        waveform.masked_fill_(~mask, 0.0)
+        return waveform.unsqueeze(1)
+
+    def decode_spectrogram(self, spectrogram):
+        """Computes waveforms from a single mel-spectrogram
+        Arguments
+        ---------
+        spectrogram: torch.Tensor
+            mel-spectrogram [mels, time]
+        Returns
+        -------
+        waveform: torch.Tensor
+            waveform [1, time]
+        audio can be saved by:
+        >>> import torchaudio
+        >>> waveform = torch.rand(1, 666666)
+        >>> sample_rate = 22050
+        >>> torchaudio.save(str(getfixture('tmpdir') / "test.wav"), waveform, sample_rate)
+        """
+        if self.first_call:
+            self.hparams.generator.remove_weight_norm()
+            self.first_call = False
+        with torch.no_grad():
+            waveform = self.infer(spectrogram.unsqueeze(0).to(self.device))
+        return waveform.squeeze(0)
+
+    def forward(self, spectrogram):
+        "Decodes the input spectrograms"
+        return self.decode_batch(spectrogram)
+
+
+class DiffWaveVocoder(Pretrained):
+    """
+    A ready-to-use inference wrapper for DiffWave as vocoder.
+    The wrapper allows to perform generative tasks:
+        locally-conditional generation: mel_spec -> waveform
+    Arguments
+    ---------
+    hparams
+        Hyperparameters (from HyperPyYAML)
+    """
+
+    HPARAMS_NEEDED = ["diffusion"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if hasattr(self.hparams, "diffwave"):
+            self.infer = self.hparams.diffusion.inference
+        else:
+            raise NotImplementedError
+
+    def decode_batch(
+        self,
+        mel,
+        hop_len,
+        mel_lens=None,
+        fast_sampling=False,
+        fast_sampling_noise_schedule=None,
+    ):
+        """Generate waveforms from spectrograms
+        Arguments
+        ---------
+        mel: torch.tensor
+            spectrogram [batch, mels, time]
+        hop_len: int
+            Hop length during mel-spectrogram extraction
+            Should be the same value as in the .yaml file
+            Used to determine the output wave length
+            Also used to mask the noise for vocoding task
+        mel_lens: torch.tensor
+            Used to mask the noise caused by padding
+            A list of lengths of mel-spectrograms for the batch
+            Can be obtained from the output of Tacotron/FastSpeech
+        fast_sampling: bool
+            whether to do fast sampling
+        fast_sampling_noise_schedule: list
+            the noise schedules used for fast sampling
+        Returns
+        -------
+        waveforms: torch.tensor
+            Batch of mel-waveforms [batch, 1, time]
+
+        """
+        with torch.no_grad():
+            waveform = self.infer(
+                unconditional=False,
+                scale=hop_len,
+                condition=mel.to(self.device),
+                fast_sampling=fast_sampling,
+                fast_sampling_noise_schedule=fast_sampling_noise_schedule,
+            )
+
+        # Mask the noise caused by padding during batch inference
+        if mel_lens is not None and hop_len is not None:
+            waveform = self.mask_noise(waveform, mel_lens, hop_len)
+        return waveform
+
+    def mask_noise(self, waveform, mel_lens, hop_len):
+        """Mask the noise caused by padding during batch inference
+        Arguments
+        ---------
+        wavform: torch.tensor
+            Batch of generated waveforms [batch, 1, time]
+        mel_lens: torch.tensor
+            A list of lengths of mel-spectrograms for the batch
+            Can be obtained from the output of Tacotron/FastSpeech
+        hop_len: int
+            hop length used for mel-spectrogram extraction
+            same value as in the .yaml file
+        Returns
+        -------
+        waveform: torch.tensor
+            Batch of waveforms without padded noise [batch, 1, time]
+        """
+        waveform = waveform.squeeze(1)
+        # the correct audio length should be hop_len * mel_len
+        mask = length_to_mask(
+            mel_lens * hop_len, waveform.shape[1], device=waveform.device
+        ).bool()
+        waveform.masked_fill_(~mask, 0.0)
+        return waveform.unsqueeze(1)
+
+    def decode_spectrogram(
+        self,
+        spectrogram,
+        hop_len,
+        fast_sampling=False,
+        fast_sampling_noise_schedule=None,
+    ):
+        """Computes waveforms from a single mel-spectrogram
+        Arguments
+        ---------
+        spectrogram: torch.tensor
+            mel-spectrogram [mels, time]
+        hop_len: int
+            hop length used for mel-spectrogram extraction
+            same value as in the .yaml file
+        fast_sampling: bool
+            whether to do fast sampling
+        fast_sampling_noise_schedule: list
+            the noise schedules used for fast sampling
+        Returns
+        -------
+        waveform: torch.tensor
+            waveform [1, time]
+
+        audio can be saved by:
+        >>> import torchaudio
+        >>> waveform = torch.rand(1, 666666)
+        >>> sample_rate = 22050
+        >>> torchaudio.save(str(getfixture('tmpdir') / "test.wav"), waveform, sample_rate)
+        """
+        with torch.no_grad():
+            waveform = self.infer(
+                unconditional=False,
+                scale=hop_len,
+                condition=spectrogram.unsqueeze(0).to(self.device),
+                fast_sampling=fast_sampling,
+                fast_sampling_noise_schedule=fast_sampling_noise_schedule,
+            )
+        return waveform.squeeze(0)
+
+    def forward(self, spectrogram):
+        """Decodes the input spectrograms"""
+        return self.decode_batch(spectrogram)
+
+
+class UnitHIFIGAN(Pretrained):
+    """
+    A ready-to-use wrapper for Unit HiFiGAN (discrete units -> waveform).
+    Arguments
+    ---------
+    hparams
+        Hyperparameters (from HyperPyYAML)
+    Example
+    -------
+    >>> tmpdir_vocoder = getfixture('tmpdir') / "vocoder"
+    >>> hifi_gan = UnitHIFIGAN.from_hparams(source="speechbrain/tts-hifigan-unit-hubert-l6-k100-ljspeech", savedir=tmpdir_vocoder)
+    >>> codes = torch.randint(0, 99, (100,))
+    >>> waveform = hifi_gan.decode_unit(codes)
+    """
+
+    HPARAMS_NEEDED = ["generator"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.infer = self.hparams.generator.inference
+        self.first_call = True
+        # Temporary fix for mapping indices from the range [0, k] to [1, k+1]
+        self.tokenize = True
+
+    def decode_batch(self, units):
+        """Computes waveforms from a batch of discrete units
+        Arguments
+        ---------
+        units: torch.tensor
+            Batch of discrete units [batch, codes]
+        Returns
+        -------
+        waveforms: torch.tensor
+            Batch of mel-waveforms [batch, 1, time]
+        """
+        # Remove weight norm for inference if it's the first call
+        if self.first_call:
+            self.hparams.generator.remove_weight_norm()
+            self.first_call = False
+
+        # Ensure that the units sequence has a length of at least 4
+        if units.size(1) < 4:
+            raise RuntimeError(
+                "The 'units' argument should have a length of at least 4 because of padding size."
+            )
+
+        # Increment units if tokenization is enabled
+        if self.tokenize:
+            # Avoid changing the input in-place
+            units = units + 1
+        with torch.no_grad():
+            waveform = self.infer(units.to(self.device))
+        return waveform
+
+    def decode_unit(self, units):
+        """Computes waveforms from a single sequence of discrete units
+        Arguments
+        ---------
+        units: torch.tensor
+            codes: [time]
+        Returns
+        -------
+        waveform: torch.tensor
+            waveform [1, time]
+        """
+        # Remove weight norm for inference if it's the first call
+        if self.first_call:
+            self.hparams.generator.remove_weight_norm()
+            self.first_call = False
+
+        # Ensure that the units sequence has a length of at least 4
+        if units.size(0) < 4:
+            raise RuntimeError(
+                "The 'units' argument should have a length of at least 4 because of padding size."
+            )
+
+        # Increment units if tokenization is enabled
+        if self.tokenize:
+            # Avoid changing the input in-place
+            units = units + 1
+        with torch.no_grad():
+            waveform = self.infer(units.unsqueeze(0).to(self.device))
+        return waveform.squeeze(0)
+
+    def forward(self, units):
+        "Decodes the input units"
+        return self.decode_batch(units)
diff --git a/speechbrain/k2_integration/__init__.py b/speechbrain/k2_integration/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..28b0111f296cefd8ce178eded97d330dd4227e2f
--- /dev/null
+++ b/speechbrain/k2_integration/__init__.py
@@ -0,0 +1,34 @@
+"""
+Package providing `k2-fsa <https://github.com/k2-fsa/k2>`_ integration.
+
+Intended loading manner:
+
+    >>> import speechbrain.k2_integration as sbk2
+    >>> # Then use: sbk2.graph_compiler.CtcGraphCompiler for example
+
+"""
+
+
+__all__ = [
+    "k2",
+    "utils",
+    "graph_compiler",
+    "lattice_decoder",
+    "lexicon",
+    "losses",
+    "prepare_lang",
+]
+
+try:
+    import k2
+except ImportError:
+    MSG = "Please install k2 to use k2\n"
+    MSG += "Checkout: https://k2-fsa.github.io/k2/installation/from_wheels.html"
+    raise ImportError(MSG)
+
+from . import utils
+from . import graph_compiler
+from . import lattice_decoder
+from . import lexicon
+from . import losses
+from . import prepare_lang
diff --git a/speechbrain/k2_integration/graph_compiler.py b/speechbrain/k2_integration/graph_compiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..38e444419ebdcd785ce00c102e41c2a023b8023b
--- /dev/null
+++ b/speechbrain/k2_integration/graph_compiler.py
@@ -0,0 +1,379 @@
+"""Graph compiler class to create, store, and use k2 decoding graphs in
+speechbrain. Limits the output words to the ones in the lexicon.
+
+This code is an extension, and therefore heavily inspired or taken from
+icefall's (https://github.com/k2-fsa/icefall) graph compiler.
+
+Authors:
+  * Pierre Champion 2023
+  * Zeyu Zhao 2023
+  * Georgios Karakasidis 2023
+"""
+
+
+import os
+from typing import List, Optional, Tuple
+import abc
+import torch
+import logging
+
+from . import k2  # import k2 from ./__init__.py
+from . import lexicon
+
+logger = logging.getLogger(__name__)
+
+
+class GraphCompiler(abc.ABC):
+    """
+    This abstract class is used to compile graphs for training and decoding.
+    """
+
+    @abc.abstractproperty
+    def topo(self) -> k2.Fsa:
+        """
+        Return the topology used to compile the graph.
+        """
+        pass
+
+    @abc.abstractproperty
+    def lexicon(self) -> lexicon.Lexicon:
+        """
+        Return the lexicon used to compile the graph.
+        """
+        pass
+
+    @abc.abstractproperty
+    def device(self):
+        """
+        Return the device used to compile the graph.
+        """
+        pass
+
+    @abc.abstractmethod
+    def compile(
+        self, texts: List[str], is_training: bool = True
+    ) -> Tuple[k2.Fsa, torch.Tensor]:
+        """
+        Compile the graph for the given texts.
+
+        Arguments
+        ---------
+        texts: List[str]
+            A list of strings. Each string contains a sentence for an utterance.
+            A sentence consists of spaces separated words. An example `texts`
+            looks like:
+
+                ['hello world', 'CTC training with k2']
+
+        is_training: bool
+            Indictating whether this is for training or not
+            (OOV warning in training).
+        Returns
+        -------
+        graph: GraphCompiler
+            An FsaVec, the composition result of `self.ctc_topo` and the
+            transcript FSA.
+        target_lens: Torch.tensor
+            It is an long tensor of shape (batch,). It contains lengths of
+            each target sequence.
+        """
+        pass
+
+    def compile_HL(self, cache_dir: Optional[str] = None, cache: bool = False):
+        """
+        Compile the decoding graph by composing H with L.
+        This is for decoding without language model.
+
+        Arguments
+        ---------
+        cache_dir: str
+            The path to store the composition in a .pt format.
+        cache: bool
+            Whether or not to load the composition from the .pt format (in the
+            cache_dir dir).
+
+        Returns
+        -------
+        HL: k2.Fsa
+            The HL composition
+        """
+        logger.info("Arc sorting L")
+        L = k2.arc_sort(self.lexicon.L).to("cpu")
+        H = self.topo.to("cpu")
+
+        file_hash = str(hash(H.shape[0])) + str(hash(L.shape[0]))
+        if cache and cache_dir is not None:
+            path = cache_dir + "/.HL_" + file_hash + ".pt"
+            if os.path.exists(path):
+                logger.warning(
+                    f"Loading HL '{path}' from its cached .pt format."
+                    " Set 'caching: False' in the yaml"
+                    " if this is not what you want."
+                )
+                HL = k2.Fsa.from_dict(torch.load(path, map_location="cpu"))
+                return HL
+
+        logger.info("Composing H and L")
+        HL = k2.compose(H, L, inner_labels="tokens")
+
+        logger.info("Connecting HL")
+        HL = k2.connect(HL)
+
+        logger.info("Arc sorting HL")
+        HL = k2.arc_sort(HL)
+        logger.debug(f"HL.shape: {HL.shape}")
+
+        if cache_dir is not None:
+            path = cache_dir + "/.HL_" + file_hash + ".pt"
+            logger.info("Caching HL to: " + path)
+            torch.save(HL.as_dict(), path)
+
+        return HL
+
+    def compile_HLG(
+        self, G, cache_dir: Optional[str] = None, cache: bool = False
+    ):
+        """
+        Compile the decoding graph by composing H with LG.
+        This is for decoding with small language model.
+
+        Arguments
+        ---------
+        G: k2.Fsa
+            The language model FSA.
+        cache_dir: str
+            The path to store the composition in a .pt format.
+        cache: bool
+            Whether or not to load the composition from the .pt format (in the
+            cache_dir dir).
+
+        Returns
+        -------
+        HL: k2.Fsa
+            The HLG composition
+        """
+        logger.info("Arc sorting L")
+        L = k2.arc_sort(self.lexicon.L_disambig).to("cpu")
+        G = k2.arc_sort(G).to("cpu")
+        H = self.topo.to("cpu")
+
+        file_hash = (
+            str(hash(H.shape[0]))
+            + str(hash(L.shape[0]))
+            + str(hash(G.shape[0]))
+        )
+        if cache and cache_dir is not None:
+            path = cache_dir + "/.HLG_" + file_hash + ".pt"
+            if os.path.exists(path):
+                logger.warning(
+                    f"Loading HLG '{path}' from its cached .pt format."
+                    " Set 'caching: False' in the yaml"
+                    " if this is not what you want."
+                )
+                HLG = k2.Fsa.from_dict(torch.load(path, map_location="cpu"))
+                return HLG
+
+        logger.info("Intersecting L and G")
+        LG = k2.compose(L, G)
+
+        logger.info("Connecting LG")
+        LG = k2.connect(LG)
+
+        logger.info("Determinizing LG")
+        LG = k2.determinize(LG)
+
+        logger.info("Connecting LG after k2.determinize")
+        LG = k2.connect(LG)
+        LG = self.lexicon.remove_LG_disambig_symbols(LG)
+
+        LG = k2.remove_epsilon(LG)
+
+        LG = k2.connect(LG)
+        LG.aux_labels = LG.aux_labels.remove_values_eq(0)
+        logger.info("Arc sorting LG")
+        LG = k2.arc_sort(LG)
+
+        logger.info("Composing H and LG")
+        HLG = k2.compose(H, LG, inner_labels="tokens")
+
+        logger.info("Connecting HLG")
+        HLG = k2.connect(HLG)
+
+        logger.info("Arc sorting HLG")
+        HLG = k2.arc_sort(HLG)
+        logger.debug(f"HLG.shape: {HLG.shape}")
+
+        if cache_dir is not None:
+            path = cache_dir + "/.HLG_" + file_hash + ".pt"
+            logger.info("Caching HLG to: " + path)
+            torch.save(HLG.as_dict(), path)
+
+        return HLG
+
+
+class CtcGraphCompiler(GraphCompiler):
+    """
+    This class is used to compile decoding graphs for CTC training.
+
+    Arguments
+    ---------
+    lexicon: Lexicon
+        It is built from `data/lang/lexicon.txt`.
+    device: torch.device
+        The device to use for operations compiling transcripts to FSAs.
+    need_repeat_flag: bool
+        If True, will add an attribute named `_is_repeat_token_` to ctc_topo
+        indicating whether this token is a repeat token in ctc graph.
+        This attribute is needed to implement delay-penalty for phone-based
+        ctc loss. See https://github.com/k2-fsa/k2/pull/1086 for more
+        details. Note: The above change MUST be included in k2 to enable this
+        flag so make sure you have an up-to-date version.
+
+    Example
+    -------
+    >>> import torch
+    >>> from speechbrain.k2_integration.losses import ctc_k2
+    >>> from speechbrain.k2_integration.graph_compiler import CtcGraphCompiler
+    >>> from speechbrain.k2_integration.lexicon import Lexicon
+    >>> from speechbrain.k2_integration.prepare_lang import prepare_lang
+
+    >>> # Create a random batch of log-probs
+    >>> batch_size = 4
+
+    >>> log_probs = torch.randn(batch_size, 100, 30)
+    >>> log_probs.requires_grad = True
+    >>> # Assume all utterances have the same length so no padding was needed.
+    >>> input_lens = torch.ones(batch_size)
+    >>> # Create a samll lexicon containing only two words and write it to a file.
+    >>> lang_tmpdir = getfixture('tmpdir')
+    >>> lexicon_sample = "hello h e l l o\\nworld w o r l d\\n<UNK> <unk>"
+    >>> lexicon_file = lang_tmpdir.join("lexicon.txt")
+    >>> lexicon_file.write(lexicon_sample)
+    >>> # Create a lang directory with the lexicon and L.pt, L_inv.pt, L_disambig.pt
+    >>> prepare_lang(lang_tmpdir)
+    >>> # Create a lexicon object
+    >>> lexicon = Lexicon(lang_tmpdir)
+    >>> # Create a random decoding graph
+    >>> graph = CtcGraphCompiler(
+    ...     lexicon,
+    ...     log_probs.device,
+    ... )
+    >>> isinstance(graph.topo, k2.Fsa)
+    True
+
+    """
+
+    def __init__(
+        self,
+        _lexicon: lexicon.Lexicon,
+        device: torch.device,
+        need_repeat_flag: bool = False,
+    ):
+        self._device = device
+
+        self._lexicon = _lexicon
+        self.lexicon.to(device)
+        assert self.lexicon.L_inv.requires_grad is False
+        self.lexicon.arc_sort()
+
+        max_token_id = max(self.lexicon.tokens)
+        ctc_topo = k2.ctc_topo(max_token_id, modified=False)
+
+        self.ctc_topo = ctc_topo.to(device)
+
+        if need_repeat_flag:
+            self.ctc_topo._is_repeat_token_ = (
+                self.ctc_topo.labels != self.ctc_topo.aux_labels
+            )
+
+    @property
+    def topo(self):
+        """
+        Return the ctc_topo.
+        """
+        return self.ctc_topo
+
+    @property
+    def lexicon(self):
+        """
+        Return the lexicon.
+        """
+        return self._lexicon
+
+    @property
+    def device(self):
+        """Return the device used for compiling graphs."""
+        return self._device
+
+    def compile(
+        self, texts: List[str], is_training: bool = True
+    ) -> Tuple[k2.Fsa, torch.Tensor]:
+        """
+        Build decoding graphs by composing ctc_topo with given transcripts.
+
+        Arguments
+        ---------
+        texts: List[str]
+            A list of strings. Each string contains a sentence for an utterance.
+            A sentence consists of spaces separated words. An example `texts`
+            looks like:
+
+                ['hello world', 'CTC training with k2']
+
+        is_training: bool
+            Indictating whether this is for training or not
+            (OOV warning in training).
+
+        Returns
+        -------
+        graph: GraphCompiler
+            An FsaVec, the composition result of `self.ctc_topo` and the
+            transcript FSA.
+        target_lens: Torch.tensor
+            It is an long tensor of shape (batch,). It contains lengths of
+            each target sequence.
+        """
+
+        word_idx = self.lexicon.texts_to_word_ids(
+            texts, log_unknown_warning=is_training
+        )
+
+        # ["test", "testa"] -> [[23, 8, 22, 23], [23, 8, 22, 23, 5]] -> [4, 5]
+        word2tids = self.lexicon.texts_to_token_ids(
+            texts, log_unknown_warning=is_training
+        )
+        scentence_ids = [sum(inner, []) for inner in word2tids]
+
+        target_lens = torch.tensor(
+            [len(t) for t in scentence_ids], dtype=torch.long
+        )
+
+        word_fsa_with_self_loops = k2.add_epsilon_self_loops(
+            k2.linear_fsa(word_idx, self.device)
+        )
+
+        fsa = k2.intersect(
+            self.lexicon.L_inv,
+            word_fsa_with_self_loops,
+            treat_epsilons_specially=False,
+        )
+        # fsa has word ID as labels and token ID as aux_labels, so
+        # we need to invert it
+        ans_fsa = fsa.invert_()
+        transcript_fsa = k2.arc_sort(ans_fsa)
+
+        # NOTE: k2.compose runs on CUDA only when treat_epsilons_specially
+        # is False, so we add epsilon self-loops here
+        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
+            transcript_fsa
+        )
+
+        fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)
+
+        graph = k2.compose(
+            self.ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False
+        )
+
+        assert graph.requires_grad is False
+
+        return graph, target_lens
diff --git a/speechbrain/k2_integration/lattice_decoder.py b/speechbrain/k2_integration/lattice_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..14a152370cfda5e3cf53d5d0e9364d4914737bd8
--- /dev/null
+++ b/speechbrain/k2_integration/lattice_decoder.py
@@ -0,0 +1,442 @@
+"""Different decoding graph algorithms for k2, be it HL or HLG (with G LM
+and bigger rescoring LM).
+
+This code was adjusted from icefall (https://github.com/k2-fsa/icefall/blob/master/icefall/decode.py).
+
+
+Authors:
+  * Pierre Champion 2023
+  * Zeyu Zhao 2023
+  * Georgios Karakasidis 2023
+"""
+
+from pathlib import Path
+from typing import Dict, List, Optional, Union
+from collections import OrderedDict
+
+from . import k2  # import k2 from ./__init__.py
+from speechbrain.utils.distributed import run_on_main
+from speechbrain.lm.arpa import arpa_to_fst
+
+import torch
+import logging
+
+from . import graph_compiler, utils
+
+logger = logging.getLogger(__name__)
+
+
+def get_decoding(
+    hparams: Dict, graphCompiler: graph_compiler.GraphCompiler, device="cpu"
+):
+    """
+    This function reads a config and creates the decoder for k2 graph compiler
+    decoding.
+    There are the following cases:
+        - HLG is compiled and LM rescoring is used. In that case,
+          compose_HL_with_G and use_G_rescoring are both True and we will
+          create for example G_3_gram.fst.txt and G_4_gram.fst.txt. Note that
+          the 3gram and 4gram ARPA lms will need to exist under
+          `hparams['lm_dir']`.
+        - HLG is compiled but LM rescoring is not used. In that case,
+          compose_HL_with_G is True and use_G_rescoring is False and we will
+          create for example G_3_gram.fst.txt. Note that the 3gram ARPA lm will
+          need to exist under `hparams['lm_dir']`.
+        - HLG is not compiled (only use HL graph) and LM rescoring used.
+          In that case, compose_HL_with_G is False and use_G_rescoring is True.
+          Note that the 4gram ARPA lms will need to exist under
+          `hparams['lm_dir']`.
+        - HLG is not compiled (only use HL graph) and LM rescoring is not used.
+          In that case, compose_HL_with_G is False and use_G_rescoring is False
+          and we will not convert LM to FST.
+
+    Arguments
+    ---------
+    hparams: dict
+        The hyperparameters.
+    graphCompiler: graph_compiler.GraphCompiler
+        The graphCompiler (H)
+    device : torch.device
+        The device to use.
+
+    Returns
+    -------
+    Dict:
+        decoding_graph: k2.Fsa
+            A HL or HLG decoding graph.
+            Used with a nnet output and the function `get_lattice` to
+            obtain a decoding lattice `k2.Fsa`.
+        decoding_method: Callable[[k2.Fsa], k2.Fsa]
+            A function to call with a decoding lattice `k2.Fsa` (obtained
+            after nnet output intersect with a HL or HLG).
+            Retuns an FsaVec containing linear FSAs
+
+    Example
+    -------
+    >>> import torch
+    >>> from speechbrain.k2_integration.losses import ctc_k2
+    >>> from speechbrain.k2_integration.utils import lattice_paths_to_text
+    >>> from speechbrain.k2_integration.graph_compiler import CtcGraphCompiler
+    >>> from speechbrain.k2_integration.lexicon import Lexicon
+    >>> from speechbrain.k2_integration.prepare_lang import prepare_lang
+    >>> from speechbrain.k2_integration.lattice_decoder import get_decoding
+    >>> from speechbrain.k2_integration.lattice_decoder import get_lattice
+
+    >>> batch_size = 1
+
+    >>> log_probs = torch.randn(batch_size, 40, 10)
+    >>> log_probs.requires_grad = True
+    >>> # Assume all utterances have the same length so no padding was needed.
+    >>> input_lens = torch.ones(batch_size)
+    >>> # Create a samll lexicon containing only two words and write it to a file.
+    >>> lang_tmpdir = getfixture('tmpdir')
+    >>> lexicon_sample = "hello h e l l o\\nworld w o r l d\\n<UNK> <unk>"
+    >>> lexicon_file = lang_tmpdir.join("lexicon.txt")
+    >>> lexicon_file.write(lexicon_sample)
+    >>> # Create a lang directory with the lexicon and L.pt, L_inv.pt, L_disambig.pt
+    >>> prepare_lang(lang_tmpdir)
+    >>> # Create a lexicon object
+    >>> lexicon = Lexicon(lang_tmpdir)
+    >>> # Create a random decoding graph
+    >>> graph = CtcGraphCompiler(
+    ...     lexicon,
+    ...     log_probs.device,
+    ... )
+
+    >>> decode = get_decoding(
+    ...     {"compose_HL_with_G": False,
+    ...      "decoding_method": "onebest",
+    ...      "lang_dir": lang_tmpdir},
+    ...     graph)
+    >>> lattice = get_lattice(log_probs, input_lens, decode["decoding_graph"])
+    >>> path = decode["decoding_method"](lattice)['1best']
+    >>> text = lattice_paths_to_text(path, lexicon.word_table)
+    """
+
+    compose_HL_with_G = hparams.get("compose_HL_with_G")
+    use_G_rescoring = (
+        hparams.get("decoding_method") == "whole-lattice-rescoring"
+    )
+
+    caching = (
+        False if "caching" in hparams and hparams["caching"] is False else True
+    )
+
+    if compose_HL_with_G or use_G_rescoring:
+        lm_dir = Path(hparams["lm_dir"])
+        G_path = lm_dir / (hparams["G_arpa"].replace("arpa", "fst.txt"))
+        G_rescoring_path = (
+            lm_dir / (hparams["G_rescoring_arpa"].replace("arpa", "fst.txt"))
+            if use_G_rescoring
+            else None
+        )
+        if compose_HL_with_G:
+            run_on_main(
+                arpa_to_fst,
+                kwargs={
+                    "words_txt": Path(hparams["lang_dir"]) / "words.txt",
+                    "in_arpa": lm_dir / hparams["G_arpa"],
+                    "out_fst": G_path,
+                    "ngram_order": 3,  # by default use 3-gram for HLG's LM
+                    "cache": caching,
+                },
+            )
+        if use_G_rescoring:
+            run_on_main(
+                arpa_to_fst,
+                kwargs={
+                    "words_txt": Path(hparams["lang_dir"]) / "words.txt",
+                    "in_arpa": lm_dir / hparams["G_rescoring_arpa"],
+                    "out_fst": G_rescoring_path,
+                    "ngram_order": 4,  # by default use 4-gram for rescoring LM
+                    "cache": caching,
+                },
+            )
+
+    output_folder = None
+    if "output_folder" in hparams:
+        output_folder = output_folder
+
+    if compose_HL_with_G:
+        G = utils.load_G(G_path, cache=caching)
+        decoding_graph = graphCompiler.compile_HLG(
+            G, cache_dir=output_folder, cache=caching
+        )
+    else:
+        decoding_graph = graphCompiler.compile_HL(
+            cache_dir=output_folder, cache=caching
+        )
+
+    if hparams.get("decoding_method") == "whole-lattice-rescoring":
+        G_rescoring = None
+        if not isinstance(hparams["rescoring_lm_scale"], list):
+            hparams["rescoring_lm_scale"] = [hparams["rescoring_lm_scale"]]
+
+        def decoding_method(lattice: k2.Fsa) -> Dict[str, k2.Fsa]:
+            """Get the best path from a lattice given rescoring_lm_scale."""
+
+            # Lazy load rescoring G (takes a lot of time) for developer happiness
+            nonlocal G_rescoring
+            if G_rescoring is None:
+                logger.info("Decoding method: whole-lattice-rescoring")
+                logger.info(f"Loading rescoring LM: {G_rescoring_path}")
+                G_rescoring_pt = utils.load_G(G_rescoring_path, cache=caching)
+                graphCompiler.lexicon.remove_G_rescoring_disambig_symbols(
+                    G_rescoring_pt
+                )
+                G_rescoring = utils.prepare_rescoring_G(G_rescoring_pt)
+
+            # rescore_with_whole_lattice returns a list of paths depending on
+            # lm_scale values.
+            return rescore_with_whole_lattice(
+                lattice,
+                G_rescoring,
+                lm_scale_list=hparams["rescoring_lm_scale"],
+            )
+
+    elif hparams.get("decoding_method") in ["1best", "onebest"]:
+        logger.info("Decoding method: one-best-decoding")
+
+        def decoding_method(lattice: k2.Fsa) -> Dict[str, k2.Fsa]:
+            """Get the best path from a lattice."""
+            return OrderedDict({"1best": one_best_decoding(lattice)})
+
+    else:
+
+        def decoding_method(lattice: k2.Fsa):
+            """A dummy decoding method that raises an error."""
+            raise NotImplementedError(
+                f"{hparams.get('decoding_method')} not implemented as a decoding_method"
+            )
+
+    return {
+        "decoding_graph": decoding_graph.to(device),
+        "decoding_method": decoding_method,
+    }
+
+
+@torch.no_grad()
+def get_lattice(
+    log_probs_nnet_output: torch.Tensor,
+    input_lens: torch.Tensor,
+    decoder: k2.Fsa,
+    search_beam: int = 5,
+    output_beam: int = 5,
+    min_active_states: int = 300,
+    max_active_states: int = 1000,
+    ac_scale: float = 1.0,
+    subsampling_factor: int = 1,
+) -> k2.Fsa:
+    """
+    Get the decoding lattice from a decoding graph and neural network output.
+
+    Arguments
+    ---------
+    log_probs_nnet_output:
+        It is the output of a neural model of shape `(batch, seq_len, num_tokens)`.
+    input_lens:
+        It is an int tensor of shape (batch,). It contains lengths of
+        each sequence in `log_probs_nnet_output`.
+    decoder:
+        It is an instance of :class:`k2.Fsa` that represents the decoding graph.
+    search_beam:
+        Decoding beam, e.g. 20.  Ger is faster, larger is more exact
+        (less pruning). This is the default value; it may be modified by
+        `min_active_states` and `max_active_states`.
+    output_beam:
+         Beam to prune output, similar to lattice-beam in Kaldi.  Relative
+         to best path of output.
+    min_active_states:
+        Minimum number of FSA states that are allowed to be active on any given
+        frame for any given intersection/composition task. This is advisory,
+        in that it will try not to have fewer than this number active.
+        Set it to zero if there is no constraint.
+    max_active_states:
+        Maximum number of FSA states that are allowed to be active on any given
+        frame for any given intersection/composition task. This is advisory,
+        in that it will try not to exceed that but may not always succeed.
+        You can use a very large number if no constraint is needed.
+    ac_scale:
+        acoustic scale applied to `log_probs_nnet_output`
+    subsampling_factor:
+        The subsampling factor of the model.
+
+    Returns
+    -------
+    lattice:
+        An FsaVec containing the decoding result. It has axes [utt][state][arc].
+    """
+
+    device = log_probs_nnet_output.device
+    input_lens = input_lens.to(device)
+    if decoder.device != device:
+        logger.warn(
+            "Decoding graph (HL or HLG) not loaded on the same device"
+            "  as nnet, this will cause decoding speed degradation"
+        )
+        decoder = decoder.to(device)
+
+    input_lens = (input_lens * log_probs_nnet_output.shape[1]).round().int()
+    # NOTE: low ac_scales may results in very big lattices and OOM errors.
+    log_probs_nnet_output *= ac_scale
+
+    lattice = k2.get_lattice(
+        log_probs_nnet_output,
+        input_lens,
+        decoder,
+        search_beam=search_beam,
+        output_beam=output_beam,
+        min_active_states=min_active_states,
+        max_active_states=max_active_states,
+        subsampling_factor=subsampling_factor,
+    )
+
+    return lattice
+
+
+@torch.no_grad()
+def one_best_decoding(
+    lattice: k2.Fsa, use_double_scores: bool = True,
+) -> k2.Fsa:
+    """
+    Get the best path from a lattice.
+
+    Arguments
+    ---------
+      lattice:
+        The decoding lattice returned by :func:`get_lattice`.
+      use_double_scores:
+        True to use double precision floating point in the computation.
+        False to use single precision.
+
+    Returns
+    -------
+    best_path:
+        An FsaVec containing linear paths.
+    """
+    best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
+    return best_path
+
+
+@torch.no_grad()
+def rescore_with_whole_lattice(
+    lattice: k2.Fsa,
+    G_with_epsilon_loops: k2.Fsa,
+    lm_scale_list: Optional[List[float]] = None,
+    use_double_scores: bool = True,
+) -> Union[k2.Fsa, Dict[str, k2.Fsa]]:
+    """
+    Intersect the lattice with an n-gram LM and use shortest path to decode.
+    The input lattice is obtained by intersecting `HLG` with
+    a DenseFsaVec, where the `G` in `HLG` is in general a 3-gram LM.
+    The input `G_with_epsilon_loops` is usually a 4-gram LM. You can consider
+    this function as a second pass decoding. In the first pass decoding, we
+    use a small G, while we use a larger G in the second pass decoding.
+
+    Arguments
+    ---------
+    lattice: k2.Fsa
+        An FsaVec with axes [utt][state][arc]. Its `aux_labels` are word IDs.
+        It must have an attribute `lm_scores`.
+    G_with_epsilon_loops: k2.Fsa
+        An FsaVec containing only a single FSA. It contains epsilon self-loops.
+        It is an acceptor and its labels are word IDs.
+    lm_scale_list: Optional[List[float]]
+        If none, return the intersection of `lattice` and `G_with_epsilon_loops`.
+        If not None, it contains a list of values to scale LM scores.
+        For each scale, there is a corresponding decoding result contained in
+        the resulting dict.
+    use_double_scores: bool
+        True to use double precision in the computation.
+        False to use single precision.
+
+    Returns
+    -------
+        If `lm_scale_list` is None, return a new lattice which is the intersection
+        result of `lattice` and `G_with_epsilon_loops`.
+        Otherwise, return a dict whose key is an entry in `lm_scale_list` and the
+        value is the decoding result (i.e., an FsaVec containing linear FSAs).
+    """
+    assert G_with_epsilon_loops.shape == (1, None, None)
+    G_with_epsilon_loops = G_with_epsilon_loops.to(lattice.device)
+    device = lattice.device
+    if hasattr(lattice, "lm_scores"):
+        lattice.scores = lattice.scores - lattice.lm_scores
+        # We will use lm_scores from G, so remove lats.lm_scores here
+        del lattice.lm_scores
+
+    assert hasattr(G_with_epsilon_loops, "lm_scores")
+
+    # Now, lattice.scores contains only am_scores
+
+    # inv_lattice has word IDs as labels.
+    # Its `aux_labels` is token IDs
+    inv_lattice = k2.invert(lattice)
+    num_seqs = lattice.shape[0]
+
+    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
+
+    # NOTE: The choice of the threshold list is arbitrary here to avoid OOM.
+    # You may need to fine tune it.
+    prune_th_list = [1e-10, 1e-9, 1e-8, 1e-7, 1e-6]
+    prune_th_list += [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
+    max_loop_count = 10
+    loop_count = 0
+    while loop_count <= max_loop_count:
+        try:
+            if device == "cpu":
+                rescoring_lattice = k2.intersect(
+                    G_with_epsilon_loops,
+                    inv_lattice,
+                    treat_epsilons_specially=True,
+                )
+            else:
+                rescoring_lattice = k2.intersect_device(
+                    G_with_epsilon_loops,
+                    inv_lattice,
+                    b_to_a_map,
+                    sorted_match_a=True,
+                )
+            rescoring_lattice = k2.top_sort(k2.connect(rescoring_lattice))
+            break
+        except RuntimeError as e:
+            logger.info(f"Caught exception:\n{e}\n")
+            if loop_count >= max_loop_count:
+                logger.info(
+                    "Return None as the resulting lattice is too large."
+                )
+                return None
+            logger.info(
+                f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
+            )
+            logger.info(
+                "This OOM is not an error. You can ignore it. "
+                "If your model does not converge well, or the segment length "
+                "is too large, or the input sound file is difficult to "
+                "decode, you will meet this exception."
+            )
+            inv_lattice = k2.prune_on_arc_post(
+                inv_lattice, prune_th_list[loop_count], True,
+            )
+            logger.info(
+                f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
+            )
+        loop_count += 1
+
+    # lat has token IDs as labels
+    # and word IDs as aux_labels.
+    lat = k2.invert(rescoring_lattice)
+
+    if lm_scale_list is None:
+        return lat
+
+    ans = OrderedDict()
+    saved_am_scores = lat.scores - lat.lm_scores
+    for lm_scale in lm_scale_list:
+        am_scores = saved_am_scores / lm_scale
+        lat.scores = am_scores + lat.lm_scores
+
+        best_path = k2.shortest_path(lat, use_double_scores=use_double_scores)
+        key = f"whole_lattice_rescore_lm_scale_{lm_scale:.1f}"
+        ans[key] = best_path
+    return ans
diff --git a/speechbrain/k2_integration/lexicon.py b/speechbrain/k2_integration/lexicon.py
new file mode 100644
index 0000000000000000000000000000000000000000..4445c75fe73d9734dea7d8e0e57044f4296760e2
--- /dev/null
+++ b/speechbrain/k2_integration/lexicon.py
@@ -0,0 +1,565 @@
+"""Lexicon class and utilities. Provides functions to read/write
+lexicon files and convert them to k2 ragged tensors. The Lexicon
+class provides a way to convert a list of words to a ragged tensor
+containing token IDs. It also stores the lexicon graph which can
+be used by a graph compiler to decode sequences.
+
+This code was adjusted, and therefore heavily inspired or taken from
+from icefall's (https://github.com/k2-fsa/icefall) Lexicon class and
+its utility functions.
+
+
+Authors:
+  * Pierre Champion 2023
+  * Zeyu Zhao 2023
+  * Georgios Karakasidis 2023
+"""
+
+
+import logging
+import re
+import os
+import csv
+from pathlib import Path
+from typing import List, Union, Tuple, Optional
+
+from . import k2  # import k2 from ./__init__.py
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+UNK = "<UNK>"  # unknow word
+UNK_t = "<unk>"  # unknow token
+EOW = "<eow>"  # end of word
+EPS = "<eps>"  # epsilon
+
+DISAMBIG_PATTERN: re.Pattern = re.compile(
+    r"^#\d+$"
+)  # pattern for disambiguation symbols.
+
+
+class Lexicon(object):
+    """
+    Unit based lexicon. It is used to map a list of words to each word's
+    sequence of tokens (characters). It also stores the lexicon graph which
+    can be used by a graph compiler to decode sequences.
+
+    Arguments
+    ---------
+    lang_dir: str
+        Path to the lang directory. It is expected to contain the following
+        files:
+            - tokens.txt
+            - words.txt
+            - L.pt
+
+    Example
+    -------
+    >>> from speechbrain.k2_integration import k2
+    >>> from speechbrain.k2_integration.lexicon import Lexicon
+    >>> from speechbrain.k2_integration.graph_compiler import CtcGraphCompiler
+    >>> from speechbrain.k2_integration.prepare_lang import prepare_lang
+
+    >>> # Create a small lexicon containing only two words and write it to a file.
+    >>> lang_tmpdir = getfixture('tmpdir')
+    >>> lexicon_sample = '''hello h e l l o\\nworld w o r l d'''
+    >>> lexicon_file = lang_tmpdir.join("lexicon.txt")
+    >>> lexicon_file.write(lexicon_sample)
+    >>> # Create a lang directory with the lexicon and L.pt, L_inv.pt, L_disambig.pt
+    >>> prepare_lang(lang_tmpdir)
+    >>> # Create a lexicon object
+    >>> lexicon = Lexicon(lang_tmpdir)
+    >>> # Make sure the lexicon was loaded correctly
+    >>> assert isinstance(lexicon.token_table, k2.SymbolTable)
+    >>> assert isinstance(lexicon.L, k2.Fsa)
+    """
+
+    def __init__(
+        self, lang_dir: Path,
+    ):
+        self.lang_dir = lang_dir = Path(lang_dir)
+        self.token_table = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
+        self.word_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
+        self.word2tokenids = {}
+        with open(lang_dir / "lexicon.txt", "r", encoding="utf-8") as f:
+            for line in f:
+                word = line.strip().split()[0]
+                tokens = line.strip().split()[1:]
+                tids = [self.token_table[t] for t in tokens]
+                # handle multiple pronunciation
+                if word not in self.word2tokenids:
+                    self.word2tokenids[word] = []
+                self.word2tokenids[word].append(tids)
+
+        self._L_disambig = None
+
+        if (lang_dir / "L.pt").exists():
+            logger.info(f"Loading compiled {lang_dir}/L.pt")
+            L = k2.Fsa.from_dict(torch.load(lang_dir / "L.pt"))
+        else:
+            raise RuntimeError(
+                f"{lang_dir}/L.pt does not exist. Please make sure "
+                f"you have successfully created L.pt in {lang_dir}"
+            )
+
+        if (lang_dir / "Linv.pt").exists():
+            logger.info(f"Loading compiled {lang_dir}/Linv.pt")
+            L_inv = k2.Fsa.from_dict(torch.load(lang_dir / "Linv.pt"))
+        else:
+            logger.info("Converting L.pt to Linv.pt")
+            L_inv = k2.arc_sort(L.invert())
+            torch.save(L_inv.as_dict(), lang_dir / "Linv.pt")
+
+        # We save L_inv instead of L because it will be used to intersect with
+        # transcript FSAs, both of whose labels are word IDs.
+        self.L_inv = L_inv
+        self.L = L
+
+    @property
+    def tokens(self) -> List[int]:
+        """
+        Return a list of token IDs excluding those from
+        disambiguation symbols and epsilon.
+        """
+        symbols = self.token_table.symbols
+        ans = []
+        for s in symbols:
+            if not DISAMBIG_PATTERN.match(s) or s != EPS:
+                ans.append(self.token_table[s])
+        ans.sort()
+        return ans
+
+    @property
+    def L_disambig(self) -> k2.Fsa:
+        """
+        Return the lexicon FSA (with disambiguation symbols).
+        Needed for HLG construction.
+        """
+        if self._L_disambig is None:
+            logger.info(f"Loading compiled {self.lang_dir}/L_disambig.pt")
+            if (self.lang_dir / "L_disambig.pt").exists():
+                self._L_disambig = k2.Fsa.from_dict(
+                    torch.load(self.lang_dir / "L_disambig.pt")
+                )
+            else:
+                raise RuntimeError(
+                    f"{self.lang_dir}/L_disambig.pt does not exist. Please make sure "
+                    f"you have successfully created L_disambig.pt in {self.lang_dir}"
+                )
+        return self._L_disambig
+
+    def remove_G_rescoring_disambig_symbols(self, G: k2.Fsa):
+        """
+        Remove the disambiguation symbols of a G graph
+
+        Arguments
+        ---------
+        G: k2.Fsa
+            The G graph to be modified
+        """
+        G.labels[G.labels >= self.word_table["#0"]] = 0
+
+    def remove_LG_disambig_symbols(self, LG: k2.Fsa) -> k2.Fsa:
+        """
+        Remove the disambiguation symbols of an LG graph
+        Needed for HLG construction.
+
+        Arguments
+        ---------
+        LG: k2.Fsa
+            The LG graph to be modified
+        """
+
+        first_token_disambig_id = self.token_table["#0"]
+        first_word_disambig_id = self.word_table["#0"]
+
+        logger.debug("Removing disambiguation symbols on LG")
+        # NOTE: We need to clone here since LG.labels is just a reference to a tensor
+        #       and we will end up having issues with misversioned updates on fsa's
+        #       properties.
+        labels = LG.labels.clone()
+        labels[labels >= first_token_disambig_id] = 0
+        LG.labels = labels
+
+        assert isinstance(LG.aux_labels, k2.RaggedTensor)
+        LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
+        return LG
+
+    def texts_to_word_ids(
+        self,
+        texts: List[str],
+        add_sil_token_as_separator=False,
+        sil_token_id: Optional[int] = None,
+        log_unknown_warning=True,
+    ) -> List[List[int]]:
+        """
+        Convert a list of texts into word IDs.
+
+        This method performs the mapping of each word in the input texts to its corresponding ID.
+        The result is a list of lists, where each inner list contains the word IDs for a sentence.
+        If the `add_sil_token_as_separator` flag is True, a silence token is inserted between words,
+        and the `sil_token_id` parameter specifies the ID for the silence token.
+        If a word is not found in the vocabulary, a warning is logged if `log_unknown_warning` is True.
+
+        Arguments
+        ---------
+        texts: List[str]
+            A list of strings where each string represents a sentence.
+            Each sentence is composed of space-separated words.
+
+        add_sil_token_as_separator: bool
+            Flag indicating whether to add a silence token as a separator between words.
+
+        sil_token_id: Optional[int]
+            The ID of the silence token. If not provided, the separator is not added.
+
+        log_unknown_warning: bool
+            Flag indicating whether to log a warning for unknown words.
+
+        Returns
+        -------
+        word_ids: List[List[int]]
+            A list of lists where each inner list represents the word IDs for a sentence.
+            The word IDs are obtained based on the vocabulary mapping.
+        """
+        word_ids = self._texts_to_ids(
+            texts, log_unknown_warning, _mapper="word_table"
+        )
+        if add_sil_token_as_separator:
+            assert (
+                sil_token_id is not None
+            ), "sil_token_id=None while add_sil_token_as_separator=True"
+            for i in range(len(word_ids)):
+                word_ids[i] = [
+                    x for item in word_ids[i] for x in (item, sil_token_id)
+                ][:-1]
+        return word_ids
+
+    def texts_to_token_ids(
+        self, texts: List[str], log_unknown_warning=True,
+    ) -> List[List[List[int]]]:
+        """
+        Convert a list of text sentences into token IDs.
+
+        Parameters
+        ----------
+        texts: List[str]
+            A list of strings, where each string represents a sentence.
+            Each sentence consists of space-separated words.
+            Example:
+                ['hello world', 'tokenization with lexicon']
+
+        log_unknown_warning: bool
+            Flag indicating whether to log warnings for out-of-vocabulary tokens.
+            If True, warnings will be logged when encountering unknown tokens.
+
+        Returns
+        -------
+        token_ids: List[List[List[int]]]
+            A list containing token IDs for each sentence in the input.
+            The structure of the list is as follows:
+            [
+                [  # For the first sentence
+                    [token_id_1, token_id_2, ..., token_id_n],
+                    [token_id_1, token_id_2, ..., token_id_m],
+                    ...
+                ],
+                [  # For the second sentence
+                    [token_id_1, token_id_2, ..., token_id_p],
+                    [token_id_1, token_id_2, ..., token_id_q],
+                    ...
+                ],
+                ...
+            ]
+            Each innermost list represents the token IDs for a word in the sentence.
+        """
+        return self._texts_to_ids(
+            texts, log_unknown_warning, _mapper="word2tokenids"
+        )
+
+    def texts_to_token_ids_with_multiple_pronunciation(
+        self, texts: List[str], log_unknown_warning=True,
+    ) -> List[List[List[List[int]]]]:
+        """
+        Convert a list of input texts to token IDs with multiple pronunciation variants.
+
+        This method converts input texts into token IDs, considering multiple pronunciation variants.
+        The resulting structure allows for handling various pronunciations of words within the given texts.
+
+        Arguments
+        ---------
+        texts: List[str]
+            A list of strings, where each string represents a sentence for an utterance.
+            Each sentence consists of space-separated words.
+
+        log_unknown_warning: bool
+            Indicates whether to log warnings for out-of-vocabulary (OOV) tokens.
+            If set to True, warnings will be logged for OOV tokens during the conversion.
+
+        Returns
+        -------
+        token_ids: List[List[List[List[int]]]]
+            A nested list structure containing token IDs for each utterance. The structure is as follows:
+            - Outer List: Represents different utterances.
+            - Middle List: Represents different pronunciation variants for each utterance.
+            - Inner List: Represents the sequence of token IDs for each pronunciation variant.
+            - Innermost List: Represents the token IDs for each word in the sequence.
+        """
+        return self._texts_to_ids(
+            texts,
+            log_unknown_warning,
+            _mapper="word2tokenids",
+            _multiple_pronunciation=True,
+        )
+
+    def _texts_to_ids(
+        self,
+        texts: List[str],
+        log_unknown_warning: bool,
+        _mapper: str,
+        _multiple_pronunciation=False,
+    ):
+        """
+        Convert a list of texts to a list of IDs, which can be either word IDs or
+        a list of token IDs.
+
+        Arguments
+        ---------
+        texts: List[str]
+            A list of strings where each string consists of space-separated words.
+            Example:
+                ['hello world', 'tokenization with lexicon']
+
+        log_unknown_warning: bool
+            Log a warning if a word is not found in the token-to-IDs mapping.
+
+        _mapper: str
+            The mapper to use, either "word_table" (e.g., "TEST" -> 176838) or
+            "word2tokenids" (e.g., "TEST" -> [23, 8, 22, 23]).
+
+        _multiple_pronunciation: bool
+            Allow returning all pronunciations of a word from the lexicon.
+            If False, only return the first pronunciation.
+
+        Returns
+        -------
+        ids_list: List[List[int] or int]
+            Returns a list-of-list of word IDs or a list of token IDs.
+        """
+        oov_token_id = self.word_table[UNK]
+        if _mapper == "word2tokenids":
+            oov_token_id = [self.token_table[UNK_t]]
+        ids = getattr(self, _mapper)
+
+        ids_list = []
+        for text in texts:
+            word_ids = []
+            words = text.split()
+            for i, word in enumerate(words):
+                if word in ids:
+                    idword = ids[word]
+                    if isinstance(idword, list) and not _multiple_pronunciation:
+                        idword = idword[
+                            0
+                        ]  # only first spelling of a word (for word2tokenids mapper)
+                    word_ids.append(idword)
+                else:
+                    word_ids.append(oov_token_id)
+                    if log_unknown_warning:
+                        logger.warning(
+                            f"Cannot find word {word} in the mapper {_mapper}."
+                            f" Replacing it with OOV token."
+                            f" Note that it is fine if you are testing."
+                        )
+
+            ids_list.append(word_ids)
+        return ids_list
+
+    def arc_sort(self):
+        """
+        Sort L, L_inv, L_disambig arcs of every state.
+        """
+        self.L = k2.arc_sort(self.L)
+        self.L_inv = k2.arc_sort(self.L_inv)
+        if self._L_disambig is not None:
+            self._L_disambig = k2.arc_sort(self._L_disambig)
+
+    def to(self, device: str = "cpu"):
+        """
+        Device to move L, L_inv and L_disambig to
+
+        Arguments
+        ---------
+        device: str
+            The device
+        """
+        self.L = self.L.to(device)
+        self.L_inv = self.L_inv.to(device)
+        if self._L_disambig is not None:
+            self._L_disambig = self._L_disambig.to(device)
+
+
+def prepare_char_lexicon(
+    lang_dir,
+    vocab_files,
+    extra_csv_files=[],
+    column_text_key="wrd",
+    add_word_boundary=True,
+):
+    """
+    Read extra_csv_files to generate a $lang_dir/lexicon.txt for k2 training.
+    This usually includes the csv files of the training set and the dev set in the
+    output_folder. During training, we need to make sure that the lexicon.txt contains
+    all (or the majority of) the words in the training set and the dev set.
+
+    NOTE: This assumes that the csv files contain the transcription in the last column.
+
+    Also note that in each csv_file, the first line is the header, and the remaining
+    lines are in the following format:
+
+    ID, duration, wav, spk_id, wrd (transcription)
+
+    We only need the transcription in this function.
+
+    Writes out $lang_dir/lexicon.txt
+
+    Note that the lexicon.txt is a text file with the following format:
+    word1 phone1 phone2 phone3 ...
+    word2 phone1 phone2 phone3 ...
+
+    In this code, we simply use the characters in the word as the phones.
+    You can use other phone sets, e.g., phonemes, BPEs, to train a better model.
+
+    Arguments
+    ---------
+    lang_dir: str
+        The directory to store the lexicon.txt
+    vocab_files: List[str]
+        A list of extra vocab files. For example, for librispeech this could be the
+        librispeech-vocab.txt file.
+    extra_csv_files: List[str]
+        A list of csv file paths
+    column_text_key: str
+        The column name of the transcription in the csv file. By default, it is "wrd".
+    add_word_boundary: bool
+        whether to add word boundary symbols <eow> at the end of each line to the
+        lexicon for every word.
+
+    Example
+    -------
+    >>> from speechbrain.k2_integration.lexicon import prepare_char_lexicon
+    >>> # Create some dummy csv files containing only the words `hello`, `world`.
+    >>> # The first line is the header, and the remaining lines are in the following
+    >>> # format:
+    >>> # ID, duration, wav, spk_id, wrd (transcription)
+    >>> csv_file = getfixture('tmpdir').join("train.csv")
+    >>> # Data to be written to the CSV file.
+    >>> import csv
+    >>> data = [
+    ...    ["ID", "duration", "wav", "spk_id", "wrd"],
+    ...    [1, 1, 1, 1, "hello world"],
+    ...    [2, 0.5, 1, 1, "hello"]
+    ... ]
+    >>> with open(csv_file, "w", newline="") as f:
+    ...    writer = csv.writer(f)
+    ...    writer.writerows(data)
+    >>> extra_csv_files = [csv_file]
+    >>> lang_dir = getfixture('tmpdir')
+    >>> vocab_files = []
+    >>> prepare_char_lexicon(lang_dir, vocab_files, extra_csv_files=extra_csv_files, add_word_boundary=False)
+    """
+    # Read train.csv, dev-clean.csv to generate a lexicon.txt for k2 training
+    lexicon = dict()
+    if len(extra_csv_files) != 0:
+        for file in extra_csv_files:
+            with open(file, "r") as f:
+                csv_reader = csv.DictReader(f)
+                for row in csv_reader:
+                    # Split the transcription into words
+                    words = row[column_text_key].split()
+                    for word in words:
+                        if word not in lexicon:
+                            if add_word_boundary:
+                                lexicon[word] = list(word) + [EOW]
+                            else:
+                                lexicon[word] = list(word)
+
+    for file in vocab_files:
+        with open(file) as f:
+            for line in f:
+                # Split the line
+                word = line.strip().split()[0]
+                # Split the transcription into words
+                if word not in lexicon:
+                    if add_word_boundary:
+                        lexicon[word] = list(word) + [EOW]
+                    else:
+                        lexicon[word] = list(word)
+    # Write the lexicon to lang_dir/lexicon.txt
+    os.makedirs(lang_dir, exist_ok=True)
+    with open(os.path.join(lang_dir, "lexicon.txt"), "w") as f:
+        fc = f"{UNK} {UNK_t}\n"
+        for word in lexicon:
+            fc += word + " " + " ".join(lexicon[word]) + "\n"
+        f.write(fc)
+
+
+def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]:
+    """
+    Read a lexicon from `filename`.
+
+    Each line in the lexicon contains "word p1 p2 p3 ...".
+    That is, the first field is a word and the remaining
+    fields are tokens. Fields are separated by space(s).
+
+    Arguments
+    ---------
+    filename: str
+        Path to the lexicon.txt
+
+    Returns
+    -------
+    ans:
+        A list of tuples., e.g., [('w', ['p1', 'p2']), ('w1', ['p3, 'p4'])]
+    """
+    ans = []
+
+    with open(filename, "r", encoding="utf-8") as f:
+        whitespace = re.compile("[ \t]+")
+        for line in f:
+            a = whitespace.split(line.strip(" \t\r\n"))
+            if len(a) == 0:
+                continue
+            if len(a) < 2:
+                raise RuntimeError(
+                    f"Found bad line {line} in lexicon file {filename}"
+                    "Every line is expected to contain at least 2 fields"
+                )
+            word = a[0]
+            if word == EPS:
+                raise RuntimeError(
+                    f"Found bad line {line} in lexicon file {filename}"
+                    f"{EPS} should not be a valid word"
+                )
+            tokens = a[1:]
+            ans.append((word, tokens))
+    return ans
+
+
+def write_lexicon(
+    filename: Union[str, Path], lexicon: List[Tuple[str, List[str]]]
+) -> None:
+    """
+    Write a lexicon to a file.
+
+    Arguments
+    ---------
+    filename: str
+        Path to the lexicon file to be generated.
+    lexicon: List[Tuple[str, List[str]]]
+        It can be the return value of :func:`read_lexicon`.
+    """
+    with open(filename, "w", encoding="utf-8") as f:
+        for word, tokens in lexicon:
+            f.write(f"{word} {' '.join(tokens)}\n")
diff --git a/speechbrain/k2_integration/losses.py b/speechbrain/k2_integration/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c80c281b7e2e84120aa60206153c7bca6386f6d
--- /dev/null
+++ b/speechbrain/k2_integration/losses.py
@@ -0,0 +1,130 @@
+""" This file contains the loss functions for k2 training. Currently, we only
+support CTC loss.
+
+Authors:
+ * Pierre Champion 2023
+ * Zeyu Zhao 2023
+ * Georgios Karakasidis 2023
+"""
+
+from . import k2  # import k2 from ./__init__.py
+
+import torch
+
+
+def ctc_k2(
+    log_probs,
+    input_lens,
+    graph_compiler,
+    texts,
+    reduction="mean",
+    beam_size=10,
+    use_double_scores=True,
+    is_training=True,
+):
+    """
+    CTC loss implemented with k2. Make sure that k2 has been installed properly.
+    Note that the blank index must be 0 in this implementation.
+
+    Arguments
+    ---------
+    log_probs: torch.Tensor
+        Log-probs of shape (batch, time, num_classes).
+    input_lens : torch.Tensor
+        Length of each utterance.
+    graph_compiler : k2.Fsa
+        Decoding graph.
+    texts : List[str]
+        List of texts.
+    reduction : str
+        What reduction to apply to the output. 'mean', 'sum', 'none'.
+        See k2.ctc_loss for 'mean', 'sum', 'none'.
+    beam_size : int
+        Beam size.
+    use_double_scores : bool
+        If true, use double precision for scores.
+    is_training : bool
+        If true, the returned loss requires gradient.
+
+    Returns
+    -------
+    loss: torch.Tensor
+        CTC loss.
+
+    Example
+    -------
+    >>> import torch
+    >>> from speechbrain.k2_integration.losses import ctc_k2
+    >>> from speechbrain.k2_integration.graph_compiler import CtcGraphCompiler
+    >>> from speechbrain.k2_integration.lexicon import Lexicon
+    >>> from speechbrain.k2_integration.prepare_lang import prepare_lang
+
+    >>> # Create a random batch of log-probs
+    >>> batch_size = 4
+
+    >>> log_probs = torch.randn(batch_size, 100, 30)
+    >>> log_probs.requires_grad = True
+    >>> # Assume all utterances have the same length so no padding was needed.
+    >>> input_lens = torch.ones(batch_size)
+    >>> # Create a samll lexicon containing only two words and write it to a file.
+    >>> lang_tmpdir = getfixture('tmpdir')
+    >>> lexicon_sample = "hello h e l l o\\nworld w o r l d\\n<UNK> <unk>"
+    >>> lexicon_file = lang_tmpdir.join("lexicon.txt")
+    >>> lexicon_file.write(lexicon_sample)
+    >>> # Create a lang directory with the lexicon and L.pt, L_inv.pt, L_disambig.pt
+    >>> prepare_lang(lang_tmpdir)
+    >>> # Create a lexicon object
+    >>> lexicon = Lexicon(lang_tmpdir)
+    >>> # Create a random decoding graph
+    >>> graph = CtcGraphCompiler(
+    ...     lexicon,
+    ...     log_probs.device,
+    ... )
+    >>> # Create a random batch of texts
+    >>> texts = ["hello world", "world hello", "hello", "world"]
+    >>> # Compute the loss
+    >>> loss = ctc_k2(
+    ...     log_probs=log_probs,
+    ...     input_lens=input_lens,
+    ...     graph_compiler=graph,
+    ...     texts=texts,
+    ...     reduction="mean",
+    ...     beam_size=10,
+    ...     use_double_scores=True,
+    ...     is_training=True,
+    ... )
+    """
+    input_lens = (input_lens * log_probs.shape[1]).round().int()
+
+    batch_size = log_probs.shape[0]
+
+    supervision_segments = torch.tensor(
+        [[i, 0, input_lens[i]] for i in range(batch_size)],
+        device="cpu",
+        dtype=torch.int32,
+    )
+
+    decoding_graph, target_lens = graph_compiler.compile(
+        texts, is_training=is_training
+    )
+
+    # An introduction to DenseFsaVec:
+    # https://k2-fsa.github.io/k2/core_concepts/index.html#dense-fsa-vector
+    # It could be viewed as a fsa-type log_probs,
+    # whose weight on the arcs are initialized with log_probs.
+    # The goal of converting tensor-type to fsa-type is using
+    # fsa related functions in k2. e.g. k2.ctc_loss.
+    dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)
+
+    loss = k2.ctc_loss(
+        decoding_graph=decoding_graph.to(log_probs.device),
+        dense_fsa_vec=dense_fsa_vec,
+        target_lengths=target_lens.to(log_probs.device),
+        output_beam=beam_size,
+        reduction=reduction,
+        use_double_scores=use_double_scores,
+    )
+
+    assert loss.requires_grad == is_training
+
+    return loss
diff --git a/speechbrain/k2_integration/prepare_lang.py b/speechbrain/k2_integration/prepare_lang.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb22efec080e3df929e580e86c8ba51ab2122774
--- /dev/null
+++ b/speechbrain/k2_integration/prepare_lang.py
@@ -0,0 +1,555 @@
+#!/usr/bin/env python3
+""" This module contains functions to prepare the lexicon and the language model
+for k2 training. It is based on the script `prepare_lang.sh` from k2/icefall (work
+of Fangjun Kuang). The original script is under Apache 2.0 license.
+This script is modified to work with SpeechBrain.
+
+Modified by:
+  * Pierre Champion 2023
+  * Zeyu Zhao 2023
+  * Georgios Karakasidis 2023
+"""
+
+
+from . import k2  # import k2 from ./__init__.py
+from .lexicon import read_lexicon, write_lexicon, EPS
+import math
+import os
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Union
+import torch
+
+logger = logging.getLogger(__name__)
+
+Lexicon = List[Tuple[str, List[str]]]
+
+
+def write_mapping(filename: Union[str, Path], sym2id: Dict[str, int]) -> None:
+    """
+    Write a symbol to ID mapping to a file.
+
+    NOTE: No need to implement `read_mapping` as it can be done through
+      :func:`k2.SymbolTable.from_file`.
+
+    Arguments
+    ---------
+    filename: str
+        Filename to save the mapping.
+    sym2id: Dict[str, int]
+        A dict mapping symbols to IDs.
+    """
+    with open(filename, "w", encoding="utf-8") as f:
+        for sym, i in sym2id.items():
+            f.write(f"{sym} {i}\n")
+
+
+def get_tokens(
+    lexicon: Lexicon, sil_token="SIL", manually_add_sil_to_tokens=False
+) -> List[str]:
+    """
+    Get tokens from a lexicon.
+
+    Arguments
+    ---------
+    lexicon: Lexicon
+        It is the return value of :func:`read_lexicon`.
+    sil_token: str
+        The optional silence token between words. It should not appear in the lexicon,
+        otherwise it will cause an error.
+    manually_add_sil_to_tokens: bool
+        If true, add `sil_token` to the tokens. This is useful when the lexicon
+        does not contain `sil_token` but it is needed in the tokens.
+
+    Returns
+    -------
+    sorted_ans: List[str]
+        A list of unique tokens.
+    """
+    ans = set()
+    if manually_add_sil_to_tokens:
+        ans.add(sil_token)
+    for _, tokens in lexicon:
+        assert (
+            sil_token not in tokens
+        ), f"{sil_token} should not appear in the lexicon but it is found in {_}"
+        ans.update(tokens)
+    sorted_ans = sorted(list(ans))
+    return sorted_ans
+
+
+def get_words(lexicon: Lexicon) -> List[str]:
+    """
+    Get words from a lexicon.
+
+    Arguments
+    ---------
+    lexicon: Lexicon
+        It is the return value of :func:`read_lexicon`.
+
+    Returns
+    -------
+    sorted_ans:
+        Return a list of unique words.
+    """
+    ans = set()
+    for word, _ in lexicon:
+        ans.add(word)
+    sorted_ans = sorted(list(ans))
+    return sorted_ans
+
+
+def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
+    """
+    It adds pseudo-token disambiguation symbols #1, #2 and so on
+    at the ends of tokens to ensure that all pronunciations are different,
+    and that none is a prefix of another.
+
+    See also add_lex_disambig.pl from kaldi.
+
+    Arguments
+    ---------
+    lexicon: Lexicon
+        It is returned by :func:`read_lexicon`.
+
+    Returns
+    -------
+    ans:
+        The output lexicon with disambiguation symbols
+    max_disambig:
+        The ID of the max disambiguation symbol that appears
+        in the lexicon
+    """
+
+    # (1) Work out the count of each token-sequence in the
+    # lexicon.
+    count = defaultdict(int)
+    for _, tokens in lexicon:
+        count[" ".join(tokens)] += 1
+
+    # (2) For each left sub-sequence of each token-sequence, note down
+    # that it exists (for identifying prefixes of longer strings).
+    issubseq = defaultdict(int)
+    for _, tokens in lexicon:
+        tokens = tokens.copy()
+        tokens.pop()
+        while tokens:
+            issubseq[" ".join(tokens)] = 1
+            tokens.pop()
+
+    # (3) For each entry in the lexicon:
+    # if the token sequence is unique and is not a
+    # prefix of another word, no disambig symbol.
+    # Else output #1, or #2, #3, ... if the same token-seq
+    # has already been assigned a disambig symbol.
+    ans = []
+
+    # We start with #1 since #0 has its own purpose
+    first_allowed_disambig = 1
+    max_disambig = first_allowed_disambig - 1
+    last_used_disambig_symbol_of = defaultdict(int)
+
+    for word, tokens in lexicon:
+        tokenseq = " ".join(tokens)
+        assert tokenseq != ""
+        if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
+            ans.append((word, tokens))
+            continue
+
+        cur_disambig = last_used_disambig_symbol_of[tokenseq]
+        if cur_disambig == 0:
+            cur_disambig = first_allowed_disambig
+        else:
+            cur_disambig += 1
+
+        if cur_disambig > max_disambig:
+            max_disambig = cur_disambig
+        last_used_disambig_symbol_of[tokenseq] = cur_disambig
+        tokenseq += f" #{cur_disambig}"
+        ans.append((word, tokenseq.split()))
+    return ans, max_disambig
+
+
+def generate_id_map(symbols: List[str]) -> Dict[str, int]:
+    """
+    Generate ID maps, i.e., map a symbol to a unique ID.
+
+    Arguments
+    ---------
+    symbols: List[str]
+        A list of unique symbols.
+
+    Returns
+    -------
+    A dict containing the mapping between symbols and IDs.
+    """
+    return {sym: i for i, sym in enumerate(symbols)}
+
+
+def add_self_loops(
+    arcs: List[List[Any]], disambig_token: int, disambig_word: int
+) -> List[List[Any]]:
+    """
+    Adds self-loops to states of an FST to propagate disambiguation symbols
+    through it. They are added on each state with non-epsilon output symbols
+    on at least one arc out of the state.
+
+    See also fstaddselfloops.pl from Kaldi. One difference is that
+    Kaldi uses OpenFst style FSTs and it has multiple final states.
+    This function uses k2 style FSTs and it does not need to add self-loops
+    to the final state.
+
+    The input label of a self-loop is `disambig_token`, while the output
+    label is `disambig_word`.
+
+    Arguments
+    ---------
+    arcs: List[List[Any]]
+        A list-of-list. The sublist contains
+        `[src_state, dest_state, label, aux_label, score]`
+    disambig_token: int
+        It is the token ID of the symbol `#0`.
+    disambig_word: int
+        It is the word ID of the symbol `#0`.
+
+    Returns
+    -------
+    Return new `arcs` containing self-loops.
+    """
+    states_needs_self_loops = set()
+    for arc in arcs:
+        src, dst, ilabel, olabel, score = arc
+        if olabel != 0:
+            states_needs_self_loops.add(src)
+
+    ans = []
+    for s in states_needs_self_loops:
+        ans.append([s, s, disambig_token, disambig_word, 0])
+
+    return arcs + ans
+
+
+def lexicon_to_fst(
+    lexicon: Lexicon,
+    token2id: Dict[str, int],
+    word2id: Dict[str, int],
+    sil_token: str = "SIL",
+    sil_prob: float = 0.5,
+    need_self_loops: bool = False,
+) -> k2.Fsa:
+    """
+    Convert a lexicon to an FST (in k2 format) with optional silence at the
+    beginning and end of each word.
+
+    Arguments
+    ---------
+    lexicon: Lexicon
+        The input lexicon. See also :func:`read_lexicon`
+    token2id: Dict[str, int]
+        A dict mapping tokens to IDs.
+    word2id: Dict[str, int]
+        A dict mapping words to IDs.
+    sil_token: str
+        The silence token.
+    sil_prob: float
+        The probability for adding a silence at the beginning and end
+        of the word.
+    need_self_loops: bool
+        If True, add self-loop to states with non-epsilon output symbols
+        on at least one arc out of the state. The input label for this
+        self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
+
+    Returns
+    -------
+    fsa: k2.Fsa
+        An FSA representing the given lexicon.
+    """
+    assert sil_prob > 0.0 and sil_prob < 1.0
+    # CAUTION: we use score, i.e, negative cost.
+    sil_score = math.log(sil_prob)
+    no_sil_score = math.log(1.0 - sil_prob)
+
+    start_state = 0
+    loop_state = 1  # words enter and leave from here
+    sil_state = 2  # words terminate here when followed by silence; this state
+    # has a silence transition to loop_state.
+    next_state = 3  # the next un-allocated state, will be incremented as we go.
+    arcs = []
+
+    assert token2id[EPS] == 0
+    assert word2id[EPS] == 0
+
+    eps = 0
+
+    sil_token_id = token2id[sil_token]
+
+    arcs.append([start_state, loop_state, eps, eps, no_sil_score])
+    arcs.append([start_state, sil_state, eps, eps, sil_score])
+    arcs.append([sil_state, loop_state, sil_token_id, eps, 0])
+
+    for word, tokens in lexicon:
+        assert len(tokens) > 0, f"{word} has no pronunciations"
+        cur_state = loop_state
+
+        word = word2id[word]
+        tokens = [token2id[i] for i in tokens]
+
+        for i in range(len(tokens) - 1):
+            w = word if i == 0 else eps
+            arcs.append([cur_state, next_state, tokens[i], w, 0])
+
+            cur_state = next_state
+            next_state += 1
+
+        # now for the last token of this word
+        # It has two out-going arcs, one to the loop state,
+        # the other one to the sil_state.
+        i = len(tokens) - 1
+        w = word if i == 0 else eps
+        arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
+        arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
+
+    if need_self_loops:
+        disambig_token = token2id["#0"]
+        disambig_word = word2id["#0"]
+        arcs = add_self_loops(
+            arcs, disambig_token=disambig_token, disambig_word=disambig_word,
+        )
+
+    final_state = next_state
+    arcs.append([loop_state, final_state, -1, -1, 0])
+    arcs.append([final_state])
+
+    arcs = sorted(arcs, key=lambda arc: arc[0])
+    arcs = [[str(i) for i in arc] for arc in arcs]
+    arcs = [" ".join(arc) for arc in arcs]
+    arcs = "\n".join(arcs)
+
+    fsa = k2.Fsa.from_str(arcs, acceptor=False)
+    return fsa
+
+
+def lexicon_to_fst_no_sil(
+    lexicon: Lexicon,
+    token2id: Dict[str, int],
+    word2id: Dict[str, int],
+    need_self_loops: bool = False,
+) -> k2.Fsa:
+    """
+    Convert a lexicon to an FST (in k2 format).
+
+    Arguments
+    ---------
+    lexicon: Lexicon
+        The input lexicon. See also :func:`read_lexicon`
+    token2id: Dict[str, int]
+        A dict mapping tokens to IDs.
+    word2id: Dict[str, int]
+        A dict mapping words to IDs.
+    need_self_loops: bool
+        If True, add self-loop to states with non-epsilon output symbols
+        on at least one arc out of the state. The input label for this
+        self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
+
+    Returns
+    -------
+    fsa: k2.Fsa
+        An FSA representing the given lexicon.
+    """
+    loop_state = 0  # words enter and leave from here
+    next_state = 1  # the next un-allocated state, will be incremented as we go
+
+    arcs = []
+
+    assert token2id[EPS] == 0
+    assert word2id[EPS] == 0
+
+    eps = 0
+
+    for word, pieces in lexicon:
+        assert len(pieces) > 0, f"{word} has no pronunciations"
+        cur_state = loop_state
+
+        word = word2id[word]
+        pieces = [token2id[i] for i in pieces]
+
+        for i in range(len(pieces) - 1):
+            w = word if i == 0 else eps
+            arcs.append([cur_state, next_state, pieces[i], w, 0])
+
+            cur_state = next_state
+            next_state += 1
+
+        # now for the last piece of this word
+        i = len(pieces) - 1
+        w = word if i == 0 else eps
+        arcs.append([cur_state, loop_state, pieces[i], w, 0])
+
+    if need_self_loops:
+        disambig_token = token2id["#0"]
+        disambig_word = word2id["#0"]
+        arcs = add_self_loops(
+            arcs, disambig_token=disambig_token, disambig_word=disambig_word,
+        )
+
+    final_state = next_state
+    arcs.append([loop_state, final_state, -1, -1, 0])
+    arcs.append([final_state])
+
+    arcs = sorted(arcs, key=lambda arc: arc[0])
+    arcs = [[str(i) for i in arc] for arc in arcs]
+    arcs = [" ".join(arc) for arc in arcs]
+    arcs = "\n".join(arcs)
+
+    fsa = k2.Fsa.from_str(arcs, acceptor=False)
+    return fsa
+
+
+def prepare_lang(lang_dir, sil_token="SIL", sil_prob=0.5, cache=True):
+    """
+    This function takes as input a lexicon file "$lang_dir/lexicon.txt"
+    consisting of words and tokens (i.e., phones) and does the following:
+
+    1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
+
+    2. Generate tokens.txt, the token table mapping a token to a unique integer.
+
+    3. Generate words.txt, the word table mapping a word to a unique integer.
+
+    4. Generate L.pt, in k2 format. It can be loaded by
+
+            d = torch.load("L.pt")
+            lexicon = k2.Fsa.from_dict(d)
+
+    5. Generate L_disambig.pt, in k2 format.
+
+
+    Arguments
+    ---------
+    lang_dir: str
+        The directory to store the output files and read the input file lexicon.txt.
+    sil_token: str
+        The silence token. Default is "SIL".
+    sil_prob: float
+        The probability for adding a silence at the beginning and end of the word.
+        Default is 0.5.
+    cache: bool
+        Whether or not to load/cache from/to the .pt format.
+
+    Example
+    -------
+    >>> from speechbrain.k2_integration.prepare_lang import prepare_lang
+
+    >>> # Create a small lexicon containing only two words and write it to a file.
+    >>> lang_tmpdir = getfixture('tmpdir')
+    >>> lexicon_sample = '''hello h e l l o\\nworld w o r l d'''
+    >>> lexicon_file = lang_tmpdir.join("lexicon.txt")
+    >>> lexicon_file.write(lexicon_sample)
+
+    >>> prepare_lang(lang_tmpdir)
+    >>> for expected_file in ["tokens.txt", "words.txt", "L.pt", "L_disambig.pt", "Linv.pt" ]:
+    ...     assert os.path.exists(os.path.join(lang_tmpdir, expected_file))
+    """
+
+    out_dir = Path(lang_dir)
+    lexicon_filename = out_dir / "lexicon.txt"
+
+    # if source lexicon_filename has been re-created (only use 'Linv.pt' for date modification query)
+    if (
+        cache
+        and (out_dir / "Linv.pt").exists()
+        and (out_dir / "Linv.pt").stat().st_mtime
+        < lexicon_filename.stat().st_mtime
+    ):
+        logger.warning(
+            f"Skipping lang preparation of '{out_dir}'."
+            " Set 'caching: False' in the yaml"
+            " if this is not what you want."
+        )
+        return
+
+    # backup L.pt, L_disambig.pt, tokens.txt and words.txt, Linv.pt and lexicon_disambig.txt
+    for f in [
+        "L.pt",
+        "L_disambig.pt",
+        "tokens.txt",
+        "words.txt",
+        "Linv.pt",
+        "lexicon_disambig.txt",
+    ]:
+        if (out_dir / f).exists():
+            os.makedirs(out_dir / "backup", exist_ok=True)
+            logger.debug(f"Backing up {out_dir / f} to {out_dir}/backup/{f}")
+            os.rename(out_dir / f, out_dir / "backup" / f)
+
+    lexicon = read_lexicon(str(lexicon_filename))
+    if sil_prob != 0:
+        # add silence to the tokens
+        tokens = get_tokens(
+            lexicon, sil_token=sil_token, manually_add_sil_to_tokens=True
+        )
+    else:
+        tokens = get_tokens(lexicon, manually_add_sil_to_tokens=False)
+    words = get_words(lexicon)
+
+    lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+
+    for i in range(max_disambig + 1):
+        disambig = f"#{i}"
+        assert disambig not in tokens
+        tokens.append(f"#{i}")
+
+    assert EPS not in tokens
+    tokens = [EPS] + tokens
+
+    assert EPS not in words
+    assert "#0" not in words
+    assert "<s>" not in words
+    assert "</s>" not in words
+
+    words = [EPS] + words + ["#0", "<s>", "</s>"]
+
+    token2id = generate_id_map(tokens)
+    word2id = generate_id_map(words)
+
+    logger.info(
+        f"Saving tokens.txt, words.txt, lexicon_disambig.txt to '{out_dir}'"
+    )
+    write_mapping(out_dir / "tokens.txt", token2id)
+    write_mapping(out_dir / "words.txt", word2id)
+    write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig)
+
+    if sil_prob != 0:
+        L = lexicon_to_fst(
+            lexicon,
+            token2id=token2id,
+            word2id=word2id,
+            sil_token=sil_token,
+            sil_prob=sil_prob,
+        )
+    else:
+        L = lexicon_to_fst_no_sil(lexicon, token2id=token2id, word2id=word2id,)
+
+    if sil_prob != 0:
+        L_disambig = lexicon_to_fst(
+            lexicon_disambig,
+            token2id=token2id,
+            word2id=word2id,
+            sil_token=sil_token,
+            sil_prob=sil_prob,
+            need_self_loops=True,
+        )
+    else:
+        L_disambig = lexicon_to_fst_no_sil(
+            lexicon_disambig,
+            token2id=token2id,
+            word2id=word2id,
+            need_self_loops=True,
+        )
+
+    L_inv = k2.arc_sort(L.invert())
+    logger.info(f"Saving L.pt, Linv.pt, L_disambig.pt to '{out_dir}'")
+    torch.save(L.as_dict(), out_dir / "L.pt")
+    torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")
+    torch.save(L_inv.as_dict(), out_dir / "Linv.pt")
diff --git a/speechbrain/k2_integration/utils.py b/speechbrain/k2_integration/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..358e86d94c6d56b20f8e04937c85ab9a9efef35a
--- /dev/null
+++ b/speechbrain/k2_integration/utils.py
@@ -0,0 +1,166 @@
+"""Utilities for k2 integration with SpeechBrain.
+
+This code was adjusted from icefall (https://github.com/k2-fsa/icefall).
+
+
+Authors:
+  * Pierre Champion 2023
+  * Zeyu Zhao 2023
+  * Georgios Karakasidis 2023
+"""
+
+import os
+import logging
+from pathlib import Path
+from typing import List, Union
+import torch
+
+from . import k2  # import k2 from ./__init__.py
+
+logger = logging.getLogger(__name__)
+
+
+def lattice_path_to_textid(
+    best_paths: k2.Fsa, return_ragged: bool = False
+) -> Union[List[List[int]], k2.RaggedTensor]:
+    """
+    Extract the texts (as word IDs) from the best-path FSAs.
+
+    Arguments
+    ---------
+    best_paths: k2.Fsa
+        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
+        containing multiple FSAs, which is expected to be the result
+        of k2.shortest_path (otherwise the returned values won't
+        be meaningful).
+    return_ragged: bool
+        True to return a ragged tensor with two axes [utt][word_id].
+        False to return a list-of-list word IDs.
+
+    Returns
+    -------
+    Returns a list of lists of int, containing the label sequences we
+    decoded.
+    """
+    if isinstance(best_paths.aux_labels, k2.RaggedTensor):
+        # remove 0's and -1's.
+        aux_labels = best_paths.aux_labels.remove_values_leq(0)
+        # TODO: change arcs.shape() to arcs.shape
+        aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)
+
+        # remove the states and arcs axes.
+        aux_shape = aux_shape.remove_axis(1)
+        aux_shape = aux_shape.remove_axis(1)
+        aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
+    else:
+        # remove axis corresponding to states.
+        aux_shape = best_paths.arcs.shape().remove_axis(1)
+        aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
+        # remove 0's and -1's.
+        aux_labels = aux_labels.remove_values_leq(0)
+
+    assert aux_labels.num_axes == 2
+    if return_ragged:
+        return aux_labels
+    else:
+        return aux_labels.tolist()
+
+
+def lattice_paths_to_text(best_paths: k2.Fsa, word_table) -> List[str]:
+    """
+    Convert the best path to a list of strings.
+
+    Arguments
+    ---------
+    best_paths: k2.Fsa
+        It is the path in the lattice with the highest score for a
+        given utterance.
+    word_table: List[str] or Dict[int,str]
+        It is a list or dict that maps word IDs to words.
+
+    Returns
+    -------
+    texts: List[str]
+        A list of strings, each of which is the decoding result of the
+        corresponding utterance.
+    """
+    hyps: List[List[int]] = lattice_path_to_textid(
+        best_paths, return_ragged=False
+    )
+    texts = []
+    for wids in hyps:
+        texts.append(" ".join([word_table[wid] for wid in wids]))
+    return texts
+
+
+def load_G(path: Union[str, Path], cache: bool = True) -> k2.Fsa:
+    """
+    load a lm to be used in the decoding graph creation (or lm rescoring).
+
+    Arguments
+    ---------
+    path: str
+        The path to an FST LM (ending with .fst.txt) or a k2-converted
+        LM (in pytorch .pt format).
+    cache: bool
+        Whether or not to load/cache the LM from/to the .pt format (in the same dir).
+
+    Returns
+    -------
+    G: k2.Fsa
+        An FSA representing the LM.
+    """
+    path = str(path)
+    if os.path.exists(path.replace(".fst.txt", ".pt")) and cache:
+        logger.warning(
+            f"Loading '{path}' from its cached .pt format."
+            " Set 'caching: False' in the yaml"
+            " if this is not what you want."
+        )
+        G = k2.Fsa.from_dict(
+            torch.load(path.replace(".fst.txt", ".pt"), map_location="cpu")
+        )
+        return G
+
+    logger.info(f"Loading G LM: {path}")
+    # If G_path is an fst.txt file then convert to .pt file
+    if not os.path.isfile(path):
+        raise FileNotFoundError(
+            f"File {path} not found. " "You need to run arpa_to_fst to get it."
+        )
+    with open(path) as f:
+        G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+        torch.save(G.as_dict(), path[:-8] + ".pt")
+    return G
+
+
+def prepare_rescoring_G(G: k2.Fsa) -> k2.Fsa:
+    """
+    Prepare a LM with the purpose of using it for LM rescoring.
+    For instance, in the librispeech recipe this is a 4-gram LM (while a
+    3gram LM is used for HLG construction).
+
+    Arguments
+    ---------
+    G: k2.Fsa
+        An FSA representing the LM.
+
+    Returns
+    -------
+    G: k2.Fsa
+        An FSA representing the LM, with the following modifications:
+        - G.aux_labels is removed
+        - G.lm_scores is set to G.scores
+        - G is arc-sorted
+    """
+    if "_properties" in G.__dict__:
+        G.__dict__["_properties"] = None
+    del G.aux_labels
+    G = k2.Fsa.from_fsas([G]).to("cpu")  # only used for decoding
+    G = k2.arc_sort(G)
+    G = k2.add_epsilon_self_loops(G)
+    G = k2.arc_sort(G)
+    # G.lm_scores is used to replace HLG.lm_scores during LM rescoring.
+    if not hasattr(G, "lm_scores"):
+        G.lm_scores = G.scores.clone()
+    return G
diff --git a/speechbrain/lm/arpa.py b/speechbrain/lm/arpa.py
index 0d287dd7f74744d1b12b4070e769cdeccff93f0b..4fad9034bdfba441486ead0f424700590e4dbba6 100644
--- a/speechbrain/lm/arpa.py
+++ b/speechbrain/lm/arpa.py
@@ -58,9 +58,12 @@ Example
 
 Authors
  * Aku Rouhe 2020
+ * Pierre Champion 2023
 """
 import collections
 import logging
+from pathlib import Path
+from typing import Union
 
 logger = logging.getLogger(__name__)
 
@@ -227,3 +230,112 @@ def _parse_order(line):
 
 def _ends_arpa(line):
     return line == "\\end\\"
+
+
+def arpa_to_fst(
+    words_txt: Union[str, Path],
+    in_arpa: Union[str, Path],
+    out_fst: Union[str, Path],
+    ngram_order: int,
+    disambig_symbol: str = "#0",
+    cache: bool = True,
+):
+    r"""
+    Use kaldilm to convert an ARPA LM to FST. For example, you could use
+    speechbrain.lm.train_ngram to create an ARPA LM and then use this function
+    to convert it to an FST.
+
+    It is worth noting that if the fst already exists in the output_dir,
+    then they will not be converted again (so you may need to delete them
+    by hand if you, at any point, change your ARPA model).
+
+    Arguments
+    ---------
+    words_txt: str | Path
+        path to the words.txt file created by prepare_lang.
+    in_arpa: str | Path
+        Path to an ARPA LM to convert to an FST.
+    out_fst: str | Path
+        Path to where the fst will be saved.
+    ngram_order: int
+        ARPA (and FST) ngram order.
+    disambig_symbol: str
+        the disambiguation symbol to use.
+    cache: bool
+        Whether or not to re-create the fst.txt file if it already exist.
+
+    Raises
+    ---------
+        ImportError: If kaldilm is not installed.
+
+    Example
+    -------
+    >>> from speechbrain.lm.arpa import arpa_to_fst
+
+    >>> # Create a small arpa model
+    >>> arpa_file = getfixture('tmpdir').join("bigram.arpa")
+    >>> arpa_file.write(
+    ...     "Anything can be here\n"
+    ...     + "\n"
+    ...     + "\\data\\\n"
+    ...     + "ngram 1=3\n"
+    ...     + "ngram 2=4\n"
+    ...     + "\n"
+    ...     + "\\1-grams:\n"
+    ...     + "0 <s>\n"
+    ...     + "-0.6931 a\n"
+    ...     + "-0.6931 b 0.\n"
+    ...     + "" # Ends unigram section
+    ...     + "\\2-grams:\n"
+    ...     + "-0.6931 <s> a\n"
+    ...     + "-0.6931 a a\n"
+    ...     + "-0.6931 a b\n"
+    ...     + "-0.6931 b a\n"
+    ...     + "\n"  # Ends bigram section
+    ...     + "\\end\\\n")  # Ends whole file
+    >>> # Create words vocab
+    >>> vocav = getfixture('tmpdir').join("words.txt")
+    >>> vocav.write(
+    ...     "a 1\n"
+    ...     + "b 2\n"
+    ...     + "<s> 3\n"
+    ...     + "#0 4")  # Ends whole file
+    >>> out = getfixture('tmpdir').join("bigram.txt.fst")
+    >>> arpa_to_fst(vocav, arpa_file, out, 2)
+    """
+    try:
+        from kaldilm.arpa2fst import arpa2fst
+    except ImportError:
+        # This error will occur when there is fst LM in the provided lm_dir
+        # and we are trying to create it by converting an ARPA LM to FST.
+        # For this, we need to install kaldilm.
+        raise ImportError(
+            "Optional dependencies must be installed to use kaldilm.\n"
+            "Install using `pip install kaldilm`."
+        )
+
+    if cache and out_fst.exists():
+        return
+    if not in_arpa.exists():
+        raise FileNotFoundError(
+            f"{in_arpa} not found while trying to create"
+            f" the {ngram_order} FST."
+        )
+    try:
+        logger.info(f"Converting arpa LM '{in_arpa}' to FST")
+        s = arpa2fst(
+            input_arpa=str(in_arpa),
+            disambig_symbol=disambig_symbol,
+            read_symbol_table=str(words_txt),
+            max_order=ngram_order,
+        )
+    except Exception as e:
+        logger.info(
+            f"Failed to create {ngram_order}-gram FST from input={in_arpa}"
+            f", disambig_symbol={disambig_symbol},"
+            f" read_symbol_table={words_txt}"
+        )
+        raise e
+    logger.info(f"Writing {out_fst}")
+    with open(out_fst, "w") as f:
+        f.write(s)
diff --git a/speechbrain/lobes/augment.py b/speechbrain/lobes/augment.py
deleted file mode 100644
index bf64da37ba808e38a6d19f5d12beec3044ffb0eb..0000000000000000000000000000000000000000
--- a/speechbrain/lobes/augment.py
+++ /dev/null
@@ -1,563 +0,0 @@
-"""
-Combinations of processing algorithms to implement common augmentations.
-
-Examples:
- * SpecAugment
- * Environmental corruption (noise, reverberation)
-
-Authors
- * Peter Plantinga 2020
- * Jianyuan Zhong 2020
-"""
-import os
-import torch
-import torchaudio
-import speechbrain as sb
-from speechbrain.utils.data_utils import download_file
-from speechbrain.processing.speech_augmentation import (
-    SpeedPerturb,
-    DropFreq,
-    DropChunk,
-    AddBabble,
-    AddNoise,
-    AddReverb,
-)
-from speechbrain.utils.torch_audio_backend import check_torchaudio_backend
-
-check_torchaudio_backend()
-
-OPENRIR_URL = "http://www.openslr.org/resources/28/rirs_noises.zip"
-
-
-class SpecAugment(torch.nn.Module):
-    """An implementation of the SpecAugment algorithm.
-
-    Reference:
-        https://arxiv.org/abs/1904.08779
-
-    Arguments
-    ---------
-    time_warp : bool
-        Whether applying time warping.
-    time_warp_window : int
-        Time warp window.
-    time_warp_mode : str
-        Interpolation mode for time warping (default "bicubic").
-    freq_mask : bool
-        Whether applying freq mask.
-    freq_mask_width : int or tuple
-        Freq mask width range.
-    n_freq_mask : int
-        Number of freq mask.
-    time_mask : bool
-        Whether applying time mask.
-    time_mask_width : int or tuple
-        Time mask width range.
-    n_time_mask : int
-        Number of time mask.
-    replace_with_zero : bool
-        If True, replace masked value with 0, else replace masked value with mean of the input tensor.
-
-    Example
-    -------
-    >>> aug = SpecAugment()
-    >>> a = torch.rand([8, 120, 80])
-    >>> a = aug(a)
-    >>> print(a.shape)
-    torch.Size([8, 120, 80])
-    """
-
-    def __init__(
-        self,
-        time_warp=True,
-        time_warp_window=5,
-        time_warp_mode="bicubic",
-        freq_mask=True,
-        freq_mask_width=(0, 20),
-        n_freq_mask=2,
-        time_mask=True,
-        time_mask_width=(0, 100),
-        n_time_mask=2,
-        replace_with_zero=True,
-    ):
-        super().__init__()
-        assert (
-            time_warp or freq_mask or time_mask
-        ), "at least one of time_warp, time_mask, or freq_mask should be applied"
-
-        self.apply_time_warp = time_warp
-        self.time_warp_window = time_warp_window
-        self.time_warp_mode = time_warp_mode
-        self.freq_mask = freq_mask
-        if isinstance(freq_mask_width, int):
-            freq_mask_width = (0, freq_mask_width)
-        self.freq_mask_width = freq_mask_width
-        self.n_freq_mask = n_freq_mask
-
-        self.time_mask = time_mask
-        if isinstance(time_mask_width, int):
-            time_mask_width = (0, time_mask_width)
-        self.time_mask_width = time_mask_width
-        self.n_time_mask = n_time_mask
-
-        self.replace_with_zero = replace_with_zero
-
-    def forward(self, x):
-        """Takes in input a tensors and returns an augmented one."""
-        if self.apply_time_warp:
-            x = self.time_warp(x)
-        if self.freq_mask:
-            x = self.mask_along_axis(x, dim=2)
-        if self.time_mask:
-            x = self.mask_along_axis(x, dim=1)
-        return x
-
-    def time_warp(self, x):
-        """Time warping with torch.nn.functional.interpolate"""
-        original_size = x.shape
-        window = self.time_warp_window
-
-        # 2d interpolation requires 4D or higher dimension tensors
-        # x: (Batch, Time, Freq) -> (Batch, 1, Time, Freq)
-        if x.dim() == 3:
-            x = x.unsqueeze(1)
-
-        time = x.shape[2]
-        if time - window <= window:
-            return x.view(*original_size)
-
-        # compute center and corresponding window
-        c = torch.randint(window, time - window, (1,))[0]
-        w = torch.randint(c - window, c + window, (1,))[0] + 1
-
-        left = torch.nn.functional.interpolate(
-            x[:, :, :c],
-            (w, x.shape[3]),
-            mode=self.time_warp_mode,
-            align_corners=True,
-        )
-        right = torch.nn.functional.interpolate(
-            x[:, :, c:],
-            (time - w, x.shape[3]),
-            mode=self.time_warp_mode,
-            align_corners=True,
-        )
-
-        x[:, :, :w] = left
-        x[:, :, w:] = right
-
-        return x.view(*original_size)
-
-    def mask_along_axis(self, x, dim):
-        """Mask along time or frequency axis.
-
-        Arguments
-        ---------
-        x : tensor
-            Input tensor.
-        dim : int
-            Corresponding dimension to mask.
-        """
-        original_size = x.shape
-        if x.dim() == 4:
-            x = x.view(-1, x.shape[2], x.shape[3])
-
-        batch, time, fea = x.shape
-
-        if dim == 1:
-            D = time
-            n_mask = self.n_time_mask
-            width_range = self.time_mask_width
-        else:
-            D = fea
-            n_mask = self.n_freq_mask
-            width_range = self.freq_mask_width
-
-        mask_len = torch.randint(
-            width_range[0], width_range[1], (batch, n_mask), device=x.device
-        ).unsqueeze(2)
-
-        mask_pos = torch.randint(
-            0, max(1, D - mask_len.max()), (batch, n_mask), device=x.device
-        ).unsqueeze(2)
-
-        # compute masks
-        arange = torch.arange(D, device=x.device).view(1, 1, -1)
-        mask = (mask_pos <= arange) * (arange < (mask_pos + mask_len))
-        mask = mask.any(dim=1)
-
-        if dim == 1:
-            mask = mask.unsqueeze(2)
-        else:
-            mask = mask.unsqueeze(1)
-
-        if self.replace_with_zero:
-            val = 0.0
-        else:
-            with torch.no_grad():
-                val = x.mean()
-
-        x = x.masked_fill_(mask, val)
-        return x.view(*original_size)
-
-
-class TimeDomainSpecAugment(torch.nn.Module):
-    """A time-domain approximation of the SpecAugment algorithm.
-
-    This augmentation module implements three augmentations in
-    the time-domain.
-
-     1. Drop chunks of the audio (zero amplitude or white noise)
-     2. Drop frequency bands (with band-drop filters)
-     3. Speed peturbation (via resampling to slightly different rate)
-
-    Arguments
-    ---------
-    perturb_prob : float from 0 to 1
-        The probability that a batch will have speed perturbation applied.
-    drop_freq_prob : float from 0 to 1
-        The probability that a batch will have frequencies dropped.
-    drop_chunk_prob : float from 0 to 1
-        The probability that a batch will have chunks dropped.
-    speeds : list of ints
-        A set of different speeds to use to perturb each batch.
-        See ``speechbrain.processing.speech_augmentation.SpeedPerturb``
-    sample_rate : int
-        Sampling rate of the input waveforms.
-    drop_freq_count_low : int
-        Lowest number of frequencies that could be dropped.
-    drop_freq_count_high : int
-        Highest number of frequencies that could be dropped.
-    drop_chunk_count_low : int
-        Lowest number of chunks that could be dropped.
-    drop_chunk_count_high : int
-        Highest number of chunks that could be dropped.
-    drop_chunk_length_low : int
-        Lowest length of chunks that could be dropped.
-    drop_chunk_length_high : int
-        Highest length of chunks that could be dropped.
-    drop_chunk_noise_factor : float
-        The noise factor used to scale the white noise inserted, relative to
-        the average amplitude of the utterance. Default 0 (no noise inserted).
-
-    Example
-    -------
-    >>> inputs = torch.randn([10, 16000])
-    >>> feature_maker = TimeDomainSpecAugment(speeds=[80])
-    >>> feats = feature_maker(inputs, torch.ones(10))
-    >>> feats.shape
-    torch.Size([10, 12800])
-    """
-
-    def __init__(
-        self,
-        perturb_prob=1.0,
-        drop_freq_prob=1.0,
-        drop_chunk_prob=1.0,
-        speeds=[95, 100, 105],
-        sample_rate=16000,
-        drop_freq_count_low=0,
-        drop_freq_count_high=3,
-        drop_chunk_count_low=0,
-        drop_chunk_count_high=5,
-        drop_chunk_length_low=1000,
-        drop_chunk_length_high=2000,
-        drop_chunk_noise_factor=0,
-    ):
-        super().__init__()
-        self.speed_perturb = SpeedPerturb(
-            perturb_prob=perturb_prob, orig_freq=sample_rate, speeds=speeds
-        )
-        self.drop_freq = DropFreq(
-            drop_prob=drop_freq_prob,
-            drop_count_low=drop_freq_count_low,
-            drop_count_high=drop_freq_count_high,
-        )
-        self.drop_chunk = DropChunk(
-            drop_prob=drop_chunk_prob,
-            drop_count_low=drop_chunk_count_low,
-            drop_count_high=drop_chunk_count_high,
-            drop_length_low=drop_chunk_length_low,
-            drop_length_high=drop_chunk_length_high,
-            noise_factor=drop_chunk_noise_factor,
-        )
-
-    def forward(self, waveforms, lengths):
-        """Returns the distorted waveforms.
-
-        Arguments
-        ---------
-        waveforms : torch.Tensor
-            The waveforms to distort
-        """
-        # Augmentation
-        with torch.no_grad():
-            waveforms = self.speed_perturb(waveforms)
-            waveforms = self.drop_freq(waveforms)
-            waveforms = self.drop_chunk(waveforms, lengths)
-
-        return waveforms
-
-
-class EnvCorrupt(torch.nn.Module):
-    """Environmental Corruptions for speech signals: noise, reverb, babble.
-
-    Arguments
-    ---------
-    reverb_prob : float from 0 to 1
-        The probability that each batch will have reverberation applied.
-    babble_prob : float from 0 to 1
-        The probability that each batch will have babble added.
-    noise_prob : float from 0 to 1
-        The probability that each batch will have noise added.
-    openrir_folder : str
-        If provided, download and prepare openrir to this location. The
-        reverberation csv and noise csv will come from here unless overridden
-        by the ``reverb_csv`` or ``noise_csv`` arguments.
-    openrir_max_noise_len : float
-        The maximum length in seconds for a noise segment from openrir. Only
-        takes effect if ``openrir_folder`` is used for noises. Cuts longer
-        noises into segments equal to or less than this length.
-    reverb_csv : str
-        A prepared csv file for loading room impulse responses.
-    noise_csv : str
-        A prepared csv file for loading noise data.
-    noise_num_workers : int
-        Number of workers to use for loading noises.
-    babble_speaker_count : int
-        Number of speakers to use for babble. Must be less than batch size.
-    babble_snr_low : int
-        Lowest generated SNR of reverbed signal to babble.
-    babble_snr_high : int
-        Highest generated SNR of reverbed signal to babble.
-    noise_snr_low : int
-        Lowest generated SNR of babbled signal to noise.
-    noise_snr_high : int
-        Highest generated SNR of babbled signal to noise.
-    rir_scale_factor : float
-        It compresses or dilates the given impulse response.
-        If ``0 < rir_scale_factor < 1``, the impulse response is compressed
-        (less reverb), while if ``rir_scale_factor > 1`` it is dilated
-        (more reverb).
-    reverb_sample_rate : int
-        Sample rate of input audio signals (rirs) used for reverberation.
-    noise_sample_rate: int
-        Sample rate of input audio signals used for adding noise.
-    clean_sample_rate: int
-        Sample rate of original (clean) audio signals.
-
-    Example
-    -------
-    >>> inputs = torch.randn([10, 16000])
-    >>> corrupter = EnvCorrupt(babble_speaker_count=9)
-    >>> feats = corrupter(inputs, torch.ones(10))
-    """
-
-    def __init__(
-        self,
-        reverb_prob=1.0,
-        babble_prob=1.0,
-        noise_prob=1.0,
-        openrir_folder=None,
-        openrir_max_noise_len=None,
-        reverb_csv=None,
-        noise_csv=None,
-        noise_num_workers=0,
-        babble_speaker_count=0,
-        babble_snr_low=0,
-        babble_snr_high=0,
-        noise_snr_low=0,
-        noise_snr_high=0,
-        rir_scale_factor=1.0,
-        reverb_sample_rate=16000,
-        noise_sample_rate=16000,
-        clean_sample_rate=16000,
-    ):
-        super().__init__()
-
-        # Download and prepare openrir
-        if openrir_folder and (not reverb_csv or not noise_csv):
-
-            open_reverb_csv = os.path.join(openrir_folder, "reverb.csv")
-            open_noise_csv = os.path.join(openrir_folder, "noise.csv")
-            _prepare_openrir(
-                openrir_folder,
-                open_reverb_csv,
-                open_noise_csv,
-                openrir_max_noise_len,
-            )
-
-            # Specify filepath and sample rate if not specified already
-            if not reverb_csv:
-                reverb_csv = open_reverb_csv
-                reverb_sample_rate = 16000
-
-            if not noise_csv:
-                noise_csv = open_noise_csv
-                noise_sample_rate = 16000
-
-        # Initialize corrupters
-        if reverb_csv is not None and reverb_prob > 0.0:
-            self.add_reverb = AddReverb(
-                reverb_prob=reverb_prob,
-                csv_file=reverb_csv,
-                replacements={"rir_root": openrir_folder},
-                rir_scale_factor=rir_scale_factor,
-                reverb_sample_rate=reverb_sample_rate,
-                clean_sample_rate=clean_sample_rate,
-            )
-
-        if babble_speaker_count > 0 and babble_prob > 0.0:
-            self.add_babble = AddBabble(
-                mix_prob=babble_prob,
-                speaker_count=babble_speaker_count,
-                snr_low=babble_snr_low,
-                snr_high=babble_snr_high,
-            )
-
-        if noise_csv is not None and noise_prob > 0.0:
-            self.add_noise = AddNoise(
-                mix_prob=noise_prob,
-                csv_file=noise_csv,
-                replacements={"rir_root": openrir_folder},
-                num_workers=noise_num_workers,
-                snr_low=noise_snr_low,
-                snr_high=noise_snr_high,
-                noise_sample_rate=noise_sample_rate,
-                clean_sample_rate=clean_sample_rate,
-            )
-
-    def forward(self, waveforms, lengths):
-        """Returns the distorted waveforms.
-
-        Arguments
-        ---------
-        waveforms : torch.Tensor
-            The waveforms to distort.
-        """
-        # Augmentation
-        with torch.no_grad():
-            if hasattr(self, "add_reverb"):
-                try:
-                    waveforms = self.add_reverb(waveforms, lengths)
-                except Exception:
-                    pass
-            if hasattr(self, "add_babble"):
-                waveforms = self.add_babble(waveforms, lengths)
-            if hasattr(self, "add_noise"):
-                waveforms = self.add_noise(waveforms, lengths)
-
-        return waveforms
-
-
-def _prepare_openrir(folder, reverb_csv, noise_csv, max_noise_len):
-    """Prepare the openrir dataset for adding reverb and noises.
-
-    Arguments
-    ---------
-    folder : str
-        The location of the folder containing the dataset.
-    reverb_csv : str
-        Filename for storing the prepared reverb csv.
-    noise_csv : str
-        Filename for storing the prepared noise csv.
-    max_noise_len : float
-        The maximum noise length in seconds. Noises longer
-        than this will be cut into pieces.
-    """
-
-    # Download and unpack if necessary
-    filepath = os.path.join(folder, "rirs_noises.zip")
-
-    if not os.path.isdir(os.path.join(folder, "RIRS_NOISES")):
-        download_file(OPENRIR_URL, filepath, unpack=True)
-    else:
-        download_file(OPENRIR_URL, filepath)
-
-    # Prepare reverb csv if necessary
-    if not os.path.isfile(reverb_csv):
-        rir_filelist = os.path.join(
-            folder, "RIRS_NOISES", "real_rirs_isotropic_noises", "rir_list"
-        )
-        _prepare_csv(folder, rir_filelist, reverb_csv)
-
-    # Prepare noise csv if necessary
-    if not os.path.isfile(noise_csv):
-        noise_filelist = os.path.join(
-            folder, "RIRS_NOISES", "pointsource_noises", "noise_list"
-        )
-        _prepare_csv(folder, noise_filelist, noise_csv, max_noise_len)
-
-
-def _prepare_csv(folder, filelist, csv_file, max_length=None):
-    """Iterate a set of wavs and write the corresponding csv file.
-
-    Arguments
-    ---------
-    folder : str
-        The folder relative to which the files in the list are listed.
-    filelist : str
-        The location of a file listing the files to be used.
-    csvfile : str
-        The location to use for writing the csv file.
-    max_length : float
-        The maximum length in seconds. Waveforms longer
-        than this will be cut into pieces.
-    """
-    try:
-        # make sure all processing reached here before main preocess create csv_file
-        sb.utils.distributed.ddp_barrier()
-        if sb.utils.distributed.if_main_process():
-            with open(csv_file, "w") as w:
-                w.write("ID,duration,wav,wav_format,wav_opts\n\n")
-                for line in open(filelist):
-
-                    # Read file for duration/channel info
-                    filename = os.path.join(folder, line.split()[-1])
-                    signal, rate = torchaudio.load(filename)
-
-                    # Ensure only one channel
-                    if signal.shape[0] > 1:
-                        signal = signal[0].unsqueeze(0)
-                        torchaudio.save(filename, signal, rate)
-
-                    ID, ext = os.path.basename(filename).split(".")
-                    duration = signal.shape[1] / rate
-
-                    # Handle long waveforms
-                    if max_length is not None and duration > max_length:
-                        # Delete old file
-                        os.remove(filename)
-                        for i in range(int(duration / max_length)):
-                            start = int(max_length * i * rate)
-                            stop = int(
-                                min(max_length * (i + 1), duration) * rate
-                            )
-                            new_filename = (
-                                filename[: -len(f".{ext}")] + f"_{i}.{ext}"
-                            )
-                            torchaudio.save(
-                                new_filename, signal[:, start:stop], rate
-                            )
-                            csv_row = (
-                                f"{ID}_{i}",
-                                str((stop - start) / rate),
-                                "$rir_root/" + new_filename[len(folder) :],
-                                ext,
-                                "\n",
-                            )
-                            w.write(",".join(csv_row))
-                    else:
-                        w.write(
-                            ",".join(
-                                (
-                                    ID,
-                                    str(duration),
-                                    "$rir_root/" + filename[len(folder) :],
-                                    ext,
-                                    "\n",
-                                )
-                            )
-                        )
-    finally:
-        sb.utils.distributed.ddp_barrier()
diff --git a/speechbrain/lobes/features.py b/speechbrain/lobes/features.py
index 416eb7af4358c8b5ced7c1a5eb1cc98ca9b15f95..de2bbfaab93f37628570d30d8d0305b2b16dd9e2 100644
--- a/speechbrain/lobes/features.py
+++ b/speechbrain/lobes/features.py
@@ -4,8 +4,11 @@ Authors
  * Mirco Ravanelli 2020
  * Peter Plantinga 2020
  * Sarthak Yadav 2020
+ * Sylvain de Langen 2024
 """
+from dataclasses import dataclass
 import torch
+from typing import Optional
 from speechbrain.processing.features import (
     STFT,
     spectral_magnitude,
@@ -14,6 +17,8 @@ from speechbrain.processing.features import (
     Deltas,
     ContextWindow,
 )
+from speechbrain.utils.autocast import fwd_default_precision
+from speechbrain.utils.filter_analysis import FilterProperties
 from speechbrain.nnet.CNN import GaborConv1d
 from speechbrain.nnet.normalization import PCEN
 from speechbrain.nnet.pooling import GaussianLowpassPooling
@@ -127,6 +132,7 @@ class Fbank(torch.nn.Module):
             left_frames=left_frames, right_frames=right_frames,
         )
 
+    @fwd_default_precision(cast_inputs=torch.float32)
     def forward(self, wav):
         """Returns a set of features generated from the input waveforms.
 
@@ -146,6 +152,10 @@ class Fbank(torch.nn.Module):
             fbanks = self.context_window(fbanks)
         return fbanks
 
+    def get_filter_properties(self) -> FilterProperties:
+        # only the STFT affects the FilterProperties of the Fbank
+        return self.compute_STFT.get_filter_properties()
+
 
 class MFCC(torch.nn.Module):
     """Generate features for input to the speech pipeline.
@@ -260,6 +270,7 @@ class MFCC(torch.nn.Module):
             left_frames=left_frames, right_frames=right_frames,
         )
 
+    @fwd_default_precision(cast_inputs=torch.float32)
     def forward(self, wav):
         """Returns a set of mfccs generated from the input waveforms.
 
@@ -387,6 +398,7 @@ class Leaf(torch.nn.Module):
             self.compression = None
         self.skip_transpose = skip_transpose
 
+    @fwd_default_precision(cast_inputs=torch.float32)
     def forward(self, x):
         """
         Returns the learned LEAF features
@@ -437,3 +449,185 @@ class Leaf(torch.nn.Module):
                 "Leaf expects 2d or 3d inputs. Got " + str(len(shape))
             )
         return in_channels
+
+
+def upalign_value(x, to: int) -> int:
+    """If `x` cannot evenly divide `to`, round it up to the next value that
+    can."""
+
+    assert x >= 0
+
+    if (x % to) == 0:
+        return x
+
+    return x + to - (x % to)
+
+
+@dataclass
+class StreamingFeatureWrapperContext:
+    """Streaming metadata for the feature extractor. Holds some past context
+    frames."""
+
+    left_context: Optional[torch.Tensor]
+    """Cached left frames to be inserted as left padding for the next chunk.
+    Initially `None` then gets updated from the last frames of the current
+    chunk.
+    See the relevant `forward` function for details."""
+
+
+class StreamingFeatureWrapper(torch.nn.Module):
+    """Wraps an arbitrary filter so that it can be used in a streaming fashion
+    (i.e. on a per-chunk basis), by remembering context and making "clever" use
+    of padding.
+
+    Arguments
+    ---------
+    module : torch.nn.Module
+        The filter to wrap; e.g. a module list that constitutes a sequential
+        feature extraction pipeline.
+        The module is assumed to pad its inputs, e.g. the output of a
+        convolution with a stride of 1 would end up with the same frame count
+        as the input.
+
+    properties : FilterProperties
+        The effective filter properties of the provided module. This is used to
+        determine padding and caching.
+    """
+
+    def __init__(self, module: torch.nn.Module, properties: FilterProperties):
+        super().__init__()
+
+        self.module = module
+        self.properties = properties
+
+        if self.properties.causal:
+            raise ValueError(
+                "Causal streaming feature wrapper is not yet supported"
+            )
+
+        if self.properties.dilation != 1:
+            raise ValueError(
+                "Dilation not yet supported in streaming feature wrapper"
+            )
+
+    def get_required_padding(self) -> int:
+        """Computes the number of padding/context frames that need to be
+        injected at the past and future of the input signal in the forward pass.
+        """
+
+        return upalign_value(
+            (self.properties.window_size - 1) // 2, self.properties.stride
+        )
+
+    def get_output_count_per_pad_frame(self) -> int:
+        """Computes the exact number of produced frames (along the time
+        dimension) per input pad frame."""
+
+        return self.get_required_padding() // self.properties.stride
+
+    def get_recommended_final_chunk_count(self, frames_per_chunk: int) -> int:
+        """Get the recommended number of zero chunks to inject at the end of an
+        input stream depending on the filter properties of the extractor.
+
+        The number of injected chunks is chosen to ensure that the filter has
+        output frames centered on the last input frames.
+        See also :meth:`~StreamingFeatureWrapper.forward`.
+
+        Arguments
+        ---------
+        frames_per_chunk : int
+            The number of frames per chunk, i.e. the size of the time dimension
+            passed to :meth:`~StreamingFeatureWrapper.forward`."""
+
+        return (
+            upalign_value(self.get_required_padding(), frames_per_chunk)
+            // frames_per_chunk
+        )
+
+    def forward(
+        self,
+        chunk: torch.Tensor,
+        context: StreamingFeatureWrapperContext,
+        *extra_args,
+        **extra_kwargs,
+    ) -> torch.Tensor:
+        """Forward pass for the streaming feature wrapper.
+
+        For the first chunk, 0-padding is inserted at the past of the input.
+        For any chunk (including the first), some future frames get truncated
+        and cached to be inserted as left context for the next chunk in time.
+
+        For further explanations, see the comments in the code.
+
+        Note that due to how the padding is implemented, you may want to call
+        this with a chunk worth full of zeros (potentially more for filters with
+        large windows) at the end of your input so that the final frames have a
+        chance to get processed by the filter.
+        See :meth:`~StreamingFeatureWrapper.get_recommended_final_chunk_count`.
+        This is not really an issue when processing endless streams, but when
+        processing files, it could otherwise result in truncated outputs.
+
+        Arguments
+        ---------
+        chunk : torch.Tensor
+            Chunk of input of shape [batch size, time]; typically a raw
+            waveform. Normally, in a chunkwise streaming scenario,
+            `time = (stride-1) * chunk_size` where `chunk_size` is the desired
+            **output** frame count.
+        context : StreamingFeatureWrapperContext
+            Mutable streaming context object; should be reused for subsequent
+            calls in the same streaming session.
+
+        Returns
+        -------
+        torch.Tensor
+            Processed chunk of shape [batch size, output frames]. This shape is
+            equivalent to the shape of `module(chunk)`.
+        """
+
+        feat_pad_size = self.get_required_padding()
+        num_outputs_per_pad = self.get_output_count_per_pad_frame()
+
+        # consider two audio chunks of 6 samples (for the example), where
+        # each sample is denoted by 1, 2, ..., 6
+        # so chunk 1 is 123456 and chunk 2 is 123456
+        if context.left_context is None:
+            # for the first chunk we left pad the input by two padding's worth of zeros,
+            # and truncate the right, so that we can pretend to have right padding and
+            # still consume the same amount of samples every time
+            #
+            # our first processed chunk will look like:
+            # 0000123456
+            #         ^^ right padding (truncated)
+            #   ^^^^^^ frames that some outputs are centered on
+            # ^^ left padding (truncated)
+            chunk = torch.nn.functional.pad(chunk, (feat_pad_size * 2, 0))
+        else:
+            # prepend left context
+            #
+            # for the second chunk ownwards, given the above example:
+            # 34 of the previous chunk becomes left padding
+            # 56 of the previous chunk becomes the first frames of this chunk
+            # thus on the second iteration (and onwards) it will look like:
+            # 3456123456
+            #         ^^ right padding (truncated)
+            #   ^^^^^^ frames that some outputs are centered on
+            # ^^ left padding (truncated)
+            chunk = torch.cat((context.left_context, chunk), 1)
+
+        # our chunk's right context will become the start of the "next processed chunk"
+        # plus we need left padding for that one, so make it double
+        context.left_context = chunk[:, -feat_pad_size * 2 :]
+
+        feats = self.module(chunk, *extra_args, **extra_kwargs)
+
+        # truncate left and right context
+        feats = feats[:, num_outputs_per_pad:-num_outputs_per_pad, ...]
+
+        return feats
+
+    def get_filter_properties(self) -> FilterProperties:
+        return self.properties
+
+    def make_streaming_context(self) -> StreamingFeatureWrapperContext:
+        return StreamingFeatureWrapperContext(None)
diff --git a/speechbrain/lobes/models/FastSpeech2.py b/speechbrain/lobes/models/FastSpeech2.py
index 026b4b7ca30e4de03fa1f1acb5f2f2f4903181e7..a5d9af547520b731638cfc6e3f566b06bfe5772a 100644
--- a/speechbrain/lobes/models/FastSpeech2.py
+++ b/speechbrain/lobes/models/FastSpeech2.py
@@ -15,50 +15,13 @@ from speechbrain.nnet import CNN, linear
 from speechbrain.nnet.embedding import Embedding
 from speechbrain.lobes.models.transformer.Transformer import (
     TransformerEncoder,
+    PositionalEncoding,
     get_key_padding_mask,
+    get_mask_from_lengths,
 )
 from speechbrain.nnet.normalization import LayerNorm
 from speechbrain.nnet.losses import bce_loss
-
-
-class PositionalEmbedding(nn.Module):
-    """Computation of the positional embeddings.
-    Arguments
-    ---------
-    embed_dim: int
-        dimensionality of the embeddings.
-    """
-
-    def __init__(self, embed_dim):
-        super(PositionalEmbedding, self).__init__()
-        self.demb = embed_dim
-        inv_freq = 1 / (
-            10000 ** (torch.arange(0.0, embed_dim, 2.0) / embed_dim)
-        )
-        self.register_buffer("inv_freq", inv_freq)
-
-    def forward(self, seq_len, mask, dtype):
-        """Computes the forward pass
-        Arguments
-        ---------
-        seq_len: int
-            length of the sequence
-        mask: torch.tensor
-            mask applied to the positional embeddings
-        dtype: str
-            dtype of the embeddings
-        Returns
-        -------
-        pos_emb: torch.Tensor
-            the tensor with positional embeddings
-        """
-        pos_seq = torch.arange(seq_len, device=mask.device).to(dtype)
-
-        sinusoid_inp = torch.matmul(
-            torch.unsqueeze(pos_seq, -1), torch.unsqueeze(self.inv_freq, 0)
-        )
-        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1)
-        return pos_emb[None, :, :] * mask[:, :, None]
+import numpy as np
 
 
 class EncoderPreNet(nn.Module):
@@ -322,7 +285,7 @@ class SPNPredictor(nn.Module):
             n_char, padding_idx, out_channels=enc_d_model
         )
 
-        self.sinusoidal_positional_embed_encoder = PositionalEmbedding(
+        self.sinusoidal_positional_embed_encoder = PositionalEncoding(
             enc_d_model
         )
 
@@ -365,9 +328,7 @@ class SPNPredictor(nn.Module):
 
         srcmask = get_key_padding_mask(tokens, pad_idx=self.padding_idx)
         srcmask_inverted = (~srcmask).unsqueeze(-1)
-        pos = self.sinusoidal_positional_embed_encoder(
-            token_feats.shape[1], srcmask, token_feats.dtype
-        )
+        pos = self.sinusoidal_positional_embed_encoder(token_feats)
         token_feats = torch.add(token_feats, pos) * srcmask_inverted
 
         spn_mask = (
@@ -414,7 +375,7 @@ class FastSpeech2(nn.Module):
     This class is the main entry point for the model, which is responsible
     for instantiating all submodules, which, in turn, manage the individual
     neural network layers
-    Simplified STRUCTURE: input->token embedding ->encoder ->duration predictor ->duration
+    Simplified STRUCTURE: input->token embedding ->encoder ->duration/pitch/energy predictor ->duration
     upsampler -> decoder -> output
     During training, teacher forcing is used (ground truth durations are used for upsampling)
     Arguments
@@ -567,10 +528,10 @@ class FastSpeech2(nn.Module):
         self.enc_num_head = enc_num_head
         self.dec_num_head = dec_num_head
         self.padding_idx = padding_idx
-        self.sinusoidal_positional_embed_encoder = PositionalEmbedding(
+        self.sinusoidal_positional_embed_encoder = PositionalEncoding(
             enc_d_model
         )
-        self.sinusoidal_positional_embed_decoder = PositionalEmbedding(
+        self.sinusoidal_positional_embed_decoder = PositionalEncoding(
             dec_d_model
         )
 
@@ -700,9 +661,7 @@ class FastSpeech2(nn.Module):
 
         # prenet & encoder
         token_feats = self.encPreNet(tokens)
-        pos = self.sinusoidal_positional_embed_encoder(
-            token_feats.shape[1], srcmask, token_feats.dtype
-        )
+        pos = self.sinusoidal_positional_embed_encoder(token_feats)
         token_feats = torch.add(token_feats, pos) * srcmask_inverted
         attn_mask = (
             srcmask.unsqueeze(-1)
@@ -761,7 +720,8 @@ class FastSpeech2(nn.Module):
             durations if durations is not None else dur_pred_reverse_log,
             pace=pace,
         )
-        srcmask = get_key_padding_mask(spec_feats, pad_idx=self.padding_idx)
+        srcmask = get_mask_from_lengths(torch.tensor(mel_lens))
+        srcmask = srcmask.to(spec_feats.device)
         srcmask_inverted = (~srcmask).unsqueeze(-1)
         attn_mask = (
             srcmask.unsqueeze(-1)
@@ -771,10 +731,7 @@ class FastSpeech2(nn.Module):
         )
 
         # decoder
-        pos = self.sinusoidal_positional_embed_decoder(
-            spec_feats.shape[1], srcmask, spec_feats.dtype
-        )
-
+        pos = self.sinusoidal_positional_embed_decoder(spec_feats)
         spec_feats = torch.add(spec_feats, pos) * srcmask_inverted
 
         output_mel_feats, memory, *_ = self.decoder(
@@ -1771,3 +1728,1060 @@ class _SSIMLoss(_Loss):
             k2=self.k2,
         )
         return torch.ones_like(score) - score
+
+
+class TextMelCollateWithAlignment:
+    """ Zero-pads model inputs and targets based on number of frames per step
+    result: tuple
+        a tuple of tensors to be used as inputs/targets
+        (
+            text_padded,
+            dur_padded,
+            input_lengths,
+            mel_padded,
+            output_lengths,
+            len_x,
+            labels,
+            wavs
+        )
+    """
+
+    # TODO: Make this more intuitive, use the pipeline
+    def __call__(self, batch):
+        """Collate's training batch from normalized text and mel-spectrogram
+        Arguments
+        ---------
+        batch: list
+            [text_normalized, mel_normalized]
+        """
+        # TODO: Remove for loops
+        raw_batch = list(batch)
+        for i in range(
+            len(batch)
+        ):  # the pipline return a dictionary wiht one elemnent
+            batch[i] = batch[i]["mel_text_pair"]
+
+        # Right zero-pad all one-hot text sequences to max input length
+        input_lengths, ids_sorted_decreasing = torch.sort(
+            torch.LongTensor([len(x[0]) for x in batch]), dim=0, descending=True
+        )
+
+        max_input_len = input_lengths[0]
+
+        phoneme_padded = torch.LongTensor(len(batch), max_input_len)
+        phoneme_padded.zero_()
+
+        for i in range(len(ids_sorted_decreasing)):
+            phoneme = batch[ids_sorted_decreasing[i]][0]
+            phoneme_padded[i, : phoneme.size(0)] = phoneme
+
+        # Right zero-pad mel-spec
+        num_mels = batch[0][1].size(0)
+        max_target_len = max([x[1].size(1) for x in batch])
+
+        # include mel padded and gate padded
+        mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
+        mel_padded.zero_()
+        pitch_padded = torch.FloatTensor(len(batch), max_target_len)
+        pitch_padded.zero_()
+        energy_padded = torch.FloatTensor(len(batch), max_target_len)
+        energy_padded.zero_()
+        output_lengths = torch.LongTensor(len(batch))
+        labels, wavs = [], []
+        for i in range(len(ids_sorted_decreasing)):
+            idx = ids_sorted_decreasing[i]
+            mel = batch[idx][1]
+            pitch = batch[idx][2]
+            energy = batch[idx][3]
+            mel_padded[i, :, : mel.size(1)] = mel
+            pitch_padded[i, : pitch.size(0)] = pitch
+            energy_padded[i, : energy.size(0)] = energy
+            output_lengths[i] = mel.size(1)
+            labels.append(raw_batch[idx]["label"])
+            wavs.append(raw_batch[idx]["wav"])
+
+        mel_padded = mel_padded.permute(0, 2, 1)
+        return (
+            phoneme_padded,
+            input_lengths,
+            mel_padded,
+            pitch_padded,
+            energy_padded,
+            output_lengths,
+            labels,
+            wavs,
+        )
+
+
+def maximum_path_numpy(value, mask):
+    """
+    Monotonic alignment search algorithm, numpy works faster than the torch implementation.
+    Arguments
+    ---------
+    value: torch.Tensor
+        input alignment values [b, t_x, t_y]
+    mask: torch.Tensor
+        input alignment mask [b, t_x, t_y]
+    Example
+    -------
+    >>> import torch
+    >>> from speechbrain.lobes.models.FastSpeech2 import maximum_path_numpy
+    >>> alignment = torch.rand(2, 5, 100)
+    >>> mask = torch.ones(2, 5, 100)
+    >>> hard_alignments = maximum_path_numpy(alignment, mask)
+    """
+    max_neg_val = -np.inf  # Patch for Sphinx complaint
+    value = value * mask
+
+    device = value.device
+    dtype = value.dtype
+    value = value.cpu().detach().numpy()
+    mask = mask.cpu().detach().numpy().astype(np.bool_)
+
+    b, t_x, t_y = value.shape
+    direction = np.zeros(value.shape, dtype=np.int64)
+    v = np.zeros((b, t_x), dtype=np.float32)
+    x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1)
+    for j in range(t_y):
+        v0 = np.pad(
+            v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val
+        )[:, :-1]
+        v1 = v
+        max_mask = v1 >= v0
+        v_max = np.where(max_mask, v1, v0)
+        direction[:, :, j] = max_mask
+
+        index_mask = x_range <= j
+        v = np.where(index_mask, v_max + value[:, :, j], max_neg_val)
+    direction = np.where(mask, direction, 1)
+
+    path = np.zeros(value.shape, dtype=np.float32)
+    index = mask[:, :, 0].sum(1).astype(np.int64) - 1
+    index_range = np.arange(b)
+    for j in reversed(range(t_y)):
+        path[index_range, index, j] = 1
+        index = index + direction[index_range, index, j] - 1
+    path = path * mask.astype(np.float32)
+    path = torch.from_numpy(path).to(device=device, dtype=dtype)
+    return path
+
+
+class AlignmentNetwork(torch.nn.Module):
+    """Learns the alignment between the input text
+    and the spectrogram with Gaussian Attention.
+
+    query -> conv1d -> relu -> conv1d -> relu -> conv1d -> L2_dist -> softmax -> alignment
+    key   -> conv1d -> relu -> conv1d - - - - - - - - - - - -^
+
+    Arguments
+    ---------
+    in_query_channels: int
+        Number of channels in the query network. Defaults to 80.
+    in_key_channels: int
+        Number of channels in the key network. Defaults to 512.
+    attn_channels: int
+        Number of inner channels in the attention layers. Defaults to 80.
+    temperature: float
+        Temperature for the softmax. Defaults to 0.0005.
+
+    Example
+    -------
+    >>> import torch
+    >>> from speechbrain.lobes.models.FastSpeech2 import AlignmentNetwork
+    >>> aligner = AlignmentNetwork(
+    ...     in_query_channels=80,
+    ...     in_key_channels=512,
+    ...     attn_channels=80,
+    ...     temperature=0.0005,
+    ... )
+    >>> phoneme_feats = torch.rand(2, 512, 20)
+    >>> mels = torch.rand(2, 80, 100)
+    >>> alignment_soft, alignment_logprob = aligner(mels, phoneme_feats, None, None)
+    >>> alignment_soft.shape, alignment_logprob.shape
+    (torch.Size([2, 1, 100, 20]), torch.Size([2, 1, 100, 20]))
+    """
+
+    def __init__(
+        self,
+        in_query_channels=80,
+        in_key_channels=512,
+        attn_channels=80,
+        temperature=0.0005,
+    ):
+        super().__init__()
+        self.temperature = temperature
+        self.softmax = torch.nn.Softmax(dim=3)
+        self.log_softmax = torch.nn.LogSoftmax(dim=3)
+
+        self.key_layer = nn.Sequential(
+            CNN.Conv1d(
+                in_channels=in_key_channels,
+                out_channels=in_key_channels * 2,
+                kernel_size=3,
+                padding="same",
+                bias=True,
+                skip_transpose=True,
+            ),
+            torch.nn.ReLU(),
+            CNN.Conv1d(
+                in_channels=in_key_channels * 2,
+                out_channels=attn_channels,
+                kernel_size=1,
+                padding="same",
+                bias=True,
+                skip_transpose=True,
+            ),
+        )
+
+        self.query_layer = nn.Sequential(
+            CNN.Conv1d(
+                in_channels=in_query_channels,
+                out_channels=in_query_channels * 2,
+                kernel_size=3,
+                padding="same",
+                bias=True,
+                skip_transpose=True,
+            ),
+            torch.nn.ReLU(),
+            CNN.Conv1d(
+                in_channels=in_query_channels * 2,
+                out_channels=in_query_channels,
+                kernel_size=1,
+                padding="same",
+                bias=True,
+                skip_transpose=True,
+            ),
+            torch.nn.ReLU(),
+            CNN.Conv1d(
+                in_channels=in_query_channels,
+                out_channels=attn_channels,
+                kernel_size=1,
+                padding="same",
+                bias=True,
+                skip_transpose=True,
+            ),
+        )
+
+    def forward(self, queries, keys, mask, attn_prior):
+        """Forward pass of the aligner encoder.
+        Arguments
+        ---------
+        queries: torch.Tensor
+            the query tensor [B, C, T_de]
+        keys: torch.Tensor
+            the query tensor [B, C_emb, T_en]
+        mask: torch.Tensor
+            the query mask[B, T_de]
+        attn_prior: torch.Tensor
+            the prior attention tensor [B, 1, T_en, T_de]
+        Returns
+        ---------
+        attn: torch.tensor
+            soft attention [B, 1, T_en, T_de]
+        attn_logp: torch.tensor
+            log probabilities [B, 1, T_en , T_de]
+        """
+        key_out = self.key_layer(keys)
+        query_out = self.query_layer(queries)
+        attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2
+        attn_logp = -self.temperature * attn_factor.sum(1, keepdim=True)
+        if attn_prior is not None:
+            attn_logp = self.log_softmax(attn_logp) + torch.log(
+                attn_prior[:, None] + 1e-8
+            )
+        if mask is not None:
+            attn_logp.data.masked_fill_(
+                ~mask.bool().unsqueeze(2), -float("inf")
+            )
+        attn = self.softmax(attn_logp)
+        return attn, attn_logp
+
+
+class FastSpeech2WithAlignment(nn.Module):
+    """The FastSpeech2 text-to-speech model with internal alignment.
+    This class is the main entry point for the model, which is responsible
+    for instantiating all submodules, which, in turn, manage the individual
+    neural network layers. Certain parts are adopted from the following implementation:
+    https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/models/forward_tts.py
+
+    Simplified STRUCTURE:
+    input -> token embedding -> encoder -> aligner -> duration/pitch/energy -> upsampler -> decoder -> output
+
+    Arguments
+    ---------
+    #encoder parameters
+    enc_num_layers: int
+        number of transformer layers (TransformerEncoderLayer) in encoder
+    enc_num_head: int
+        number of multi-head-attention (MHA) heads in encoder transformer layers
+    enc_d_model: int
+        the number of expected features in the encoder
+    enc_ffn_dim: int
+        the dimension of the feedforward network model
+    enc_k_dim: int
+        the dimension of the key
+    enc_v_dim: int
+        the dimension of the value
+    enc_dropout: float
+        Dropout for the encoder
+    normalize_before: bool
+        whether normalization should be applied before or after MHA or FFN in Transformer layers.
+    ffn_type: str
+        whether to use convolutional layers instead of feed forward network inside tranformer layer
+    ffn_cnn_kernel_size_list: list of int
+        conv kernel size of 2 1d-convs if ffn_type is 1dcnn
+    #aligner parameters
+    in_query_channels: int
+        Number of channels in the query network.
+    in_key_channels: int
+        Number of channels in the key network.
+    attn_channels: int
+        Number of inner channels in the attention layers.
+    temperature: float
+        Temperature for the softmax.
+    #decoder parameters
+    dec_num_layers: int
+        number of transformer layers (TransformerEncoderLayer) in decoder
+    dec_num_head: int
+        number of multi-head-attention (MHA) heads in decoder transformer layers
+    dec_d_model: int
+        the number of expected features in the decoder
+    dec_ffn_dim: int
+        the dimension of the feedforward network model
+    dec_k_dim: int
+        the dimension of the key
+    dec_v_dim: int
+        the dimension of the value
+    dec_dropout: float
+        dropout for the decoder
+    normalize_before: bool
+        whether normalization should be applied before or after MHA or FFN in Transformer layers.
+    ffn_type: str
+        whether to use convolutional layers instead of feed forward network inside tranformer layer.
+    ffn_cnn_kernel_size_list: list of int
+        conv kernel size of 2 1d-convs if ffn_type is 1dcnn
+    n_char: int
+        the number of symbols for the token embedding
+    n_mels: int
+        number of bins in mel spectrogram
+    postnet_embedding_dim: int
+       output feature dimension for convolution layers
+    postnet_kernel_size: int
+       postnet convolution kernal size
+    postnet_n_convolutions: int
+       number of convolution layers
+    postnet_dropout: float
+        dropout probability fot postnet
+    padding_idx: int
+        the index for padding
+    dur_pred_kernel_size: int
+        the convolution kernel size in duration predictor
+    pitch_pred_kernel_size: int
+        kernel size for pitch prediction.
+    energy_pred_kernel_size: int
+        kernel size for energy prediction.
+    variance_predictor_dropout: float
+        dropout probability for variance predictor (duration/pitch/energy)
+    Example
+    -------
+    >>> import torch
+    >>> from speechbrain.lobes.models.FastSpeech2 import FastSpeech2WithAlignment
+    >>> model = FastSpeech2WithAlignment(
+    ...    enc_num_layers=6,
+    ...    enc_num_head=2,
+    ...    enc_d_model=384,
+    ...    enc_ffn_dim=1536,
+    ...    enc_k_dim=384,
+    ...    enc_v_dim=384,
+    ...    enc_dropout=0.1,
+    ...    in_query_channels=80,
+    ...    in_key_channels=384,
+    ...    attn_channels=80,
+    ...    temperature=0.0005,
+    ...    dec_num_layers=6,
+    ...    dec_num_head=2,
+    ...    dec_d_model=384,
+    ...    dec_ffn_dim=1536,
+    ...    dec_k_dim=384,
+    ...    dec_v_dim=384,
+    ...    dec_dropout=0.1,
+    ...    normalize_before=False,
+    ...    ffn_type='1dcnn',
+    ...    ffn_cnn_kernel_size_list=[9, 1],
+    ...    n_char=40,
+    ...    n_mels=80,
+    ...    postnet_embedding_dim=512,
+    ...    postnet_kernel_size=5,
+    ...    postnet_n_convolutions=5,
+    ...    postnet_dropout=0.5,
+    ...    padding_idx=0,
+    ...    dur_pred_kernel_size=3,
+    ...    pitch_pred_kernel_size=3,
+    ...    energy_pred_kernel_size=3,
+    ...    variance_predictor_dropout=0.5)
+    >>> inputs = torch.tensor([
+    ...     [13, 12, 31, 14, 19],
+    ...     [31, 16, 30, 31, 0],
+    ... ])
+    >>> mels = torch.rand(2, 100, 80)
+    >>> mel_post, postnet_output, durations, predict_pitch, avg_pitch, predict_energy, avg_energy, mel_lens, alignment_durations, alignment_soft, alignment_logprob, alignment_mas = model(inputs, mels)
+    >>> mel_post.shape, durations.shape
+    (torch.Size([2, 100, 80]), torch.Size([2, 5]))
+    >>> predict_pitch.shape, predict_energy.shape
+    (torch.Size([2, 5, 1]), torch.Size([2, 5, 1]))
+    >>> alignment_soft.shape, alignment_mas.shape
+    (torch.Size([2, 100, 5]), torch.Size([2, 100, 5]))
+    """
+
+    def __init__(
+        self,
+        enc_num_layers,
+        enc_num_head,
+        enc_d_model,
+        enc_ffn_dim,
+        enc_k_dim,
+        enc_v_dim,
+        enc_dropout,
+        in_query_channels,
+        in_key_channels,
+        attn_channels,
+        temperature,
+        dec_num_layers,
+        dec_num_head,
+        dec_d_model,
+        dec_ffn_dim,
+        dec_k_dim,
+        dec_v_dim,
+        dec_dropout,
+        normalize_before,
+        ffn_type,
+        ffn_cnn_kernel_size_list,
+        n_char,
+        n_mels,
+        postnet_embedding_dim,
+        postnet_kernel_size,
+        postnet_n_convolutions,
+        postnet_dropout,
+        padding_idx,
+        dur_pred_kernel_size,
+        pitch_pred_kernel_size,
+        energy_pred_kernel_size,
+        variance_predictor_dropout,
+    ):
+        super().__init__()
+        self.enc_num_head = enc_num_head
+        self.dec_num_head = dec_num_head
+        self.padding_idx = padding_idx
+        self.sinusoidal_positional_embed_encoder = PositionalEncoding(
+            enc_d_model
+        )
+        self.sinusoidal_positional_embed_decoder = PositionalEncoding(
+            dec_d_model
+        )
+
+        self.encPreNet = EncoderPreNet(
+            n_char, padding_idx, out_channels=enc_d_model
+        )
+        self.durPred = DurationPredictor(
+            in_channels=enc_d_model,
+            out_channels=enc_d_model,
+            kernel_size=dur_pred_kernel_size,
+            dropout=variance_predictor_dropout,
+        )
+        self.pitchPred = DurationPredictor(
+            in_channels=enc_d_model,
+            out_channels=enc_d_model,
+            kernel_size=dur_pred_kernel_size,
+            dropout=variance_predictor_dropout,
+        )
+        self.energyPred = DurationPredictor(
+            in_channels=enc_d_model,
+            out_channels=enc_d_model,
+            kernel_size=dur_pred_kernel_size,
+            dropout=variance_predictor_dropout,
+        )
+        self.pitchEmbed = CNN.Conv1d(
+            in_channels=1,
+            out_channels=enc_d_model,
+            kernel_size=pitch_pred_kernel_size,
+            padding="same",
+            skip_transpose=True,
+        )
+
+        self.energyEmbed = CNN.Conv1d(
+            in_channels=1,
+            out_channels=enc_d_model,
+            kernel_size=energy_pred_kernel_size,
+            padding="same",
+            skip_transpose=True,
+        )
+        self.encoder = TransformerEncoder(
+            num_layers=enc_num_layers,
+            nhead=enc_num_head,
+            d_ffn=enc_ffn_dim,
+            d_model=enc_d_model,
+            kdim=enc_k_dim,
+            vdim=enc_v_dim,
+            dropout=enc_dropout,
+            activation=nn.ReLU,
+            normalize_before=normalize_before,
+            ffn_type=ffn_type,
+            ffn_cnn_kernel_size_list=ffn_cnn_kernel_size_list,
+        )
+
+        self.decoder = TransformerEncoder(
+            num_layers=dec_num_layers,
+            nhead=dec_num_head,
+            d_ffn=dec_ffn_dim,
+            d_model=dec_d_model,
+            kdim=dec_k_dim,
+            vdim=dec_v_dim,
+            dropout=dec_dropout,
+            activation=nn.ReLU,
+            normalize_before=normalize_before,
+            ffn_type=ffn_type,
+            ffn_cnn_kernel_size_list=ffn_cnn_kernel_size_list,
+        )
+
+        self.linear = linear.Linear(n_neurons=n_mels, input_size=dec_d_model)
+        self.postnet = PostNet(
+            n_mel_channels=n_mels,
+            postnet_embedding_dim=postnet_embedding_dim,
+            postnet_kernel_size=postnet_kernel_size,
+            postnet_n_convolutions=postnet_n_convolutions,
+            postnet_dropout=postnet_dropout,
+        )
+        self.aligner = AlignmentNetwork(
+            in_query_channels=in_query_channels,
+            in_key_channels=in_key_channels,
+            attn_channels=attn_channels,
+            temperature=temperature,
+        )
+
+    def _forward_aligner(self, x, y, x_mask, y_mask):
+        """Aligner forward pass.
+        1. Compute a mask to apply to the attention map.
+        2. Run the alignment network.
+        3. Apply MAS (Monotonic alignment search) to compute the hard alignment map.
+        4. Compute the durations from the hard alignment map.
+
+        Arguments
+        ---------
+        x: torch.Tensor
+            Input sequence [B, T_en, C_en].
+        y: torch.Tensor
+            Output sequence [B, T_de, C_de].
+        x_mask: torch.Tensor
+            Input sequence mask [B, 1, T_en].
+        y_mask: torch.Tensor
+            Output sequence mask [B, 1, T_de].
+        Returns
+        ---------
+        durations: torch.Tensor
+            Durations from the hard alignment map [B, T_en].
+        alignment_soft: torch.Tensor
+            soft alignment potentials [B, T_en, T_de].
+        alignment_logprob: torch.Tensor
+            log scale alignment potentials [B, 1, T_de, T_en].
+        alignment_mas: torch.Tensor
+            hard alignment map [B, T_en, T_de].
+        """
+        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
+        alignment_soft, alignment_logprob = self.aligner(
+            y.transpose(1, 2), x.transpose(1, 2), x_mask, None
+        )
+        alignment_mas = maximum_path_numpy(
+            alignment_soft.squeeze(1).transpose(1, 2).contiguous(),
+            attn_mask.squeeze(1).contiguous(),
+        )
+        durations = torch.sum(alignment_mas, -1).int()
+        alignment_soft = alignment_soft.squeeze(1).transpose(1, 2)
+        return durations, alignment_soft, alignment_logprob, alignment_mas
+
+    def forward(
+        self,
+        tokens,
+        mel_spectograms=None,
+        pitch=None,
+        energy=None,
+        pace=1.0,
+        pitch_rate=1.0,
+        energy_rate=1.0,
+    ):
+        """forward pass for training and inference
+        Arguments
+        ---------
+        tokens: torch.Tensor
+            batch of input tokens
+        mel_spectograms: torch.Tensor
+            batch of mel_spectograms (used only for training)
+        pitch: torch.Tensor
+            batch of pitch for each frame. If it is None, the model will infer on predicted pitches
+        energy: torch.Tensor
+            batch of energy for each frame. If it is None, the model will infer on predicted energies
+        pace: float
+            scaling factor for durations
+        pitch_rate: float
+            scaling factor for pitches
+        energy_rate: float
+            scaling factor for energies
+        Returns
+        ---------
+        mel_post: torch.Tensor
+            mel outputs from the decoder
+        postnet_output: torch.Tensor
+            mel outputs from the postnet
+        predict_durations: torch.Tensor
+            predicted durations of each token
+        predict_pitch: torch.Tensor
+            predicted pitches of each token
+        avg_pitch: torch.Tensor
+            target pitches for each token if input pitch is not None
+            None if input pitch is None
+        predict_energy: torch.Tensor
+            predicted energies of each token
+        avg_energy: torch.Tensor
+            target energies for each token if input energy is not None
+            None if input energy is None
+        mel_length:
+            predicted lengths of mel spectrograms
+        alignment_durations:
+            durations from the hard alignment map
+        alignment_soft: torch.Tensor
+            soft alignment potentials
+        alignment_logprob: torch.Tensor
+            log scale alignment potentials
+        alignment_mas: torch.Tensor
+            hard alignment map
+        """
+        srcmask = get_key_padding_mask(tokens, pad_idx=self.padding_idx)
+        srcmask_inverted = (~srcmask).unsqueeze(-1)
+
+        # encoder
+        token_feats = self.encPreNet(tokens)
+        pos = self.sinusoidal_positional_embed_encoder(token_feats)
+        token_feats = torch.add(token_feats, pos) * srcmask_inverted
+        attn_mask = (
+            srcmask.unsqueeze(-1)
+            .repeat(self.enc_num_head, 1, token_feats.shape[1])
+            .permute(0, 2, 1)
+            .bool()
+        )
+        token_feats, _ = self.encoder(
+            token_feats, src_mask=attn_mask, src_key_padding_mask=srcmask
+        )
+        token_feats = token_feats * srcmask_inverted
+
+        # aligner
+        alignment_durations = None
+        alignment_soft = None
+        alignment_logprob = None
+        alignment_mas = None
+        if mel_spectograms is not None:
+            y_mask = get_key_padding_mask(
+                mel_spectograms, pad_idx=self.padding_idx
+            )
+            y_mask_inverted = (~y_mask).unsqueeze(-1)
+
+            (
+                alignment_durations,
+                alignment_soft,
+                alignment_logprob,
+                alignment_mas,
+            ) = self._forward_aligner(
+                token_feats,
+                mel_spectograms,
+                srcmask_inverted.transpose(1, 2),
+                y_mask_inverted.transpose(1, 2),
+            )
+
+            alignment_soft = alignment_soft.transpose(1, 2)
+            alignment_mas = alignment_mas.transpose(1, 2)
+
+        # duration predictor
+        predict_durations = self.durPred(
+            token_feats, srcmask_inverted
+        ).squeeze()
+        if predict_durations.dim() == 1:
+            predict_durations = predict_durations.unsqueeze(0)
+        predict_durations_reverse_log = torch.clamp(
+            torch.exp(predict_durations) - 1, 0
+        )
+
+        # pitch predictor
+        avg_pitch = None
+        predict_pitch = self.pitchPred(token_feats, srcmask_inverted)
+        # use a pitch rate to adjust the pitch
+        predict_pitch = predict_pitch * pitch_rate
+        if pitch is not None:
+            avg_pitch = average_over_durations(
+                pitch.unsqueeze(1), alignment_durations
+            )
+            pitch = self.pitchEmbed(avg_pitch)
+            avg_pitch = avg_pitch.permute(0, 2, 1)
+        else:
+            pitch = self.pitchEmbed(predict_pitch.permute(0, 2, 1))
+        pitch = pitch.permute(0, 2, 1)
+        token_feats = token_feats.add(pitch)
+
+        # energy predictor
+        avg_energy = None
+        predict_energy = self.energyPred(token_feats, srcmask_inverted)
+        # use an energy rate to adjust the energy
+        predict_energy = predict_energy * energy_rate
+        if energy is not None:
+            avg_energy = average_over_durations(
+                energy.unsqueeze(1), alignment_durations
+            )
+            energy = self.energyEmbed(avg_energy)
+            avg_energy = avg_energy.permute(0, 2, 1)
+        else:
+            energy = self.energyEmbed(predict_energy.permute(0, 2, 1))
+        energy = energy.permute(0, 2, 1)
+        token_feats = token_feats.add(energy)
+
+        # upsampling
+        spec_feats, mel_lens = upsample(
+            token_feats,
+            alignment_durations
+            if alignment_durations is not None
+            else predict_durations_reverse_log,
+            pace=pace,
+        )
+        srcmask = get_mask_from_lengths(torch.tensor(mel_lens))
+        srcmask = srcmask.to(spec_feats.device)
+        srcmask_inverted = (~srcmask).unsqueeze(-1)
+        attn_mask = (
+            srcmask.unsqueeze(-1)
+            .repeat(self.dec_num_head, 1, spec_feats.shape[1])
+            .permute(0, 2, 1)
+            .bool()
+        )
+
+        # decoder
+        pos = self.sinusoidal_positional_embed_decoder(spec_feats)
+        spec_feats = torch.add(spec_feats, pos) * srcmask_inverted
+
+        output_mel_feats, memory, *_ = self.decoder(
+            spec_feats, src_mask=attn_mask, src_key_padding_mask=srcmask
+        )
+
+        # postnet
+        mel_post = self.linear(output_mel_feats) * srcmask_inverted
+        postnet_output = self.postnet(mel_post) + mel_post
+
+        return (
+            mel_post,
+            postnet_output,
+            predict_durations,
+            predict_pitch,
+            avg_pitch,
+            predict_energy,
+            avg_energy,
+            torch.tensor(mel_lens),
+            alignment_durations,
+            alignment_soft,
+            alignment_logprob,
+            alignment_mas,
+        )
+
+
+class LossWithAlignment(nn.Module):
+    """Loss computation including internal aligner
+    Arguments
+    ---------
+    log_scale_durations: bool
+       applies logarithm to target durations
+    ssim_loss_weight: float
+       weight for the ssim loss
+    duration_loss_weight: float
+       weight for the duration loss
+    pitch_loss_weight: float
+       weight for the pitch loss
+    energy_loss_weight: float
+       weight for the energy loss
+    mel_loss_weight: float
+       weight for the mel loss
+    postnet_mel_loss_weight: float
+       weight for the postnet mel loss
+    aligner_loss_weight: float
+       weight for the alignment loss
+    binary_alignment_loss_weight: float
+       weight for the postnet mel loss
+    binary_alignment_loss_warmup_epochs: int
+       Number of epochs to gradually increase the impact of binary loss.
+    binary_alignment_loss_max_epochs: int
+       From this epoch on the impact of binary loss is ignored.
+    """
+
+    def __init__(
+        self,
+        log_scale_durations,
+        ssim_loss_weight,
+        duration_loss_weight,
+        pitch_loss_weight,
+        energy_loss_weight,
+        mel_loss_weight,
+        postnet_mel_loss_weight,
+        aligner_loss_weight,
+        binary_alignment_loss_weight,
+        binary_alignment_loss_warmup_epochs,
+        binary_alignment_loss_max_epochs,
+    ):
+        super().__init__()
+
+        self.ssim_loss = SSIMLoss()
+        self.mel_loss = nn.MSELoss()
+        self.postnet_mel_loss = nn.MSELoss()
+        self.dur_loss = nn.MSELoss()
+        self.pitch_loss = nn.MSELoss()
+        self.energy_loss = nn.MSELoss()
+        self.aligner_loss = ForwardSumLoss()
+        self.binary_alignment_loss = BinaryAlignmentLoss()
+        self.log_scale_durations = log_scale_durations
+        self.ssim_loss_weight = ssim_loss_weight
+        self.mel_loss_weight = mel_loss_weight
+        self.postnet_mel_loss_weight = postnet_mel_loss_weight
+        self.duration_loss_weight = duration_loss_weight
+        self.pitch_loss_weight = pitch_loss_weight
+        self.energy_loss_weight = energy_loss_weight
+        self.aligner_loss_weight = aligner_loss_weight
+        self.binary_alignment_loss_weight = binary_alignment_loss_weight
+        self.binary_alignment_loss_warmup_epochs = (
+            binary_alignment_loss_warmup_epochs
+        )
+        self.binary_alignment_loss_max_epochs = binary_alignment_loss_max_epochs
+
+    def forward(self, predictions, targets, current_epoch):
+        """Computes the value of the loss function and updates stats
+        Arguments
+        ---------
+        predictions: tuple
+            model predictions
+        targets: tuple
+            ground truth data
+        current_epoch: int
+            used to determinate the start/end of the binary alignment loss
+        Returns
+        -------
+        loss: torch.Tensor
+            the loss value
+        """
+        (
+            mel_target,
+            target_pitch,
+            target_energy,
+            mel_length,
+            phon_len,
+        ) = targets
+        assert len(mel_target.shape) == 3
+        (
+            mel_out,
+            postnet_mel_out,
+            log_durations,
+            predicted_pitch,
+            average_pitch,
+            predicted_energy,
+            average_energy,
+            mel_lens,
+            alignment_durations,
+            alignment_soft,
+            alignment_logprob,
+            alignment_hard,
+        ) = predictions
+
+        predicted_pitch = predicted_pitch.squeeze(-1)
+        predicted_energy = predicted_energy.squeeze(-1)
+
+        target_pitch = average_pitch.squeeze(-1)
+        target_energy = average_energy.squeeze(-1)
+
+        log_durations = log_durations.squeeze(-1)
+        if self.log_scale_durations:
+            log_target_durations = torch.log(alignment_durations.float() + 1)
+        # change this to perform batch level using padding mask
+
+        for i in range(mel_target.shape[0]):
+            if i == 0:
+                mel_loss = self.mel_loss(
+                    mel_out[i, : mel_length[i], :],
+                    mel_target[i, : mel_length[i], :],
+                )
+                postnet_mel_loss = self.postnet_mel_loss(
+                    postnet_mel_out[i, : mel_length[i], :],
+                    mel_target[i, : mel_length[i], :],
+                )
+                dur_loss = self.dur_loss(
+                    log_durations[i, : phon_len[i]],
+                    log_target_durations[i, : phon_len[i]].to(torch.float32),
+                )
+                pitch_loss = self.pitch_loss(
+                    predicted_pitch[i, : mel_length[i]],
+                    target_pitch[i, : mel_length[i]].to(torch.float32),
+                )
+                energy_loss = self.energy_loss(
+                    predicted_energy[i, : mel_length[i]],
+                    target_energy[i, : mel_length[i]].to(torch.float32),
+                )
+            else:
+                mel_loss = mel_loss + self.mel_loss(
+                    mel_out[i, : mel_length[i], :],
+                    mel_target[i, : mel_length[i], :],
+                )
+                postnet_mel_loss = postnet_mel_loss + self.postnet_mel_loss(
+                    postnet_mel_out[i, : mel_length[i], :],
+                    mel_target[i, : mel_length[i], :],
+                )
+                dur_loss = dur_loss + self.dur_loss(
+                    log_durations[i, : phon_len[i]],
+                    log_target_durations[i, : phon_len[i]].to(torch.float32),
+                )
+                pitch_loss = pitch_loss + self.pitch_loss(
+                    predicted_pitch[i, : mel_length[i]],
+                    target_pitch[i, : mel_length[i]].to(torch.float32),
+                )
+                energy_loss = energy_loss + self.energy_loss(
+                    predicted_energy[i, : mel_length[i]],
+                    target_energy[i, : mel_length[i]].to(torch.float32),
+                )
+
+        total_loss = 0
+        loss = {}
+
+        ssim_loss = self.ssim_loss(mel_out, mel_target, mel_length)
+        loss["ssim_loss"] = ssim_loss * self.ssim_loss_weight
+
+        mel_loss = torch.div(mel_loss, len(mel_target))
+        loss["mel_loss"] = mel_loss * self.mel_loss_weight
+
+        postnet_mel_loss = torch.div(postnet_mel_loss, len(mel_target))
+        loss["postnet_mel_loss"] = (
+            postnet_mel_loss * self.postnet_mel_loss_weight
+        )
+
+        dur_loss = torch.div(dur_loss, len(mel_target))
+        loss["dur_loss"] = dur_loss * self.duration_loss_weight
+
+        pitch_loss = torch.div(pitch_loss, len(mel_target))
+        loss["pitch_loss"] = pitch_loss * self.pitch_loss_weight
+
+        energy_loss = torch.div(energy_loss, len(mel_target))
+        loss["energy_loss"] = energy_loss * self.energy_loss_weight
+
+        if alignment_logprob is not None:
+            aligner_loss = self.aligner_loss(
+                alignment_logprob, phon_len, mel_length
+            )
+            loss["aligner_loss"] = aligner_loss * self.aligner_loss_weight
+
+        if alignment_soft is not None and alignment_hard is not None:
+            if current_epoch > self.binary_alignment_loss_max_epochs:
+                binary_loss_warmup_weight = 0
+            else:
+                binary_loss_warmup_weight = (
+                    min(
+                        current_epoch
+                        / self.binary_alignment_loss_warmup_epochs,
+                        1.0,
+                    )
+                    * 1.0
+                )
+
+            binary_alignment_loss = self.binary_alignment_loss(
+                alignment_hard, alignment_soft
+            )
+            loss["binary_alignment_loss"] = (
+                binary_alignment_loss
+                * self.binary_alignment_loss_weight
+                * binary_loss_warmup_weight
+            )
+
+        total_loss = sum(loss.values())
+        loss["total_loss"] = total_loss
+        return loss
+
+
+class ForwardSumLoss(nn.Module):
+    """CTC alignment loss
+    Arguments
+    ---------
+    blank_logprob: pad value
+    Example
+    -------
+    >>> import torch
+    >>> from speechbrain.lobes.models.FastSpeech2 import ForwardSumLoss
+    >>> loss_func = ForwardSumLoss()
+    >>> attn_logprob = torch.rand(2, 1, 100, 5)
+    >>> key_lens = torch.tensor([5, 5])
+    >>> query_lens = torch.tensor([100, 100])
+    >>> loss = loss_func(attn_logprob, key_lens, query_lens)
+    """
+
+    def __init__(self, blank_logprob=-1):
+        super().__init__()
+        self.log_softmax = torch.nn.LogSoftmax(dim=3)
+        self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True)
+        self.blank_logprob = blank_logprob
+
+    def forward(self, attn_logprob, key_lens, query_lens):
+        """
+        Arguments
+        ---------
+        attn_logprob: torch.Tensor
+            log scale alignment potentials [B, 1, query_lens, key_lens]
+        key_lens: torch.Tensor
+            mel lengths
+        query_lens: torch.Tensor
+            phoneme lengths
+        """
+        attn_logprob_padded = torch.nn.functional.pad(
+            input=attn_logprob, pad=(1, 0), value=self.blank_logprob
+        )
+
+        total_loss = 0.0
+        for bid in range(attn_logprob.shape[0]):
+            target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
+            curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[
+                : query_lens[bid], :, : key_lens[bid] + 1
+            ]
+
+            curr_logprob = self.log_softmax(curr_logprob[None])[0]
+            loss = self.ctc_loss(
+                curr_logprob,
+                target_seq,
+                input_lengths=query_lens[bid : bid + 1],
+                target_lengths=key_lens[bid : bid + 1],
+            )
+            total_loss = total_loss + loss
+
+        total_loss = total_loss / attn_logprob.shape[0]
+        return total_loss
+
+
+class BinaryAlignmentLoss(nn.Module):
+    """Binary loss that forces soft alignments to match the hard alignments as
+    explained in `https://arxiv.org/pdf/2108.10447.pdf`.
+    Example
+    -------
+    >>> import torch
+    >>> from speechbrain.lobes.models.FastSpeech2 import BinaryAlignmentLoss
+    >>> loss_func = BinaryAlignmentLoss()
+    >>> alignment_hard = torch.randint(0, 2, (2, 100, 5))
+    >>> alignment_soft = torch.rand(2, 100, 5)
+    >>> loss = loss_func(alignment_hard, alignment_soft)
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, alignment_hard, alignment_soft):
+        """
+        alignment_hard: torch.Tensor
+            hard alignment map [B, mel_lens, phoneme_lens]
+        alignment_soft: torch.Tensor
+            soft alignment potentials [B, mel_lens, phoneme_lens]
+        """
+        log_sum = torch.log(
+            torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)
+        ).sum()
+        return -log_sum / alignment_hard.sum()
diff --git a/speechbrain/lobes/models/HifiGAN.py b/speechbrain/lobes/models/HifiGAN.py
index 35e06b3e9344a900884f5d6c46dc3ecfa794bbf5..774469ffdda6dfe3ddca6f55cfcf8ddcfdd524b7 100644
--- a/speechbrain/lobes/models/HifiGAN.py
+++ b/speechbrain/lobes/models/HifiGAN.py
@@ -5,7 +5,7 @@ Efficient and High Fidelity Speech Synthesis
 For more details: https://arxiv.org/pdf/2010.05646.pdf
 
 Authors
- * Duret Jarod 2021
+ * Jarod Duret 2021
  * Yingzhi WANG 2022
 """
 
@@ -35,6 +35,7 @@ Authors
 import torch
 import torch.nn.functional as F
 import torch.nn as nn
+import speechbrain as sb
 from speechbrain.nnet.CNN import Conv1d, ConvTranspose1d, Conv2d
 from torchaudio import transforms
 
@@ -42,8 +43,7 @@ LRELU_SLOPE = 0.1
 
 
 def dynamic_range_compression(x, C=1, clip_val=1e-5):
-    """Dynamique range compression for audio signals
-    """
+    """Dynamique range compression for audio signals"""
     return torch.log(torch.clamp(x, min=clip_val) * C)
 
 
@@ -116,6 +116,67 @@ def mel_spectogram(
     return mel
 
 
+def process_duration(code, code_feat):
+    """
+    Process a given batch of code to extract consecutive unique elements and their associated features.
+
+    Parameters
+    ----------
+    code : torch.Tensor (batch, time)
+        Tensor of code indices.
+    code_feat : torch.Tensor (batch, time, channel)
+        Tensor of code features.
+
+    Returns
+    -------
+    uniq_code_feat_filtered : torch.Tensor (batch, time)
+        Features of consecutive unique codes.
+    mask : torch.Tensor (batch, time)
+        Padding mask for the unique codes.
+    uniq_code_count : torch.Tensor (n)
+        Count of unique codes.
+
+    Example
+    -------
+    >>> code = torch.IntTensor([[40, 18, 18, 10]])
+    >>> code_feat = torch.rand([1, 4, 128])
+    >>> out_tensor, mask, uniq_code = process_duration(code, code_feat)
+    >>> out_tensor.shape
+    torch.Size([1, 1, 128])
+    >>> mask.shape
+    torch.Size([1, 1])
+    >>> uniq_code.shape
+    torch.Size([1])
+    """
+    uniq_code_count = []
+    uniq_code_feat = []
+    for i in range(code.size(0)):
+        _, count = torch.unique_consecutive(code[i, :], return_counts=True)
+        if len(count) > 2:
+            # remove first and last code as segment sampling may cause incomplete segment length
+            uniq_code_count.append(count[1:-1])
+            uniq_code_idx = count.cumsum(dim=0)[:-2]
+        else:
+            uniq_code_count.append(count)
+            uniq_code_idx = count.cumsum(dim=0) - 1
+        uniq_code_feat.append(
+            code_feat[i, uniq_code_idx, :].view(-1, code_feat.size(2))
+        )
+    uniq_code_count = torch.cat(uniq_code_count)
+
+    # collate
+    max_len = max(feat.size(0) for feat in uniq_code_feat)
+    uniq_code_feat_filtered = uniq_code_feat[0].new_zeros(
+        (len(uniq_code_feat), max_len, uniq_code_feat[0].size(1))
+    )
+    mask = torch.arange(max_len).repeat(len(uniq_code_feat), 1)
+    for i, v in enumerate(uniq_code_feat):
+        uniq_code_feat_filtered[i, : v.size(0)] = v
+        mask[i, :] = mask[i, :] < v.size(0)
+
+    return uniq_code_feat_filtered, mask.bool(), uniq_code_count.float()
+
+
 ##################################
 # Generator
 ##################################
@@ -225,8 +286,7 @@ class ResBlock1(torch.nn.Module):
         return x
 
     def remove_weight_norm(self):
-        """This functions removes weight normalization during inference.
-        """
+        """This functions removes weight normalization during inference."""
         for l in self.convs1:
             l.remove_weight_norm()
         for l in self.convs2:
@@ -290,8 +350,7 @@ class ResBlock2(torch.nn.Module):
         return x
 
     def remove_weight_norm(self):
-        """This functions removes weight normalization during inference.
-        """
+        """This functions removes weight normalization during inference."""
         for l in self.convs:
             l.remove_weight_norm()
 
@@ -439,8 +498,7 @@ class HifiganGenerator(torch.nn.Module):
         return o
 
     def remove_weight_norm(self):
-        """This functions removes weight normalization during inference.
-        """
+        """This functions removes weight normalization during inference."""
 
         for l in self.ups:
             l.remove_weight_norm()
@@ -450,7 +508,7 @@ class HifiganGenerator(torch.nn.Module):
         self.conv_post.remove_weight_norm()
 
     @torch.no_grad()
-    def inference(self, c):
+    def inference(self, c, padding=True):
         """The inference function performs a padding and runs the forward method.
 
         Arguments
@@ -458,12 +516,241 @@ class HifiganGenerator(torch.nn.Module):
         x : torch.Tensor (batch, channel, time)
             feature input tensor.
         """
-        c = torch.nn.functional.pad(
-            c, (self.inference_padding, self.inference_padding), "replicate"
-        )
+        if padding:
+            c = torch.nn.functional.pad(
+                c, (self.inference_padding, self.inference_padding), "replicate"
+            )
         return self.forward(c)
 
 
+class VariancePredictor(nn.Module):
+    """Variance predictor inspired from FastSpeech2
+
+    Arguments
+    ---------
+    encoder_embed_dim : int
+        number of input tensor channels.
+    var_pred_hidden_dim : int
+        size of hidden channels for the convolutional layers.
+    var_pred_kernel_size : int
+        size of the convolution filter in each layer.
+    var_pred_dropout : float
+        dropout probability of each layer.
+
+    Example
+    -------
+    >>> inp_tensor = torch.rand([4, 80, 128])
+    >>> duration_predictor = VariancePredictor(
+    ...    encoder_embed_dim = 128,
+    ...    var_pred_hidden_dim = 128,
+    ...    var_pred_kernel_size = 3,
+    ...    var_pred_dropout = 0.5,
+    ... )
+    >>> out_tensor = duration_predictor (inp_tensor)
+    >>> out_tensor.shape
+    torch.Size([4, 80])
+    """
+
+    def __init__(
+        self,
+        encoder_embed_dim,
+        var_pred_hidden_dim,
+        var_pred_kernel_size,
+        var_pred_dropout,
+    ):
+        super().__init__()
+        self.conv1 = nn.Sequential(
+            Conv1d(
+                in_channels=encoder_embed_dim,
+                out_channels=var_pred_hidden_dim,
+                kernel_size=var_pred_kernel_size,
+                padding="same",
+                skip_transpose=True,
+                weight_norm=True,
+            ),
+            nn.ReLU(),
+        )
+        self.dropout = var_pred_dropout
+        self.conv2 = nn.Sequential(
+            Conv1d(
+                in_channels=var_pred_hidden_dim,
+                out_channels=var_pred_hidden_dim,
+                kernel_size=var_pred_kernel_size,
+                padding="same",
+                skip_transpose=True,
+                weight_norm=True,
+            ),
+            nn.ReLU(),
+        )
+        self.proj = nn.Linear(var_pred_hidden_dim, 1)
+
+    def forward(self, x):
+        """
+        Arguments
+        ---------
+        x : torch.Tensor (batch, channel, time)
+            feature input tensor.
+        """
+        x = self.conv1(x.transpose(1, 2)).transpose(1, 2)
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        x = self.conv2(x.transpose(1, 2)).transpose(1, 2)
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        return self.proj(x).squeeze(dim=2)
+
+
+class UnitHifiganGenerator(HifiganGenerator):
+    """Unit HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
+
+    Arguments
+    ---------
+    in_channels : int
+        number of input tensor channels.
+    out_channels : int
+        number of output tensor channels.
+    resblock_type : str
+        type of the `ResBlock`. '1' or '2'.
+    resblock_dilation_sizes : List[List[int]]
+        list of dilation values in each layer of a `ResBlock`.
+    resblock_kernel_sizes : List[int]
+        list of kernel sizes for each `ResBlock`.
+    upsample_kernel_sizes : List[int]
+        list of kernel sizes for each transposed convolution.
+    upsample_initial_channel : int
+        number of channels for the first upsampling layer. This is divided by 2
+        for each consecutive upsampling layer.
+    upsample_factors : List[int]
+        upsampling factors (stride) for each upsampling layer.
+    inference_padding : int
+        constant padding applied to the input at inference time. Defaults to 5.
+    num_embeddings : int
+        size of the dictionary of embeddings.
+    embedding_dim : int
+        size of each embedding vector.
+    duration_predictor : bool
+        enable duration predictor module.
+    var_pred_hidden_dim : int
+        size of hidden channels for the convolutional layers of the duration predictor.
+    var_pred_kernel_size : int
+        size of the convolution filter in each layer of the duration predictor.
+    var_pred_dropout : float
+        dropout probability of each layer in the duration predictor.
+
+    Example
+    -------
+    >>> inp_tensor = torch.randint(0, 100, (4, 10))
+    >>> unit_hifigan_generator= UnitHifiganGenerator(
+    ...    in_channels = 128,
+    ...    out_channels = 1,
+    ...    resblock_type = "1",
+    ...    resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+    ...    resblock_kernel_sizes = [3, 7, 11],
+    ...    upsample_kernel_sizes = [11, 8, 8, 4, 4],
+    ...    upsample_initial_channel = 512,
+    ...    upsample_factors = [5, 4, 4, 2, 2],
+    ...    num_embeddings = 100,
+    ...    embedding_dim = 128,
+    ...    duration_predictor = True,
+    ...    var_pred_hidden_dim = 128,
+    ...    var_pred_kernel_size = 3,
+    ...    var_pred_dropout = 0.5,
+    ... )
+    >>> out_tensor, _ = unit_hifigan_generator(inp_tensor)
+    >>> out_tensor.shape
+    torch.Size([4, 1, 3200])
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        resblock_type,
+        resblock_dilation_sizes,
+        resblock_kernel_sizes,
+        upsample_kernel_sizes,
+        upsample_initial_channel,
+        upsample_factors,
+        inference_padding=5,
+        cond_channels=0,
+        conv_post_bias=True,
+        num_embeddings=100,
+        embedding_dim=128,
+        duration_predictor=False,
+        var_pred_hidden_dim=128,
+        var_pred_kernel_size=3,
+        var_pred_dropout=0.5,
+    ):
+        super().__init__(
+            in_channels,
+            out_channels,
+            resblock_type,
+            resblock_dilation_sizes,
+            resblock_kernel_sizes,
+            upsample_kernel_sizes,
+            upsample_initial_channel,
+            upsample_factors,
+            inference_padding,
+            cond_channels,
+            conv_post_bias,
+        )
+        self.unit_embedding = torch.nn.Embedding(num_embeddings, embedding_dim)
+        self.duration_predictor = duration_predictor
+        if duration_predictor:
+            self.var_predictor = VariancePredictor(
+                embedding_dim,
+                var_pred_hidden_dim,
+                var_pred_kernel_size,
+                var_pred_dropout,
+            )
+
+    def forward(self, x, g=None):
+        """
+        Arguments
+        ---------
+        x : torch.Tensor (batch, time)
+            feature input tensor.
+        g : torch.Tensor (batch, 1, time)
+            global conditioning input tensor.
+        """
+        u = self.unit_embedding(x).transpose(1, 2)
+
+        log_dur = None
+        log_dur_pred = None
+
+        if self.duration_predictor:
+            uniq_code_feat, uniq_code_mask, dur = process_duration(
+                x, u.transpose(1, 2)
+            )
+            log_dur_pred = self.var_predictor(uniq_code_feat)
+            log_dur_pred = log_dur_pred[uniq_code_mask]
+            log_dur = torch.log(dur + 1)
+
+        return super().forward(u), (log_dur_pred, log_dur)
+
+    @torch.no_grad()
+    def inference(self, x):
+        """The inference function performs duration prediction and runs the forward method.
+
+        Arguments
+        ---------
+        x : torch.Tensor (batch, time)
+            feature input tensor.
+        """
+        x = self.unit_embedding(x).transpose(1, 2)
+
+        if self.duration_predictor:
+            assert (
+                x.size(0) == 1
+            ), "only support single sample batch in inference"
+            log_dur_pred = self.var_predictor(x.transpose(1, 2))
+            dur_out = torch.clamp(
+                torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1
+            )
+            # B x C x T
+            x = torch.repeat_interleave(x, dur_out.view(-1), dim=2)
+
+        return super().forward(x)
+
+
 ##################################
 # DISCRIMINATOR
 ##################################
@@ -738,8 +1025,7 @@ class HifiganDiscriminator(nn.Module):
 
 
 def stft(x, n_fft, hop_length, win_length, window_fn="hann_window"):
-    """computes the Fourier transform of short overlapping windows of the input
-    """
+    """computes the Fourier transform of short overlapping windows of the input"""
     o = torch.stft(x.squeeze(1), n_fft, hop_length, win_length,)
     M = o[:, :, :, 0]
     P = o[:, :, :, 1]
@@ -1135,6 +1421,8 @@ class GeneratorLoss(nn.Module):
         feat_match_loss_weight=0,
         l1_spec_loss=None,
         l1_spec_loss_weight=0,
+        mseg_dur_loss=None,
+        mseg_dur_loss_weight=0,
     ):
         super().__init__()
         self.stft_loss = stft_loss
@@ -1145,14 +1433,19 @@ class GeneratorLoss(nn.Module):
         self.feat_match_loss_weight = feat_match_loss_weight
         self.l1_spec_loss = l1_spec_loss
         self.l1_spec_loss_weight = l1_spec_loss_weight
+        self.mseg_dur_loss = mseg_dur_loss
+        self.mseg_dur_loss_weight = mseg_dur_loss_weight
 
     def forward(
         self,
+        stage,
         y_hat=None,
         y=None,
         scores_fake=None,
         feats_fake=None,
         feats_real=None,
+        log_dur_pred=None,
+        log_dur=None,
     ):
         """Returns a dictionary of generator losses and applies weights
 
@@ -1172,6 +1465,7 @@ class GeneratorLoss(nn.Module):
 
         gen_loss = 0
         adv_loss = 0
+        dur_loss = 0
         loss = {}
 
         # STFT Loss
@@ -1202,7 +1496,14 @@ class GeneratorLoss(nn.Module):
             feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
             loss["G_feat_match_loss"] = feat_match_loss
             adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss
-        loss["G_loss"] = gen_loss + adv_loss
+
+        # Duration loss
+        if self.mseg_dur_loss and stage == sb.Stage.TRAIN:
+            dur_loss = F.mse_loss(log_dur_pred, log_dur, reduction="mean")
+            loss["G_dur_loss"] = dur_loss
+            dur_loss *= self.mseg_dur_loss_weight
+
+        loss["G_loss"] = gen_loss + adv_loss + dur_loss
         loss["G_gen_loss"] = gen_loss
         loss["G_adv_loss"] = adv_loss
 
diff --git a/speechbrain/lobes/models/MSTacotron2.py b/speechbrain/lobes/models/MSTacotron2.py
new file mode 100644
index 0000000000000000000000000000000000000000..55b8deac2c52732cc02fd40196011c2a532aa04e
--- /dev/null
+++ b/speechbrain/lobes/models/MSTacotron2.py
@@ -0,0 +1,750 @@
+"""
+Neural network modules for the Zero-Shot Multi-Speaker Tacotron2 end-to-end neural
+Text-to-Speech (TTS) model
+
+Authors
+* Georges Abous-Rjeili 2021
+* Artem Ploujnikov 2021
+* Pradnya Kandarkar 2023
+"""
+
+# This code uses a significant portion of the NVidia implementation, even though it
+# has been modified and enhanced
+
+# https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/tacotron2/model.py
+# *****************************************************************************
+#  Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
+#
+#  Redistribution and use in source and binary forms, with or without
+#  modification, are permitted provided that the following conditions are met:
+#      * Redistributions of source code must retain the above copyright
+#        notice, this list of conditions and the following disclaimer.
+#      * Redistributions in binary form must reproduce the above copyright
+#        notice, this list of conditions and the following disclaimer in the
+#        documentation and/or other materials provided with the distribution.
+#      * Neither the name of the NVIDIA CORPORATION nor the
+#        names of its contributors may be used to endorse or promote products
+#        derived from this software without specific prior written permission.
+#
+#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+#  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+#  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+#  DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+#  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+#  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+#  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+#  ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+#  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+# *****************************************************************************
+
+from math import sqrt
+from speechbrain.nnet.loss.guidedattn_loss import GuidedAttentionLoss
+import torch
+from torch import nn
+from torch.nn import functional as F
+from collections import namedtuple
+import pickle
+from speechbrain.lobes.models.Tacotron2 import (
+    LinearNorm,
+    Postnet,
+    Encoder,
+    Decoder,
+    get_mask_from_lengths,
+)
+
+
+class Tacotron2(nn.Module):
+    """The Tactron2 text-to-speech model, based on the NVIDIA implementation.
+
+    This class is the main entry point for the model, which is responsible
+    for instantiating all submodules, which, in turn, manage the individual
+    neural network layers
+
+    Simplified STRUCTURE: phoneme input->token embedding ->encoder -> (encoder output + speaker embedding) ->attention \
+    ->decoder(+prenet) -> postnet ->output
+
+    prenet(input is decoder previous time step) output is input to decoder
+    concatenanted with the attention output
+
+    Arguments
+    ---------
+    spk_emb_size: int
+        Speaker embedding size
+
+    mask_padding: bool
+        whether or not to mask pad-outputs of tacotron
+
+    #mel generation parameter in data io
+    n_mel_channels: int
+        number of mel channels for constructing spectrogram
+
+    #symbols
+    n_symbols:  int=128
+        number of accepted char symbols defined in textToSequence
+    symbols_embedding_dim: int
+        number of embeding dimension for symbols fed to nn.Embedding
+
+    # Encoder parameters
+    encoder_kernel_size: int
+        size of kernel processing the embeddings
+    encoder_n_convolutions: int
+        number of convolution layers in encoder
+    encoder_embedding_dim: int
+        number of kernels in encoder, this is also the dimension
+        of the bidirectional LSTM in the encoder
+
+    # Attention parameters
+    attention_rnn_dim: int
+        input dimension
+    attention_dim: int
+        number of hidden represetation in attention
+    # Location Layer parameters
+    attention_location_n_filters: int
+        number of 1-D convulation filters in attention
+    attention_location_kernel_size: int
+        length of the 1-D convolution filters
+
+    # Decoder parameters
+    n_frames_per_step: int=1
+        only 1 generated mel-frame per step is supported for the decoder as of now.
+    decoder_rnn_dim: int
+        number of 2 unidirectionnal stacked LSTM units
+    prenet_dim: int
+        dimension of linear prenet layers
+    max_decoder_steps: int
+        maximum number of steps/frames the decoder generates before stopping
+    p_attention_dropout: float
+        attention drop out probability
+    p_decoder_dropout: float
+        decoder drop  out probability
+
+    gate_threshold: int
+        cut off level any output probabilty above that is considered
+        complete and stops genration so we have variable length outputs
+    decoder_no_early_stopping: bool
+        determines early stopping of decoder
+        along with gate_threshold . The logical inverse of this is fed to the decoder
+
+
+    #Mel-post processing network parameters
+    postnet_embedding_dim: int
+        number os postnet dfilters
+    postnet_kernel_size: int
+        1d size of posnet kernel
+    postnet_n_convolutions: int
+        number of convolution layers in postnet
+
+    Example
+    -------
+    >>> import torch
+    >>> _ = torch.manual_seed(213312)
+    >>> from speechbrain.lobes.models.Tacotron2 import Tacotron2
+    >>> model = Tacotron2(
+    ...    mask_padding=True,
+    ...    n_mel_channels=80,
+    ...    n_symbols=148,
+    ...    symbols_embedding_dim=512,
+    ...    encoder_kernel_size=5,
+    ...    encoder_n_convolutions=3,
+    ...    encoder_embedding_dim=512,
+    ...    attention_rnn_dim=1024,
+    ...    attention_dim=128,
+    ...    attention_location_n_filters=32,
+    ...    attention_location_kernel_size=31,
+    ...    n_frames_per_step=1,
+    ...    decoder_rnn_dim=1024,
+    ...    prenet_dim=256,
+    ...    max_decoder_steps=32,
+    ...    gate_threshold=0.5,
+    ...    p_attention_dropout=0.1,
+    ...    p_decoder_dropout=0.1,
+    ...    postnet_embedding_dim=512,
+    ...    postnet_kernel_size=5,
+    ...    postnet_n_convolutions=5,
+    ...    decoder_no_early_stopping=False
+    ... )
+    >>> _ = model.eval()
+    >>> inputs = torch.tensor([
+    ...     [13, 12, 31, 14, 19],
+    ...     [31, 16, 30, 31, 0],
+    ... ])
+    >>> input_lengths = torch.tensor([5, 4])
+    >>> outputs, output_lengths, alignments = model.infer(inputs, input_lengths)
+    >>> outputs.shape, output_lengths.shape, alignments.shape
+    (torch.Size([2, 80, 1]), torch.Size([2]), torch.Size([2, 1, 5]))
+    """
+
+    def __init__(
+        self,
+        spk_emb_size,
+        mask_padding=True,
+        n_mel_channels=80,
+        n_symbols=148,
+        symbols_embedding_dim=512,
+        encoder_kernel_size=5,
+        encoder_n_convolutions=3,
+        encoder_embedding_dim=512,
+        attention_rnn_dim=1024,
+        attention_dim=128,
+        attention_location_n_filters=32,
+        attention_location_kernel_size=31,
+        n_frames_per_step=1,
+        decoder_rnn_dim=1024,
+        prenet_dim=256,
+        max_decoder_steps=1000,
+        gate_threshold=0.5,
+        p_attention_dropout=0.1,
+        p_decoder_dropout=0.1,
+        postnet_embedding_dim=512,
+        postnet_kernel_size=5,
+        postnet_n_convolutions=5,
+        decoder_no_early_stopping=False,
+    ):
+        super().__init__()
+        self.mask_padding = mask_padding
+        self.n_mel_channels = n_mel_channels
+        self.n_frames_per_step = n_frames_per_step
+        self.embedding = nn.Embedding(n_symbols, symbols_embedding_dim)
+        std = sqrt(2.0 / (n_symbols + symbols_embedding_dim))
+        val = sqrt(3.0) * std  # uniform bounds for std
+        self.embedding.weight.data.uniform_(-val, val)
+        self.encoder = Encoder(
+            encoder_n_convolutions, encoder_embedding_dim, encoder_kernel_size
+        )
+        self.decoder = Decoder(
+            n_mel_channels,
+            n_frames_per_step,
+            encoder_embedding_dim,
+            attention_dim,
+            attention_location_n_filters,
+            attention_location_kernel_size,
+            attention_rnn_dim,
+            decoder_rnn_dim,
+            prenet_dim,
+            max_decoder_steps,
+            gate_threshold,
+            p_attention_dropout,
+            p_decoder_dropout,
+            not decoder_no_early_stopping,
+        )
+        self.postnet = Postnet(
+            n_mel_channels,
+            postnet_embedding_dim,
+            postnet_kernel_size,
+            postnet_n_convolutions,
+        )
+
+        # Additions for Zero-Shot Multi-Speaker TTS
+        # FiLM (Feature-wise Linear Modulation) layers for injecting the speaker embeddings into the TTS pipeline
+        self.ms_film_hidden_size = int(
+            (spk_emb_size + encoder_embedding_dim) / 2
+        )
+        self.ms_film_hidden = LinearNorm(spk_emb_size, self.ms_film_hidden_size)
+        self.ms_film_h = LinearNorm(
+            self.ms_film_hidden_size, encoder_embedding_dim
+        )
+        self.ms_film_g = LinearNorm(
+            self.ms_film_hidden_size, encoder_embedding_dim
+        )
+
+    def parse_output(self, outputs, output_lengths, alignments_dim=None):
+        """
+        Masks the padded part of output
+
+        Arguments
+        ---------
+        outputs: list
+            a list of tensors - raw outputs
+        output_lengths: torch.Tensor
+            a tensor representing the lengths of all outputs
+        alignments_dim: int
+            the desired dimension of the alignments along the last axis
+            Optional but needed for data-parallel training
+
+
+        Returns
+        -------
+        result: tuple
+            a (mel_outputs, mel_outputs_postnet, gate_outputs, alignments) tuple with
+            the original outputs - with the mask applied
+        """
+        mel_outputs, mel_outputs_postnet, gate_outputs, alignments = outputs
+        if self.mask_padding and output_lengths is not None:
+            mask = get_mask_from_lengths(
+                output_lengths, max_len=mel_outputs.size(-1)
+            )
+            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
+            mask = mask.permute(1, 0, 2)
+
+            mel_outputs.clone().masked_fill_(mask, 0.0)
+            mel_outputs_postnet.masked_fill_(mask, 0.0)
+            gate_outputs.masked_fill_(mask[:, 0, :], 1e3)  # gate energies
+        if alignments_dim is not None:
+            alignments = F.pad(
+                alignments, (0, alignments_dim - alignments.size(-1))
+            )
+
+        return (
+            mel_outputs,
+            mel_outputs_postnet,
+            gate_outputs,
+            alignments,
+            output_lengths,
+        )
+
+    def forward(self, inputs, spk_embs, alignments_dim=None):
+        """Decoder forward pass for training
+
+        Arguments
+        ---------
+        inputs: tuple
+            batch object
+        spk_embs: torch.Tensor
+            Speaker embeddings corresponding to the inputs
+        alignments_dim: int
+            the desired dimension of the alignments along the last axis
+            Optional but needed for data-parallel training
+
+        Returns
+        ---------
+        mel_outputs: torch.Tensor
+            mel outputs from the decoder
+        mel_outputs_postnet: torch.Tensor
+            mel outputs from postnet
+        gate_outputs: torch.Tensor
+            gate outputs from the decoder
+        alignments: torch.Tensor
+            sequence of attention weights from the decoder
+        output_legnths: torch.Tensor
+            length of the output without padding
+        """
+        inputs, input_lengths, targets, max_len, output_lengths = inputs
+        input_lengths, output_lengths = input_lengths.data, output_lengths.data
+
+        embedded_inputs = self.embedding(inputs).transpose(1, 2)
+        encoder_outputs = self.encoder(embedded_inputs, input_lengths)
+
+        # Inject speaker embeddings into the encoder output
+        spk_embs_shared = F.relu(self.ms_film_hidden(spk_embs))
+
+        spk_embs_h = self.ms_film_h(spk_embs_shared)
+        spk_embs_h = torch.unsqueeze(spk_embs_h, 1).repeat(
+            1, encoder_outputs.shape[1], 1
+        )
+        encoder_outputs = encoder_outputs * spk_embs_h
+
+        spk_embs_g = self.ms_film_g(spk_embs_shared)
+        spk_embs_g = torch.unsqueeze(spk_embs_g, 1).repeat(
+            1, encoder_outputs.shape[1], 1
+        )
+        encoder_outputs = encoder_outputs + spk_embs_g
+
+        # Pass the encoder output combined with speaker embeddings to the next layers
+        mel_outputs, gate_outputs, alignments = self.decoder(
+            encoder_outputs, targets, memory_lengths=input_lengths
+        )
+
+        mel_outputs_postnet = self.postnet(mel_outputs)
+        mel_outputs_postnet = mel_outputs + mel_outputs_postnet
+
+        return self.parse_output(
+            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
+            output_lengths,
+            alignments_dim,
+        )
+
+    def infer(self, inputs, spk_embs, input_lengths):
+        """Produces outputs
+
+
+        Arguments
+        ---------
+        inputs: torch.tensor
+            text or phonemes converted
+        spk_embs: torch.Tensor
+            Speaker embeddings corresponding to the inputs
+        input_lengths: torch.tensor
+            the lengths of input parameters
+
+        Returns
+        -------
+        mel_outputs_postnet: torch.Tensor
+            final mel output of tacotron 2
+        mel_lengths: torch.Tensor
+            length of mels
+        alignments: torch.Tensor
+            sequence of attention weights
+        """
+
+        embedded_inputs = self.embedding(inputs).transpose(1, 2)
+        encoder_outputs = self.encoder.infer(embedded_inputs, input_lengths)
+
+        # Inject speaker embeddings into the encoder output
+        spk_embs_shared = F.relu(self.ms_film_hidden(spk_embs))
+
+        spk_embs_h = self.ms_film_h(spk_embs_shared)
+        spk_embs_h = torch.unsqueeze(spk_embs_h, 1).repeat(
+            1, encoder_outputs.shape[1], 1
+        )
+        encoder_outputs = encoder_outputs * spk_embs_h
+
+        spk_embs_g = self.ms_film_g(spk_embs_shared)
+        spk_embs_g = torch.unsqueeze(spk_embs_g, 1).repeat(
+            1, encoder_outputs.shape[1], 1
+        )
+        encoder_outputs = encoder_outputs + spk_embs_g
+
+        # Pass the encoder output combined with speaker embeddings to the next layers
+        mel_outputs, gate_outputs, alignments, mel_lengths = self.decoder.infer(
+            encoder_outputs, input_lengths
+        )
+
+        mel_outputs_postnet = self.postnet(mel_outputs)
+        mel_outputs_postnet = mel_outputs + mel_outputs_postnet
+
+        BS = mel_outputs_postnet.size(0)
+        alignments = alignments.unfold(1, BS, BS).transpose(0, 2)
+
+        return mel_outputs_postnet, mel_lengths, alignments
+
+
+LossStats = namedtuple(
+    "TacotronLoss", "loss mel_loss spk_emb_loss gate_loss attn_loss attn_weight"
+)
+
+
+class Loss(nn.Module):
+    """The Tacotron loss implementation
+    The loss consists of an MSE loss on the spectrogram, a BCE gate loss
+    and a guided attention loss (if enabled) that attempts to make the
+    attention matrix diagonal
+    The output of the moduel is a LossStats tuple, which includes both the
+    total loss
+    Arguments
+    ---------
+    guided_attention_sigma: float
+        The guided attention sigma factor, controling the "width" of
+        the mask
+    gate_loss_weight: float
+        The constant by which the gate loss will be multiplied
+    mel_loss_weight: float
+        The constant by which the mel loss will be multiplied
+    spk_emb_loss_weight: float
+        The constant by which the speaker embedding loss will be multiplied - placeholder for future work
+    spk_emb_loss_type: str
+        Type of the speaker embedding loss - placeholder for future work
+    guided_attention_weight: float
+        The weight for the guided attention
+    guided_attention_scheduler: callable
+        The scheduler class for the guided attention loss
+    guided_attention_hard_stop: int
+        The number of epochs after which guided attention will be compeltely
+        turned off
+    Example:
+    >>> import torch
+    >>> _ = torch.manual_seed(42)
+    >>> from speechbrain.lobes.models.MSTacotron2 import Loss
+    >>> loss = Loss(guided_attention_sigma=0.2)
+    >>> mel_target = torch.randn(2, 80, 861)
+    >>> gate_target = torch.randn(1722, 1)
+    >>> mel_out = torch.randn(2, 80, 861)
+    >>> mel_out_postnet = torch.randn(2, 80, 861)
+    >>> gate_out = torch.randn(2, 861)
+    >>> alignments = torch.randn(2, 861, 173)
+    >>> pred_mel_lens = torch.randn(2)
+    >>> targets = mel_target, gate_target
+    >>> model_outputs = mel_out, mel_out_postnet, gate_out, alignments, pred_mel_lens
+    >>> input_lengths = torch.tensor([173,  91])
+    >>> target_lengths = torch.tensor([861, 438])
+    >>> spk_embs = None
+    >>> loss(model_outputs, targets, input_lengths, target_lengths, spk_embs, 1)
+    TacotronLoss(loss=tensor([4.8566]), mel_loss=tensor(4.0097), spk_emb_loss=tensor([0.]), gate_loss=tensor(0.8460), attn_loss=tensor(0.0010), attn_weight=tensor(1.))
+    """
+
+    def __init__(
+        self,
+        guided_attention_sigma=None,
+        gate_loss_weight=1.0,
+        mel_loss_weight=1.0,
+        spk_emb_loss_weight=1.0,
+        spk_emb_loss_type=None,
+        guided_attention_weight=1.0,
+        guided_attention_scheduler=None,
+        guided_attention_hard_stop=None,
+    ):
+        super().__init__()
+        if guided_attention_weight == 0:
+            guided_attention_weight = None
+        self.guided_attention_weight = guided_attention_weight
+        self.gate_loss_weight = gate_loss_weight
+        self.mel_loss_weight = mel_loss_weight
+        self.spk_emb_loss_weight = spk_emb_loss_weight
+        self.spk_emb_loss_type = spk_emb_loss_type
+
+        self.mse_loss = nn.MSELoss()
+        self.bce_loss = nn.BCEWithLogitsLoss()
+        self.guided_attention_loss = GuidedAttentionLoss(
+            sigma=guided_attention_sigma
+        )
+        self.cos_sim = nn.CosineSimilarity()
+        self.triplet_loss = torch.nn.TripletMarginWithDistanceLoss(
+            distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y)
+        )
+        self.cos_emb_loss = nn.CosineEmbeddingLoss()
+
+        self.guided_attention_scheduler = guided_attention_scheduler
+        self.guided_attention_hard_stop = guided_attention_hard_stop
+
+    def forward(
+        self,
+        model_output,
+        targets,
+        input_lengths,
+        target_lengths,
+        spk_embs,
+        epoch,
+    ):
+        """Computes the loss
+        Arguments
+        ---------
+        model_output: tuple
+            the output of the model's forward():
+            (mel_outputs, mel_outputs_postnet, gate_outputs, alignments)
+        targets: tuple
+            the targets
+        input_lengths: torch.Tensor
+            a (batch, length) tensor of input lengths
+        target_lengths: torch.Tensor
+            a (batch, length) tensor of target (spectrogram) lengths
+        spk_embs: torch.Tensor
+            Speaker embedding input for the loss computation - placeholder for future work
+        epoch: int
+            the current epoch number (used for the scheduling of the guided attention
+            loss) A StepScheduler is typically used
+        Returns
+        -------
+        result: LossStats
+            the total loss - and individual losses (mel and gate)
+        """
+        mel_target, gate_target = targets[0], targets[1]
+        mel_target.requires_grad = False
+        gate_target.requires_grad = False
+        gate_target = gate_target.view(-1, 1)
+
+        (
+            mel_out,
+            mel_out_postnet,
+            gate_out,
+            alignments,
+            pred_mel_lens,
+        ) = model_output
+
+        gate_out = gate_out.view(-1, 1)
+        mel_loss = self.mse_loss(mel_out, mel_target) + self.mse_loss(
+            mel_out_postnet, mel_target
+        )
+
+        mel_loss = self.mel_loss_weight * mel_loss
+
+        gate_loss = self.gate_loss_weight * self.bce_loss(gate_out, gate_target)
+        attn_loss, attn_weight = self.get_attention_loss(
+            alignments, input_lengths, target_lengths, epoch
+        )
+
+        # Speaker embedding loss placeholder - for future work
+        spk_emb_loss = torch.Tensor([0]).to(mel_loss.device)
+
+        if self.spk_emb_loss_type == "scl_loss":
+            target_spk_embs, preds_spk_embs = spk_embs
+
+            cos_sim_scores = self.cos_sim(preds_spk_embs, target_spk_embs)
+            spk_emb_loss = -torch.div(
+                torch.sum(cos_sim_scores), len(cos_sim_scores)
+            )
+
+        if self.spk_emb_loss_type == "cos_emb_loss":
+            target_spk_embs, preds_spk_embs = spk_embs
+            spk_emb_loss = self.cos_emb_loss(
+                target_spk_embs,
+                preds_spk_embs,
+                torch.ones(len(target_spk_embs)).to(target_spk_embs.device),
+            )
+
+        if self.spk_emb_loss_type == "triplet_loss":
+            anchor_spk_embs, pos_spk_embs, neg_spk_embs = spk_embs
+            if anchor_spk_embs is not None:
+                spk_emb_loss = self.triplet_loss(
+                    anchor_spk_embs, pos_spk_embs, neg_spk_embs
+                )
+
+        spk_emb_loss = self.spk_emb_loss_weight * spk_emb_loss
+
+        total_loss = mel_loss + spk_emb_loss + gate_loss + attn_loss
+        return LossStats(
+            total_loss,
+            mel_loss,
+            spk_emb_loss,
+            gate_loss,
+            attn_loss,
+            attn_weight,
+        )
+
+    def get_attention_loss(
+        self, alignments, input_lengths, target_lengths, epoch
+    ):
+        """Computes the attention loss
+        Arguments
+        ---------
+        alignments: torch.Tensor
+            the aligment matrix from the model
+        input_lengths: torch.Tensor
+            a (batch, length) tensor of input lengths
+        target_lengths: torch.Tensor
+            a (batch, length) tensor of target (spectrogram) lengths
+        epoch: int
+            the current epoch number (used for the scheduling of the guided attention
+            loss) A StepScheduler is typically used
+        Returns
+        -------
+        attn_loss: torch.Tensor
+            the attention loss value
+        """
+        zero_tensor = torch.tensor(0.0, device=alignments.device)
+        if (
+            self.guided_attention_weight is None
+            or self.guided_attention_weight == 0
+        ):
+            attn_weight, attn_loss = zero_tensor, zero_tensor
+        else:
+            hard_stop_reached = (
+                self.guided_attention_hard_stop is not None
+                and epoch > self.guided_attention_hard_stop
+            )
+            if hard_stop_reached:
+                attn_weight, attn_loss = zero_tensor, zero_tensor
+            else:
+                attn_weight = self.guided_attention_weight
+                if self.guided_attention_scheduler is not None:
+                    _, attn_weight = self.guided_attention_scheduler(epoch)
+            attn_weight = torch.tensor(attn_weight, device=alignments.device)
+            attn_loss = attn_weight * self.guided_attention_loss(
+                alignments, input_lengths, target_lengths
+            )
+        return attn_loss, attn_weight
+
+
+class TextMelCollate:
+    """ Zero-pads model inputs and targets based on number of frames per step
+    Arguments
+    ---------
+    speaker_embeddings_pickle : str
+        Path to the file containing speaker embeddings
+    n_frames_per_step: int
+        The number of output frames per step
+    Returns
+    -------
+    result: tuple
+        a tuple inputs/targets
+        (
+            text_padded,
+            input_lengths,
+            mel_padded,
+            gate_padded,
+            output_lengths,
+            len_x,
+            labels,
+            wavs,
+            spk_embs,
+            spk_ids
+        )
+    """
+
+    def __init__(
+        self, speaker_embeddings_pickle, n_frames_per_step=1,
+    ):
+        self.n_frames_per_step = n_frames_per_step
+        self.speaker_embeddings_pickle = speaker_embeddings_pickle
+
+    # TODO: Make this more intuitive, use the pipeline
+    def __call__(self, batch):
+        """Collate's training batch from normalized text and mel-spectrogram
+        Arguments
+        ---------
+        batch: list
+            [text_normalized, mel_normalized]
+        """
+
+        # TODO: Remove for loops and this dirty hack
+        raw_batch = list(batch)
+        for i in range(
+            len(batch)
+        ):  # the pipline return a dictionary wiht one elemnent
+            batch[i] = batch[i]["mel_text_pair"]
+
+        # Right zero-pad all one-hot text sequences to max input length
+
+        input_lengths, ids_sorted_decreasing = torch.sort(
+            torch.LongTensor([len(x[0]) for x in batch]), dim=0, descending=True
+        )
+        max_input_len = input_lengths[0]
+
+        text_padded = torch.LongTensor(len(batch), max_input_len)
+        text_padded.zero_()
+        for i in range(len(ids_sorted_decreasing)):
+            text = batch[ids_sorted_decreasing[i]][0]
+            text_padded[i, : text.size(0)] = text
+
+        # Right zero-pad mel-spec
+        num_mels = batch[0][1].size(0)
+        max_target_len = max([x[1].size(1) for x in batch])
+        if max_target_len % self.n_frames_per_step != 0:
+            max_target_len += (
+                self.n_frames_per_step - max_target_len % self.n_frames_per_step
+            )
+            assert max_target_len % self.n_frames_per_step == 0
+
+        # include mel padded and gate padded
+        mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
+        mel_padded.zero_()
+        gate_padded = torch.FloatTensor(len(batch), max_target_len)
+        gate_padded.zero_()
+        output_lengths = torch.LongTensor(len(batch))
+        labels, wavs, spk_embs_list, spk_ids = [], [], [], []
+        with open(
+            self.speaker_embeddings_pickle, "rb"
+        ) as speaker_embeddings_file:
+            speaker_embeddings = pickle.load(speaker_embeddings_file)
+
+        for i in range(len(ids_sorted_decreasing)):
+            idx = ids_sorted_decreasing[i]
+            mel = batch[idx][1]
+            mel_padded[i, :, : mel.size(1)] = mel
+            gate_padded[i, mel.size(1) - 1 :] = 1
+            output_lengths[i] = mel.size(1)
+            labels.append(raw_batch[idx]["label"])
+            wavs.append(raw_batch[idx]["wav"])
+
+            spk_emb = speaker_embeddings[raw_batch[idx]["uttid"]]
+            spk_embs_list.append(spk_emb)
+
+            spk_ids.append(raw_batch[idx]["uttid"].split("_")[0])
+
+        spk_embs = torch.stack(spk_embs_list)
+
+        # count number of items - characters in text
+        len_x = [x[2] for x in batch]
+        len_x = torch.Tensor(len_x)
+        return (
+            text_padded,
+            input_lengths,
+            mel_padded,
+            gate_padded,
+            output_lengths,
+            len_x,
+            labels,
+            wavs,
+            spk_embs,
+            spk_ids,
+        )
diff --git a/speechbrain/lobes/models/Tacotron2.py b/speechbrain/lobes/models/Tacotron2.py
index 4db80005535ae14dd45c6a5e8c4a9f7bf7c63d92..5d6267c6968521d1c5cc998fbb663e1883a547f6 100644
--- a/speechbrain/lobes/models/Tacotron2.py
+++ b/speechbrain/lobes/models/Tacotron2.py
@@ -40,6 +40,9 @@ Authors
 
 from math import sqrt
 from speechbrain.nnet.loss.guidedattn_loss import GuidedAttentionLoss
+from speechbrain.lobes.models.transformer.Transformer import (
+    get_mask_from_lengths,
+)
 import torch
 from torch import nn
 from torch.nn import functional as F
@@ -268,7 +271,9 @@ class Attention(nn.Module):
     -------
     >>> import torch
     >>> from speechbrain.lobes.models.Tacotron2 import (
-    ...     Attention, get_mask_from_lengths)
+    ... Attention)
+    >>> from speechbrain.lobes.models.transformer.Transformer import (
+    ... get_mask_from_lengths)
     >>> layer = Attention()
     >>> attention_hidden_state = torch.randn(2, 1024)
     >>> memory = torch.randn(2, 173, 512)
@@ -1523,31 +1528,6 @@ class Tacotron2(nn.Module):
         return mel_outputs_postnet, mel_lengths, alignments
 
 
-def get_mask_from_lengths(lengths, max_len=None):
-    """Creates a mask from a tensor of lengths
-
-    Arguments
-    ---------
-    lengths: torch.Tensor
-        a tensor of sequence lengths
-
-    Returns
-    -------
-    mask: torch.Tensor
-        the mask
-    max_len: int
-        The maximum length, i.e. the last dimension of
-        the mask tensor. If not provided, it will be
-        calculated automatically
-    """
-    if max_len is None:
-        max_len = torch.max(lengths).item()
-    ids = torch.arange(0, max_len, device=lengths.device, dtype=lengths.dtype)
-    mask = (ids < lengths.unsqueeze(1)).byte()
-    mask = torch.le(mask, 0)
-    return mask
-
-
 def infer(model, text_sequences, input_lengths):
     """
     An inference hook for pretrained synthesizers
diff --git a/speechbrain/lobes/models/convolution.py b/speechbrain/lobes/models/convolution.py
index d0cdc3e522dbaa8556a7f31905117fa281a79a2c..3f73ca8440e23a22319246a6260375a2c509b3ac 100644
--- a/speechbrain/lobes/models/convolution.py
+++ b/speechbrain/lobes/models/convolution.py
@@ -8,6 +8,10 @@ import torch
 from speechbrain.nnet.CNN import Conv2d, Conv1d
 from speechbrain.nnet.containers import Sequential
 from speechbrain.nnet.normalization import LayerNorm
+from speechbrain.utils.filter_analysis import (
+    FilterProperties,
+    stack_filter_properties,
+)
 
 
 class ConvolutionalSpatialGatingUnit(torch.nn.Module):
@@ -174,6 +178,11 @@ class ConvolutionFrontEnd(Sequential):
                 conv_init=conv_init,
             )
 
+    def get_filter_properties(self) -> FilterProperties:
+        return stack_filter_properties(
+            block.get_filter_properties() for block in self.children()
+        )
+
 
 class ConvBlock(torch.nn.Module):
     """An implementation of convolution block with 1d or 2d convolutions (depthwise).
@@ -223,19 +232,28 @@ class ConvBlock(torch.nn.Module):
     ):
         super().__init__()
         self.convs = Sequential(input_shape=input_shape)
+        self.filter_properties = []
 
         for i in range(num_layers):
+            layer_stride = stride if i == num_layers - 1 else 1
             self.convs.append(
                 conv_module,
                 out_channels=out_channels,
                 kernel_size=kernel_size,
-                stride=stride if i == num_layers - 1 else 1,
+                stride=layer_stride,
                 dilation=dilation,
                 layer_name=f"conv_{i}",
                 bias=conv_bias,
                 padding=padding,
                 conv_init=conv_init,
             )
+            self.filter_properties.append(
+                FilterProperties(
+                    window_size=kernel_size,
+                    stride=layer_stride,
+                    dilation=dilation,
+                )
+            )
             if norm is not None:
                 self.convs.append(norm, layer_name=f"norm_{i}")
             self.convs.append(activation(), layer_name=f"act_{i}")
@@ -264,3 +282,6 @@ class ConvBlock(torch.nn.Module):
             out = out + self.reduce_conv(x)
             out = self.drop(out)
         return out
+
+    def get_filter_properties(self) -> FilterProperties:
+        return stack_filter_properties(self.filter_properties)
diff --git a/speechbrain/lobes/models/discrete/__init__.py b/speechbrain/lobes/models/discrete/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..da3fcfa0cc466c2ff8cdf9d826c4ee26ec8af5cc
--- /dev/null
+++ b/speechbrain/lobes/models/discrete/__init__.py
@@ -0,0 +1,2 @@
+""" Package defining discrete models.
+"""
diff --git a/speechbrain/lobes/models/discrete/dac.py b/speechbrain/lobes/models/discrete/dac.py
new file mode 100644
index 0000000000000000000000000000000000000000..bebc17cda1b84b80e65116cea11c314a52031245
--- /dev/null
+++ b/speechbrain/lobes/models/discrete/dac.py
@@ -0,0 +1,1139 @@
+"""
+This lobe enables the integration of pretrained discrete DAC model.
+Reference: http://arxiv.org/abs/2306.06546
+Reference: https://descript.notion.site/Descript-Audio-Codec-11389fce0ce2419891d6591a68f814d5
+Reference: https://github.com/descriptinc/descript-audio-codec
+
+Author
+ * Shubham Gupta 2023
+
+"""
+
+import math
+from pathlib import Path
+from typing import List, Union
+import logging
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# Note: The path torch.nn.utils.parametrizations may not be available
+# in older PyTorch versions, such as 1.13.1. To ensure compatibility,
+# it is recommended to check and use the appropriate import statement.
+
+# Attempt to import the preferred module for parametrizations in newer PyTorch versions
+try:
+    from torch.nn.utils.parametrizations import weight_norm
+
+# If the preferred import fails, fallback to the alternative import for compatibility
+except ImportError:
+    from torch.nn.utils import weight_norm
+
+logger = logging.getLogger(__name__)
+
+SUPPORTED_VERSIONS = ["1.0.0"]
+
+
+__MODEL_LATEST_TAGS__ = {
+    ("44khz", "8kbps"): "0.0.1",
+    ("24khz", "8kbps"): "0.0.4",
+    ("16khz", "8kbps"): "0.0.5",
+    ("44khz", "16kbps"): "1.0.0",
+}
+
+
+__MODEL_URLS__ = {
+    (
+        "44khz",
+        "0.0.1",
+        "8kbps",
+    ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
+    (
+        "24khz",
+        "0.0.4",
+        "8kbps",
+    ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
+    (
+        "16khz",
+        "0.0.5",
+        "8kbps",
+    ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
+    (
+        "44khz",
+        "1.0.0",
+        "16kbps",
+    ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
+}
+
+
+def WNConv1d(*args, **kwargs):
+    """
+    Apply weight normalization to a 1D convolutional layer.
+
+    Parameters
+    ----------
+    *args
+        Variable length argument list for nn.Conv1d.
+    **kwargs
+        Arbitrary keyword arguments for nn.Conv1d.
+
+    Returns
+    -------
+    torch.nn.Module
+        The weight-normalized nn.Conv1d layer.
+    """
+    return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+    """
+    Apply weight normalization to a 1D transposed convolutional layer.
+
+    Parameters
+    ----------
+    *args
+        Variable length argument list for nn.ConvTranspose1d.
+    **kwargs
+        Arbitrary keyword arguments for nn.ConvTranspose1d.
+
+    Returns
+    -------
+    torch.nn.Module
+        The weight-normalized nn.ConvTranspose1d layer.
+    """
+    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+def init_weights(m):
+    """
+    Initialize the weights of a 1D convolutional layer.
+    """
+    if isinstance(m, nn.Conv1d):
+        nn.init.trunc_normal_(m.weight, std=0.02)
+        nn.init.constant_(m.bias, 0)
+
+
+def download(
+    model_type: str = "44khz",
+    model_bitrate: str = "8kbps",
+    tag: str = "latest",
+    local_path: Path = None,
+):
+    """
+    Downloads a specified model file based on model type, bitrate, and tag, saving it to a local path.
+
+    Parameters
+    ----------
+    model_type : str, optional
+        The type of model to download. Can be '44khz', '24khz', or '16khz'. Default is '44khz'.
+    model_bitrate : str, optional
+        The bitrate of the model. Can be '8kbps' or '16kbps'. Default is '8kbps'.
+    tag : str, optional
+        A specific version tag for the model. Default is 'latest'.
+    local_path : Path, optional
+        The local file path where the model will be saved. If not provided, a default path will be used.
+
+    Returns
+    -------
+    Path
+        The local path where the model is saved.
+
+    Raises
+    ------
+    ValueError
+        If the model type or bitrate is not supported, or if the model cannot be found or downloaded.
+    """
+
+    model_type = model_type.lower()
+    tag = tag.lower()
+
+    assert model_type in [
+        "44khz",
+        "24khz",
+        "16khz",
+    ], "model_type must be one of '44khz', '24khz', or '16khz'"
+
+    assert model_bitrate in [
+        "8kbps",
+        "16kbps",
+    ], "model_bitrate must be one of '8kbps', or '16kbps'"
+
+    if tag == "latest":
+        tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)]
+
+    download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)
+    logger.info(f"Download link: {download_link}")
+
+    if download_link is None:
+        raise ValueError(
+            f"Could not find model with tag {tag} and model type {model_type}"
+        )
+
+    if local_path is None:
+        local_path = (
+            Path.home()
+            / f".cache/descript/dac/weights_{model_type}_{model_bitrate}_{tag}.pth"
+        )
+
+    if not local_path.exists():
+        local_path.parent.mkdir(parents=True, exist_ok=True)
+
+        # Download the model
+        import requests
+
+        response = requests.get(download_link)
+
+        if response.status_code != 200:
+            raise ValueError(
+                f"Could not download model. Received response code {response.status_code}"
+            )
+        local_path.write_bytes(response.content)
+
+    return local_path
+
+
+# Scripting this brings model speed up 1.4x
+@torch.jit.script
+def snake(x, alpha):
+    """
+    Applies the 'snake' activation function on the input tensor.
+
+    This function reshapes the input tensor, applies a modified sine function to it, and then reshapes it back
+    to its original shape.
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        The input tensor to which the snake activation function will be applied.
+    alpha : float
+        A scalar value that modifies the sine function within the snake activation.
+
+    Returns
+    -------
+    torch.Tensor
+        The transformed tensor after applying the snake activation function.
+    """
+    shape = x.shape
+    x = x.reshape(shape[0], shape[1], -1)
+    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
+    x = x.reshape(shape)
+    return x
+
+
+class VectorQuantize(nn.Module):
+    """
+    An implementation for Vector Quantization
+    """
+
+    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
+        """
+        Implementation of VQ similar to Karpathy's repo:
+        https://github.com/karpathy/deep-vector-quantization
+        Additionally uses following tricks from Improved VQGAN
+        (https://arxiv.org/pdf/2110.04627.pdf):
+            1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
+                for improved codebook usage
+            2. l2-normalized codes: Converts euclidean distance to cosine similarity which
+                improves training stability
+        """
+        super().__init__()
+        self.codebook_size = codebook_size
+        self.codebook_dim = codebook_dim
+
+        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
+        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
+        self.codebook = nn.Embedding(codebook_size, codebook_dim)
+
+    def forward(self, z: torch.Tensor):
+        """Quantized the input tensor using a fixed codebook and returns
+        the corresponding codebook vectors
+
+        Parameters
+        ----------
+        z : Tensor[B x D x T]
+
+        Returns
+        -------
+        Tensor[B x D x T]
+            Quantized continuous representation of input
+        Tensor[1]
+            Commitment loss to train encoder to predict vectors closer to codebook
+            entries
+        Tensor[1]
+            Codebook loss to update the codebook
+        Tensor[B x T]
+            Codebook indices (quantized discrete representation of input)
+        Tensor[B x D x T]
+            Projected latents (continuous representation of input before quantization)
+        """
+
+        # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
+        z_e = self.in_proj(z)  # z_e : (B x D x T)
+        z_q, indices = self.decode_latents(z_e)
+
+        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean(
+            [1, 2]
+        )
+        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean(
+            [1, 2]
+        )
+
+        z_q = (
+            z_e + (z_q - z_e).detach()
+        )  # noop in forward pass, straight-through gradient estimator in backward pass
+
+        z_q = self.out_proj(z_q)
+
+        return z_q, commitment_loss, codebook_loss, indices, z_e
+
+    def embed_code(self, embed_id: torch.Tensor):
+        """
+        Embeds an ID using the codebook weights.
+
+        This method utilizes the codebook weights to embed the given ID.
+
+        Parameters
+        ----------
+        embed_id : torch.Tensor
+            The tensor containing IDs that need to be embedded.
+
+        Returns
+        -------
+        torch.Tensor
+            The embedded output tensor after applying the codebook weights.
+        """
+        return F.embedding(embed_id, self.codebook.weight)
+
+    def decode_code(self, embed_id: torch.Tensor):
+        """
+        Decodes the embedded ID by transposing the dimensions.
+
+        This method decodes the embedded ID by applying a transpose operation to the dimensions of the
+        output tensor from the `embed_code` method.
+
+        Parameters
+        ----------
+        embed_id : torch.Tensor
+            The tensor containing embedded IDs.
+
+        Returns
+        -------
+        torch.Tensor
+            The decoded tensor
+        """
+        return self.embed_code(embed_id).transpose(1, 2)
+
+    def decode_latents(self, latents: torch.Tensor):
+        """
+        Decodes latent representations into discrete codes by comparing with the codebook.
+
+        Parameters
+        ----------
+        latents : torch.Tensor
+            The latent tensor representations to be decoded.
+
+        Returns
+        -------
+        Tuple[torch.Tensor, torch.Tensor]
+            A tuple containing the decoded latent tensor (`z_q`) and the indices of the codes.
+        """
+        encodings = latents.permute(0, 2, 1).reshape(-1, latents.size(1))
+        codebook = self.codebook.weight  # codebook: (N x D)
+
+        # L2 normalize encodings and codebook (ViT-VQGAN)
+        encodings = F.normalize(encodings)
+        codebook = F.normalize(codebook)
+
+        # Compute euclidean distance with codebook
+        dist = (
+            encodings.pow(2).sum(1, keepdim=True)
+            - 2 * encodings @ codebook.t()
+            + codebook.pow(2).sum(1, keepdim=True).t()
+        )
+
+        # indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
+
+        max_indices = (-dist).max(dim=1)[1]
+        b = latents.size(0)
+        t = max_indices.numel() // b
+        indices = max_indices.view(b, t)
+        z_q = self.decode_code(indices)
+        return z_q, indices
+
+
+class ResidualVectorQuantize(nn.Module):
+    """
+    Introduced in SoundStream: An end2end neural audio codec
+    https://arxiv.org/abs/2107.03312
+
+
+    Example
+    -------
+    Using a pretrained RVQ unit.
+
+    >>> dac = DAC(load_pretrained=True, model_type="44KHz", model_bitrate="8kbps", tag="latest")
+    >>> quantizer = dac.quantizer
+    >>> continuous_embeddings = torch.randn(1, 1024, 100) # Example shape: [Batch, Channels, Time]
+    >>> discrete_embeddings, codes, _, _, _ = quantizer(continuous_embeddings)
+
+    """
+
+    def __init__(
+        self,
+        input_dim: int = 512,
+        n_codebooks: int = 9,
+        codebook_size: int = 1024,
+        codebook_dim: Union[int, list] = 8,
+        quantizer_dropout: float = 0.0,
+    ):
+        """
+        Initializes the ResidualVectorQuantize
+
+        Parameters
+        ----------
+        input_dim : int, optional, by default 512
+        n_codebooks : int, optional, by default 9
+        codebook_size : int, optional, by default 1024
+        codebook_dim : Union[int, list], optional,  by default 8
+        quantizer_dropout : float, optional, by default 0.0
+        """
+        super().__init__()
+        if isinstance(codebook_dim, int):
+            codebook_dim = [codebook_dim for _ in range(n_codebooks)]
+
+        self.n_codebooks = n_codebooks
+        self.codebook_dim = codebook_dim
+        self.codebook_size = codebook_size
+
+        self.quantizers = nn.ModuleList(
+            [
+                VectorQuantize(input_dim, codebook_size, codebook_dim[i])
+                for i in range(n_codebooks)
+            ]
+        )
+        self.quantizer_dropout = quantizer_dropout
+
+    def forward(self, z, n_quantizers: int = None):
+        """Quantized the input tensor using a fixed set of `n` codebooks and returns
+        the corresponding codebook vectors
+        Parameters
+        ----------
+        z : Tensor[B x D x T]
+        n_quantizers : int, optional
+            No. of quantizers to use
+            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
+            Note: if `self.quantizer_dropout` is True, this argument is ignored
+                when in training mode, and a random number of quantizers is used.
+        Returns
+        -------
+        z : Tensor[B x D x T]
+            Quantized continuous representation of input
+        codes : Tensor[B x N x T]
+            Codebook indices for each codebook
+            (quantized discrete representation of input)
+        latents : Tensor[B x N*D x T]
+            Projected latents (continuous representation of input before quantization)
+        vq/commitment_loss : Tensor[1]
+            Commitment loss to train encoder to predict vectors closer to codebook
+            entries
+        vq/codebook_loss : Tensor[1]
+            Codebook loss to update the codebook
+        """
+        z_q = 0
+        residual = z
+        commitment_loss = 0
+        codebook_loss = 0
+
+        codebook_indices = []
+        latents = []
+
+        if n_quantizers is None:
+            n_quantizers = self.n_codebooks
+        if self.training:
+            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
+            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
+            n_dropout = int(z.shape[0] * self.quantizer_dropout)
+            n_quantizers[:n_dropout] = dropout[:n_dropout]
+            n_quantizers = n_quantizers.to(z.device)
+
+        for i, quantizer in enumerate(self.quantizers):
+            if self.training is False and i >= n_quantizers:
+                break
+
+            (
+                z_q_i,
+                commitment_loss_i,
+                codebook_loss_i,
+                indices_i,
+                z_e_i,
+            ) = quantizer(residual)
+
+            # Create mask to apply quantizer dropout
+            mask = (
+                torch.full((z.shape[0],), fill_value=i, device=z.device)
+                < n_quantizers
+            )
+            z_q = z_q + z_q_i * mask[:, None, None]
+            residual = residual - z_q_i
+
+            # Sum losses
+            commitment_loss += (commitment_loss_i * mask).mean()
+            codebook_loss += (codebook_loss_i * mask).mean()
+
+            codebook_indices.append(indices_i)
+            latents.append(z_e_i)
+
+        codes = torch.stack(codebook_indices, dim=1)
+        latents = torch.cat(latents, dim=1)
+
+        return z_q, codes, latents, commitment_loss, codebook_loss
+
+    def from_codes(self, codes: torch.Tensor):
+        """Given the quantized codes, reconstruct the continuous representation
+        Parameters
+        ----------
+        codes : Tensor[B x N x T]
+            Quantized discrete representation of input
+        Returns
+        -------
+        Tensor[B x D x T]
+            Quantized continuous representation of input
+        """
+        z_q = 0.0
+        z_p = []
+        n_codebooks = codes.shape[1]
+        for i in range(n_codebooks):
+            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
+            z_p.append(z_p_i)
+
+            z_q_i = self.quantizers[i].out_proj(z_p_i)
+            z_q = z_q + z_q_i
+        return z_q, torch.cat(z_p, dim=1), codes
+
+    def from_latents(self, latents: torch.Tensor):
+        """Given the unquantized latents, reconstruct the
+        continuous representation after quantization.
+
+        Parameters
+        ----------
+        latents : Tensor[B x N x T]
+            Continuous representation of input after projection
+
+        Returns
+        -------
+        Tensor[B x D x T]
+            Quantized representation of full-projected space
+        Tensor[B x D x T]
+            Quantized representation of latent space
+        """
+        z_q = 0
+        z_p = []
+        codes = []
+        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
+
+        n_codebooks = np.where(dims <= latents.shape[1])[0].max(
+            axis=0, keepdims=True
+        )[0]
+        for i in range(n_codebooks):
+            j, k = dims[i], dims[i + 1]
+            z_p_i, codes_i = self.quantizers[i].decode_latents(
+                latents[:, j:k, :]
+            )
+            z_p.append(z_p_i)
+            codes.append(codes_i)
+
+            z_q_i = self.quantizers[i].out_proj(z_p_i)
+            z_q = z_q + z_q_i
+
+        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
+
+
+class Snake1d(nn.Module):
+    """
+    A PyTorch module implementing the Snake activation function in 1D.
+
+    Parameters
+    ----------
+    channels : int
+        The number of channels in the input tensor.
+    """
+
+    def __init__(self, channels):
+        """
+        Initializes Snake1d
+        Parameters
+        ----------
+        channels : int
+        """
+        super().__init__()
+        self.alpha = nn.Parameter(torch.ones(1, channels, 1))
+
+    def forward(self, x):
+        """
+
+        Parameters
+        ----------
+        x : torch.Tensor
+
+        Returns
+        -------
+        torch.Tensor
+        """
+        return snake(x, self.alpha)
+
+
+class ResidualUnit(nn.Module):
+    """
+    A residual unit module for convolutional neural networks.
+
+    Parameters
+    ----------
+    dim : int, optional
+        The number of channels in the input tensor. Default is 16.
+    dilation : int, optional
+        The dilation rate for the convolutional layers. Default is 1.
+
+    """
+
+    def __init__(self, dim: int = 16, dilation: int = 1):
+        """
+        Initializes the ResidualUnit
+        Parameters
+        ----------
+        dim : int, optional, by default 16
+        dilation : int, optional, by default 1
+        """
+        super().__init__()
+        pad = ((7 - 1) * dilation) // 2
+        self.block = nn.Sequential(
+            Snake1d(dim),
+            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
+            Snake1d(dim),
+            WNConv1d(dim, dim, kernel_size=1),
+        )
+
+    def forward(self, x: torch.tensor) -> torch.tensor:
+        """
+        Parameters
+        ----------
+        x : torch.tensor
+
+        Returns
+        -------
+        torch.tensor
+        """
+        y = self.block(x)
+        pad = (x.shape[-1] - y.shape[-1]) // 2
+        if pad > 0:
+            x = x[..., pad:-pad]
+        return x + y
+
+
+class EncoderBlock(nn.Module):
+    """
+    An encoder block module for convolutional neural networks.
+
+    This module constructs an encoder block consisting of a series of ResidualUnits and a final Snake1d
+    activation followed by a weighted normalized 1D convolution. This block can be used as part of an
+    encoder in architectures like autoencoders.
+
+    Parameters
+    ----------
+    dim : int, optional
+        The number of output channels. Default is 16.
+    stride : int, optional
+        The stride for the final convolutional layer. Default is 1.
+    """
+
+    def __init__(self, dim: int = 16, stride: int = 1):
+        """
+        Initializes the EncoderBlock
+        Parameters
+        ----------
+        dim : int, optional, by default 16
+        stride : int, optional, by default 1
+        """
+        super().__init__()
+        self.block = nn.Sequential(
+            ResidualUnit(dim // 2, dilation=1),
+            ResidualUnit(dim // 2, dilation=3),
+            ResidualUnit(dim // 2, dilation=9),
+            Snake1d(dim // 2),
+            WNConv1d(
+                dim // 2,
+                dim,
+                kernel_size=2 * stride,
+                stride=stride,
+                padding=math.ceil(stride / 2),
+            ),
+        )
+
+    def forward(self, x: torch.tensor):
+        """
+        Parameters
+        ----------
+        x : torch.tensor
+
+        Returns
+        -------
+        torch.tensor
+        """
+        return self.block(x)
+
+
+class Encoder(nn.Module):
+    """
+    A PyTorch module for the Encoder part of DAC.
+
+    Parameters
+    ----------
+    d_model : int, optional
+        The initial dimensionality of the model. Default is 64.
+    strides : list, optional
+        A list of stride values for downsampling in each EncoderBlock. Default is [2, 4, 8, 8].
+    d_latent : int, optional
+        The dimensionality of the output latent space. Default is 64.
+
+    Example
+    -------
+    Creating an Encoder instance
+    >>> encoder = Encoder()
+    >>> audio_input = torch.randn(1, 1, 44100) # Example shape: [Batch, Channels, Time]
+    >>> continuous_embedding = encoder(audio_input)
+
+    Using a pretrained encoder.
+
+    >>> dac = DAC(load_pretrained=True, model_type="44KHz", model_bitrate="8kbps", tag="latest")
+    >>> encoder = dac.encoder
+    >>> audio_input = torch.randn(1, 1, 44100) # Example shape: [Batch, Channels, Time]
+    >>> continuous_embeddings = encoder(audio_input)
+    """
+
+    def __init__(
+        self,
+        d_model: int = 64,
+        strides: list = [2, 4, 8, 8],
+        d_latent: int = 64,
+    ):
+        """
+        Initializes the Encoder
+
+        Parameters
+        ----------
+        d_model : int, optional, by default 64
+        strides : list, optional, by default [2, 4, 8, 8]
+        d_latent : int, optional, by default 64
+        """
+        super().__init__()
+        # Create first convolution
+        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
+
+        # Create EncoderBlocks that double channels as they downsample by `stride`
+        for stride in strides:
+            d_model *= 2
+            self.block += [EncoderBlock(d_model, stride=stride)]
+
+        # Create last convolution
+        self.block += [
+            Snake1d(d_model),
+            WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
+        ]
+
+        # Wrap black into nn.Sequential
+        self.block = nn.Sequential(*self.block)
+        self.enc_dim = d_model
+
+    def forward(self, x):
+        """
+        Parameters
+        ----------
+        x : torch.tensor
+
+        Returns
+        -------
+        torch.tensor
+        """
+        return self.block(x)
+
+
+class DecoderBlock(nn.Module):
+    """
+    A PyTorch module representing a block within the Decoder architecture.
+
+    Parameters
+    ----------
+    input_dim : int, optional
+        The number of input channels. Default is 16.
+    output_dim : int, optional
+        The number of output channels. Default is 8.
+    stride : int, optional
+        The stride for the transposed convolution, controlling the upsampling. Default is 1.
+    """
+
+    def __init__(
+        self, input_dim: int = 16, output_dim: int = 8, stride: int = 1
+    ):
+        """
+        Initializes the DecoderBlock
+
+        Parameters
+        ----------
+        input_dim : int, optional, by default 16
+        output_dim : int, optional, by default 8
+        stride : int, optional, by default 1
+        """
+        super().__init__()
+        self.block = nn.Sequential(
+            Snake1d(input_dim),
+            WNConvTranspose1d(
+                input_dim,
+                output_dim,
+                kernel_size=2 * stride,
+                stride=stride,
+                padding=math.ceil(stride / 2),
+            ),
+            ResidualUnit(output_dim, dilation=1),
+            ResidualUnit(output_dim, dilation=3),
+            ResidualUnit(output_dim, dilation=9),
+        )
+
+    def forward(self, x):
+        """
+
+        Parameters
+        ----------
+        x : torch.tensor
+
+        Returns
+        -------
+        torch.tensor
+        """
+        return self.block(x)
+
+
+class Decoder(nn.Module):
+    """
+    A PyTorch module for the Decoder part of DAC.
+
+    Parameters
+    ----------
+    input_channel : int
+        The number of channels in the input tensor.
+    channels : int
+        The base number of channels for the convolutional layers.
+    rates : list
+        A list of stride rates for each decoder block
+    d_out: int
+        The out dimension of the final conv layer, Default is 1.
+
+    Example
+    -------
+    Creating a Decoder instance
+
+    >>> decoder = Decoder(256, 1536,  [8, 8, 4, 2])
+    >>> discrete_embeddings = torch.randn(2, 256, 200) # Example shape: [Batch, Channels, Time]
+    >>> recovered_audio = decoder(discrete_embeddings)
+
+    Using a pretrained decoder. Note that the actual input should be proper discrete representation.
+    Using randomly generated input here for illustration of use.
+
+    >>> dac = DAC(load_pretrained=True, model_type="44KHz", model_bitrate="8kbps", tag="latest")
+    >>> decoder = dac.decoder
+    >>> discrete_embeddings = torch.randn(1, 1024, 500) # Example shape: [Batch, Channels, Time]
+    >>> recovered_audio = decoder(discrete_embeddings)
+    """
+
+    def __init__(
+        self,
+        input_channel: int,
+        channels: int,
+        rates: List[int],
+        d_out: int = 1,
+    ):
+        """Initializes Decoder
+
+        Parameters
+        ----------
+        input_channel : int
+        channels : int
+        rates : List[int]
+        d_out : int, optional, by default 1
+        """
+        super().__init__()
+
+        # Add first conv layer
+        layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
+
+        # Add upsampling + MRF blocks
+        for i, stride in enumerate(rates):
+            input_dim = channels // 2 ** i
+            output_dim = channels // 2 ** (i + 1)
+            layers += [DecoderBlock(input_dim, output_dim, stride)]
+
+        # Add final conv layer
+        layers += [
+            Snake1d(output_dim),
+            WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
+            nn.Tanh(),
+        ]
+
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, x):
+        """
+
+        Parameters
+        ----------
+        x : torch.tensor
+
+        Returns
+        -------
+        torch.tensor
+        """
+        return self.model(x)
+
+
+class DAC(nn.Module):
+    """
+    Discrete Autoencoder Codec (DAC) for audio data encoding and decoding.
+
+    This class implements an autoencoder architecture with quantization for efficient audio processing.
+    It includes an encoder, quantizer, and decoder for transforming audio data into a compressed latent representation and reconstructing it back into audio.
+    This implementation supports both initializing a new model and loading a pretrained model.
+
+    Parameters
+    ----------
+    encoder_dim : int
+        Dimensionality of the encoder.
+    encoder_rates : List[int]
+        Downsampling rates for each encoder layer.
+    latent_dim : int, optional
+        Dimensionality of the latent space, automatically calculated if None.
+    decoder_dim : int
+        Dimensionality of the decoder.
+    decoder_rates : List[int]
+        Upsampling rates for each decoder layer.
+    n_codebooks : int
+        Number of codebooks for vector quantization.
+    codebook_size : int
+        Size of each codebook.
+    codebook_dim : Union[int, list]
+        Dimensionality of each codebook entry.
+    quantizer_dropout : bool
+        Whether to use dropout in the quantizer.
+    sample_rate : int
+        Sample rate of the audio data.
+    model_type : str
+        Type of the model to load (if pretrained).
+    model_bitrate : str
+        Bitrate of the model to load (if pretrained).
+    tag : str
+        Specific tag of the model to load (if pretrained).
+    load_path : str, optional
+        Path to load the pretrained model from, automatically downloaded if None.
+    strict : bool
+        Whether to strictly enforce the state dictionary match.
+    load_pretrained : bool
+        Whether to load a pretrained model.
+
+    Example
+    -------
+    Creating a new DAC instance:
+
+    >>> dac = DAC()
+    >>> audio_data = torch.randn(1, 1, 16000) # Example shape: [Batch, Channels, Time]
+    >>> tokens, embeddings = dac(audio_data)
+
+
+    Loading a pretrained DAC instance:
+
+    >>> dac = DAC(load_pretrained=True, model_type="44KHz", model_bitrate="8kbps", tag="latest")
+    >>> audio_data = torch.randn(1, 1, 16000) # Example shape: [Batch, Channels, Time]
+    >>> tokens, embeddings = dac(audio_data)
+
+    The tokens and the discrete embeddings obtained above or from other sources can be decoded:
+
+    >>> dac = DAC(load_pretrained=True, model_type="44KHz", model_bitrate="8kbps", tag="latest")
+    >>> audio_data = torch.randn(1, 1, 16000) # Example shape: [Batch, Channels, Time]
+    >>> tokens, embeddings = dac(audio_data)
+    >>> decoded_audio = dac.decode(embeddings)
+    """
+
+    def __init__(
+        self,
+        encoder_dim: int = 64,
+        encoder_rates: List[int] = [2, 4, 8, 8],
+        latent_dim: int = None,
+        decoder_dim: int = 1536,
+        decoder_rates: List[int] = [8, 8, 4, 2],
+        n_codebooks: int = 9,
+        codebook_size: int = 1024,
+        codebook_dim: Union[int, list] = 8,
+        quantizer_dropout: bool = False,
+        sample_rate: int = 44100,
+        model_type: str = "44khz",
+        model_bitrate: str = "8kbps",
+        tag: str = "latest",
+        load_path: str = None,
+        strict: bool = False,
+        load_pretrained: bool = False,
+    ):
+        """ Initializes DAC
+
+        Parameters
+        ----------
+        encoder_dim : int, optional, by default 64
+        encoder_rates : List[int], optional, by default [2, 4, 8, 8]
+        latent_dim : int, optional, by default None
+        decoder_dim : int, optional, by default 1536
+        decoder_rates : List[int], optional, by default [8, 8, 4, 2]
+        n_codebooks : int, optional, by default 9
+        codebook_size : int, optional, by default 1024
+        codebook_dim : Union[int, list], optional, by default 8
+        quantizer_dropout : bool, optional, by default False
+        sample_rate : int, optional, by default 44100
+        model_type : str, optional, by default "44khz"
+        model_bitrate : str, optional, by default "8kbps"
+        tag : str, optional, by default "latest"
+        load_path : str, optional, by default None
+        strict : bool, optional, by default False
+        load_pretrained : bool, optional
+             If True, then a pretrained model is loaded, by default False
+        """
+        super().__init__()
+
+        self.encoder_dim = encoder_dim
+        self.encoder_rates = encoder_rates
+        self.decoder_dim = decoder_dim
+        self.decoder_rates = decoder_rates
+        self.sample_rate = sample_rate
+        self.n_codebooks = n_codebooks
+        self.codebook_size = codebook_size
+        self.codebook_dim = codebook_dim
+        self.latent_dim = latent_dim
+        self.quantizer_dropout = quantizer_dropout
+
+        if load_pretrained:
+            if not load_path:
+                load_path = download(
+                    model_type=model_type, model_bitrate=model_bitrate, tag=tag
+                )
+                logger.info(f"Obtained load path as: {load_path}")
+            model_dict = torch.load(load_path, "cpu")
+            metadata = model_dict["metadata"]
+            for key, value in metadata["kwargs"].items():
+                setattr(self, key, value)
+
+        self.hop_length = np.prod(self.encoder_rates)
+        if self.latent_dim is None:
+            self.latent_dim = self.encoder_dim * (2 ** len(self.encoder_rates))
+        self.encoder = Encoder(
+            self.encoder_dim, self.encoder_rates, self.latent_dim
+        )
+        self.quantizer = ResidualVectorQuantize(
+            input_dim=self.latent_dim,
+            n_codebooks=self.n_codebooks,
+            codebook_size=self.codebook_size,
+            codebook_dim=self.codebook_dim,
+            quantizer_dropout=self.quantizer_dropout,
+        )
+        self.decoder = Decoder(
+            self.latent_dim, self.decoder_dim, self.decoder_rates,
+        )
+        self.apply(init_weights)
+
+        if load_pretrained:
+            self.load_state_dict(model_dict["state_dict"], strict=strict)
+            self.metadata = metadata
+
+    def encode(
+        self, audio_data: torch.Tensor, n_quantizers: int = None,
+    ):
+        """Encode given audio data and return quantized latent codes
+
+        Parameters
+        ----------
+        audio_data : Tensor[B x 1 x T]
+            Audio data to encode
+        n_quantizers : int, optional
+            Number of quantizers to use, by default None
+            If None, all quantizers are used.
+
+        Returns
+        -------
+        "z" : Tensor[B x D x T]
+            Quantized continuous representation of input
+        "codes" : Tensor[B x N x T]
+            Codebook indices for each codebook
+            (quantized discrete representation of input)
+        "latents" : Tensor[B x N*D x T]
+            Projected latents (continuous representation of input before quantization)
+        "vq/commitment_loss" : Tensor[1]
+            Commitment loss to train encoder to predict vectors closer to codebook
+            entries
+        "vq/codebook_loss" : Tensor[1]
+            Codebook loss to update the codebook
+        "length" : int
+            Number of samples in input audio
+        """
+        z = self.encoder(audio_data)
+        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
+            z, n_quantizers
+        )
+        return z, codes, latents, commitment_loss, codebook_loss
+
+    def decode(self, z: torch.Tensor):
+        """Decode given latent codes and return audio data
+
+        Parameters
+        ----------
+        z : Tensor[B x D x T]
+            Quantized continuous representation of input
+        length : int, optional
+            Number of samples in output audio, by default None
+
+        Returns
+        -------
+        torch.Tensor: shape B x 1 x length
+            Decoded audio data.
+        """
+        return self.decoder(z)
+
+    def forward(
+        self,
+        audio_data: torch.Tensor,
+        sample_rate: int = None,
+        n_quantizers: int = None,
+    ):
+        """Model forward pass
+
+        Parameters
+        ----------
+        audio_data : Tensor[B x 1 x T]
+            Audio data to encode
+        sample_rate : int, optional
+            Sample rate of audio data in Hz, by default None
+            If None, defaults to `self.sample_rate`
+        n_quantizers : int, optional
+            Number of quantizers to use, by default None.
+            If None, all quantizers are used.
+
+        Returns
+        -------
+        "tokens" : Tensor[B x N x T]
+            Codebook indices for each codebook
+            (quantized discrete representation of input)
+        "embeddings" : Tensor[B x D x T]
+            Quantized continuous representation of input
+        """
+        # Preprocess the audio data to have the right padded lengths
+        length = audio_data.shape[-1]
+        right_pad = (
+            math.ceil(length / self.hop_length) * self.hop_length - length
+        )
+        audio_data = nn.functional.pad(audio_data, (0, right_pad))
+
+        z, codes, _, _, _ = self.encode(audio_data, n_quantizers)
+        return codes, z
diff --git a/speechbrain/lobes/models/g2p/dataio.py b/speechbrain/lobes/models/g2p/dataio.py
index 6b4fdc9a20c227be40311006b5df6e8062935d3e..d92faec4e3b38e73029b4f41f8e7d3d1db184f2b 100644
--- a/speechbrain/lobes/models/g2p/dataio.py
+++ b/speechbrain/lobes/models/g2p/dataio.py
@@ -9,6 +9,7 @@ Authors
 
 from functools import reduce
 from speechbrain.wordemb.util import expand_to_chars
+from torch import nn
 import speechbrain as sb
 import torch
 import re
@@ -522,6 +523,43 @@ def _map_tokens_item(tokens, char_map):
     return [char_map[char] for char in tokens]
 
 
+class LazyInit(nn.Module):
+    """A lazy initialization wrapper
+
+    Arguments
+    ---------
+    init : callable
+        The function to initialize the underlying object"""
+
+    def __init__(self, init):
+        super().__init__()
+        self.instance = None
+        self.init = init
+        self.device = None
+
+    def __call__(self):
+        """Initializes the object instance, if necessary
+        and returns it."""
+        if self.instance is None:
+            self.instance = self.init()
+        return self.instance
+
+    def to(self, device):
+        """Moves the underlying object to the specified device
+
+        Arguments
+        ---------
+        device : str | torch.device
+            the device
+        """
+        super().to(device)
+        if self.instance is None:
+            self.instance = self.init()
+        if hasattr(self.instance, "to"):
+            self.instance = self.instance.to(device)
+        return self
+
+
 def lazy_init(init):
     """A wrapper to ensure that the specified object is initialzied
     only once (used mainly for tokenizers that train when the
@@ -537,16 +575,7 @@ def lazy_init(init):
     instance: object
         the object instance
     """
-    instance = None
-
-    def f():
-        """The initializer function"""
-        nonlocal instance
-        if instance is None:
-            instance = init()
-        return instance
-
-    return f
+    return LazyInit(init)
 
 
 def get_sequence_key(key, mode):
diff --git a/speechbrain/lobes/models/huggingface_transformers/__init__.py b/speechbrain/lobes/models/huggingface_transformers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c42af90f0bd3ef643de951079439284ec5b966c
--- /dev/null
+++ b/speechbrain/lobes/models/huggingface_transformers/__init__.py
@@ -0,0 +1,22 @@
+"""High level processing blocks.
+
+This subpackage gathers higher level blocks, or "lobes" for HuggingFace Transformers.
+"""
+
+# Transformers is required for this package.
+try:
+    import transformers  # noqa
+except ImportError:
+    MSG = "Please install transformers from HuggingFace.\n"
+    MSG += "E.G. run: pip install transformers \n"
+    MSG += "For more information, visit: https://huggingface.co/docs/transformers/installation"
+    raise ImportError(MSG)
+
+from .gpt import *  # noqa
+from .hubert import *  # noqa
+from .huggingface import *  # noqa
+from .wav2vec2 import *  # noqa
+from .wavlm import *  # noqa
+from .whisper import *  # noqa
+from .encodec import *  # noqa
+from .weighted_ssl import *  # noqa
diff --git a/speechbrain/lobes/models/huggingface_transformers/discrete_hubert.py b/speechbrain/lobes/models/huggingface_transformers/discrete_hubert.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f947fa00988706d590e03cf78ea294bc474e06a
--- /dev/null
+++ b/speechbrain/lobes/models/huggingface_transformers/discrete_hubert.py
@@ -0,0 +1,175 @@
+"""This lobe enables the integration of pretrained discrete Hubert.
+
+Reference: https://arxiv.org/abs/2006.11477
+Reference: https://arxiv.org/abs/1904.05862
+Reference: https://arxiv.org/abs/2110.13900
+Transformer from HuggingFace needs to be installed:
+https://huggingface.co/transformers/installation.html
+
+Author
+ * Pooneh Mousavi 2023
+
+"""
+import logging
+import torch
+from huggingface_hub import hf_hub_download
+import joblib
+
+from speechbrain.lobes.models.huggingface_transformers.hubert import HuBERT
+
+logger = logging.getLogger(__name__)
+
+
+class DiscreteHuBERT(HuBERT):
+    """This lobe enables the integration of HuggingFace and SpeechBrain
+    pretrained Discrete HuBERT models.
+
+    Source paper HuBERT: https://arxiv.org/abs/2106.07447
+    Transformer from HuggingFace needs to be installed:
+    https://huggingface.co/transformers/installation.html
+
+    The model can be used as a fixed Discrete feature extractor or can be finetuned. It
+    will download automatically the model from HuggingFace or use a local path.
+
+    For now, HuggingFace's HuBERT and WavLM model can be loaded using the exact code for Wav2Vec2 model.
+    For this reason, HuBERT and WavLM can be fine inheriting the Wav2Vec2 class.
+
+    Arguments
+    ---------
+    source : str
+        HuggingFace hub name: e.g "facebook/hubert-base-ls960"
+    save_path : str
+        Path (dir) of the downloaded model.
+    kmeans_repo_id : str
+        Huggingface repository if that contains the pretrained kmean model
+    kmeans_filename : str
+        Name of the file in HF repo that need to be downloaded.
+    kmeans_cache_dir: str
+        Path (dir) of the downloaded kmeans model.
+    output_norm : bool (default: True)
+        If True, a layer_norm (affine) will be applied to the output obtained
+        from the HuBERT model.
+    freeze : bool (default: True)
+        If True, the model is frozen. If False, the model will be trained
+        alongside with the rest of the pipeline.
+    freeze_feature_extractor :  bool (default: False)
+        When freeze = False and freeze_feature_extractor True, the featue_extractor module of the model is Frozen. If False
+        all the HuBERT model will be trained including featue_extractor module.
+    apply_spec_augment : bool (default: False)
+        If True, the model will apply spec augment on the output of feature extractor
+        (inside huggingface Hubert Model() class).
+        If False, the model will not apply spec augment. We set this to false to prevent from doing it twice.
+    output_all_hiddens : bool (default: True)
+        If True, the forward function outputs the hidden states from all transformer layers.
+        For example facebook/hubert-base-ls960 has 12 transformer layers and the output is of shape (13, B, T, C),
+        where a projection of the CNN output is added to the beginning.
+        If False, the forward function outputs the hidden states only from the last transformer layer.
+    ssl_layer_num : (int) (default: -1)
+        determine the output of which layer of the SSL model should be used for clustering.
+
+
+    Example
+    -------
+    >>> import torch
+    >>> inputs = torch.rand([10, 600])
+    >>> model_hub = "facebook/hubert-base-ls960"
+    >>> save_path = "savedir"
+    >>> ssl_layer_num = -1
+    >>> kmeans_repo_id = "speechbrain/SSL_Quantization"
+    >>> kmeans_filename = "LibriSpeech_hubert_k128_L7.pt"
+    >>> kmeans_cache_dir="savedir"
+    >>> model = DiscreteHuBERT(model_hub, save_path,freeze = True,ssl_layer_num=ssl_layer_num,kmeans_repo_id=kmeans_repo_id, kmeans_filename=kmeans_filename, kmeans_cache_dir=kmeans_cache_dir)
+    >>> embs, tokens = model(inputs)
+    >>> embs.shape
+    torch.Size([10, 1, 768])
+    >>> tokens.shape
+    torch.Size([10, 1])
+    """
+
+    def __init__(
+        self,
+        source,
+        save_path,
+        kmeans_filename,
+        kmeans_cache_dir,
+        kmeans_repo_id="speechbrain/SSL_Quantization",
+        output_norm=False,
+        freeze=False,
+        freeze_feature_extractor=False,
+        apply_spec_augment=False,
+        output_all_hiddens=True,
+        ssl_layer_num=-1,
+    ):
+        super().__init__(
+            source=source,
+            save_path=save_path,
+            output_norm=output_norm,
+            freeze=freeze,
+            freeze_feature_extractor=freeze_feature_extractor,
+            apply_spec_augment=apply_spec_augment,
+            output_all_hiddens=output_all_hiddens,
+        )
+
+        self.kmeans = self.load_kmeans(
+            kmeans_repo_id, kmeans_filename, kmeans_cache_dir
+        )
+        self.vocabulary = self.kmeans.cluster_centers_
+        self.ssl_layer_num = ssl_layer_num
+
+    def load_kmeans(self, repo_id, filename, cache_dir):
+        """Load a Pretrained kmeans model from HF.
+
+        Arguments
+        ---------
+        repo_id : str
+           The hugingface repo id that contains the model.
+        filename : str
+            The name of the checkpoints in the repo that need to be downloaded.
+        cache_dir: str
+            Path (dir) of the downloaded model.
+        Returns:
+        ---------
+        kmeans_model : MiniBatchKMeans:
+            pretrained Kmeans  model loaded from the HF.
+        """
+        kmeans_model = joblib.load(
+            hf_hub_download(
+                repo_id=repo_id, filename=filename, cache_dir=cache_dir
+            )
+        )
+        return kmeans_model
+
+    def forward(self, wav, wav_lens=None):
+        """Takes an input waveform and return its corresponding wav2vec encoding.
+
+        Arguments
+        ---------
+        wav : torch.Tensor (signal)
+            A batch of audio signals to transform to features.
+        wav_len : tensor
+            The relative length of the wav given in SpeechBrain format.
+        Returns:
+        ---------
+        tokens : torch.Tensor
+            A (Batch x Seq) tensor of audio tokens
+        emb : torch.Tensor
+            A (Batch x Seq x embedding_dim ) cluster_centers embeddings for each tokens
+        """
+
+        # If we freeze, we simply remove all grads from the graph.
+        with torch.set_grad_enabled(not self.freeze):
+            feats = self.extract_features(wav, wav_lens)[self.ssl_layer_num]
+        tokens = self.kmeans.predict(feats.flatten(end_dim=-2).cpu())
+        embs = self.vocabulary[tokens]
+        return (
+            torch.tensor(
+                embs.reshape(wav.shape[0], -1, embs.shape[-1]),
+                dtype=torch.float,
+                device=wav.device,
+            ),
+            torch.tensor(
+                tokens.reshape(wav.shape[0], -1),
+                dtype=torch.long,
+                device=wav.device,
+            ),
+        )
diff --git a/speechbrain/lobes/models/huggingface_transformers/discrete_wav2vec2.py b/speechbrain/lobes/models/huggingface_transformers/discrete_wav2vec2.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cee8ec4d5fc5980a307868d400ca8576f2b7d2f
--- /dev/null
+++ b/speechbrain/lobes/models/huggingface_transformers/discrete_wav2vec2.py
@@ -0,0 +1,172 @@
+"""This lobe enables the integration of pretrained discrete wav2vec2 model.
+
+Reference: https://arxiv.org/abs/2006.11477
+Reference: https://arxiv.org/abs/1904.05862
+Reference: https://arxiv.org/abs/2110.13900
+Transformer from HuggingFace needs to be installed:
+https://huggingface.co/transformers/installation.html
+
+Author
+ * Pooneh Mousavi 2023
+
+"""
+import logging
+import torch
+from huggingface_hub import hf_hub_download
+import joblib
+
+from speechbrain.lobes.models.huggingface_transformers.wav2vec2 import Wav2Vec2
+
+logger = logging.getLogger(__name__)
+
+
+class DiscreteWav2Vec2(Wav2Vec2):
+    """This lobe enables the integration of HuggingFace and SpeechBrain
+    pretrained Discrete Wav2Vec2 models.
+
+     Source paper wav2vec2.0: https://arxiv.org/abs/2006.11477
+    Transformer from HuggingFace needs to be installed:
+    https://huggingface.co/transformers/installation.html
+
+    The model can be used as a fixed Discrete feature extractor or can be finetuned. It
+    will download automatically the model from HuggingFace or use a local path.
+
+    Arguments
+    ---------
+    source : str
+        HuggingFace hub name: e.g "facebook/wav2vec2-large-lv60"
+    save_path : str
+        Path (dir) of the downloaded model.
+    kmeans_repo_id : str
+        Huggingface repository if that contains the pretrained kmean model
+    kmeans_filename : str
+        Name of the file in HF repo that need to be downloaded.
+    kmeans_cache_dir: str
+        Path (dir) of the downloaded kmeans model.
+    output_norm : bool (default: True)
+        If True, a layer_norm (affine) will be applied to the output obtained
+        from the Wav2Vec2 model.
+    freeze : bool (default: True)
+        If True, the model is frozen. If False, the model will be trained
+        alongside with the rest of the pipeline.
+    freeze_feature_extractor :  bool (default: False)
+        When freeze = False and freeze_feature_extractor True, the featue_extractor module of the model is Frozen. If False
+        all the Wav2Vec2 model will be trained including featue_extractor module.
+    apply_spec_augment : bool (default: False)
+        If True, the model will apply spec augment on the output of feature extractor
+        (inside huggingface Wav2Vec2 Model() class).
+        If False, the model will not apply spec augment. We set this to false to prevent from doing it twice.
+    output_all_hiddens : bool (default: True)
+        If True, the forward function outputs the hidden states from all transformer layers.
+        For example facebook/wav2vec2-large-lv60 has 12 transformer layers and the output is of shape (13, B, T, C),
+        where a projection of the CNN output is added to the beginning.
+        If False, the forward function outputs the hidden states only from the last transformer layer.
+    ssl_layer_num : (int) (default: -1)
+        determine the output of which layer of the SSL model should be used for clustering.
+
+
+    Example
+    -------
+    >>> import torch
+    >>> inputs = torch.rand([10, 600])
+    >>> model_hub = "facebook/wav2vec2-large-lv60"
+    >>> save_path = "savedir"
+    >>> ssl_layer_num = -1
+    >>> kmeans_repo_id = "speechbrain/SSL_Quantization"
+    >>> kmeans_filename = "LibriSpeech_wav2vec_k128_L7.pt"
+    >>> kmeans_cache_dir="savedir"
+    >>> model = DiscreteWav2Vec2(model_hub, save_path,freeze = True,ssl_layer_num=ssl_layer_num,kmeans_repo_id=kmeans_repo_id, kmeans_filename=kmeans_filename, kmeans_cache_dir=kmeans_cache_dir)
+    >>> embs, tokens = model(inputs)
+    >>> embs.shape
+    torch.Size([10, 1, 1024])
+    >>> tokens.shape
+    torch.Size([10, 1])
+    """
+
+    def __init__(
+        self,
+        source,
+        save_path,
+        kmeans_filename,
+        kmeans_cache_dir,
+        kmeans_repo_id="speechbrain/SSL_Quantization",
+        output_norm=False,
+        freeze=False,
+        freeze_feature_extractor=False,
+        apply_spec_augment=False,
+        output_all_hiddens=True,
+        ssl_layer_num=-1,
+    ):
+        super().__init__(
+            source=source,
+            save_path=save_path,
+            output_norm=output_norm,
+            freeze=freeze,
+            freeze_feature_extractor=freeze_feature_extractor,
+            apply_spec_augment=apply_spec_augment,
+            output_all_hiddens=output_all_hiddens,
+        )
+
+        self.kmeans = self.load_kmeans(
+            kmeans_repo_id, kmeans_filename, kmeans_cache_dir
+        )
+        self.vocabulary = self.kmeans.cluster_centers_
+        self.ssl_layer_num = ssl_layer_num
+
+    def load_kmeans(self, repo_id, filename, cache_dir):
+        """Load a Pretrained kmeans model from HF.
+
+        Arguments
+        ---------
+        repo_id : str
+           The hugingface repo id that contains the model.
+        filename : str
+            The name of the checkpoints in the repo that need to be downloaded.
+        cache_dir: str
+            Path (dir) of the downloaded model.
+        Returns:
+        ---------
+        kmeans_model : MiniBatchKMeans:
+            pretrained Kmeans  model loaded from the HF.
+        """
+        kmeans_model = joblib.load(
+            hf_hub_download(
+                repo_id=repo_id, filename=filename, cache_dir=cache_dir
+            )
+        )
+        return kmeans_model
+
+    def forward(self, wav, wav_lens=None):
+        """Takes an input waveform and return its corresponding wav2vec encoding.
+
+        Arguments
+        ---------
+        wav : torch.Tensor (signal)
+            A batch of audio signals to transform to features.
+        wav_len : tensor
+            The relative length of the wav given in SpeechBrain format.
+        Returns:
+        ---------
+        tokens : torch.Tensor
+            A (Batch x Seq) tensor of audio tokens
+        emb : torch.Tensor
+            A (Batch x Seq x embedding_dim ) cluster_centers embeddings for each tokens
+        """
+
+        # If we freeze, we simply remove all grads from the graph.
+        with torch.set_grad_enabled(not self.freeze):
+            feats = self.extract_features(wav, wav_lens)[self.ssl_layer_num]
+        tokens = self.kmeans.predict(feats.flatten(end_dim=-2).cpu())
+        embs = self.vocabulary[tokens]
+        return (
+            torch.tensor(
+                embs.reshape(wav.shape[0], -1, embs.shape[-1]),
+                dtype=torch.float,
+                device=wav.device,
+            ),
+            torch.tensor(
+                tokens.reshape(wav.shape[0], -1),
+                dtype=torch.long,
+                device=wav.device,
+            ),
+        )
diff --git a/speechbrain/lobes/models/huggingface_transformers/discrete_wavlm.py b/speechbrain/lobes/models/huggingface_transformers/discrete_wavlm.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a7ab473b857ddf59bda288af7484d20f1ced366
--- /dev/null
+++ b/speechbrain/lobes/models/huggingface_transformers/discrete_wavlm.py
@@ -0,0 +1,172 @@
+"""This lobe enables the integration of pretrained discrete wavLM.
+
+Reference: https://arxiv.org/abs/2006.11477
+Reference: https://arxiv.org/abs/1904.05862
+Reference: https://arxiv.org/abs/2110.13900
+Transformer from HuggingFace needs to be installed:
+https://huggingface.co/transformers/installation.html
+
+Author
+ * Pooneh Mousavi 2023
+
+"""
+import logging
+import torch
+from huggingface_hub import hf_hub_download
+import joblib
+
+from speechbrain.lobes.models.huggingface_transformers.wavlm import WavLM
+
+logger = logging.getLogger(__name__)
+
+
+class DiscreteWavLM(WavLM):
+    """This lobe enables the integration of HuggingFace and SpeechBrain
+    pretrained Discrete WavLM models.
+
+    Source paper WavLM: https://arxiv.org/abs/2110.13900
+    Transformer from HuggingFace needs to be installed:
+    https://huggingface.co/transformers/installation.html
+
+    The model can be used as a fixed Discrete feature extractor or can be finetuned. It
+    will download automatically the model from HuggingFace or use a local path.
+
+    Arguments
+    ---------
+    source : str
+        HuggingFace hub name: e.g "microsoft/wavlm-large"
+    save_path : str
+        Path (dir) of the downloaded model.
+    kmeans_repo_id : str
+        Huggingface repository if that contains the pretrained kmean model
+    kmeans_filename : str
+        Name of the file in HF repo that need to be downloaded.
+    kmeans_cache_dir: str
+        Path (dir) of the downloaded kmeans model.
+    output_norm : bool (default: True)
+        If True, a layer_norm (affine) will be applied to the output obtained
+        from the WavLM model.
+    freeze : bool (default: True)
+        If True, the model is frozen. If False, the model will be trained
+        alongside with the rest of the pipeline.
+    freeze_feature_extractor :  bool (default: False)
+        When freeze = False and freeze_feature_extractor True, the featue_extractor module of the model is Frozen. If False
+        all the WavLM model will be trained including featue_extractor module.
+    apply_spec_augment : bool (default: False)
+        If True, the model will apply spec augment on the output of feature extractor
+        (inside huggingface WavLM Model() class).
+        If False, the model will not apply spec augment. We set this to false to prevent from doing it twice.
+    output_all_hiddens : bool (default: True)
+        If True, the forward function outputs the hidden states from all transformer layers.
+        For example microsoft/wavlm-large  has 12 transformer layers and the output is of shape (13, B, T, C),
+        where a projection of the CNN output is added to the beginning.
+        If False, the forward function outputs the hidden states only from the last transformer layer.
+    ssl_layer_num : (int) (default: -1)
+        determine the output of which layer of the SSL model should be used for clustering.
+
+
+    Example
+    -------
+    >>> import torch
+    >>> inputs = torch.rand([10, 600])
+    >>> model_hub = "microsoft/wavlm-large"
+    >>> save_path = "savedir"
+    >>> ssl_layer_num = -1
+    >>> kmeans_repo_id = "speechbrain/SSL_Quantization"
+    >>> kmeans_filename = "LJSpeech_wavlm_k128_L7.pt"
+    >>> kmeans_cache_dir="savedir"
+    >>> model = DiscreteWavLM(model_hub, save_path,freeze = True,ssl_layer_num=ssl_layer_num,kmeans_repo_id=kmeans_repo_id, kmeans_filename=kmeans_filename, kmeans_cache_dir=kmeans_cache_dir)
+    >>> embs, tokens = model(inputs)
+    >>> embs.shape
+    torch.Size([10, 1, 1024])
+    >>> tokens.shape
+    torch.Size([10, 1])
+    """
+
+    def __init__(
+        self,
+        source,
+        save_path,
+        kmeans_filename,
+        kmeans_cache_dir,
+        kmeans_repo_id="speechbrain/SSL_Quantization",
+        output_norm=False,
+        freeze=False,
+        freeze_feature_extractor=False,
+        apply_spec_augment=False,
+        output_all_hiddens=True,
+        ssl_layer_num=-1,
+    ):
+        super().__init__(
+            source=source,
+            save_path=save_path,
+            output_norm=output_norm,
+            freeze=freeze,
+            freeze_feature_extractor=freeze_feature_extractor,
+            apply_spec_augment=apply_spec_augment,
+            output_all_hiddens=output_all_hiddens,
+        )
+
+        self.kmeans = self.load_kmeans(
+            kmeans_repo_id, kmeans_filename, kmeans_cache_dir
+        )
+        self.vocabulary = self.kmeans.cluster_centers_
+        self.ssl_layer_num = ssl_layer_num
+
+    def load_kmeans(self, repo_id, filename, cache_dir):
+        """Load a Pretrained kmeans model from HF.
+
+        Arguments
+        ---------
+        repo_id : str
+           The hugingface repo id that contains the model.
+        filename : str
+            The name of the checkpoints in the repo that need to be downloaded.
+        cache_dir: str
+            Path (dir) of the downloaded model.
+        Returns:
+        ---------
+        kmeans_model : MiniBatchKMeans:
+            pretrained Kmeans  model loaded from the HF.
+        """
+        kmeans_model = joblib.load(
+            hf_hub_download(
+                repo_id=repo_id, filename=filename, cache_dir=cache_dir
+            )
+        )
+        return kmeans_model
+
+    def forward(self, wav, wav_lens=None):
+        """Takes an input waveform and return its corresponding wav2vec encoding.
+
+        Arguments
+        ---------
+        wav : torch.Tensor (signal)
+            A batch of audio signals to transform to features.
+        wav_len : tensor
+            The relative length of the wav given in SpeechBrain format.
+        Returns:
+        ---------
+        tokens : torch.Tensor
+            A (Batch x Seq) tensor of audio tokens
+        emb : torch.Tensor
+            A (Batch x Seq x embedding_dim ) cluster_centers embeddings for each tokens
+        """
+
+        # If we freeze, we simply remove all grads from the graph.
+        with torch.set_grad_enabled(not self.freeze):
+            feats = self.extract_features(wav, wav_lens)[self.ssl_layer_num]
+        tokens = self.kmeans.predict(feats.flatten(end_dim=-2).cpu())
+        embs = self.vocabulary[tokens]
+        return (
+            torch.tensor(
+                embs.reshape(wav.shape[0], -1, embs.shape[-1]),
+                dtype=torch.float,
+                device=wav.device,
+            ),
+            torch.tensor(
+                tokens.reshape(wav.shape[0], -1),
+                dtype=torch.long,
+                device=wav.device,
+            ),
+        )
diff --git a/speechbrain/lobes/models/huggingface_transformers/encodec.py b/speechbrain/lobes/models/huggingface_transformers/encodec.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ec3b3415ff5d74066020362f464a24a9f674723
--- /dev/null
+++ b/speechbrain/lobes/models/huggingface_transformers/encodec.py
@@ -0,0 +1,377 @@
+"""This lobe enables the integration of huggingface pretrained EnCodec.
+
+EnCodec makes it possible to compress audio into a sequence of discrete tokens
+at different bandwidths - and to reconstruct audio from such sequences, with
+some loss of quality depending on the bandwidth.
+
+Note that while encodec can be used to reconstruct speech data, for a
+high-quality reconstruction, it is recommended to use a specially trained
+vocoder, such as Vocos (speechbrain.lobes.models.huggingface_transformers.vocos)
+
+Repository: https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/encodec
+Paper: https://arxiv.org/abs/2210.13438
+
+Authors
+ * Artem Ploujnikov 2023
+"""
+
+import torch
+import logging
+from torch.nn import functional as F
+from speechbrain.dataio.dataio import length_to_mask, clean_padding_
+from speechbrain.lobes.models.huggingface_transformers.huggingface import (
+    HFTransformersInterface,
+)
+
+DEFAULT_SAMPLE_RATE = 24000
+
+logger = logging.getLogger(__name__)
+
+
+class Encodec(HFTransformersInterface):
+    """An wrapper for the HuggingFace encodec model
+
+    Arguments
+    ---------
+    source : str
+        A HuggingFace repository identifier or a path
+    save_path : str
+        The location where the pretrained model will be saved
+    sample_rate : int
+        The audio sampling rate
+    bandwidth : float
+        The encoding bandwidth, in kbps (optional)
+        Supported bandwidths:
+        1.5, 3.0, 6.0, 12.0, 24.0
+    flat_embeddings : bool
+        If set to True, embeddings will be flattened into
+        (Batch x Length x (Heads * Embedding))
+    freeze : bool
+        whether the model will be frozen (e.g. not trainable if used
+        as part of training another model)
+    renorm_embeddings : bool
+        whether embeddings should be renormalized. In the original
+        model.
+
+    Example
+    -------
+    >>> model_hub = "facebook/encodec_24khz"
+    >>> save_path = "savedir"
+    >>> model = Encodec(model_hub, save_path)
+    >>> audio = torch.randn(4, 1000)
+    >>> length = torch.tensor([1.0, .5, .75, 1.0])
+    >>> tokens, emb = model.encode(audio, length)
+    >>> tokens.shape
+    torch.Size([4, 4, 2])
+    >>> emb.shape
+    torch.Size([4, 4, 2, 128])
+    >>> rec = model.decode(tokens, length)
+    >>> rec.shape
+    torch.Size([4, 1, 1280])
+    >>> rec_emb = model.decode_emb(emb, length)
+    >>> rec_emb.shape
+    torch.Size([4, 1, 1280])
+    >>> rec_tokens = model.tokens(emb, length)
+    >>> rec_tokens.shape
+    torch.Size([4, 4, 2])
+    >>> model = Encodec(model_hub, save_path, flat_embeddings=True)
+    >>> _, emb = model.encode(audio, length)
+    >>> emb.shape
+    torch.Size([4, 4, 256])
+    """
+
+    def __init__(
+        self,
+        source,
+        save_path=None,
+        sample_rate=None,
+        bandwidth=1.5,
+        flat_embeddings=False,
+        freeze=True,
+        renorm_embeddings=True,
+    ):
+        super().__init__(source=source, save_path=save_path, freeze=freeze)
+        if not sample_rate:
+            sample_rate = DEFAULT_SAMPLE_RATE
+        self.sample_rate = sample_rate
+        self.bandwidth = bandwidth
+        self.flat_embeddings = flat_embeddings
+        self.num_heads = self.model.quantizer.get_num_quantizers_for_bandwidth(
+            bandwidth
+        )
+        self.num_tokens = self.model.config.codebook_size
+        quantizer_layers = self.model.quantizer.layers[: self.num_heads]
+        vocabulary = torch.stack(
+            [layer.codebook.embed for layer in quantizer_layers]
+        )
+        self.register_buffer("vocabulary", vocabulary)
+        _, self.num_tokens, self.emb_dim = self.vocabulary.shape
+        vocabulary_flat = self.vocabulary.reshape(
+            self.num_heads * self.num_tokens, self.emb_dim
+        )
+        self.register_buffer("vocabulary_flat", vocabulary_flat)
+        token_index_offsets = (
+            torch.arange(self.num_heads)[None, None, :] * self.num_tokens
+        )
+        self.register_buffer("token_index_offsets", token_index_offsets)
+        self.renorm_embeddings = renorm_embeddings
+        if self.renorm_embeddings:
+            emb_mean, emb_std = self._precalibrate()
+            self.register_buffer("emb_mean", emb_mean)
+            self.register_buffer("emb_std", emb_std)
+        if self.freeze:
+            logger.warning("huggingface_Encodec - Encodec is frozen.")
+            for param in self.model.parameters():
+                param.requires_grad = False
+
+    def _precalibrate(self):
+        """Compute parameters required to renormalize embeddings"""
+        sample = torch.arange(self.num_tokens)[None, :, None].expand(
+            1, self.num_tokens, self.num_heads
+        )
+        return self._compute_embedding_norm(sample)
+
+    def _compute_embedding_norm(self, sample, length=None):
+        """Computes the normalization for embeddings based on
+        a sample.
+
+        Arguments
+        ---------
+        sample : torch.Tensor
+            A (Batch x Samples) or (Batch x Channel x Samples)
+            audio sample
+
+        length : torch.Tensor
+            A tensor of relative lengths
+        """
+        if length is None:
+            length = torch.ones(len(sample), device=sample.device)
+        max_len = sample.size(1)
+        emb = self._raw_embeddings(sample)
+        mask = length_to_mask(length * max_len, max_len)[
+            :, :, None, None
+        ].expand_as(emb)
+        emb_mean = (emb.mean(-1).sum(1) / mask.mean(-1).sum(1)).mean(0)[
+            None, None, :, None
+        ]
+        emb_diff_sq = ((emb - emb_mean) * mask) ** 2
+        emb_std = (
+            emb_diff_sq.sum(dim=[0, 1, 3])
+            / (mask.expand_as(emb_diff_sq).sum(dim=[0, 1, 3]) - 1)
+        ).sqrt()[None, None, :, None]
+        return emb_mean, emb_std
+
+    def calibrate(self, sample, length):
+        """Calibrates the normalization on a sound sample
+
+        Arguments
+        ---------
+        sample : torch.Tensor
+            A (Batch x Samples) or (Batch x Channel x Samples)
+            audio sample
+
+        length : torch.Tensor
+            A tensor of relative lengths
+
+        Returns
+        -------
+        emb_mean : torch.Tensor
+            The embedding mean
+
+        emb_std : torch.Tensor
+            The embedding standard deviation
+        """
+        if not self.renorm_embeddings:
+            raise ValueError("Not supported when renorm_embeddings is disabled")
+        sample_tokens = self._encode_tokens(sample, length)
+        self.emb_mean, self.emb_std = self._compute_embedding_norm(
+            sample_tokens, length
+        )
+        return self.emb_mean.squeeze(), self.emb_std.squeeze()
+
+    def forward(self, inputs, length):
+        """Encodes the input audio as tokens
+
+        Arguments
+        ---------
+        inputs : torch.Tensor
+            A (Batch x Samples) or (Batch x Channel x Samples)
+            tensor of audio
+        length : torch.Tensor
+            A tensor of relative lengths
+
+        Returns
+        -------
+        tokens : torch.Tensor
+            A (Batch X Tokens) tensor of audio tokens
+        """
+        return self.encode(inputs, length)
+
+    def encode(self, inputs, length):
+        """Encodes the input audio as tokens and embeddings
+
+        Arguments
+        ---------
+        inputs : torch.Tensor
+            A (Batch x Samples) or (Batch x Channel x Samples)
+            tensor of audio
+        length : torch.Tensor
+            A tensor of relative lengths
+
+        Returns
+        -------
+        tokens : torch.Tensor
+            A (Batch x Tokens x Heads) tensor of audio tokens
+        emb : torch.Tensor
+            Raw vector embeddings from the model's
+            quantizers
+        """
+        with torch.set_grad_enabled(not self.freeze):
+            tokens = self._encode_tokens(inputs, length)
+            emb = self.embeddings(tokens)
+            return tokens, emb
+
+    def _encode_tokens(self, inputs, length):
+        """Encodes audio as tokens only
+
+        Arguments
+        ---------
+        inputs : torch.Tensor
+            A (Batch x Samples) or (Batch x Channel x Samples)
+            tensor of audio
+        length : torch.Tensor
+            A tensor of relative lengths
+
+        Returns
+        -------
+        tokens : torch.Tensor
+            A (Batch x Tokens x Heads) tensor of audio tokens
+        """
+        if inputs.dim() == 2:
+            inputs = inputs.unsqueeze(1)
+        max_len = inputs.size(-1)
+        mask = length_to_mask(
+            length * max_len, max_len, device=inputs.device
+        ).unsqueeze(1)
+        result = self.model.encode(inputs, mask, bandwidth=self.bandwidth)
+        tokens = result.audio_codes.squeeze(0).transpose(-1, -2)
+        return tokens
+
+    def _raw_embeddings(self, tokens):
+        """Converts token indexes to vector embeddings, for
+        each quantizer
+
+        Arguments
+        ---------
+        tokens : torch.Tensor
+            a (Batch x Length x Heads) tensor of token indexes
+
+        Returns
+        -------
+        emb : torch.Tensor
+            a (Batch x Length x Heads x Embedding) tensor
+            of raw vector embeddings from the model's
+            quantizer codebooks
+        """
+        idx = tokens + self.token_index_offsets
+        emb = F.embedding(idx, self.vocabulary_flat)
+        return emb
+
+    def embeddings(self, tokens):
+        """Converts token indexes to vector embeddings
+
+        Arguments
+        ---------
+        tokens : torch.Tensor
+            a (Batch x Length x Heads) tensor of token indexes
+
+        Returns
+        -------
+        emb : torch.Tensor
+            a (Batch x Length x Heads x Embedding) tensor
+            of raw vector embeddings from the model's
+            quantizer codebooks
+        """
+        emb = self._raw_embeddings(tokens)
+        if self.renorm_embeddings:
+            emb = (emb - self.emb_mean) / self.emb_std
+        if self.flat_embeddings:
+            batch_size, max_len, num_heads, emb_dim = emb.shape
+            emb = emb.reshape(batch_size, max_len, num_heads * emb_dim)
+        return emb
+
+    def decode(self, tokens, length=None):
+        """Decodes audio from tokens
+
+        Arguments
+        ---------
+        tokens : torch.Tensor
+            A (Batch x Length x Heads) tensor of audio tokens
+        length : torch.Tensor
+            A 1-D tensor of relative lengths
+
+        Returns
+        -------
+        audio : torch.Tensor
+            the reconstructed audio
+        """
+        with torch.set_grad_enabled(not self.freeze):
+            result = self.model.decode(
+                tokens.unsqueeze(0).transpose(-1, -2), [None]
+            )
+            audio = result.audio_values
+            if length is not None:
+                clean_padding_(audio, length)
+            return audio
+
+    def tokens(self, emb, length=None):
+        """Comberts embeddings to raw tokens
+
+        Arguments
+        ---------
+        emb : torch.Tensor
+            Raw embeddings
+        length : torch.Tensor
+            A 1-D tensor of relative lengths. If supplied,
+            padded positions will be zeroed out
+
+        Returns
+        -------
+        tokens : torch.Tensor
+            A (Batch x Length) tensor of token indices"""
+        with torch.set_grad_enabled(not self.freeze):
+            if self.flat_embeddings:
+                batch_size, max_len, _ = emb.shape
+                emb = emb.reshape(
+                    batch_size, max_len, self.num_heads, self.emb_dim
+                )
+            if self.renorm_embeddings:
+                emb = emb * self.emb_std + self.emb_mean
+            scaled_states = emb.pow(2).sum(-1, keepdim=True)
+            vocab = self.vocabulary.transpose(-1, -2).unsqueeze(0)
+            emb_perm = emb.permute(0, 2, 1, 3)
+            emb_vocab_prod = (emb_perm @ vocab).moveaxis(1, 2)
+            vocab_sum = vocab.pow(2).sum(-2, keepdim=True).moveaxis(1, 2)
+            dist = -(scaled_states - 2 * emb_vocab_prod + vocab_sum)
+            tokens = dist.max(dim=-1).indices
+            if length is not None:
+                clean_padding_(tokens, length)
+            return tokens
+
+    def decode_emb(self, emb, length):
+        """Decodes raw vector embeddings into audio
+
+        Arguments
+        ---------
+        emb : torch.Tensor
+            A (Batch x Length x Heads x Embedding) tensor of
+            raw vector embeddings
+
+        Returns
+        -------
+        audio : torch.Tensor
+            the reconstructed audio
+        """
+        with torch.set_grad_enabled(not self.freeze):
+            tokens = self.tokens(emb)
+            return self.decode(tokens, length)
diff --git a/speechbrain/lobes/models/huggingface_transformers/gpt.py b/speechbrain/lobes/models/huggingface_transformers/gpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec191a7d61f3506950dfb70ff61fbd9de2d3f93f
--- /dev/null
+++ b/speechbrain/lobes/models/huggingface_transformers/gpt.py
@@ -0,0 +1,152 @@
+"""This lobe enables the integration of huggingface pretrained GPT2LMHeadModel model.
+
+Transformer from HuggingFace needs to be installed:
+https://huggingface.co/transformers/installation.html
+
+Authors
+ * Pooneh Mousavi 2023
+ * Simone Alghisi 2023
+"""
+
+import logging
+from torch import Tensor
+import torch
+
+from speechbrain.lobes.models.huggingface_transformers.huggingface import (
+    HFTransformersInterface,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+class GPT(HFTransformersInterface):
+    """This lobe enables the integration of HuggingFace pretrained GPT model.
+     Source paper whisper:
+        https://life-extension.github.io/2020/05/27/GPT%E6%8A%80%E6%9C%AF%E5%88%9D%E6%8E%A2/language-models.pdf
+    Transformer from HuggingFace needs to be installed:
+        https://huggingface.co/transformers/installation.html
+
+    The model can be finetuned. It will download automatically the model from
+    HuggingFace or use a local path.
+
+    Arguments
+    ---------
+    source : str
+        HuggingFace hub name: e.g "gpt2"
+    save_path : str
+        Path (dir) of the downloaded model.
+    freeze : bool (default: False)
+        If True, the model is frozen. If False, the model will be trained
+        alongside with the rest of the pipeline.
+    Example
+    -------
+    >>> model_hub = "gpt2"
+    >>> save_path = "savedir"
+    >>> model = GPT(model_hub, save_path)
+    >>> tokens = torch.tensor([[1, 1]])
+    >>> tokens_type = torch.tensor([[1, 1]])
+    >>> attention_mask = torch.tensor([[1, 1]])
+    >>> outputs = model(tokens, tokens_type, attention_mask)
+    """
+
+    def __init__(
+        self,
+        source,
+        save_path,
+        freeze=False,
+        max_new_tokens=200,
+        min_length=1,
+        top_k=45,
+        top_p=0.9,
+        num_beams=8,
+        eos_token_id=50258,
+        early_stopping=True,
+    ) -> None:
+        super().__init__(
+            source=source, save_path=save_path, freeze=freeze, with_lm_head=True
+        )
+        self.max_new_tokens = max_new_tokens
+        self.min_length = min_length
+        self.top_k = top_k
+        self.top_p = top_p
+        self.num_beams = num_beams
+        self.early_stopping = early_stopping
+        self.eos_token_id = eos_token_id
+
+        self.load_tokenizer(source=source, pad_token=None, use_fast=False)
+
+        if self.freeze:
+            logger.warning("huggingface_GPT - GPT  is frozen.")
+            self.model.train()  # we keep it to train to have dropout and LN computed adequaly
+            for param in self.model.parameters():
+                param.requires_grad = False
+
+    def forward(
+        self, input_ids: Tensor, token_type_ids: Tensor, attention_mask: Tensor,
+    ):
+        """ Takes an input a history of conversation and returns its corresponding reply.
+
+        Arguments
+        ---------
+        input_ids : torch.Tensor ()
+            A batch of input-id to transform to features.
+        token_type_ids : torch.Tensor
+            Token Type(Speaker) for each token in input_ids.
+        attention_mask : torch.Tensor ()
+            A batch of attention_mask.
+        """
+        with torch.set_grad_enabled(not self.freeze):
+            output = self.model.forward(
+                input_ids,
+                token_type_ids=token_type_ids,
+                attention_mask=attention_mask,
+            )
+        return output
+
+    def generate(
+        self,
+        input_ids: Tensor,
+        token_type_ids,
+        attention_mask: Tensor,
+        decoder_type="greedy",
+    ):
+        """ Takes an input a history of conversation and returns its corresponding reply.
+
+        Arguments
+        --------
+        input_ids : torch.Tensor ()
+            A batch of input-id   which are dialogue context tokens
+        decoder_type : Str
+            It shows strategy for autoregressive decoding either beam seach or greedy.
+        attention_mask : torch.Tensor ()
+            A batch of attention_mask.
+        """
+
+        with torch.no_grad():
+            if decoder_type == "beam":
+                # beam decoding based on the input_ids which are dialogue context tokens (here only history)
+                hyp = self.model.generate(
+                    input_ids=input_ids,
+                    token_type_ids=token_type_ids,
+                    attention_mask=attention_mask,
+                    do_sample=True,
+                    max_new_tokens=self.max_new_tokens,
+                    min_length=self.min_length,
+                    top_k=self.top_k,
+                    top_p=self.top_p,
+                    num_beams=self.num_beams,
+                    num_return_sequences=1,
+                    eos_token_id=self.eos_token_id,
+                    early_stopping=self.early_stopping,
+                )
+            else:
+                # greedy decoding based on the input_ids which are dialogue context tokens (here only history)
+                hyp = self.model.generate(
+                    input_ids,
+                    token_type_ids=token_type_ids,
+                    max_new_tokens=self.max_new_tokens,
+                    eos_token_id=self.eos_token_id,
+                    attention_mask=attention_mask,
+                )
+        return hyp
diff --git a/speechbrain/lobes/models/huggingface_transformers/hubert.py b/speechbrain/lobes/models/huggingface_transformers/hubert.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ab06b3a32c9de19260564d558341103a0ca36ba
--- /dev/null
+++ b/speechbrain/lobes/models/huggingface_transformers/hubert.py
@@ -0,0 +1,89 @@
+"""This lobe enables the integration of huggingface pretrained hubert models.
+
+Reference: https://arxiv.org/abs/2006.11477
+Reference: https://arxiv.org/abs/1904.05862
+Reference: https://arxiv.org/abs/2110.13900
+Transformer from HuggingFace needs to be installed:
+https://huggingface.co/transformers/installation.html
+
+Authors
+ * Titouan Parcollet 2021
+ * Boumadane Abdelmoumene 2021
+ * Ha Nguyen 2023
+"""
+
+import logging
+
+from speechbrain.lobes.models.huggingface_transformers.wav2vec2 import Wav2Vec2
+
+logger = logging.getLogger(__name__)
+
+
+class HuBERT(Wav2Vec2):
+    """This lobe enables the integration of HuggingFace and SpeechBrain
+    pretrained HuBERT models.
+
+    Source paper HuBERT: https://arxiv.org/abs/2106.07447
+    Transformer from HuggingFace needs to be installed:
+    https://huggingface.co/transformers/installation.html
+
+    The model can be used as a fixed feature extractor or can be finetuned. It
+    will download automatically the model from HuggingFace or use a local path.
+
+    For now, HuggingFace's HuBERT and WavLM model can be loaded using the exact code for Wav2Vec2 model.
+    For this reason, HuBERT and WavLM can be fine inheriting the Wav2Vec2 class.
+
+    Arguments
+    ---------
+    source : str
+        HuggingFace hub name: e.g "facebook/hubert-base-ls960"
+    save_path : str
+        Path (dir) of the downloaded model.
+    output_norm : bool (default: True)
+        If True, a layer_norm (affine) will be applied to the output obtained
+        from the HuBERT model.
+    freeze : bool (default: True)
+        If True, the model is frozen. If False, the model will be trained
+        alongside with the rest of the pipeline.
+    freeze_feature_extractor :  bool (default: False)
+        When freeze = False and freeze_feature_extractor True, the featue_extractor module of the model is Frozen. If False
+        all the HuBERT model will be trained including featue_extractor module.
+    apply_spec_augment : bool (default: False)
+        If True, the model will apply spec augment on the output of feature extractor
+        (inside huggingface HubertModel() class).
+        If False, the model will not apply spec augment. We set this to false to prevent from doing it twice.
+    output_all_hiddens : bool (default: False)
+        If True, the forward function outputs the hidden states from all transformer layers.
+        For example facebook/hubert-base-ls960 has 12 transformer layers and the output is of shape (13, B, T, C),
+        where a projection of the CNN output is added to the beginning.
+        If False, the forward function outputs the hidden states only from the last transformer layer.
+
+    Example
+    -------
+    >>> import torch
+    >>> inputs = torch.rand([10, 600])
+    >>> model_hub = "facebook/hubert-base-ls960"
+    >>> save_path = "savedir"
+    >>> model = HuBERT(model_hub, save_path)
+    >>> outputs = model(inputs)
+    """
+
+    def __init__(
+        self,
+        source,
+        save_path,
+        output_norm=False,
+        freeze=False,
+        freeze_feature_extractor=False,
+        apply_spec_augment=False,
+        output_all_hiddens=False,
+    ):
+        super().__init__(
+            source=source,
+            save_path=save_path,
+            output_norm=output_norm,
+            freeze=freeze,
+            freeze_feature_extractor=freeze_feature_extractor,
+            apply_spec_augment=apply_spec_augment,
+            output_all_hiddens=output_all_hiddens,
+        )
diff --git a/speechbrain/lobes/models/huggingface_transformers/huggingface.py b/speechbrain/lobes/models/huggingface_transformers/huggingface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d57c07e868b2966d7d2d83b642611e6c7531a83a
--- /dev/null
+++ b/speechbrain/lobes/models/huggingface_transformers/huggingface.py
@@ -0,0 +1,419 @@
+"""This lobe is the interface for huggingface transformers models
+It enables loading config and model via AutoConfig & AutoModel.
+
+Transformer from HuggingFace needs to be installed:
+https://huggingface.co/transformers/installation.html
+
+Authors
+ * Titouan Parcollet 2021, 2022, 2023
+ * Mirco Ravanelli 2021
+ * Boumadane Abdelmoumene 2021
+ * Ju-Chieh Chou 2021
+ * Artem Ploujnikov 2021, 2022
+ * Abdel Heba 2021
+ * Aku Rouhe 2022
+ * Arseniy Gorin 2022
+ * Ali Safaya 2022
+ * Benoit Wang 2022
+ * Adel Moumen 2022, 2023
+ * Andreas Nautsch 2022, 2023
+ * Luca Della Libera 2022
+ * Heitor Guimarães 2022
+ * Ha Nguyen 2023
+"""
+import os
+import torch
+import logging
+import pathlib
+from torch import nn
+from huggingface_hub import model_info
+from speechbrain.utils.fetching import fetch
+from speechbrain.dataio.dataio import length_to_mask
+
+from transformers import (
+    AutoConfig,
+    AutoTokenizer,
+    AutoFeatureExtractor,
+    AutoModelForPreTraining,
+    AutoModel,
+    AutoModelWithLMHead,
+    AutoModelForSeq2SeqLM,
+    AutoModelForCausalLM,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class HFTransformersInterface(nn.Module):
+    """This lobe provides an interface for integrating any HuggingFace transformer model within SpeechBrain.
+
+    We use AutoClasses for loading any model from the hub and its necessary components.
+    For example, we build Wav2Vec2 class which inherits HFTransformersInterface for working with HuggingFace's wav2vec models.
+    While Wav2Vec2 can enjoy some already built features like modeling loading, pretrained weights loading, all weights freezing,
+    feature_extractor loading, etc.
+    Users are expected to override the essential forward() function to fit their specific needs.
+    Depending on the HuggingFace transformer model in question, one can also modify the state_dict by overwriting the _modify_state_dict() method,
+    or adapting their config by modifying override_config() method, etc.
+    See:
+    https://huggingface.co/docs/transformers/model_doc/auto
+    https://huggingface.co/docs/transformers/autoclass_tutorial
+
+    Arguments
+    ---------
+    source : str
+        HuggingFace hub name: e.g "facebook/wav2vec2-large-lv60"
+    save_path : str
+        save directory of the downloaded model.
+    for_pretraining: bool (default: False)
+        If True, build the model for pretraining
+    with_lm_head : bool (default: False)
+        If True, build the model with lm_head
+    with_casual_lm : bool (default: False)
+        If True, build casual lm  model
+    seq2seqlm : bool (default: False)
+        If True, build a sequence-to-sequence model with lm_head
+    quantization_config : dict (default: None)
+        Quantization config, extremely useful for deadling with LLM
+    freeze : bool (default: True)
+        If True, the model is frozen. If False, the model will be trained
+        alongside with the rest of the pipeline.
+    cache_dir : str or Path (default: None)
+        Location of HuggingFace cache for storing pre-trained models, to which symlinks are created.
+
+    Example
+    -------
+    >>> model_hub = "facebook/wav2vec2-base-960h"
+    >>> save_path = "tmp"
+    >>> model = HFTransformersInterface(model_hub, save_path=save_path)
+    """
+
+    def __init__(
+        self,
+        source,
+        save_path="",
+        for_pretraining=False,
+        with_lm_head=False,
+        with_casual_lm=False,
+        seq2seqlm=False,
+        quantization_config=None,
+        freeze=False,
+        cache_dir="pretrained_models",
+        **kwarg,
+    ):
+        super().__init__()
+
+        # Fetch config
+        self.config, _unused_kwargs = AutoConfig.from_pretrained(
+            source, cache_dir=save_path, return_unused_kwargs=True,
+        )
+
+        self.config = self.override_config(self.config)
+        self.quantization_config = quantization_config
+
+        self.for_pretraining = for_pretraining
+
+        if self.for_pretraining:
+            self.auto_class = AutoModelForPreTraining
+        elif with_lm_head:
+            self.auto_class = AutoModelWithLMHead
+        elif with_casual_lm:
+            self.auto_class = AutoModelForCausalLM
+        elif seq2seqlm:
+            self.auto_class = AutoModelForSeq2SeqLM
+        else:
+            self.auto_class = AutoModel
+
+        # Download model
+        self._from_pretrained(
+            source, save_path=save_path, cache_dir=cache_dir,
+        )
+
+        # Prepare for training, fine-tuning, or inference
+        self.freeze = freeze
+        if self.freeze:
+            logger.warning(
+                f"speechbrain.lobes.models.huggingface_transformers.huggingface - {type(self.model).__name__} is frozen."
+            )
+            self.freeze_model(self.model)
+        else:
+            self.model.gradient_checkpointing_disable()  # Required by DDP
+            self.model.train()
+
+    def _from_pretrained(
+        self, source, save_path, cache_dir,
+    ):
+        """This function manages the source checking and loading of the params.
+
+        # 1. Is the model from HF or a local path
+        # 2. Is the model pretrained with HF or SpeechBrain
+        # 3. Download (if appropriate) and load with respect to 1. and 2.
+
+        Arguments
+        ---------
+        source : str
+            HuggingFace hub name: e.g "facebook/wav2vec2-large-lv60"
+        save_path : str
+            Path (dir) of the downloaded model.
+        cache_dir : str
+            Path (dir) in which a downloaded pretrained model configuration should be cached.
+        """
+        is_sb, ckpt_file, is_local = self._check_model_source(source, save_path)
+
+        if is_sb or self.for_pretraining:
+            self.model = self.auto_class.from_config(self.config)
+
+        if is_sb:
+            self.model.gradient_checkpointing_disable()  # Required by DDP
+            # fetch the checkpoint file
+            ckpt_full_path = fetch(
+                filename=ckpt_file, source=source, savedir=save_path,
+            )
+            # We transfer the parameters from the checkpoint.
+            self._load_sb_pretrained_parameters(ckpt_full_path)
+        elif not self.for_pretraining:
+            self.model = self.auto_class.from_pretrained(
+                source,
+                config=self.config,
+                cache_dir=save_path,
+                quantization_config=self.quantization_config,
+            )
+
+    def _check_model_source(self, path, save_path):
+        """Checks if the pretrained model has been trained with SpeechBrain and
+        is hosted locally or on a HuggingFace hub.
+        Called as static function in HFTransformersInterface._from_pretrained.
+
+        Arguments
+        ---------
+        path : str
+            Used as "source"; local path or HuggingFace hub name: e.g "facebook/wav2vec2-large-lv60"
+        save_path : str
+            norm_output (dir) of the downloaded model.
+
+        Returns
+        -------
+        is_sb : bool
+            Whether/not the model is deserializable w/ SpeechBrain or not (then, model conversion is needed).
+        checkpoint_filename : str
+            as of HuggingFace documentation: file name relative to the repo root (guaranteed to be here).
+        is_local : bool
+            Whether/not the model is hosted locally or on a HuggingFace hub.
+
+        Raises
+        ------
+        ValueError
+            If file is not found
+        """
+        checkpoint_filename = ""
+        source = pathlib.Path(path)
+        is_local = True
+
+        # If path is a huggingface hub.
+        if not source.exists():
+            is_local = False
+
+        # Check if source is downloaded already
+        sink = pathlib.Path(
+            save_path + "/models--" + path.replace("/", "--") + "/snapshots"
+        )
+        if sink.exists():
+            sink = (
+                sink / os.listdir(str(sink))[0]
+            )  # there's a hash-id subfolder
+            if any(
+                File.endswith(".bin") or File.endswith(".ckpt")
+                for File in os.listdir(str(sink))
+            ):
+                is_local = True
+                local_path = str(sink)
+            else:
+                local_path = path
+        else:
+            local_path = path
+
+        if is_local:
+            # Test for HuggingFace model
+            if any(File.endswith(".bin") for File in os.listdir(local_path)):
+                is_sb = False
+                return is_sb, checkpoint_filename, is_local
+
+            # Test for SpeechBrain model and get the filename.
+            for File in os.listdir(local_path):
+                if File.endswith(".ckpt"):
+                    checkpoint_filename = os.path.join(path, File)
+                    is_sb = True
+                    return is_sb, checkpoint_filename, is_local
+        else:
+            files = model_info(
+                path
+            ).siblings  # get the list of files of the Hub
+
+            # Test if it's an HuggingFace model or a SB one
+            for File in files:
+                if File.rfilename.endswith(".ckpt"):
+                    checkpoint_filename = File.rfilename
+                    is_sb = True
+                    return is_sb, checkpoint_filename, is_local
+
+            for File in files:
+                if File.rfilename.endswith(".bin"):
+                    checkpoint_filename = File.rfilename
+                    is_sb = False
+                    return is_sb, checkpoint_filename, is_local
+
+        err_msg = f"{path} does not contain a .bin or .ckpt checkpoint !"
+        raise FileNotFoundError(err_msg)
+
+    def _modify_state_dict(self, path, **kwargs):
+        """A custom loading ensures SpeechBrain compatibility for pretrain and model.
+
+        For example, wav2vec2 model pretrained with SB (Wav2Vec2Pretrain) has slightly different keys from Wav2Vec2.
+        This method handle the compatibility between the two.
+
+        Users should modify this function according to their own tasks.
+
+        Arguments
+        ---------
+        path : str
+            Checkpoint path, file name relative to the repo root.
+        """
+        return None
+
+    def _load_sb_pretrained_parameters(self, path):
+        """Loads the parameter of a HuggingFace model pretrained with SpeechBrain
+        and the HuggingFace Pretrain Object. It is necessary to perform a custom
+        loading because HuggingFace adds a level to the checkpoint when storing
+        the model breaking the compatibility Pretrain and model de/serialization.
+
+        For example, a typical Wav2Vec2 checkpoint for a given parameter
+        would be: model.conv.weight.data while for Wav2Vec2Pretrain it
+        is: model.wav2vec2.weight.data (wav2vec2 must be removed before loading).
+
+        Arguments
+        ---------
+        path : pathlib.Path
+            The full path to the checkpoint.
+        """
+        modified_state_dict = self._modify_state_dict(path)
+
+        if modified_state_dict is None:
+            modified_state_dict = torch.load(path, map_location="cpu")
+
+        incompatible_keys = self.model.load_state_dict(
+            modified_state_dict, strict=False
+        )
+        for missing_key in incompatible_keys.missing_keys:
+            logger.warning(
+                f"During parameter transfer to {self.model} loading from "
+                + f"{path}, the transferred parameters did not have "
+                + f"parameters for the key: {missing_key}"
+            )
+        for unexpected_key in incompatible_keys.unexpected_keys:
+            logger.warning(
+                f"The param with the key: {unexpected_key} is discarded as it "
+                + f"is useless for finetuning this {type(self.model).__name__} model."
+            )
+
+    def forward(self, **kwargs):
+        """Users should modify this function according to their own tasks."""
+        raise NotImplementedError
+
+    def forward_encoder(self, **kwargs):
+        """Users should modify this function according to their own tasks."""
+        raise NotImplementedError
+
+    def forward_decoder(self, **kwargs):
+        """Users should modify this function according to their own tasks."""
+        raise NotImplementedError
+
+    def decode(self, **kwargs):
+        """Might be useful for models like mbart, which can exploit SB's beamsearch for inference
+        Users should modify this function according to their own tasks."""
+        raise NotImplementedError
+
+    def encode(self, **kwargs):
+        """Customed encoding for inference
+        Users should modify this function according to their own tasks."""
+        raise NotImplementedError
+
+    def freeze_model(self, model):
+        """
+        Freezes parameters of a model.
+        This should be overrided too, depending on users' needs, for example, adapters use.
+
+        Arguments
+        ---------
+        model : from AutoModel.from_config
+            Valid HuggingFace transformers model object.
+        """
+        model.eval()
+        for param in model.parameters():
+            param.requires_grad = False
+
+    def override_config(self, config):
+        """Users should modify this function according to their own tasks.
+
+        Arguments
+        ---------
+        config : HuggingFace config object
+            The orginal config.
+
+        Returns
+        ---------
+        config : HuggingFace config object
+            Overridden config.
+        """
+        return config
+
+    def load_feature_extractor(self, source, cache_dir, **kwarg):
+        """Load model's feature_extractor from the hub.
+
+        Arguments
+        ---------
+        source : str
+            HuggingFace hub name: e.g "facebook/wav2vec2-large-lv60"
+        cache_dir : str
+            Path (dir) in which a downloaded pretrained model configuration should be cached.
+        **kwarg
+            Keyword arguments to pass to the AutoFeatureExtractor.from_pretrained() method.
+        """
+        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
+            source, cache_dir=cache_dir, **kwarg
+        )
+
+    def load_tokenizer(self, source, **kwarg):
+        """Load model's tokenizer from the hub.
+
+        Arguments
+        ---------
+        source : str
+            HuggingFace hub name: e.g "facebook/wav2vec2-large-lv60"
+        **kwarg
+            Keyword arguments to pass to the AutoFeatureExtractor.from_pretrained() method.
+        """
+        self.tokenizer = AutoTokenizer.from_pretrained(source, **kwarg)
+
+
+def make_padding_masks(src, wav_len=None, pad_idx=0):
+    """This method generates the padding masks.
+
+    Arguments
+    ---------
+    src : tensor
+        The sequence to the encoder (required).
+    wav_len : tensor
+        The relative length of the wav given in SpeechBrain format.
+    pad_idx : int
+        The index for <pad> token (default=0).
+
+    Returns
+    ---------
+    src_key_padding_mask : tensor
+        The padding mask.
+    """
+    src_key_padding_mask = None
+    if wav_len is not None:
+        abs_len = torch.round(wav_len * src.shape[1])
+        src_key_padding_mask = length_to_mask(abs_len).bool()
+
+    return src_key_padding_mask
diff --git a/speechbrain/lobes/models/huggingface_transformers/labse.py b/speechbrain/lobes/models/huggingface_transformers/labse.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d68ac2cb3959a3aa545c81d7fbe2bd01e6f8d88
--- /dev/null
+++ b/speechbrain/lobes/models/huggingface_transformers/labse.py
@@ -0,0 +1,111 @@
+"""This lobe enables the integration of huggingface pretrained LaBSE models.
+Reference: https://arxiv.org/abs/2007.01852
+
+Transformer from HuggingFace needs to be installed:
+https://huggingface.co/transformers/installation.html
+
+Authors
+ * Ha Nguyen 2023
+"""
+
+import torch
+import logging
+import torch.nn.functional as F
+import os
+
+from speechbrain.lobes.models.huggingface_transformers.huggingface import (
+    HFTransformersInterface,
+)
+
+logger = logging.getLogger(__name__)
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+class LaBSE(HFTransformersInterface):
+    """This lobe enables the integration of HuggingFace and SpeechBrain
+    pretrained LaBSE models.
+
+    Source paper LaBSE: https://arxiv.org/abs/2007.01852
+    Transformer from HuggingFace needs to be installed:
+    https://huggingface.co/transformers/installation.html
+
+    The model can be used as a fixed text-based sentence-level embeddings generator or can be finetuned.
+    It will download automatically the model from HuggingFace or use a local path.
+
+    Arguments
+    ---------
+    source : str
+        HuggingFace hub name: e.g "setu4993/LaBSE"
+    save_path : str
+        Path (dir) of the downloaded model.
+    freeze : bool (default: True)
+        If True, the model is frozen. If False, the model will be trained
+        alongside with the rest of the pipeline.
+    output_norm : bool (default: True)
+        If True, normalize the output.
+    Example
+    -------
+    >>> inputs = ["La vie est belle"]
+    >>> model_hub = "setu4993/smaller-LaBSE"
+    >>> save_path = "savedir"
+    >>> model = LaBSE(model_hub, save_path)
+    >>> outputs = model(inputs)
+    """
+
+    def __init__(
+        self, source, save_path, freeze=True, output_norm=True,
+    ):
+        super().__init__(source=source, save_path=save_path, freeze=freeze)
+
+        self.load_tokenizer(source=source)
+
+        self.output_norm = output_norm
+
+    def forward(self, input_texts):
+        """This method implements a forward of the labse model,
+        which generates sentence-level embeddings from input text.
+
+        Arguments
+        ----------
+        input_texts (translation): list
+            The list of texts (required).
+        """
+
+        # Transform input to the right format of the LaBSE model.
+        if self.freeze:
+            with torch.no_grad():
+                # Tokenize the input text before feeding to LaBSE model.
+                input_texts = self.tokenizer(
+                    input_texts, return_tensors="pt", padding=True
+                )
+                # Set the right device for the input.
+                for key in input_texts.keys():
+                    input_texts[key] = input_texts[key].to(
+                        device=self.model.device
+                    )
+                    input_texts[key].requires_grad = False
+
+                embeddings = self.model(**input_texts).pooler_output
+
+                if self.output_norm:
+                    # Output normalizing if needed.
+                    embeddings = F.normalize(embeddings, p=2)
+
+                return embeddings
+
+        # Tokenize the input text before feeding to LaBSE model.
+        input_texts = self.tokenizer(
+            input_texts, return_tensors="pt", padding=True
+        )
+        # Set the right device for the input.
+        for key in input_texts.keys():
+            input_texts[key] = input_texts[key].to(device=self.model.device)
+
+        embeddings = self.model(**input_texts).pooler_output
+
+        if self.output_norm:
+            # Output normalizing if needed.
+            embeddings = F.normalize(embeddings, p=2)
+
+        return embeddings
diff --git a/speechbrain/lobes/models/huggingface_transformers/llama2.py b/speechbrain/lobes/models/huggingface_transformers/llama2.py
new file mode 100644
index 0000000000000000000000000000000000000000..977577bf7b97a9123fe82284d0084dee80ec6377
--- /dev/null
+++ b/speechbrain/lobes/models/huggingface_transformers/llama2.py
@@ -0,0 +1,406 @@
+"""This lobe enables the integration of huggingface pretrained LLAMA2-chat model.
+
+Transformer from HuggingFace needs to be installed:
+https://huggingface.co/transformers/installation.html
+
+Authors
+ * Pooneh Mousavi 2023
+ * Ha Nguyen 2023
+"""
+
+import logging
+from torch import Tensor
+import torch
+
+import torch.nn as nn
+from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
+from speechbrain.lobes.models.huggingface_transformers.huggingface import (
+    HFTransformersInterface,
+)
+from transformers import BitsAndBytesConfig
+
+from bitsandbytes.nn import Linear4bit
+
+logger = logging.getLogger(__name__)
+
+
+class LLAMA2(HFTransformersInterface):
+    """This lobe enables the integration of HuggingFace pretrained LLAMA2 model.
+     Source paper LLAMA2:
+       https://arxiv.org/abs/2307.09288
+    Transformer from HuggingFace needs to be installed:
+        https://huggingface.co/transformers/installation.html
+
+    The model can be finetuned. It will download automatically the model from
+    HuggingFace or use a local path.
+
+    Notes:
+    - To use this model, you need to install the extra dependencies in recipes/MultiWOZ/response_generation/llama2/extra_requirements.txt
+    - transformers and peft libraries should follow the versions mentioned in the extra_requirements.
+    - Llama 2 is licensed under the LLAMA 2 Community License, Copyright © Meta Platforms, Inc. All Rights Reserved.
+
+    Arguments
+    ---------
+    source : str
+        HuggingFace hub name: e.g "meta-llama/Llama-2-7b-chat-hf"
+    save_path : str
+        Path (dir) of the downloaded model.
+    freeze : bool (default: False)
+        If True, the model is frozen. If False, the model will be trained
+        alongside with the rest of the pipeline.
+    max_new_tokens: int (default: 200)
+    use_4bit: bool (default: True)
+    bnb_4bit_compute_dtype: str (default: "float16")
+        This sets the computational type which might be different than the input time. For example, inputs might be fp32, but computation can be set to bf16 for speedups.
+    bnb_4bit_quant_type: str (default:"nf4")
+        This sets the quantization data type in the bnb.nn.Linear4Bit layers. Options are FP4 and NF4 data types which are specified by fp4 or nf4.
+    use_nested_quant: bool (default: False)
+        You have set this to False, which means you're not using nested quantization. This seems reasonable, as nested quantization can be computationally expensive.
+    min_length: int (default: 1)
+        The minimum length of the sequence to be generated. Corresponds to the length of the input prompt + min_new_tokens. Its effect is overridden by min_new_tokens, if also set
+    top_k: int (default: 45)
+        The number of highest probability vocabulary tokens to keep for top-k-filtering.
+    top_p: float (default: 0.9)
+        If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+    num_beams: int (default: 8)
+         Number of beams for beam search. 1 means no beam search.
+    early_stopping: bool (default: True)
+        Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
+        - True, where the generation stops as soon as there are num_beams complete candidates
+        - False, where an heuristic is applied and the generation stops when is it very unlikely to find better candidates
+        - "never", where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm).
+    with_peft: bool (default:False)
+        If set to True, the peft model (model + adaptors) are loaded. If set to False, the original model is loaded.
+
+    Example
+    -------
+    >>> model_hub = "meta-llama/Llama-2-7b-chat-hf"
+    >>> save_path = "savedir"
+    >>> model = LLAMA2(model_hub, save_path)
+    >>> tokens = torch.tensor([[1, 1]])
+    >>> attention_mask = torch.tensor([[1, 1]])
+    >>> outputs = model(tokens, attention_mask)
+    """
+
+    def __init__(
+        self,
+        source: str,
+        save_path: str,
+        freeze: bool = False,
+        max_new_tokens: int = 200,
+        use_4bit: bool = True,
+        bnb_4bit_compute_dtype: str = "float16",
+        bnb_4bit_quant_type: str = "nf4",
+        use_nested_quant: bool = False,
+        min_length: int = 1,
+        top_k: int = 45,
+        top_p: float = 0.9,
+        num_beams: int = 8,
+        early_stopping: bool = True,
+        with_peft: bool = False,
+    ) -> None:
+
+        self.with_peft = with_peft
+        self.max_new_tokens = max_new_tokens
+        self.min_length = min_length
+        self.top_k = top_k
+        self.top_p = top_p
+        self.num_beams = num_beams
+        self.early_stopping = early_stopping
+        self.source = source
+        self.save_path = save_path
+        self.is_sb = False
+
+        compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
+        self.bnb_config = None
+        if with_peft:
+            self.bnb_config = BitsAndBytesConfig(
+                load_in_4bit=use_4bit,
+                bnb_4bit_quant_type=bnb_4bit_quant_type,
+                bnb_4bit_compute_dtype=compute_dtype,
+                bnb_4bit_use_double_quant=use_nested_quant,
+            )
+            # Check GPU compatibility with bfloat16
+            if compute_dtype == torch.float16 and use_4bit:
+                major, _ = torch.cuda.get_device_capability()
+                if major >= 8:
+                    logger.info("=" * 80)
+                    logger.info(
+                        "Your GPU supports bfloat16: accelerate training with bf16=True"
+                    )
+                    logger.info("=" * 80)
+
+        super().__init__(
+            source=source,
+            save_path=save_path,
+            freeze=freeze,
+            with_casual_lm=True,
+            quantization_config=self.bnb_config,
+        )
+
+        self.load_tokenizer(source=source, pad_token=None, use_fast=False)
+        # Define a custom padding token
+        self.tokenizer.pad_token = "<PAD>"
+        # Set the padding direction to the right
+        self.tokenizer.padding_side = "right"
+
+        # Here we deal with quantization
+        # If the loaded model is an SB checkpoint, skip this because we also do it in _modify_state_dict
+        if with_peft and not self.is_sb:
+            self.model = prepare_model_for_kbit_training(self.model)
+
+            config = LoraConfig(
+                lora_alpha=16,
+                lora_dropout=0.1,
+                r=64,
+                bias="none",
+                task_type="CAUSAL_LM",
+            )
+
+            self.model = get_peft_model(self.model, config)
+        self.print_trainable_parameters(self.model)
+
+    def forward(
+        self, input_ids: Tensor, attention_mask: Tensor,
+    ):
+        """ Takes an input a history of conversation and returns its corresponding reply.
+
+        Arguments
+        ---------
+        input_ids : torch.Tensor ()
+            A batch of input-id to transform to features.
+        attention_mask : torch.Tensor ()
+            A batch of attention_mask.
+        """
+        with torch.set_grad_enabled(not self.freeze):
+            output = self.model.forward(
+                input_ids, attention_mask=attention_mask,
+            )
+        return output
+
+    def _modify_state_dict(self, path, replacables=["base_model"]):
+        """A custom loading ensures SpeechBrain compatibility for Pretrain and model
+        de/serialization. Here, the scope is to remove '.wav2vec2' before loading.
+
+        Arguments
+        ---------
+        path : str
+            Checkpoint path, file name relative to the repo root.
+        replacables : List[str]
+            State dict sub-keys that if found, shall be dropped (incl. the 'model.' parent key), elevating key structures.
+
+        Returns
+        -------
+        modified_state_dict : see torch.load
+            SpeechBrain-valid deserialized pretrained model.
+        """
+
+        # Set is_sb = True for the ckpt is SB's nature
+        self.is_sb = True
+
+        # Load the state_dict of the ckpt
+        orig_state_dict = torch.load(path, map_location="cpu")
+
+        # Check if the dimension of the embed_tokens layer is greater than the vocab size defined by the HF Llama config
+        # If it is True, enlarge this layer
+        # This happens because sometimes one wants to add a <pad> token to the vocab.
+        desired_key = next(
+            (key for key in orig_state_dict if "embed_tokens.weight" in key),
+            None,
+        )
+        new_num_tokens = (
+            orig_state_dict.get(desired_key).size(0)
+            - self.model.config.vocab_size
+        )
+        if new_num_tokens > 0:
+            self.model.resize_token_embeddings(new_num_tokens=32001)
+
+        # Here we deal with quantization
+        if self.with_peft:
+            from transformers.integrations import replace_with_bnb_linear
+
+            self.model = replace_with_bnb_linear(
+                self.model,
+                modules_to_not_convert=["lm_head"],
+                quantization_config=self.bnb_config,
+            )
+
+            from transformers.modeling_utils import (
+                _load_state_dict_into_meta_model,
+            )
+
+            state_dict = self.model.state_dict()
+            for key in state_dict.keys():
+                state_dict[key] = torch.rand(
+                    state_dict[key].shape, dtype=torch.float16, device="cpu"
+                )
+
+            (
+                new_error_msgs,
+                offload_index,
+                state_dict_index,
+            ) = _load_state_dict_into_meta_model(
+                model=self.model,
+                state_dict=state_dict,
+                loaded_state_dict_keys=state_dict.keys(),
+                start_prefix="",
+                expected_keys=state_dict.keys(),
+                device_map={"": 0},
+                dtype=torch.float16,
+                is_quantized=True,
+            )
+
+            from transformers.utils.quantization_config import (
+                QuantizationMethod,
+            )
+
+            self.model._is_quantized_training_enabled = True
+            self.model.is_8bit_serializable = True
+            self.model.quantization_method = QuantizationMethod.BITS_AND_BYTES
+            self.model.is_quantized = True
+            self.model.is_loaded_in_4bit = True
+            self.model.is_loaded_in_8bit = False
+
+            quantization_config = {}
+            quantization_config["bnb_4bit_compute_dtype"] = "float16"
+            quantization_config["bnb_4bit_quant_type"] = "nf4"
+            quantization_config["bnb_4bit_use_double_quant"] = False
+            quantization_config["llm_int8_enable_fp32_cpu_offload"] = False
+            quantization_config["llm_int8_has_fp16_weight"] = False
+            quantization_config["llm_int8_skip_modules"] = None
+            quantization_config["llm_int8_threshold"] = 6.0
+            quantization_config["load_in_4bit"] = True
+            quantization_config["load_in_8bit"] = False
+            quantization_config["quant_method"] = "bitsandbytes"
+
+            self.model.config.quantization_config = quantization_config
+
+            from accelerate import dispatch_model
+
+            device_map_kwargs = {
+                "device_map": {"": 0},
+                "offload_dir": None,
+                "offload_index": None,
+                "skip_keys": "past_key_values",
+            }
+
+            dispatch_model(self.model, **device_map_kwargs)
+
+            self.model = prepare_model_for_kbit_training(self.model)
+
+            lora_config = LoraConfig(
+                lora_alpha=16,
+                lora_dropout=0.1,
+                r=64,
+                bias="none",
+                task_type="CAUSAL_LM",
+            )
+
+            self.model = get_peft_model(self.model, lora_config)
+
+        modified_state_dict = {}
+        # Matching the state_dict of the ckpt with that of the HF Llama model.
+        for key, params in orig_state_dict.items():
+            for tag in replacables:
+                if f"{tag}" in key:
+                    save_key = key.replace(f"model.{tag}", f"{tag}")
+                    modified_state_dict[save_key] = params
+        return modified_state_dict
+
+    def replace_linear(self, module):
+        """Modify the loaded module linear layers with Linear4bit to be compatible
+
+        Arguments
+        ---------
+        module : nn.nodule
+            llama2 model.
+        """
+        for name, child in module.named_children():
+            if isinstance(child, nn.Linear) and name != "lm_head":
+                # Replace Linear layer with your custom layer
+                setattr(
+                    module,
+                    name,
+                    Linear4bit(
+                        child.in_features, child.out_features, bias=child.bias
+                    ),
+                )
+            else:
+                self.replace_linear(child)
+
+    def generate(
+        self, input_ids: Tensor, attention_mask: Tensor, decoder_type="greedy",
+    ):
+        """ Takes an input a history of conversation and returns its corresponding reply.
+
+        Arguments
+        --------
+        input_ids : torch.Tensor ()
+            A batch of input-id   which are dialogue context tokens
+        # decoder_type : Str
+        #     It shows strategy for autoregressive decoding either beam seach or greedy.
+        # attention_mask : torch.Tensor ()
+        #     A batch of attention_mask.
+        """
+
+        with torch.no_grad():
+            if decoder_type == "beam":
+                # beam decoding based on the input_ids which are dialogue context tokens (here only history)
+                hyp = self.model.generate(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    do_sample=True,
+                    max_new_tokens=self.max_new_tokens,
+                    min_length=self.min_length,
+                    top_k=self.top_k,
+                    top_p=self.top_p,
+                    temperature=1.0,
+                    num_beams=self.num_beams,
+                    num_return_sequences=1,
+                    repetition_penalty=1.0,
+                    length_penalty=1,
+                    early_stopping=self.early_stopping,
+                )
+            else:
+                # greedy decoding based on the input_ids which are dialogue context tokens (here only history)
+                hyp = self.model.generate(
+                    input_ids=input_ids,
+                    max_new_tokens=self.max_new_tokens,
+                    attention_mask=attention_mask,
+                )
+        return hyp
+
+    def override_config(self, config):
+        """override config to include quantization config.
+
+        Arguments
+        ---------
+        config : HuggingFace config object
+            The orginal config.
+
+        Returns
+        ---------
+        config : HuggingFace config object
+            Overridden config.
+        """
+        if self.bnb_config:
+            config = config.from_pretrained(
+                self.source,
+                cache_dir=self.save_path,
+                quantization_config=self.bnb_config,
+            )
+        return config
+
+    def print_trainable_parameters(self, model):
+        """
+        Prints the number of trainable parameters in the model.
+        """
+        trainable_params = 0
+        all_param = 0
+        for _, param in model.named_parameters():
+            all_param += param.numel()
+            if param.requires_grad:
+                trainable_params += param.numel()
+        logger.info(
+            f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
+        )
diff --git a/speechbrain/lobes/models/huggingface_transformers/mbart.py b/speechbrain/lobes/models/huggingface_transformers/mbart.py
new file mode 100644
index 0000000000000000000000000000000000000000..1749cddd5fb963dad05cec99f63bc0aed38e56ad
--- /dev/null
+++ b/speechbrain/lobes/models/huggingface_transformers/mbart.py
@@ -0,0 +1,201 @@
+"""This lobe enables the integration of huggingface pretrained mBART models.
+Reference: https://arxiv.org/abs/2001.08210
+
+Transformer from HuggingFace needs to be installed:
+https://huggingface.co/transformers/installation.html
+
+Authors
+ * Ha Nguyen 2023
+"""
+
+import torch
+import logging
+
+from speechbrain.lobes.models.huggingface_transformers.huggingface import (
+    HFTransformersInterface,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class mBART(HFTransformersInterface):
+    """This lobe enables the integration of HuggingFace and SpeechBrain
+    pretrained mBART models.
+
+    Source paper mBART: https://arxiv.org/abs/2001.08210
+    Transformer from HuggingFace needs to be installed:
+    https://huggingface.co/transformers/installation.html
+
+    The model is normally used as a text decoder of seq2seq models. It
+    will download automatically the model from HuggingFace or use a local path.
+
+    Arguments
+    ---------
+    source : str
+        HuggingFace hub name: e.g "facebook/mbart-large-50-many-to-many-mmt"
+    save_path : str
+        Path (dir) of the downloaded model.
+    freeze : bool (default: True)
+        If True, the model is frozen. If False, the model will be trained
+        alongside with the rest of the pipeline.
+    target_lang: str (default: fra_Latn (a.k.a French)
+        The target language code according to NLLB model.
+    decoder_only : bool (default: True)
+        If True, only take the decoder part (and/or the lm_head) of the model.
+        This is useful in case one wants to couple a pre-trained speech encoder (e.g. wav2vec)
+        with a text-based pre-trained decoder (e.g. mBART, NLLB).
+    share_input_output_embed : bool (default: True)
+        If True, use the embedded layer as the lm_head.
+
+    Example
+    -------
+    >>> src = torch.rand([10, 1, 1024])
+    >>> tgt = torch.LongTensor([[250008,    313,     25,    525,    773,  21525,   4004,      2]])
+    >>> model_hub = "facebook/mbart-large-50-many-to-many-mmt"
+    >>> save_path = "savedir"
+    >>> model = mBART(model_hub, save_path) # doctest: +SKIP
+    >>> outputs = model(src, tgt) # doctest: +SKIP
+    """
+
+    def __init__(
+        self,
+        source,
+        save_path,
+        freeze=True,
+        target_lang="fr_XX",
+        decoder_only=True,
+        share_input_output_embed=True,
+    ):
+        super().__init__(
+            source=source, save_path=save_path, freeze=freeze, seq2seqlm=True,
+        )
+
+        self.target_lang = target_lang
+        self.decoder_only = decoder_only
+        self.share_input_output_embed = share_input_output_embed
+
+        self.load_tokenizer(source=source, pad_token=None, tgt_lang=target_lang)
+
+        if share_input_output_embed:
+            self.model.lm_head.weight = (
+                self.model.model.decoder.embed_tokens.weight
+            )
+            self.model.lm_head.requires_grad = False
+            self.model.model.decoder.embed_tokens.requires_grad = False
+
+        if decoder_only:
+            # When we only want to use the decoder part
+            del self.model.model.encoder
+
+        for k, p in self.model.named_parameters():
+            # It is a common practice to only fine-tune the encoder_attn and layer_norm layers of this model.
+            if "encoder_attn" in k or "layer_norm" in k:
+                p.requires_grad = True
+            else:
+                p.requires_grad = False
+
+    def forward(self, src, tgt, pad_idx=0):
+        """This method implements a forward step for mt task using a wav2vec encoder
+        (same than above, but without the encoder stack)
+
+        Arguments
+        ----------
+        src (transcription): tensor
+            output features from the w2v2 encoder
+        tgt (translation): tensor
+            The sequence to the decoder (required).
+        pad_idx : int
+            The index for <pad> token (default=0).
+        """
+
+        # should we replace 0 elements by pax_idx as pad_idx of mbart model seems to be different from 0?
+        tgt = self.custom_padding(
+            tgt, 0, self.model.model.decoder.config.pad_token_id
+        )
+
+        if self.freeze:
+            with torch.no_grad():
+                if hasattr(self.model.model, "encoder"):
+                    src = self.model.model.encoder(
+                        inputs_embeds=src
+                    ).last_hidden_state.detach()
+                dec_out = self.model.model.decoder(
+                    input_ids=tgt, encoder_hidden_states=src
+                ).last_hidden_state.detach()
+                dec_out = self.model.lm_head(dec_out).detach()
+                return dec_out
+
+        if hasattr(self.model.model, "encoder"):
+            src = self.model.model.encoder(inputs_embeds=src).last_hidden_state
+        dec_out = self.model.model.decoder(
+            input_ids=tgt, encoder_hidden_states=src
+        ).last_hidden_state
+        dec_out = self.model.lm_head(dec_out)
+        return dec_out
+
+    @torch.no_grad()
+    def decode(self, tgt, encoder_out, enc_len=None):
+        """This method implements a decoding step for the transformer model.
+
+        Arguments
+        ---------
+        tgt : torch.Tensor
+            The sequence to the decoder.
+        encoder_out : torch.Tensor
+            Hidden output of the encoder.
+        enc_len : torch.LongTensor
+            The actual length of encoder states.
+        """
+
+        if tgt.dtype not in [torch.long, torch.int64]:
+            tgt = tgt.long()
+
+        tgt_mask = torch.ones(tgt.size(), device=tgt.device)
+
+        output = self.model.model.decoder(
+            input_ids=tgt,
+            encoder_hidden_states=encoder_out,
+            attention_mask=tgt_mask,
+            output_attentions=True,
+        )
+
+        return (
+            self.model.lm_head(output.last_hidden_state),
+            output.cross_attentions[-1],
+        )
+
+    def custom_padding(self, x, org_pad, custom_pad):
+        """This method customizes the padding.
+        Default pad_idx of SpeechBrain is 0.
+        However, it happens that some text-based models like mBART reserves 0 for something else,
+        and are trained with specific pad_idx.
+        This method change org_pad to custom_pad
+
+        Arguments
+        ---------
+        x : torch.Tensor
+          Input tensor with original pad_idx
+        org_pad : int
+          Orginal pad_idx
+        custom_pad : int
+          Custom pad_idx
+        """
+        out = x.clone()
+        out[x == org_pad] = custom_pad
+
+        return out
+
+    def override_config(self, config):
+        """If the config needs to be overrided, here is the place.
+
+        Arguments
+        ---------
+        config : MBartConfig
+            The original config needs to be overrided.
+
+        Returns
+        -------
+        Overridded config
+        """
+        config.decoder_layerdrop = 0.05
+        return config
diff --git a/speechbrain/lobes/models/huggingface_transformers/nllb.py b/speechbrain/lobes/models/huggingface_transformers/nllb.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ad0dc93eb9247386974e770edc6d2358fee0c9e
--- /dev/null
+++ b/speechbrain/lobes/models/huggingface_transformers/nllb.py
@@ -0,0 +1,76 @@
+"""This lobe enables the integration of huggingface pretrained NLLB models.
+Reference: https://arxiv.org/abs/2207.04672
+
+Transformer from HuggingFace needs to be installed:
+https://huggingface.co/transformers/installation.html
+
+Authors
+ * Ha Nguyen 2023
+"""
+
+import logging
+
+from speechbrain.lobes.models.huggingface_transformers.mbart import mBART
+
+logger = logging.getLogger(__name__)
+
+
+class NLLB(mBART):
+    """This lobe enables the integration of HuggingFace and SpeechBrain
+    pretrained NLLB models.
+
+    Source paper NLLB: https://arxiv.org/abs/2207.04672
+    Transformer from HuggingFace needs to be installed:
+    https://huggingface.co/transformers/installation.html
+
+    The model is normally used as a text decoder of seq2seq models. It
+    will download automatically the model from HuggingFace or use a local path.
+
+    For now, HuggingFace's NLLB model can be loaded using the exact code for mBART model.
+    For this reason, NLLB can be fine inheriting the mBART class.
+
+    Arguments
+    ---------
+    source : str
+        HuggingFace hub name: e.g "facebook/nllb-200-1.3B"
+    save_path : str
+        Path (dir) of the downloaded model.
+    freeze : bool (default: True)
+        If True, the model is frozen. If False, the model will be trained
+        alongside with the rest of the pipeline.
+    target_lang: str (default: fra_Latn (a.k.a French)
+        The target language code according to NLLB model.
+    decoder_only : bool (default: True)
+        If True, only take the decoder part (and/or the lm_head) of the model.
+        This is useful in case one wants to couple a pre-trained speech encoder (e.g. wav2vec)
+        with a text-based pre-trained decoder (e.g. mBART, NLLB).
+    share_input_output_embed : bool (default: True)
+        If True, use the embedded layer as the lm_head.
+    Example
+    -------
+    >>> import torch
+    >>> src = torch.rand([10, 1, 1024])
+    >>> tgt = torch.LongTensor([[256057,    313,     25,    525,    773,  21525,   4004,      2]])
+    >>> model_hub = "facebook/nllb-200-distilled-600M"
+    >>> save_path = "savedir"
+    >>> model = NLLB(model_hub, save_path)
+    >>> outputs = model(src, tgt)
+    """
+
+    def __init__(
+        self,
+        source,
+        save_path,
+        freeze=True,
+        target_lang="fra_Latn",
+        decoder_only=True,
+        share_input_output_embed=True,
+    ):
+        super().__init__(
+            source=source,
+            save_path=save_path,
+            freeze=freeze,
+            target_lang=target_lang,
+            decoder_only=decoder_only,
+            share_input_output_embed=share_input_output_embed,
+        )
diff --git a/speechbrain/lobes/models/huggingface_transformers/vocos.py b/speechbrain/lobes/models/huggingface_transformers/vocos.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdfa44520e80fbdaefbb7585a53047ac077908b1
--- /dev/null
+++ b/speechbrain/lobes/models/huggingface_transformers/vocos.py
@@ -0,0 +1,151 @@
+"""This lobe enables the integration of huggingface pretrained
+Vocos model.
+
+Vocos is a vocoder trained on top of EnCodec tokens. While
+EnCodec itself can be used for a lossy reconstruction of speech,
+a vocoder, such as Vocos, can be used to improve the quality.
+
+Repository: https://huggingface.co/charactr/vocos-encodec-24khz
+Paper: https://arxiv.org/pdf/2306.00814.pdf
+
+TODO: There is an open feature request to add this model to
+HuggingFace Transformers.
+
+If this is impemented, it will be possible to make this model
+inherit from HFTransformersInterface
+
+https://github.com/huggingface/transformers/issues/25123
+
+Authors
+ * Artem Ploujnikov 2023
+"""
+
+import torch
+import logging
+from torch import nn
+from speechbrain.dataio.dataio import length_to_mask
+from huggingface_hub import hf_hub_download
+
+try:
+    from vocos import Vocos as VocosModel
+    from vocos.feature_extractors import EncodecFeatures
+except ImportError:
+    MSG = "Please install vocos to use the Vocos model\n"
+    MSG += "E.G. run: pip install vocos"
+    raise ImportError(MSG)
+
+
+DEFAULT_SAMPLE_RATE = 24000
+BANDWIDTHS = [1.5, 3.0, 6.0, 12.0]
+
+logger = logging.getLogger(__name__)
+
+
+class Vocos(nn.Module):
+    """An wrapper for the HuggingFace Vocos model
+
+    Arguments
+    ---------
+    source : str
+        A HuggingFace repository identifier or a path
+    save_path : str
+        The location where the pretrained model will be saved
+    revision : str
+        The model revision
+    bandwidth : float
+        The bandwidth value
+        Supported:
+        1.5, 3.0, 6.0, 12.0
+    freeze : bool
+        Whether or not parameters should be
+        frozen
+
+    Example
+    -------
+    >>> model_hub = "charactr/vocos-encodec-24khz"
+    >>> save_path = "savedir"
+    >>> model = Vocos(model_hub, save_path)
+    >>> tokens = torch.randint(1024, (4, 10, 2))
+    >>> length = torch.tensor([1.0, 0.5, 0.75, 1.0])
+    >>> audio, out_length = model(tokens, length)
+    >>> audio.shape
+    torch.Size([4, 3200])
+    >>> out_length
+    tensor([1.0000, 0.5000, 0.7500, 1.0000])
+    """
+
+    def __init__(
+        self, source, save_path, revision=None, bandwidth=1.5, freeze=True,
+    ):
+        super().__init__()
+        self.source = source
+        self.save_path = save_path
+        self.revision = revision
+        self.model = self._load_model()
+        self.freeze = freeze
+        self.bandwidth = bandwidth
+        self.bandwidth_id = (
+            (torch.tensor(BANDWIDTHS) - bandwidth).abs().argmin().item()
+        )
+        if self.freeze:
+            logger.warning("huggingface_Vocos - Vocos is frozen.")
+            for param in self.model.parameters():
+                param.requires_grad = False
+
+    def _load_model(self):
+        """Loads the pretrained model. This is a customized implementation of
+        Vocos.from_pretrained(), which has been customized to specify an
+        alternate cache_dir"""
+        config_path = hf_hub_download(
+            repo_id=self.source,
+            filename="config.yaml",
+            revision=self.revision,
+            cache_dir=self.save_path,
+        )
+        model_path = hf_hub_download(
+            repo_id=self.source,
+            filename="pytorch_model.bin",
+            revision=self.revision,
+            cache_dir=self.save_path,
+        )
+        model = VocosModel.from_hparams(config_path)
+        state_dict = torch.load(model_path, map_location="cpu")
+        if isinstance(model.feature_extractor, EncodecFeatures):
+            encodec_parameters = {
+                "feature_extractor.encodec." + key: value
+                for key, value in model.feature_extractor.encodec.state_dict().items()
+            }
+            state_dict.update(encodec_parameters)
+        model.load_state_dict(state_dict)
+        model.eval()
+        return model
+
+    def forward(self, inputs, length):
+        """Converts EnCodec tokens to audio
+
+        Arguments
+        ---------
+        inputs : torch.Tensor
+            A tensor of EnCodec tokens
+        length : torch.Tensor
+            A 1-D tensor of relative lengths
+
+        Returns
+        -------
+        wavs : torch.Tensor
+            A (Batch x Length) tensor of raw waveforms
+        length : torch.Tensor
+            Relative lengths
+        """
+        with torch.set_grad_enabled(not self.freeze):
+            features = self.model.codes_to_features(inputs.permute(2, 0, 1))
+            wavs = self.model.decode(
+                features,
+                bandwidth_id=torch.tensor(
+                    [self.bandwidth_id], device=inputs.device
+                ),
+            )
+            mask = length_to_mask(
+                length * wavs.size(1), max_len=wavs.size(1), device=wavs.device
+            )
+            return wavs * mask, length
diff --git a/speechbrain/lobes/models/huggingface_transformers/wav2vec2.py b/speechbrain/lobes/models/huggingface_transformers/wav2vec2.py
new file mode 100644
index 0000000000000000000000000000000000000000..9594b3393485c7c5c328fd32557eb4adf01f658b
--- /dev/null
+++ b/speechbrain/lobes/models/huggingface_transformers/wav2vec2.py
@@ -0,0 +1,311 @@
+"""This lobe enables the integration of huggingface pretrained wav2vec2 models.
+
+Reference: https://arxiv.org/abs/2006.11477
+Reference: https://arxiv.org/abs/1904.05862
+Reference: https://arxiv.org/abs/2110.13900
+Transformer from HuggingFace needs to be installed:
+https://huggingface.co/transformers/installation.html
+
+Authors
+ * Titouan Parcollet 2021
+ * Boumadane Abdelmoumene 2021
+ * Ha Nguyen 2023
+"""
+
+import torch
+import logging
+import numpy as np
+import torch.nn.functional as F
+from speechbrain.lobes.models.huggingface_transformers.huggingface import (
+    HFTransformersInterface,
+)
+from speechbrain.lobes.models.huggingface_transformers.huggingface import (
+    make_padding_masks,
+)
+import transformers
+from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
+
+logger = logging.getLogger(__name__)
+
+
+class Wav2Vec2(HFTransformersInterface):
+    """This lobe enables the integration of HuggingFace and SpeechBrain
+    pretrained wav2vec2.0/Hubert models.
+
+    Source paper wav2vec2.0: https://arxiv.org/abs/2006.11477
+    Source paper Hubert: https://arxiv.org/abs/2106.07447
+    Transformer from HuggingFace needs to be installed:
+    https://huggingface.co/transformers/installation.html
+
+    The model can be used as a fixed feature extractor or can be finetuned. It
+    will download automatically the model from HuggingFace or use a local path.
+
+    Arguments
+    ---------
+    source : str
+        HuggingFace hub name: e.g "facebook/wav2vec2-large-lv60"
+    save_path : str
+        Path (dir) of the downloaded model.
+    output_norm : bool (default: True)
+        If True, a layer_norm (affine) will be applied to the output obtained
+        from the wav2vec model.
+    freeze : bool (default: True)
+        If True, the model is frozen. If False, the model will be trained
+        alongside with the rest of the pipeline.
+    freeze_feature_extractor :  bool (default: False)
+        When freeze = False and freeze_feature_extractor True, the featue_extractor module of the model is Frozen. If False
+        all the wav2vec model will be trained including featue_extractor module.
+    apply_spec_augment : bool (default: False)
+        If True, the model will apply spec augment on the output of feature extractor
+        (inside huggingface Wav2VecModel() class).
+        If False, the model will not apply spec augment. We set this to false to prevent from doing it twice.
+    output_all_hiddens : bool (default: False)
+        If True, the forward function outputs the hidden states from all transformer layers.
+        For example wav2vec2-base has 12 transformer layers and the output is of shape (13, B, T, C),
+        where a projection of the CNN output is added to the beginning.
+        If False, the forward function outputs the hidden states only from the last transformer layer.
+
+    Example
+    -------
+    >>> inputs = torch.rand([10, 600])
+    >>> model_hub = "facebook/wav2vec2-base-960h"
+    >>> save_path = "savedir"
+    >>> model = Wav2Vec2(model_hub, save_path)
+    >>> outputs = model(inputs)
+    """
+
+    def __init__(
+        self,
+        source,
+        save_path,
+        output_norm=False,
+        freeze=False,
+        freeze_feature_extractor=False,
+        apply_spec_augment=False,
+        output_all_hiddens=False,
+    ):
+        super().__init__(source=source, save_path=save_path, freeze=freeze)
+
+        self.model.config.apply_spec_augment = apply_spec_augment
+
+        # We check if inputs need to be normalized w.r.t pretrained wav2vec2
+        self.load_feature_extractor(source, cache_dir=save_path)
+        self.normalize_wav = self.feature_extractor.do_normalize
+
+        self.freeze_feature_extractor = freeze_feature_extractor
+        if not self.freeze and self.freeze_feature_extractor:
+            logger.warning(
+                "speechbrain.lobes.models.huggingface_transformers.wav2vec2 - wav2vec 2.0 feature extractor is frozen."
+            )
+            self.model.feature_extractor.eval()
+            for param in self.model.feature_extractor.parameters():
+                param.requires_grad = False
+
+        self.output_norm = output_norm
+        self.output_all_hiddens = output_all_hiddens
+
+    def _modify_state_dict(self, path, replacables=["wav2vec2"]):
+        """A custom loading ensures SpeechBrain compatibility for Pretrain and model
+        de/serialization. Here, the scope is to remove '.wav2vec2' before loading.
+
+        Arguments
+        ---------
+        path : str
+            Checkpoint path, file name relative to the repo root.
+        replacables : List[str]
+            State dict sub-keys that if found, shall be dropped (incl. the 'model.' parent key), elevating key structures.
+
+        Returns
+        -------
+        modified_state_dict : see torch.load
+            SpeechBrain-valid deserialized pretrained model.
+        """
+        modified_state_dict = {}
+        orig_state_dict = torch.load(path, map_location="cpu")
+
+        # We remove the .wav2vec2 in the state dict.
+        for key, params in orig_state_dict.items():
+            for tag in replacables:
+                if f"{tag}." in key:
+                    save_key = key.replace(f"model.{tag}.", "")
+                    modified_state_dict[save_key] = params
+        return modified_state_dict
+
+    def forward(self, wav, wav_lens=None):
+        """Takes an input waveform and return its corresponding wav2vec encoding.
+
+        Arguments
+        ---------
+        wav : torch.Tensor (signal)
+            A batch of audio signals to transform to features.
+        wav_len : tensor
+            The relative length of the wav given in SpeechBrain format.
+        """
+
+        # If we freeze, we simply remove all grads from the graph.
+        if self.freeze:
+            with torch.no_grad():
+                return self.extract_features(wav, wav_lens)
+
+        return self.extract_features(wav, wav_lens)
+
+    def extract_features(self, wav, wav_lens=None):
+        """Takes an input waveform and return its corresponding wav2vec encoding.
+
+        Arguments
+        ---------
+        wav : torch.Tensor (signal)
+            A batch of audio signals to transform to features.
+        wav_len : tensor
+            The relative length of the wav given in SpeechBrain format.
+        """
+
+        padding_mask = make_padding_masks(wav, wav_len=wav_lens)
+
+        if self.normalize_wav:
+            wav = F.layer_norm(wav, wav.shape[1:])
+
+        # Extract wav2vec output
+        out = self.model(
+            wav,
+            attention_mask=padding_mask,
+            output_hidden_states=self.output_all_hiddens,
+        )
+
+        if self.output_all_hiddens:
+            out = torch.stack(list(out.hidden_states), dim=0)
+            norm_shape = out.shape[-3:]
+        else:
+            out = out.last_hidden_state
+            norm_shape = out.shape
+
+        # We normalize the output if required
+        if self.output_norm:
+            out = F.layer_norm(out, norm_shape[1:])
+
+        return out
+
+
+class Wav2Vec2Pretrain(HFTransformersInterface):
+    """This lobe enables the integration of HuggingFace
+     wav2vec2.0 models to be pretrained.
+
+    Source paper: https://arxiv.org/abs/2006.11477
+    Transformer from HuggingFace needs to be installed:
+    https://huggingface.co/transformers/installation.html
+
+    The return is an HuggingFace format and the mask indices that contains:
+    https://huggingface.co/transformers/model_doc/wav2vec2.html#wav2vec2forpretraining
+
+    For instance, it returns the loss that can be accessed with .loss
+
+    Arguments
+    ---------
+    source : str
+        HuggingFace hub name: e.g "facebook/wav2vec2-large-lv60"
+    save_path : str
+        Path (dir) of the downloaded model.
+    mask_prob : float (default: 0.65)
+        Probability of masking a given frame. Default is taken from the paper.
+    mask_length : float (default: 10)
+        Length (i.e. number of consecutive masked frames). Default is taken from
+        the paper.
+
+    Example
+    -------
+    >>> inputs = torch.rand([10, 32000])
+    >>> model_hub = "facebook/wav2vec2-base-960h"
+    >>> save_path = "savedir"
+    >>> model = Wav2Vec2Pretrain(model_hub, save_path)
+    >>> outputs, _ = model(inputs, wav_lens=None)
+    """
+
+    def __init__(
+        self,
+        source,
+        save_path,
+        mask_prob=0.65,
+        mask_length=10,
+        normalize_wav=True,
+    ):
+        super().__init__(
+            source=source, save_path=save_path, for_pretraining=True
+        )
+
+        self.mask_prob = mask_prob
+        self.mask_length = mask_length
+        self.normalize_wav = normalize_wav
+
+        # We check if inputs need to be normalized w.r.t pretrained wav2vec2
+
+    def forward(self, wav, wav_lens=None):
+        """Takes an input waveform and return its corresponding wav2vec encoding.
+
+        Arguments
+        ---------
+        wav : torch.Tensor (signal)
+            A batch of audio signals to transform to features.
+        wav_len : tensor
+            The relative length of the wav given in SpeechBrain format.
+        """
+        batch_size, raw_sequence_length = wav.shape
+
+        if self.normalize_wav:
+            wav = F.layer_norm(wav, wav.shape)
+
+        sequence_length = self.model._get_feat_extract_output_lengths(
+            raw_sequence_length
+        ).item()
+
+        # 1. Compute the indices that will be masked
+        mask_time_indices = _compute_mask_indices(
+            (batch_size, sequence_length),
+            mask_prob=self.mask_prob,
+            mask_length=self.mask_length,
+        )
+        torch_mask_time_indices = torch.tensor(
+            mask_time_indices, device=wav.device, dtype=torch.long,
+        )
+        padding_mask = make_padding_masks(wav, wav_len=wav_lens)
+
+        # 2. Sample the negative samples from the entire sequence.
+        # Fairseq does it only on the masked indices, but this only work if you
+        # have long sentences. For more versatily, we sample on the entire sequence.
+        # value.
+        full_sentence_indices = np.ones((batch_size, sequence_length))
+
+        # print(np.sum(mask_time_indices, axis=1))
+        negative_sample_indices = torch.tensor(
+            transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices(
+                (batch_size, sequence_length),
+                num_negatives=self.config.num_negatives,
+                mask_time_indices=full_sentence_indices,
+            ),
+            device=wav.device,
+            dtype=torch.long,
+        )
+
+        return (
+            self.model(
+                wav,
+                mask_time_indices=torch_mask_time_indices,
+                sampled_negative_indices=negative_sample_indices,
+                attention_mask=padding_mask,
+            ),
+            torch_mask_time_indices,
+        )
+
+    def override_config(self, config):
+        """If the config needs to be overrided, here is the place
+
+        Arguments
+        ---------
+        config : Wav2Vec2Config
+            The original config needs to be overrided.
+
+        Returns
+        -------
+        Overridded config
+        """
+        config.output_hidden_states = True
+        return config
diff --git a/speechbrain/lobes/models/huggingface_transformers/wavlm.py b/speechbrain/lobes/models/huggingface_transformers/wavlm.py
new file mode 100644
index 0000000000000000000000000000000000000000..d29a36fc63d2dc188db3a020da36f6a3c9bf1512
--- /dev/null
+++ b/speechbrain/lobes/models/huggingface_transformers/wavlm.py
@@ -0,0 +1,89 @@
+"""This lobe enables the integration of huggingface pretrained wavlm models.
+
+Reference: https://arxiv.org/abs/2006.11477
+Reference: https://arxiv.org/abs/1904.05862
+Reference: https://arxiv.org/abs/2110.13900
+Transformer from HuggingFace needs to be installed:
+https://huggingface.co/transformers/installation.html
+
+Authors
+ * Titouan Parcollet 2021
+ * Boumadane Abdelmoumene 2021
+ * Ha Nguyen 2023
+"""
+
+import logging
+
+from speechbrain.lobes.models.huggingface_transformers.wav2vec2 import Wav2Vec2
+
+logger = logging.getLogger(__name__)
+
+
+class WavLM(Wav2Vec2):
+    """This lobe enables the integration of HuggingFace and SpeechBrain
+    pretrained WavLM models.
+
+    Source paper WavLM: https://arxiv.org/abs/2110.13900
+    Transformer from HuggingFace needs to be installed:
+    https://huggingface.co/transformers/installation.html
+
+    The model can be used as a fixed feature extractor or can be finetuned. It
+    will download automatically the model from HuggingFace or use a local path.
+
+    For now, HuggingFace's HuBERT and WavLM model can be loaded using the exact code for Wav2Vec2 model.
+    For this reason, HuBERT and WavLM can be fine inheriting the Wav2Vec2 class.
+
+    Arguments
+    ---------
+    source : str
+        HuggingFace hub name: e.g "microsoft/wavlm-large"
+    save_path : str
+        Path (dir) of the downloaded model.
+    output_norm : bool (default: True)
+        If True, a layer_norm (affine) will be applied to the output obtained
+        from the wavlm model.
+    freeze : bool (default: True)
+        If True, the model is frozen. If False, the model will be trained
+        alongside with the rest of the pipeline.
+    freeze_feature_extractor :  bool (default: False)
+        When freeze = False and freeze_feature_extractor True, the featue_extractor module of the model is Frozen. If False
+        all the wavlm model will be trained including featue_extractor module.
+    apply_spec_augment : bool (default: False)
+        If True, the model will apply spec augment on the output of feature extractor
+        (inside huggingface WavLMModel() class).
+        If False, the model will not apply spec augment. We set this to false to prevent from doing it twice.
+    output_all_hiddens : bool (default: False)
+        If True, the forward function outputs the hidden states from all transformer layers.
+        For example wavlm-base has 12 transformer layers and the output is of shape (13, B, T, C),
+        where a projection of the CNN output is added to the beginning.
+        If False, the forward function outputs the hidden states only from the last transformer layer.
+
+    Example
+    -------
+    >>> import torch
+    >>> inputs = torch.rand([10, 600])
+    >>> model_hub = "microsoft/wavlm-large"
+    >>> save_path = "savedir"
+    >>> model = WavLM(model_hub, save_path)
+    >>> outputs = model(inputs)
+    """
+
+    def __init__(
+        self,
+        source,
+        save_path,
+        output_norm=False,
+        freeze=False,
+        freeze_feature_extractor=False,
+        apply_spec_augment=False,
+        output_all_hiddens=False,
+    ):
+        super().__init__(
+            source=source,
+            save_path=save_path,
+            output_norm=output_norm,
+            freeze=freeze,
+            freeze_feature_extractor=freeze_feature_extractor,
+            apply_spec_augment=apply_spec_augment,
+            output_all_hiddens=output_all_hiddens,
+        )
diff --git a/speechbrain/lobes/models/huggingface_transformers/weighted_ssl.py b/speechbrain/lobes/models/huggingface_transformers/weighted_ssl.py
new file mode 100644
index 0000000000000000000000000000000000000000..00b039d01d7d8e04260a5804f7edba6055b6e43f
--- /dev/null
+++ b/speechbrain/lobes/models/huggingface_transformers/weighted_ssl.py
@@ -0,0 +1,105 @@
+"""This lobe enables the integration of huggingface pretrained wav2vec2 models.
+
+Reference: https://arxiv.org/abs/2006.11477
+Reference: https://arxiv.org/abs/1904.05862
+Reference: https://arxiv.org/abs/2110.13900
+Transformer from HuggingFace needs to be installed:
+https://huggingface.co/transformers/installation.html
+
+Authors
+ * Salah Zaiem 2023
+ * Adel Moumen 2023, 2024
+"""
+
+import torch
+import logging
+import torch.nn.functional as F
+from speechbrain.lobes.models.huggingface_transformers.huggingface import (
+    HFTransformersInterface,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class WeightedSSLModel(HFTransformersInterface):
+    """This lobe enables the integration of use of weighted sum representations
+    from different layers in a SSL encoder.
+
+    The model can be used as a fixed feature extractor for SSL benchmarking. It
+    will download automatically the model from HuggingFace or use a local path.
+
+    More details in recipes/SSL_benchmark
+
+    Arguments
+    ---------
+    hub : str
+        HuggingFace hub name: e.g "facebook/wav2vec2-large-lv60"
+    save_path : str
+        Path (dir) of the downloaded model.
+    layernorm: bool, (default: False)
+        Whether layer representations should be layernormed before sum
+    freeze : bool (default: True)
+        If True, the model is frozen. If False, the model will be trained
+        alongside with the rest of the pipeline.
+
+    Example
+    -------
+    >>> inputs = torch.rand([10, 600])
+    >>> model_hub = "facebook/wav2vec2-base-960h"
+    >>> save_path = "savedir"
+    >>> model = WeightedSSLModel(model_hub, save_path)
+    >>> outputs = model(inputs)
+    """
+
+    def __init__(self, hub, save_path="", layernorm=False, freeze=False):
+        super().__init__(source=hub, save_path=save_path, freeze=freeze)
+        self.model.eval()
+        self.num_layers = self.config.num_hidden_layers + 1
+        # Initializing the learnable weights
+        zero_init = torch.cat([torch.zeros(self.num_layers)])
+        self.weights = torch.nn.Parameter(zero_init, requires_grad=True)
+        self.layernorm = layernorm
+
+    def forward(self, wav, wav_lens=None):
+        """This method outputs a weighted sum of the layers representations of the SSL encoder
+
+        Arguments
+        ---------
+        wav : tensor
+            The wavs
+        wav_lens : tensor
+            The wav lengths
+        """
+
+        feats = self.model(wav)
+        hidden_states = torch.stack(feats.hidden_states, dim=0).detach()
+        # First dimension should be equal to the number of layers in the hparams
+        assert (
+            self.num_layers == hidden_states.shape[0]
+        ), "Num layers not equal to num hidden states"
+        norm_weights = torch.nn.functional.softmax(self.weights, dim=-1)
+        # Layernorming the layers representations if asked
+        if self.layernorm:
+            hidden_states = [
+                F.layer_norm(t, (t.shape[-1],)) for t in hidden_states
+            ]
+        # Summing the weighted layers
+        weighted_feats = (
+            hidden_states * norm_weights[:, None, None, None]
+        ).sum(axis=0)
+        return weighted_feats
+
+    def override_config(self, config):
+        """If the config needs to be overrided, here is the place
+
+        Arguments
+        ---------
+        config : Wav2Vec2Config
+            The original config needs to be overrided.
+
+        Returns
+        -------
+        Overridded config
+        """
+        config.output_hidden_states = True
+        return config
diff --git a/speechbrain/lobes/models/huggingface_whisper.py b/speechbrain/lobes/models/huggingface_transformers/whisper.py
similarity index 84%
rename from speechbrain/lobes/models/huggingface_whisper.py
rename to speechbrain/lobes/models/huggingface_transformers/whisper.py
index e37044eb7607b7a18c40c9d16d556b5c3c6175a8..f93bbebd5111f5c3edf01cb23a1e390ed851e1f4 100644
--- a/speechbrain/lobes/models/huggingface_whisper.py
+++ b/speechbrain/lobes/models/huggingface_transformers/whisper.py
@@ -7,28 +7,23 @@ Authors
  * Adel Moumen 2022
  * Titouan Parcollet 2022
  * Luca Della Libera 2022
+ * Ha Nguyen 2023
 """
 
 import torch
 import logging
 from torch import nn
 
-try:
-    from transformers import WhisperModel
-    from transformers import WhisperFeatureExtractor
-    from transformers.models.whisper.tokenization_whisper import (
-        WhisperTokenizer,
-    )
-except ImportError:
-    MSG = "Please install transformers from HuggingFace to use Whisper\n"
-    MSG += "E.G. run: pip install transformers"
-    raise ImportError(MSG)
+from speechbrain.lobes.models.huggingface_transformers.huggingface import (
+    HFTransformersInterface,
+)
 
 logger = logging.getLogger(__name__)
 
 
-class HuggingFaceWhisper(nn.Module):
+class Whisper(HFTransformersInterface):
     """This lobe enables the integration of HuggingFace pretrained Whisper model.
+
     Source paper whisper:
         https://cdn.openai.com/papers/whisper.pdf
     Transformer from HuggingFace needs to be installed:
@@ -39,6 +34,7 @@ class HuggingFaceWhisper(nn.Module):
 
     The model can be finetuned. It will download automatically the model from
     HuggingFace or use a local path.
+
     Arguments
     ---------
     source : str
@@ -61,12 +57,13 @@ class HuggingFaceWhisper(nn.Module):
         For example whisper-base has 6 transformer layers and the output is of shape (7, B, T, C),
         where the output of the CNN output is added to the beginning.
         If False, the forward function outputs the hidden states only from the last transformer layer of the encoder.
+
     Example
     -------
     >>> model_hub = "openai/whisper-tiny"
     >>> save_path = "savedir"
     >>> sampling_rate = 16000
-    >>> model = HuggingFaceWhisper(model_hub, save_path, sampling_rate)
+    >>> model = Whisper(model_hub, save_path, sampling_rate)
     >>> tokens = torch.tensor([[1, 1]]) * model.model.config.decoder_start_token_id
     >>> inputs = torch.randn([1, 93680])
     >>> outputs = model(inputs, tokens)
@@ -83,56 +80,66 @@ class HuggingFaceWhisper(nn.Module):
         output_attentions=True,
         output_all_hiddens=False,
     ):
-        super().__init__()
+        super().__init__(
+            source=source,
+            save_path=save_path,
+            freeze=freeze,
+            sampling_rate=sampling_rate,
+        )
         self.sampling_rate = sampling_rate
         self.encoder_only = encoder_only
-        self.freeze = freeze
         self.freeze_encoder = freeze_encoder
         self.output_attentions = output_attentions
         self.output_all_hiddens = output_all_hiddens
 
-        self.tokenizer = None
-        # Download the tokenizer only if we are going to use the Decoder.
-        if not encoder_only:
-            self.tokenizer = WhisperTokenizer.from_pretrained(source)
+        if encoder_only:
+            self.tokenizer = None
+        else:
+            self.load_tokenizer(source)
 
-        # Download the extractor from HuggingFace.
-        feature_extractor = WhisperFeatureExtractor.from_pretrained(
-            source, cache_dir=save_path, sampling_rate=sampling_rate,
+        self.load_feature_extractor(
+            source, save_path, sampling_rate=sampling_rate
         )
-        self._n_fft = feature_extractor.n_fft
-        self._hop_length = feature_extractor.hop_length
-        self._n_samples = feature_extractor.n_samples
+
+        self._n_fft = self.feature_extractor.n_fft
+        self._hop_length = self.feature_extractor.hop_length
+        self._n_samples = self.feature_extractor.n_samples
         # The following breaking changes were introduced in transformers>=4.29:
         # 1) mel_filters.shape = (..., feature_extractor.feature_size) instead of (feature_extractor.feature_size, ...)
         # 2) mel_filters.dtype = float64 instead of float32
         # The following code fixes the issue in a backward compatible way
-        mel_filters = feature_extractor.mel_filters
-        if mel_filters.shape[0] != feature_extractor.feature_size:
+        mel_filters = self.feature_extractor.mel_filters
+        if mel_filters.shape[0] != self.feature_extractor.feature_size:
             mel_filters = mel_filters.T
-        assert mel_filters.shape[0] == feature_extractor.feature_size
+        assert mel_filters.shape[0] == self.feature_extractor.feature_size
         self.register_buffer(
             "_mel_filters", torch.as_tensor(mel_filters, dtype=torch.float32)
         )
         #################################################################
 
-        self.model = WhisperModel.from_pretrained(source, cache_dir=save_path)
-
-        if self.freeze:
+        if not self.freeze and self.freeze_encoder:
             logger.warning(
-                "speechbrain.lobes.models.huggingface_whisper - whisper encoder-decoder is frozen."
+                "speechbrain.lobes.models.huggingface_transformers.whisper - whisper encoder is frozen."
             )
-            self.model.train()  # we keep it to train to have dropout and LN computed adequaly
-            for param in self.model.parameters():
+            for param in self.model.encoder.parameters():
                 param.requires_grad = False
-        else:
-            self.model.train()
-            if self.freeze_encoder:
-                logger.warning(
-                    "speechbrain.lobes.models.huggingface_whisper - whisper encoder is frozen."
-                )
-                for param in self.model.encoder.parameters():
-                    param.requires_grad = False
+
+    def freeze_model(self, model):
+        """
+        Freezes parameters of a model.
+
+        Arguments
+        ---------
+        model : from AutoModel.from_config
+            Valid HuggingFace transformers model object.
+        """
+
+        logger.warning(
+            "speechbrain.lobes.models.huggingface_transformers.whisper - whisper encoder-decoder is frozen."
+        )
+        model.train()  # we keep it to train to have dropout and LN computed adequaly
+        for param in model.parameters():
+            param.requires_grad = False
 
     def forward(self, wav, decoder_input_ids=None):
         """Perform mel transformation and one step of the whisper (encoder-decoder).
@@ -185,6 +192,7 @@ class HuggingFaceWhisper(nn.Module):
 
     def forward_encoder(self, wav):
         """Perform one step of the whisper encoder with Mel FBANKs as Input.
+
         Arguments
         ---------
         wav : torch.Tensor (FBANKs)
@@ -201,6 +209,7 @@ class HuggingFaceWhisper(nn.Module):
         """Takes an input waveform and return its corresponding encoder states.
         Returns the last hidden state of the encoder or all hidden states if
         output_all_hiddens is True.
+
         Arguments
         ---------
         wav : torch.Tensor (signal)
@@ -217,6 +226,7 @@ class HuggingFaceWhisper(nn.Module):
         """Takes an input waveform and return its corresponding mel spectrogram
         according to HuggingFace implementation. WARNING: it's slow! Better push this
         in the DataLoader.
+
         Arguments
         ---------
         wav : torch.Tensor (signal)
@@ -301,6 +311,7 @@ class HuggingFaceWhisper(nn.Module):
 
     def forward_decoder(self, audio_features, decoder_input_ids):
         """Perform one step of the whisper decoder.
+
         Arguments
         ---------
         audio_features : torch.Tensor
diff --git a/speechbrain/lobes/models/huggingface_wav2vec.py b/speechbrain/lobes/models/huggingface_wav2vec.py
deleted file mode 100644
index 97bcd9341e12f918b2104687f17f4727c032ba13..0000000000000000000000000000000000000000
--- a/speechbrain/lobes/models/huggingface_wav2vec.py
+++ /dev/null
@@ -1,569 +0,0 @@
-"""This lobe enables the integration of huggingface pretrained wav2vec2/hubert/wavlm models.
-
-Reference: https://arxiv.org/abs/2006.11477
-Reference: https://arxiv.org/abs/1904.05862
-Reference: https://arxiv.org/abs/2110.13900
-Transformer from HuggingFace needs to be installed:
-https://huggingface.co/transformers/installation.html
-
-Authors
- * Titouan Parcollet 2021
- * Boumadane Abdelmoumene 2021
-"""
-
-import os
-import torch
-import logging
-import pathlib
-import numpy as np
-import torch.nn.functional as F
-from torch import nn
-from huggingface_hub import model_info
-from speechbrain.pretrained.fetching import fetch
-from speechbrain.dataio.dataio import length_to_mask
-
-# We check if transformers is installed.
-try:
-    import transformers
-    from transformers import AutoModel
-    from transformers import Wav2Vec2Model, HubertModel, WavLMModel
-    from transformers import Wav2Vec2Config, HubertConfig, WavLMConfig
-    from transformers import Wav2Vec2FeatureExtractor
-    from transformers import Wav2Vec2ForPreTraining
-    from transformers.models.wav2vec2.modeling_wav2vec2 import (
-        _compute_mask_indices,
-    )
-
-except ImportError:
-    MSG = "Please install transformers from HuggingFace to use wav2vec2 / Hubert\n"
-    MSG += "E.G. run: pip install transformers"
-    raise ImportError(MSG)
-
-logger = logging.getLogger(__name__)
-
-HF_models = {
-    "wav2vec2": Wav2Vec2Model,
-    "hubert": HubertModel,
-    "wavlm": WavLMModel,
-}
-
-HF_config = {
-    "wav2vec2": Wav2Vec2Config,
-    "hubert": HubertConfig,
-    "wavlm": WavLMConfig,
-}
-
-
-class HuggingFaceWav2Vec2(nn.Module):
-    """This lobe enables the integration of HuggingFace and SpeechBrain
-    pretrained wav2vec2.0/Hubert models.
-
-    Source paper wav2vec2.0: https://arxiv.org/abs/2006.11477
-    Source paper Hubert: https://arxiv.org/abs/2106.07447
-    Transformer from HuggingFace needs to be installed:
-    https://huggingface.co/transformers/installation.html
-
-    The model can be used as a fixed feature extractor or can be finetuned. It
-    will download automatically the model from HuggingFace or use a local path.
-
-    Arguments
-    ---------
-    source : str
-        HuggingFace hub name: e.g "facebook/wav2vec2-large-lv60"
-    save_path : str
-        Path (dir) of the downloaded model.
-    output_norm : bool (default: True)
-        If True, a layer_norm (affine) will be applied to the output obtained
-        from the wav2vec model.
-    freeze : bool (default: True)
-        If True, the model is frozen. If False, the model will be trained
-        alongside with the rest of the pipeline.
-    freeze_feature_extractor :  bool (default: False)
-        When freeze = False and freeze_feature_extractor True, the featue_extractor module of the model is Frozen. If False
-        all the wav2vec model will be trained including featue_extractor module.
-    apply_spec_augment : bool (default: False)
-        If True, the model will apply spec augment on the output of feature extractor
-        (inside huggingface Wav2VecModel() class).
-        If False, the model will not apply spec augment. We set this to false to prevent from doing it twice.
-    output_all_hiddens : bool (default: False)
-        If True, the forward function outputs the hidden states from all transformer layers.
-        For example wav2vec2-base has 12 transformer layers and the output is of shape (13, B, T, C),
-        where a projection of the CNN output is added to the beginning.
-        If False, the forward function outputs the hidden states only from the last transformer layer.
-
-    Example
-    -------
-    >>> inputs = torch.rand([10, 600])
-    >>> model_hub = "facebook/wav2vec2-base-960h"
-    >>> save_path = "savedir"
-    >>> model = HuggingFaceWav2Vec2(model_hub, save_path)
-    >>> outputs = model(inputs)
-    """
-
-    def __init__(
-        self,
-        source,
-        save_path,
-        output_norm=False,
-        freeze=False,
-        freeze_feature_extractor=False,
-        apply_spec_augment=False,
-        output_all_hiddens=False,
-    ):
-        super().__init__()
-
-        # Download the extractor from HuggingFace.
-        # The extractor is only used to retrieve the normalisation information
-        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
-            source, cache_dir=save_path
-        )
-
-        # Select specific self-supervised loader (eg. Wav2Vec2, Hubert)
-        if "hubert" in source:
-            config = HF_config.get("hubert")
-            model = HF_models.get("hubert")
-        elif "wavlm" in source:
-            config = HF_config.get("wavlm")
-            model = HF_models.get("wavlm")
-        else:
-            config = HF_config.get("wav2vec2")
-            model = HF_models.get("wav2vec2")
-
-        # Download and load the model
-        self._from_pretrained(
-            source, config=config, model=model, save_path=save_path
-        )
-
-        self.model.config.apply_spec_augment = apply_spec_augment
-
-        # We check if inputs need to be normalized w.r.t pretrained wav2vec2
-        self.normalize_wav = self.feature_extractor.do_normalize
-
-        self.freeze = freeze
-        self.freeze_feature_extractor = freeze_feature_extractor
-        self.output_norm = output_norm
-        if self.freeze:
-            logger.warning(
-                "speechbrain.lobes.models.huggingface_wav2vec - wav2vec 2.0 is frozen."
-            )
-            self.model.eval()
-            for param in self.model.parameters():
-                param.requires_grad = False
-        else:
-            self.model.train()
-            if self.freeze_feature_extractor:
-                logger.warning(
-                    "speechbrain.lobes.models.huggingface_wav2vec - wav2vec 2.0 feature extractor is frozen."
-                )
-                self.model.feature_extractor.eval()
-                for param in self.model.feature_extractor.parameters():
-                    param.requires_grad = False
-        self.output_all_hiddens = output_all_hiddens
-
-    def _from_pretrained(self, source, config, model, save_path):
-        """This function manages the source checking and loading of the params.
-        # 1. Is the model from HF or a local path
-        # 2. Is the model pretrained with HF or SpeechBrain
-        # 3. Download (if appropriate) and load with respect to 1. and 2.
-        """
-
-        is_sb, ckpt_file, is_local = self._check_model_source(source, save_path)
-        if is_sb:
-            config = config.from_pretrained(source, cache_dir=save_path)
-            self.model = model(config)
-            self.model.gradient_checkpointing_disable()  # Required by DDP
-            # fetch the checkpoint file
-            ckpt_full_path = fetch(
-                filename=ckpt_file, source=source, savedir=save_path
-            )
-            # We transfer the parameters from the checkpoint.
-            self._load_sb_pretrained_w2v2_parameters(ckpt_full_path)
-        else:
-            self.model = model.from_pretrained(
-                source, cache_dir=save_path, local_files_only=is_local
-            )
-
-    def _load_sb_pretrained_w2v2_parameters(self, path):
-        """Loads the parameter of a w2v2 model pretrained with SpeechBrain and the
-        HuggingFaceWav2Vec2Pretrain Object. It is necessary to perform a custom
-        loading because HuggingFace adds a level to the checkpoint when storing
-        the model breaking the compatibility between HuggingFaceWav2Vec2Pretrain
-        and HuggingFaceWav2Vec2.
-
-        In practice a typical HuggingFaceWav2Vec2 checkpoint for a given parameter
-        would be: model.conv.weight.data while for HuggingFaceWav2Vec2Pretrain it
-        is: model.wav2vec2.weight.data (wav2vec2 must be removed before loading).
-        """
-
-        modified_state_dict = {}
-        orig_state_dict = torch.load(path, map_location="cpu")
-
-        # We remove the .wav2vec2 in the state dict.
-        for key, params in orig_state_dict.items():
-            if "wav2vec2." in key:
-                save_key = key.replace("model.wav2vec2.", "")
-                modified_state_dict[save_key] = params
-
-        incompatible_keys = self.model.load_state_dict(
-            modified_state_dict, strict=False
-        )
-        for missing_key in incompatible_keys.missing_keys:
-            logger.warning(
-                f"During parameter transfer to {self.model} loading from "
-                + f"{path}, the transferred parameters did not have "
-                + f"parameters for the key: {missing_key}"
-            )
-        for unexpected_key in incompatible_keys.unexpected_keys:
-            logger.warning(
-                f"The param with the key: {unexpected_key} is discarded as it "
-                + "is useless for wav2vec 2.0 finetuning."
-            )
-
-    def _check_model_source(self, path, save_path):
-        """Checks if the pretrained model has been trained with SpeechBrain and
-        is hosted locally or on a HuggingFace hub.
-        Called as static function in HuggingFaceTransformer._from_pretrained.
-        Arguments
-        ---------
-        path : str
-            Used as "source"; local path or HuggingFace hub name: e.g "facebook/wav2vec2-large-lv60"
-        save_path : str
-            norm_output (dir) of the downloaded model.
-        Returns
-        -------
-        is_sb : bool
-            Whether/not the model is deserializable w/ SpeechBrain or not (then, model conversion is needed).
-        checkpoint_filename : str
-            as of HuggingFace documentation: file name relative to the repo root (guaranteed to be here).
-        """
-        checkpoint_filename = ""
-        source = pathlib.Path(path)
-        is_local = True
-
-        # If path is a huggingface hub.
-        if not source.exists():
-            is_local = False
-
-        # Check if source is downloaded already
-        sink = pathlib.Path(
-            save_path + "/models--" + path.replace("/", "--") + "/snapshots"
-        )
-        if sink.exists():
-            sink = (
-                sink / os.listdir(str(sink))[0]
-            )  # there's a hash-id subfolder
-            if any(
-                File.endswith(".bin") or File.endswith(".ckpt")
-                for File in os.listdir(str(sink))
-            ):
-                is_local = True
-                local_path = str(sink)
-            else:
-                local_path = path
-        else:
-            local_path = path
-
-        if is_local:
-            # Test for HuggingFace model
-            if any(File.endswith(".bin") for File in os.listdir(local_path)):
-                is_sb = False
-                return is_sb, checkpoint_filename, is_local
-
-            # Test for SpeechBrain model and get the filename.
-            for File in os.listdir(local_path):
-                if File.endswith(".ckpt"):
-                    checkpoint_filename = os.path.join(path, File)
-                    is_sb = True
-                    return is_sb, checkpoint_filename, is_local
-        else:
-            files = model_info(
-                path
-            ).siblings  # get the list of files of the Hub
-
-            # Test if it's an HuggingFace model or a SB one
-            for File in files:
-                if File.rfilename.endswith(".ckpt"):
-                    checkpoint_filename = File.rfilename
-                    is_sb = True
-                    return is_sb, checkpoint_filename, is_local
-
-            for File in files:
-                if File.rfilename.endswith(".bin"):
-                    checkpoint_filename = File.rfilename
-                    is_sb = False
-                    return is_sb, checkpoint_filename, is_local
-
-        err_msg = f"{path} does not contain a .bin or .ckpt checkpoint !"
-        raise FileNotFoundError(err_msg)
-
-    def forward(self, wav, wav_lens=None):
-        """Takes an input waveform and return its corresponding wav2vec encoding.
-
-        Arguments
-        ---------
-        wav : torch.Tensor (signal)
-            A batch of audio signals to transform to features.
-        wav_len : tensor
-            The relative length of the wav given in SpeechBrain format.
-        """
-
-        # If we freeze, we simply remove all grads from the graph.
-        if self.freeze:
-            with torch.no_grad():
-                return self.extract_features(wav, wav_lens)
-
-        return self.extract_features(wav, wav_lens)
-
-    def extract_features(self, wav, wav_lens=None):
-        """Takes an input waveform and return its corresponding wav2vec encoding.
-
-        Arguments
-        ---------
-        wav : torch.Tensor (signal)
-            A batch of audio signals to transform to features.
-        wav_len : tensor
-            The relative length of the wav given in SpeechBrain format.
-        """
-
-        padding_mask = self.make_masks(wav, wav_len=wav_lens)
-
-        if self.normalize_wav:
-            wav = F.layer_norm(wav, wav.shape[1:])
-
-        # Extract wav2vec output
-        out = self.model(
-            wav,
-            attention_mask=padding_mask,
-            output_hidden_states=self.output_all_hiddens,
-        )
-
-        if self.output_all_hiddens:
-            out = torch.stack(list(out.hidden_states), dim=0)
-            norm_shape = out.shape[-3:]
-        else:
-            out = out.last_hidden_state
-            norm_shape = out.shape
-
-        # We normalize the output if required
-        if self.output_norm:
-            out = F.layer_norm(out, norm_shape[1:])
-
-        return out
-
-    def make_masks(self, src, wav_len=None, pad_idx=0):
-        """This method generates the padding masks.
-        Arguments
-        ---------
-        src : tensor
-            The sequence to the encoder (required).
-        wav_len : tensor
-            The relative length of the wav given in SpeechBrain format.
-        pad_idx : int
-            The index for <pad> token (default=0).
-        """
-        src_key_padding_mask = None
-        if wav_len is not None:
-            abs_len = torch.round(wav_len * src.shape[1])
-            src_key_padding_mask = length_to_mask(abs_len).bool()
-
-        return src_key_padding_mask
-
-
-class HuggingFaceWav2Vec2Pretrain(nn.Module):
-    """This lobe enables the integration of HuggingFace
-     wav2vec2.0 models to be pretrained.
-
-    Source paper: https://arxiv.org/abs/2006.11477
-    Transformer from HuggingFace needs to be installed:
-    https://huggingface.co/transformers/installation.html
-
-    The return is an HuggingFace format and the mask indices that contains:
-    https://huggingface.co/transformers/model_doc/wav2vec2.html#wav2vec2forpretraining
-
-    For instance, it returns the loss that can be accessed with .loss
-
-    Arguments
-    ---------
-    source : str
-        HuggingFace hub name: e.g "facebook/wav2vec2-large-lv60"
-    save_path : str
-        Path (dir) of the downloaded model.
-    mask_prob : float (default: 0.65)
-        Probability of masking a given frame. Default is taken from the paper.
-    mask_length : float (default: 10)
-        Length (i.e. number of consecutive masked frames). Default is taken from
-        the paper.
-    Example
-    -------
-    >>> inputs = torch.rand([10, 32000])
-    >>> model_hub = "facebook/wav2vec2-base-960h"
-    >>> save_path = "savedir"
-    >>> model = HuggingFaceWav2Vec2Pretrain(model_hub, save_path)
-    >>> outputs, _ = model(inputs, wav_lens=None)
-    """
-
-    def __init__(
-        self,
-        source,
-        save_path,
-        mask_prob=0.65,
-        mask_length=10,
-        normalize_wav=True,
-    ):
-        super().__init__()
-
-        self.mask_prob = mask_prob
-        self.mask_length = mask_length
-        self.normalize_wav = normalize_wav
-
-        # Download the config of the model from HuggingFace.
-        self.config = Wav2Vec2Config.from_pretrained(
-            source, cache_dir=save_path
-        )
-        self.config.output_hidden_states = (
-            True  # We want the hidden states as well!
-        )
-
-        self.model = Wav2Vec2ForPreTraining(self.config)
-        self.model.gradient_checkpointing_disable()  # Required by DDP
-        self.model.train()
-
-        # We check if inputs need to be normalized w.r.t pretrained wav2vec2
-
-    def forward(self, wav, wav_lens=None):
-        """Takes an input waveform and return its corresponding wav2vec encoding.
-
-        Arguments
-        ---------
-        wav : torch.Tensor (signal)
-            A batch of audio signals to transform to features.
-        wav_len : tensor
-            The relative length of the wav given in SpeechBrain format.
-        """
-        batch_size, raw_sequence_length = wav.shape
-
-        if self.normalize_wav:
-            wav = F.layer_norm(wav, wav.shape)
-
-        sequence_length = self.model._get_feat_extract_output_lengths(
-            raw_sequence_length
-        ).item()
-
-        # 1. Compute the indices that will be masked
-        mask_time_indices = _compute_mask_indices(
-            (batch_size, sequence_length),
-            mask_prob=self.mask_prob,
-            mask_length=self.mask_length,
-        )
-        torch_mask_time_indices = torch.tensor(
-            mask_time_indices, device=wav.device, dtype=torch.long,
-        )
-        padding_mask = self.make_padding_masks(wav, wav_len=wav_lens)
-
-        # 2. Sample the negative samples from the entire sequence.
-        # Fairseq does it only on the masked indices, but this only work if you
-        # have long sentences. For more versatily, we sample on the entire sequence.
-        # value.
-        full_sentence_indices = np.ones((batch_size, sequence_length))
-
-        # print(np.sum(mask_time_indices, axis=1))
-        negative_sample_indices = torch.tensor(
-            transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices(
-                (batch_size, sequence_length),
-                num_negatives=self.config.num_negatives,
-                mask_time_indices=full_sentence_indices,
-            ),
-            device=wav.device,
-            dtype=torch.long,
-        )
-
-        return (
-            self.model(
-                wav,
-                mask_time_indices=torch_mask_time_indices,
-                sampled_negative_indices=negative_sample_indices,
-                attention_mask=padding_mask,
-            ),
-            torch_mask_time_indices,
-        )
-
-    def make_padding_masks(self, src, wav_len=None, pad_idx=0):
-        """This method generates the padding masks.
-        Arguments
-        ---------
-        src : tensor
-            The sequence to the encoder (required).
-        wav_len : tensor
-            The relative length of the wav given in SpeechBrain format.
-        pad_idx : int
-            The index for <pad> token (default=0).
-        """
-        src_key_padding_mask = None
-        if wav_len is not None:
-            abs_len = torch.round(wav_len * src.shape[1])
-            src_key_padding_mask = length_to_mask(abs_len).bool()
-
-        return src_key_padding_mask
-
-
-class WeightedSSLModel(torch.nn.Module):
-    """This lobe enables the integration of use of weighted sum representations
-    from different layers in a SSL encoder.
-
-    The model can be used as a fixed feature extractor for SSL benchmarking. It
-    will download automatically the model from HuggingFace or use a local path.
-
-    More details in recipes/SSL_benchmark
-
-    Arguments
-    ---------
-    hub : str
-        HuggingFace hub name: e.g "facebook/wav2vec2-large-lv60"
-    num_layers: int
-        Number of internal layers: e.g 13 for "Base" models.
-    layernorm: bool
-        Whether layer representations should be layernormed before sum
-    Example
-    -------
-    >>> inputs = torch.rand([10, 600])
-    >>> model_hub = "facebook/wav2vec2-base-960h"
-    >>> num_layers = 13
-    >>> model = WeightedSSLModel(model_hub, num_layers)
-    >>> outputs = model(inputs)
-    """
-
-    def __init__(self, hub, num_layers, layernorm=False):
-        super().__init__()
-        self.encoder = AutoModel.from_pretrained(hub, output_hidden_states=True)
-        self.num_layers = num_layers
-        # Initializing the learnable weights
-        zero_init = torch.cat([torch.zeros(self.num_layers)])
-        self.weights = torch.nn.Parameter(zero_init, requires_grad=True)
-        self.layernorm = layernorm
-
-    def forward(self, wav, wav_lens=None):
-        """This method outputs a weighted sum of the layers representations of the SSL encoder
-        Arguments
-        ---------
-        wav : tensor
-            The wavs
-        """
-
-        feats = self.encoder(wav)
-        hidden_states = torch.stack(feats.hidden_states, dim=0).detach()
-        # First dimension should be equal to the number of layers in the hparams
-        assert (
-            self.num_layers == hidden_states.shape[0]
-        ), "Num layers not equal to num hidden states"
-        norm_weights = torch.nn.functional.softmax(self.weights, dim=-1)
-        # Layernorming the layers representations if asked
-        if self.layernorm:
-            hidden_states = [
-                F.layer_norm(t, (t.shape[-1],)) for t in hidden_states
-            ]
-        # Summing the weighted layers
-        weighted_feats = hidden_states[0] * norm_weights[0]
-        for i in range(1, len(hidden_states)):
-            weighted_feats += hidden_states[i] * norm_weights[i]
-
-        return weighted_feats
diff --git a/speechbrain/lobes/models/transformer/Branchformer.py b/speechbrain/lobes/models/transformer/Branchformer.py
index 4de4aa535666ad7da125ed0b895c88c80fd90ab5..34ba5d20595167bab41916f4a1e7cee6119365ac 100644
--- a/speechbrain/lobes/models/transformer/Branchformer.py
+++ b/speechbrain/lobes/models/transformer/Branchformer.py
@@ -17,6 +17,8 @@ from speechbrain.nnet.attention import RelPosMHAXL, MultiheadAttention
 from speechbrain.nnet.normalization import LayerNorm
 from speechbrain.lobes.models.convolution import ConvolutionalSpatialGatingUnit
 
+from speechbrain.nnet.hypermixing import HyperMixing
+
 
 class ConvolutionBranch(nn.Module):
     """This is an implementation of the convolution branch in Branchformer.
@@ -158,6 +160,14 @@ class BranchformerEncoderLayer(nn.Module):
                 dropout=dropout,
                 mask_pos_future=False,
             )
+        elif attention_type == "hypermixing":
+            self.mha_layer = HyperMixing(
+                input_output_dim=d_model,
+                hypernet_size=d_model * 4,
+                tied=False,
+                num_heads=nhead,
+                fix_tm_hidden_size=False,
+            )
 
         self.convolution_branch = ConvolutionBranch(
             input_size=d_model,
@@ -310,6 +320,7 @@ class BranchformerEncoder(nn.Module):
         src_mask: Optional[torch.Tensor] = None,
         src_key_padding_mask: Optional[torch.Tensor] = None,
         pos_embs: Optional[torch.Tensor] = None,
+        dynchunktrain_config=None,
     ):
         """
         Arguments
@@ -325,6 +336,9 @@ class BranchformerEncoder(nn.Module):
             If custom pos_embs are given it needs to have the shape (1, 2*S-1, E)
             where S is the sequence length, and E is the embedding dimension.
         """
+        assert (
+            dynchunktrain_config is None
+        ), "Dynamic Chunk Training unsupported for this encoder"
 
         if self.attention_type == "RelPosMHAXL":
             if pos_embs is None:
diff --git a/speechbrain/lobes/models/transformer/Conformer.py b/speechbrain/lobes/models/transformer/Conformer.py
old mode 100644
new mode 100755
index c4c6e8496ecf7e05f40f5b0219177ffee4e0093c..d0aaab8f64d2c431160f4a17b7e147e73d2e28b7
--- a/speechbrain/lobes/models/transformer/Conformer.py
+++ b/speechbrain/lobes/models/transformer/Conformer.py
@@ -1,13 +1,17 @@
 """Conformer implementation.
 
 Authors
+-------
 * Jianyuan Zhong 2020
 * Samuele Cornell 2021
+* Sylvain de Langen 2023
 """
 
+from dataclasses import dataclass
 import torch
 import torch.nn as nn
-from typing import Optional
+import torch.nn.functional as F
+from typing import Optional, List
 import speechbrain as sb
 import warnings
 
@@ -17,10 +21,55 @@ from speechbrain.nnet.attention import (
     MultiheadAttention,
     PositionalwiseFeedForward,
 )
+from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
+from speechbrain.nnet.hypermixing import HyperMixing
 from speechbrain.nnet.normalization import LayerNorm
 from speechbrain.nnet.activations import Swish
 
 
+@dataclass
+class ConformerEncoderLayerStreamingContext:
+    """Streaming metadata and state for a `ConformerEncoderLayer`.
+
+    The multi-head attention and Dynamic Chunk Convolution require to save some
+    left context that gets inserted as left padding.
+
+    See :class:`.ConvolutionModule` documentation for further details.
+    """
+
+    mha_left_context_size: int
+    """For this layer, specifies how many frames of inputs should be saved.
+    Usually, the same value is used across all layers, but this can be modified.
+    """
+
+    mha_left_context: Optional[torch.Tensor] = None
+    """Left context to insert at the left of the current chunk as inputs to the
+    multi-head attention. It can be `None` (if we're dealing with the first
+    chunk) or `<= mha_left_context_size` because for the first few chunks, not
+    enough left context may be available to pad.
+    """
+
+    dcconv_left_context: Optional[torch.Tensor] = None
+    """Left context to insert at the left of the convolution according to the
+    Dynamic Chunk Convolution method.
+
+    Unlike `mha_left_context`, here the amount of frames to keep is fixed and
+    inferred from the kernel size of the convolution module.
+    """
+
+
+@dataclass
+class ConformerEncoderStreamingContext:
+    """Streaming metadata and state for a `ConformerEncoder`."""
+
+    dynchunktrain_config: DynChunkTrainConfig
+    """Dynamic Chunk Training configuration holding chunk size and context size
+    information."""
+
+    layers: List[ConformerEncoderLayerStreamingContext]
+    """Streaming metadata and state for each layer of the encoder."""
+
+
 class ConvolutionModule(nn.Module):
     """This is an implementation of convolution module in Conformer.
 
@@ -63,7 +112,9 @@ class ConvolutionModule(nn.Module):
     ):
         super().__init__()
 
+        self.kernel_size = kernel_size
         self.causal = causal
+        self.dilation = dilation
 
         if self.causal:
             self.padding = (kernel_size - 1) * 2 ** (dilation - 1)
@@ -90,6 +141,9 @@ class ConvolutionModule(nn.Module):
             bias=bias,
         )
 
+        # NOTE: there appears to be a mismatch compared to the Conformer paper:
+        # I believe the first LayerNorm below is supposed to be a BatchNorm.
+
         self.after_conv = nn.Sequential(
             nn.LayerNorm(input_size),
             activation(),
@@ -98,20 +152,172 @@ class ConvolutionModule(nn.Module):
             nn.Dropout(dropout),
         )
 
-    def forward(self, x, mask=None):
-        """ Processes the input tensor x and returns the output an output tensor"""
-        out = self.layer_norm(x)
-        out = out.transpose(1, 2)
-        out = self.bottleneck(out)
-        out = self.conv(out)
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
+    ):
+        """Applies the convolution to an input tensor `x`.
+
+        Arguments
+        ---------
+        x: torch.Tensor
+            Input tensor to the convolution module.
+        mask: torch.Tensor, optional
+            Mask to be applied over the output of the convolution using
+            `masked_fill_`, if specified.
+        dynchunktrain_config: DynChunkTrainConfig, optional
+            If specified, makes the module support Dynamic Chunk Convolution
+            (DCConv) as implemented by
+            `Dynamic Chunk Convolution for Unified Streaming and Non-Streaming Conformer ASR <https://www.amazon.science/publications/dynamic-chunk-convolution-for-unified-streaming-and-non-streaming-conformer-asr>`_.
+            This allows masking future frames while preserving better accuracy
+            than a fully causal convolution, at a small speed cost.
+            This should only be used for training (or, if you know what you're
+            doing, for masked evaluation at inference time), as the forward
+            streaming function should be used at inference time.
+            """
+
+        if dynchunktrain_config is not None:
+            # chances are chunking+causal is unintended; i don't know where it
+            # may make sense, but if it does to you, feel free to implement it.
+            assert (
+                not self.causal
+            ), "Chunked convolution not supported with causal padding"
+
+            assert (
+                self.dilation == 1
+            ), "Current DynChunkTrain logic does not support dilation != 1"
+
+            # in a causal convolution, which is not the case here, an output
+            # frame would never be able to depend on a input frame from any
+            # point in the future.
+
+            # but with the dynamic chunk convolution, we instead use a "normal"
+            # convolution but where, for any output frame, the future beyond the
+            # "current" chunk gets masked.
+            # see the paper linked in the documentation for details.
+
+            chunk_size = dynchunktrain_config.chunk_size
+            batch_size = x.shape[0]
+
+            # determine the amount of padding we need to insert at the right of
+            # the last chunk so that all chunks end up with the same size.
+            if x.shape[1] % chunk_size != 0:
+                final_right_padding = chunk_size - (x.shape[1] % chunk_size)
+            else:
+                final_right_padding = 0
+
+            # -> [batch_size, t, in_channels]
+            out = self.layer_norm(x)
+
+            # -> [batch_size, in_channels, t] for the CNN
+            out = out.transpose(1, 2)
+
+            # -> [batch_size, in_channels, t] (pointwise)
+            out = self.bottleneck(out)
+
+            # -> [batch_size, in_channels, lc+t+final_right_padding]
+            out = F.pad(out, (self.padding, final_right_padding), value=0)
+
+            # now, make chunks with left context.
+            # as a recap to what the above padding and this unfold do, consider
+            # each a/b/c letter represents a frame as part of chunks a, b, c.
+            # consider a chunk size of 4 and a kernel size of 5 (padding=2):
+            #
+            # input seq: 00aaaabbbbcc00
+            # chunk #1:  00aaaa
+            # chunk #2:      aabbbb
+            # chunk #3:          bbcc00
+            #
+            # a few remarks here:
+            # - the left padding gets inserted early so that the unfold logic
+            #   works trivially
+            # - the right 0-padding got inserted as the number of time steps
+            #   could not be evenly split in `chunk_size` chunks
+
+            # -> [batch_size, in_channels, num_chunks, lc+chunk_size]
+            out = out.unfold(2, size=chunk_size + self.padding, step=chunk_size)
+
+            # as we manually disable padding in the convolution below, we insert
+            # right 0-padding to the chunks, e.g. reusing the above example:
+            #
+            # chunk #1:  00aaaa00
+            # chunk #2:      aabbbb00
+            # chunk #3:          bbcc0000
+
+            # -> [batch_size, in_channels, num_chunks, lc+chunk_size+rpad]
+            out = F.pad(out, (0, self.padding), value=0)
+
+            # the transpose+flatten effectively flattens chunks into the batch
+            # dimension to be processed into the time-wise convolution. the
+            # chunks will later on be unflattened.
+
+            # -> [batch_size, num_chunks, in_channels, lc+chunk_size+rpad]
+            out = out.transpose(1, 2)
+
+            # -> [batch_size * num_chunks, in_channels, lc+chunk_size+rpad]
+            out = out.flatten(start_dim=0, end_dim=1)
+
+            # TODO: experiment around reflect padding, which is difficult
+            # because small chunks have too little time steps to reflect from
+
+            # let's keep backwards compat by pointing at the weights from the
+            # already declared Conv1d.
+            #
+            # still reusing the above example, the convolution will be applied,
+            # with the padding truncated on both ends. the following example
+            # shows the letter corresponding to the input frame on which the
+            # convolution was centered.
+            #
+            # as you can see, the sum of lengths of all chunks is equal to our
+            # input sequence length + `final_right_padding`.
+            #
+            # chunk #1:  aaaa
+            # chunk #2:      bbbb
+            # chunk #3:          cc00
+
+            # -> [batch_size * num_chunks, out_channels, chunk_size]
+            out = F.conv1d(
+                out,
+                weight=self.conv.weight,
+                bias=self.conv.bias,
+                stride=self.conv.stride,
+                padding=0,
+                dilation=self.conv.dilation,
+                groups=self.conv.groups,
+            )
+
+            # -> [batch_size * num_chunks, chunk_size, out_channels]
+            out = out.transpose(1, 2)
+
+            out = self.after_conv(out)
+
+            # -> [batch_size, num_chunks, chunk_size, out_channels]
+            out = torch.unflatten(out, dim=0, sizes=(batch_size, -1))
+
+            # -> [batch_size, t + final_right_padding, out_channels]
+            out = torch.flatten(out, start_dim=1, end_dim=2)
+
+            # -> [batch_size, t, out_channels]
+            if final_right_padding > 0:
+                out = out[:, :-final_right_padding, :]
+        else:
+            out = self.layer_norm(x)
+            out = out.transpose(1, 2)
+            out = self.bottleneck(out)
+            out = self.conv(out)
+
+            if self.causal:
+                # chomp
+                out = out[..., : -self.padding]
+
+            out = out.transpose(1, 2)
+            out = self.after_conv(out)
 
-        if self.causal:
-            # chomp
-            out = out[..., : -self.padding]
-        out = out.transpose(1, 2)
-        out = self.after_conv(out)
         if mask is not None:
             out.masked_fill_(mask, 0.0)
+
         return out
 
 
@@ -186,6 +392,14 @@ class ConformerEncoderLayer(nn.Module):
                 dropout=dropout,
                 mask_pos_future=causal,
             )
+        elif attention_type == "hypermixing":
+            self.mha_layer = HyperMixing(
+                input_output_dim=d_model,
+                hypernet_size=d_ffn,
+                tied=False,
+                num_heads=nhead,
+                fix_tm_hidden_size=False,
+            )
 
         self.convolution_module = ConvolutionModule(
             d_model, kernel_size, bias, activation, dropout, causal=causal
@@ -222,7 +436,8 @@ class ConformerEncoderLayer(nn.Module):
         x,
         src_mask: Optional[torch.Tensor] = None,
         src_key_padding_mask: Optional[torch.Tensor] = None,
-        pos_embs: Optional[torch.Tensor] = None,
+        pos_embs: torch.Tensor = None,
+        dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
     ):
         """
         Arguments
@@ -235,8 +450,12 @@ class ConformerEncoderLayer(nn.Module):
             The mask for the src keys per batch.
         pos_embs: torch.Tensor, torch.nn.Module, optional
             Module or tensor containing the input sequence positional embeddings
+        dynchunktrain_config: Optional[DynChunkTrainConfig]
+            Dynamic Chunk Training configuration object for streaming,
+            specifically involved here to apply Dynamic Chunk Convolution to
+            the convolution module.
         """
-        conv_mask = None
+        conv_mask: Optional[torch.Tensor] = None
         if src_key_padding_mask is not None:
             conv_mask = src_key_padding_mask.unsqueeze(-1)
         # ffn module
@@ -244,6 +463,7 @@ class ConformerEncoderLayer(nn.Module):
         # muti-head attention module
         skip = x
         x = self.norm1(x)
+
         x, self_attn = self.mha_layer(
             x,
             x,
@@ -254,11 +474,100 @@ class ConformerEncoderLayer(nn.Module):
         )
         x = x + skip
         # convolution module
-        x = x + self.convolution_module(x, conv_mask)
+        x = x + self.convolution_module(
+            x, conv_mask, dynchunktrain_config=dynchunktrain_config
+        )
+        # ffn module
+        x = self.norm2(x + 0.5 * self.ffn_module2(x))
+        return x, self_attn
+
+    def forward_streaming(
+        self,
+        x,
+        context: ConformerEncoderLayerStreamingContext,
+        pos_embs: torch.Tensor = None,
+    ):
+        """Conformer layer streaming forward (typically for
+        DynamicChunkTraining-trained models), which is to be used at inference
+        time. Relies on a mutable context object as initialized by
+        `make_streaming_context` that should be used across chunks.
+        Invoked by `ConformerEncoder.forward_streaming`.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Input tensor for this layer. Batching is supported as long as you
+            keep the context consistent.
+        context: ConformerEncoderStreamingContext
+            Mutable streaming context; the same object should be passed across
+            calls.
+        pos_embs: torch.Tensor, optional
+            Positional embeddings, if used."""
+
+        orig_len = x.shape[-2]
+        # ffn module
+        x = x + 0.5 * self.ffn_module1(x)
+
+        # TODO: make the approach for MHA left context more efficient.
+        # currently, this saves the inputs to the MHA.
+        # the naive approach is suboptimal in a few ways, namely that the
+        # outputs for this left padding is being re-computed even though we
+        # discard them immediately after.
+
+        # left pad `x` with our MHA left context
+        if context.mha_left_context is not None:
+            x = torch.cat((context.mha_left_context, x), dim=1)
+
+        # compute new MHA left context for the next call to our function
+        if context.mha_left_context_size > 0:
+            context.mha_left_context = x[
+                ..., -context.mha_left_context_size :, :
+            ]
+
+        # multi-head attention module
+        skip = x
+        x = self.norm1(x)
+
+        x, self_attn = self.mha_layer(
+            x, x, x, attn_mask=None, key_padding_mask=None, pos_embs=pos_embs,
+        )
+        x = x + skip
+
+        # truncate outputs corresponding to the MHA left context (we only care
+        # about our chunk's outputs); see above to-do
+        x = x[..., -orig_len:, :]
+
+        if context.dcconv_left_context is not None:
+            x = torch.cat((context.dcconv_left_context, x), dim=1)
+
+        # compute new DCConv left context for the next call to our function
+        context.dcconv_left_context = x[
+            ..., -self.convolution_module.padding :, :
+        ]
+
+        # convolution module
+        x = x + self.convolution_module(x)
+
+        # truncate outputs corresponding to the DCConv left context
+        x = x[..., -orig_len:, :]
+
         # ffn module
         x = self.norm2(x + 0.5 * self.ffn_module2(x))
         return x, self_attn
 
+    def make_streaming_context(self, mha_left_context_size: int):
+        """Creates a blank streaming context for this encoding layer.
+
+        Arguments
+        ---------
+        mha_left_context_size : int
+            How many left frames should be saved and used as left context to the
+            current chunk when streaming
+        """
+        return ConformerEncoderLayerStreamingContext(
+            mha_left_context_size=mha_left_context_size
+        )
+
 
 class ConformerEncoder(nn.Module):
     """This class implements the Conformer encoder.
@@ -346,6 +655,7 @@ class ConformerEncoder(nn.Module):
         src_mask: Optional[torch.Tensor] = None,
         src_key_padding_mask: Optional[torch.Tensor] = None,
         pos_embs: Optional[torch.Tensor] = None,
+        dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
     ):
         """
         Arguments
@@ -360,8 +670,11 @@ class ConformerEncoder(nn.Module):
             Module or tensor containing the input sequence positional embeddings
             If custom pos_embs are given it needs to have the shape (1, 2*S-1, E)
             where S is the sequence length, and E is the embedding dimension.
+        dynchunktrain_config: Optional[DynChunkTrainConfig]
+            Dynamic Chunk Training configuration object for streaming,
+            specifically involved here to apply Dynamic Chunk Convolution to the
+            convolution module.
         """
-
         if self.attention_type == "RelPosMHAXL":
             if pos_embs is None:
                 raise ValueError(
@@ -376,12 +689,74 @@ class ConformerEncoder(nn.Module):
                 src_mask=src_mask,
                 src_key_padding_mask=src_key_padding_mask,
                 pos_embs=pos_embs,
+                dynchunktrain_config=dynchunktrain_config,
+            )
+            attention_lst.append(attention)
+        output = self.norm(output)
+
+        return output, attention_lst
+
+    def forward_streaming(
+        self,
+        src: torch.Tensor,
+        context: ConformerEncoderStreamingContext,
+        pos_embs: Optional[torch.Tensor] = None,
+    ):
+        """Conformer streaming forward (typically for
+        DynamicChunkTraining-trained models), which is to be used at inference
+        time. Relies on a mutable context object as initialized by
+        `make_streaming_context` that should be used across chunks.
+
+        Arguments
+        ---------
+        src : torch.Tensor
+            Input tensor. Batching is supported as long as you keep the context
+            consistent.
+        context: ConformerEncoderStreamingContext
+            Mutable streaming context; the same object should be passed across
+            calls.
+        pos_embs: torch.Tensor, optional
+            Positional embeddings, if used."""
+
+        if self.attention_type == "RelPosMHAXL":
+            if pos_embs is None:
+                raise ValueError(
+                    "The chosen attention type for the Conformer is RelPosMHAXL. For this attention type, the positional embeddings are mandatory"
+                )
+
+        output = src
+        attention_lst = []
+        for i, enc_layer in enumerate(self.layers):
+            output, attention = enc_layer.forward_streaming(
+                output, pos_embs=pos_embs, context=context.layers[i]
             )
             attention_lst.append(attention)
         output = self.norm(output)
 
         return output, attention_lst
 
+    def make_streaming_context(self, dynchunktrain_config: DynChunkTrainConfig):
+        """Creates a blank streaming context for the encoder.
+
+        Arguments
+        ---------
+        dynchunktrain_config: Optional[DynChunkTrainConfig]
+            Dynamic Chunk Training configuration object for streaming
+        mha_left_context_size : int
+            How many left frames should be saved and used as left context to the
+            current chunk when streaming. This value is replicated across all
+            layers.
+        """
+        return ConformerEncoderStreamingContext(
+            dynchunktrain_config=dynchunktrain_config,
+            layers=[
+                layer.make_streaming_context(
+                    mha_left_context_size=dynchunktrain_config.left_context_size_frames()
+                )
+                for layer in self.layers
+            ],
+        )
+
 
 class ConformerDecoderLayer(nn.Module):
     """This is an implementation of Conformer encoder layer.
diff --git a/speechbrain/lobes/models/transformer/Transformer.py b/speechbrain/lobes/models/transformer/Transformer.py
index e290a7f87c2f325719da2497924e31534e35615b..1f0d4d7dc69b4e965cd2060fa541f538a2c4b5c5 100644
--- a/speechbrain/lobes/models/transformer/Transformer.py
+++ b/speechbrain/lobes/models/transformer/Transformer.py
@@ -1,4 +1,4 @@
-"""Transformer implementaion in the SpeechBrain style.
+"""Transformer implementation in the SpeechBrain style.
 Authors
 * Jianyuan Zhong 2020
 * Samuele Cornell 2021
@@ -126,7 +126,7 @@ class TransformerInterface(nn.Module):
         self.decoder_kdim = decoder_kdim
         self.decoder_vdim = decoder_vdim
 
-        assert attention_type in ["regularMHA", "RelPosMHAXL"]
+        assert attention_type in ["regularMHA", "RelPosMHAXL", "hypermixing"]
         assert positional_encoding in ["fixed_abs_sine", None]
 
         assert (
@@ -242,6 +242,10 @@ class PositionalEncoding(nn.Module):
 
     def __init__(self, input_size, max_len=2500):
         super().__init__()
+        if input_size % 2 != 0:
+            raise ValueError(
+                f"Cannot use sin/cos positional encoding with odd channels (got channels={input_size})"
+            )
         self.max_len = max_len
         pe = torch.zeros(self.max_len, input_size, requires_grad=False)
         positions = torch.arange(0, self.max_len).unsqueeze(1).float()
@@ -337,6 +341,14 @@ class TransformerEncoderLayer(nn.Module):
             self.self_att = sb.nnet.attention.RelPosMHAXL(
                 d_model, nhead, dropout, mask_pos_future=causal
             )
+        elif attention_type == "hypermixing":
+            self.self_att = sb.nnet.hypermixing.HyperMixing(
+                input_output_dim=d_model,
+                hypernet_size=d_ffn,
+                tied=False,
+                num_heads=nhead,
+                fix_tm_hidden_size=False,
+            )
 
         if ffn_type == "regularFFN":
             self.pos_ffn = sb.nnet.attention.PositionalwiseFeedForward(
@@ -351,14 +363,14 @@ class TransformerEncoderLayer(nn.Module):
                     in_channels=d_model,
                     out_channels=d_ffn,
                     kernel_size=ffn_cnn_kernel_size_list[0],
-                    padding="same",
+                    padding="causal" if causal else "same",
                 ),
                 nn.ReLU(),
                 Conv1d(
                     in_channels=d_ffn,
                     out_channels=d_model,
                     kernel_size=ffn_cnn_kernel_size_list[1],
-                    padding="same",
+                    padding="causal" if causal else "same",
                 ),
             )
 
@@ -518,6 +530,7 @@ class TransformerEncoder(nn.Module):
         src_mask: Optional[torch.Tensor] = None,
         src_key_padding_mask: Optional[torch.Tensor] = None,
         pos_embs: Optional[torch.Tensor] = None,
+        dynchunktrain_config=None,
     ):
         """
         Arguments
@@ -529,6 +542,10 @@ class TransformerEncoder(nn.Module):
         src_key_padding_mask : tensor
             The mask for the src keys per batch (optional).
         """
+        assert (
+            dynchunktrain_config is None
+        ), "Dynamic Chunk Training unsupported for this encoder"
+
         output = src
         if self.layerdrop_prob > 0.0:
             keep_probs = self.rng.random(len(self.layers))
@@ -858,6 +875,7 @@ class NormalizedEmbedding(nn.Module):
 
 def get_key_padding_mask(padded_input, pad_idx):
     """Creates a binary mask to prevent attention to padded locations.
+    We suggest using get_mask_from_lengths instead of this function.
     Arguments
     ----------
     padded_input: int
@@ -912,3 +930,32 @@ def get_lookahead_mask(padded_input):
         .masked_fill(mask == 1, float(0.0))
     )
     return mask.detach().to(padded_input.device)
+
+
+def get_mask_from_lengths(lengths, max_len=None):
+    """Creates a binary mask from sequence lengths
+    Arguments
+    ---------
+    lengths: torch.Tensor
+        A tensor of sequence lengths
+    max_len: int (Optional)
+        Maximum sequence length, defaults to None.
+    Returns
+    -------
+    mask: torch.Tensor
+        the mask where padded elements are set to True.
+        Then one can use tensor.masked_fill_(mask, 0) for the masking.
+    Example
+    -------
+    >>> lengths = torch.tensor([3, 2, 4])
+    >>> get_mask_from_lengths(lengths)
+    tensor([[False, False, False,  True],
+            [False, False,  True,  True],
+            [False, False, False, False]])
+    """
+    if max_len is None:
+        max_len = torch.max(lengths).item()
+    seq_range = torch.arange(
+        max_len, device=lengths.device, dtype=lengths.dtype
+    )
+    return ~(seq_range.unsqueeze(0) < lengths.unsqueeze(1))
diff --git a/speechbrain/lobes/models/transformer/TransformerASR.py b/speechbrain/lobes/models/transformer/TransformerASR.py
old mode 100644
new mode 100755
index 14d936e7160f02df9f7ee55bcf68164e6ad5b916..92aba87af6b0b1666d18a30f106c1e67104d197a
--- a/speechbrain/lobes/models/transformer/TransformerASR.py
+++ b/speechbrain/lobes/models/transformer/TransformerASR.py
@@ -2,11 +2,14 @@
 
 Authors
 * Jianyuan Zhong 2020
+* Titouan Parcollet 2024
+* Luca Della Libera 2024
 """
 
+from dataclasses import dataclass
 import torch  # noqa 42
 from torch import nn
-from typing import Optional
+from typing import Any, Optional
 from speechbrain.nnet.linear import Linear
 from speechbrain.nnet.containers import ModuleList
 from speechbrain.lobes.models.transformer.Transformer import (
@@ -17,6 +20,129 @@ from speechbrain.lobes.models.transformer.Transformer import (
 )
 from speechbrain.nnet.activations import Swish
 from speechbrain.dataio.dataio import length_to_mask
+from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
+
+
+@dataclass
+class TransformerASRStreamingContext:
+    """Streaming metadata and state for a `TransformerASR` instance."""
+
+    dynchunktrain_config: DynChunkTrainConfig
+    """Dynamic Chunk Training configuration holding chunk size and context size
+    information."""
+
+    encoder_context: Any
+    """Opaque encoder context information. It is constructed by the encoder's
+    `make_streaming_context` method and is passed to the encoder when using
+    `encode_streaming`.
+    """
+
+
+def make_transformer_src_mask(
+    src: torch.Tensor,
+    causal: bool = False,
+    dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
+) -> Optional[torch.Tensor]:
+    """Prepare the source transformer mask that restricts which frames can
+    attend to which frames depending on causal or other simple restricted
+    attention methods.
+
+    Arguments
+    ---------
+    src: torch.Tensor
+        The source tensor to build a mask from. The contents of the tensor are
+        not actually used currently; only its shape and other metadata (e.g.
+        device).
+    causal: bool
+        Whether strict causality shall be used. Frames will not be able to
+        attend to any future frame.
+    dynchunktrain_config: DynChunkTrainConfig, optional
+        Dynamic Chunk Training configuration. This implements a simple form of
+        chunkwise attention. Incompatible with `causal`.
+
+    Returns
+    -------
+    torch.Tensor
+        A boolean mask Tensor of shape (timesteps, timesteps).
+    """
+    if causal:
+        assert dynchunktrain_config is None
+        return get_lookahead_mask(src)
+
+    if dynchunktrain_config is None:
+        return
+
+    # The following is not really the sole source used to implement this,
+    # but it helps introduce the concept.
+    # ref: Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition
+    # https://arxiv.org/pdf/2012.05481.pdf
+    timesteps = src.size(1)
+
+    # Mask the future at the right of each chunk
+    chunk_size = dynchunktrain_config.chunk_size
+    num_chunks = timesteps // chunk_size
+    timestep_idx = torch.arange(timesteps, device=src.device)
+    mask_idx = torch.arange(
+        chunk_size, chunk_size * (num_chunks + 2), chunk_size, device=src.device
+    ).repeat_interleave(chunk_size)[:timesteps]
+    src_mask = timestep_idx[None] >= mask_idx[:, None]
+
+    # Mask the past at the left of each chunk (accounting for left context)
+    # only relevant if using left context
+    if not dynchunktrain_config.is_infinite_left_context():
+        num_left_chunks = dynchunktrain_config.left_context_size
+        mask_idx -= chunk_size * (num_left_chunks + 1)
+        src_mask += timestep_idx[None] < mask_idx[:, None]
+
+    return src_mask
+
+
+def make_transformer_src_tgt_masks(
+    src,
+    tgt=None,
+    wav_len=None,
+    pad_idx=0,
+    causal: bool = False,
+    dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
+):
+    """This function generates masks for training the transformer model,
+    opiniated for an ASR context with encoding masks and, optionally, decoding
+    masks (if specifying `tgt`).
+
+    Arguments
+    ---------
+    src : tensor
+        The sequence to the encoder (required).
+    tgt : tensor
+        The sequence to the decoder.
+    pad_idx : int
+        The index for <pad> token (default=0).
+    causal: bool
+        Whether strict causality shall be used. See `make_asr_src_mask`
+    dynchunktrain_config: DynChunkTrainConfig, optional
+        Dynamic Chunk Training configuration. See `make_asr_src_mask`
+    """
+    src_key_padding_mask = None
+
+    # mask out audio beyond the length of audio for each batch
+    if wav_len is not None:
+        abs_len = torch.round(wav_len * src.shape[1])
+        src_key_padding_mask = ~length_to_mask(abs_len).bool()
+
+    # mask out the source
+    src_mask = make_transformer_src_mask(
+        src, causal=causal, dynchunktrain_config=dynchunktrain_config
+    )
+
+    # If no decoder in the transformer...
+    if tgt is not None:
+        tgt_key_padding_mask = get_key_padding_mask(tgt, pad_idx=pad_idx)
+        tgt_mask = get_lookahead_mask(tgt)
+    else:
+        tgt_key_padding_mask = None
+        tgt_mask = None
+
+    return src_key_padding_mask, tgt_key_padding_mask, src_mask, tgt_mask
 
 
 class TransformerASR(TransformerInterface):
@@ -152,9 +278,11 @@ class TransformerASR(TransformerInterface):
             ),
             torch.nn.Dropout(dropout),
         )
-        self.custom_tgt_module = ModuleList(
-            NormalizedEmbedding(d_model, tgt_vocab)
-        )
+
+        if num_decoder_layers > 0:
+            self.custom_tgt_module = ModuleList(
+                NormalizedEmbedding(d_model, tgt_vocab)
+            )
 
         # reset parameters using xavier_normal_
         self._init_params()
@@ -183,11 +311,15 @@ class TransformerASR(TransformerInterface):
             tgt_key_padding_mask,
             src_mask,
             tgt_mask,
-        ) = self.make_masks(src, tgt, wav_len, pad_idx=pad_idx)
+        ) = make_transformer_src_tgt_masks(
+            src, tgt, wav_len, causal=self.causal, pad_idx=pad_idx
+        )
 
         src = self.custom_src_module(src)
         # add pos encoding to queries if are sinusoidal ones else
-        if self.attention_type == "RelPosMHAXL":
+        if self.attention_type == "hypermixing":
+            pos_embs_encoder = None
+        elif self.attention_type == "RelPosMHAXL":
             pos_embs_encoder = self.positional_encoding(src)
         elif self.positional_encoding_type == "fixed_abs_sine":
             src = src + self.positional_encoding(src)  # add the encodings here
@@ -202,13 +334,14 @@ class TransformerASR(TransformerInterface):
 
         tgt = self.custom_tgt_module(tgt)
 
-        # Add positional encoding to the target before feeding the decoder.
         if self.attention_type == "RelPosMHAXL":
-            # use standard sinusoidal pos encoding in decoder
             tgt = tgt + self.positional_encoding_decoder(tgt)
             pos_embs_encoder = None  # self.positional_encoding(src)
             pos_embs_target = None
-        elif self.positional_encoding_type == "fixed_abs_sine":
+        elif (
+            self.positional_encoding_type == "fixed_abs_sine"
+            or self.attention_type == "hypermixing"
+        ):
             tgt = tgt + self.positional_encoding(tgt)
             pos_embs_target = None
             pos_embs_encoder = None
@@ -216,7 +349,7 @@ class TransformerASR(TransformerInterface):
         decoder_out, _, _ = self.decoder(
             tgt=tgt,
             memory=encoder_out,
-            memory_mask=src_mask,
+            memory_mask=None,
             tgt_mask=tgt_mask,
             tgt_key_padding_mask=tgt_key_padding_mask,
             memory_key_padding_mask=src_key_padding_mask,
@@ -226,35 +359,6 @@ class TransformerASR(TransformerInterface):
 
         return encoder_out, decoder_out
 
-    def make_masks(self, src, tgt=None, wav_len=None, pad_idx=0):
-        """This method generates the masks for training the transformer model.
-
-        Arguments
-        ---------
-        src : tensor
-            The sequence to the encoder (required).
-        tgt : tensor
-            The sequence to the decoder.
-        pad_idx : int
-            The index for <pad> token (default=0).
-        """
-        src_key_padding_mask = None
-        if wav_len is not None:
-            abs_len = torch.round(wav_len * src.shape[1])
-            src_key_padding_mask = ~length_to_mask(abs_len).bool()
-
-        src_mask = None
-
-        # If no decoder in the transformer...
-        if tgt is not None:
-            tgt_key_padding_mask = get_key_padding_mask(tgt, pad_idx=pad_idx)
-            tgt_mask = get_lookahead_mask(tgt)
-        else:
-            tgt_key_padding_mask = None
-            tgt_mask = None
-
-        return src_key_padding_mask, tgt_key_padding_mask, src_mask, tgt_mask
-
     @torch.no_grad()
     def decode(self, tgt, encoder_out, enc_len=None):
         """This method implements a decoding step for the transformer model.
@@ -275,12 +379,14 @@ class TransformerASR(TransformerInterface):
 
         tgt = self.custom_tgt_module(tgt)
         if self.attention_type == "RelPosMHAXL":
-            # use standard sinusoidal pos encoding in decoder
             tgt = tgt + self.positional_encoding_decoder(tgt)
             pos_embs_encoder = None  # self.positional_encoding(src)
             pos_embs_target = None
-        elif self.positional_encoding_type == "fixed_abs_sine":
-            tgt = tgt + self.positional_encoding(tgt)
+        elif (
+            self.positional_encoding_type == "fixed_abs_sine"
+            or self.attention_type == "hypermixing"
+        ):
+            tgt = tgt + self.positional_encoding(tgt)  # add the encodings here
             pos_embs_target = None
             pos_embs_encoder = None
 
@@ -294,7 +400,13 @@ class TransformerASR(TransformerInterface):
         )
         return prediction, multihead_attns[-1]
 
-    def encode(self, src, wav_len=None, pad_idx=0):
+    def encode(
+        self,
+        src,
+        wav_len=None,
+        pad_idx=0,
+        dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
+    ):
         """
         Encoder forward pass
 
@@ -310,25 +422,156 @@ class TransformerASR(TransformerInterface):
             bz, t, ch1, ch2 = src.shape
             src = src.reshape(bz, t, ch1 * ch2)
 
-        (src_key_padding_mask, _, src_mask, _,) = self.make_masks(
-            src, None, wav_len, pad_idx=pad_idx
+        (
+            src_key_padding_mask,
+            _,
+            src_mask,
+            _,
+        ) = make_transformer_src_tgt_masks(
+            src,
+            None,
+            wav_len,
+            pad_idx=pad_idx,
+            causal=self.causal,
+            dynchunktrain_config=dynchunktrain_config,
         )
 
         src = self.custom_src_module(src)
-        if self.attention_type == "RelPosMHAXL":
+        if self.attention_type == "hypermixing":
+            pos_embs_source = None
+        elif self.attention_type == "RelPosMHAXL":
             pos_embs_source = self.positional_encoding(src)
-
         elif self.positional_encoding_type == "fixed_abs_sine":
             src = src + self.positional_encoding(src)
             pos_embs_source = None
 
         encoder_out, _ = self.encoder(
             src=src,
+            src_mask=src_mask,
             src_key_padding_mask=src_key_padding_mask,
             pos_embs=pos_embs_source,
+            dynchunktrain_config=dynchunktrain_config,
         )
+
         return encoder_out
 
+    def encode_streaming(self, src, context: TransformerASRStreamingContext):
+        """
+        Streaming encoder forward pass
+
+        Arguments
+        ---------
+        src : torch.Tensor
+            The sequence (chunk) to the encoder.
+
+        context : TransformerASRStreamingContext
+            Mutable reference to the streaming context. This holds the state
+            needed to persist across chunk inferences and can be built using
+            `make_streaming_context`. This will get mutated by this function.
+
+        Returns
+        -------
+        Encoder output for this chunk.
+
+        Example
+        -------
+        >>> import torch
+        >>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR
+        >>> from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
+        >>> net = TransformerASR(
+        ...     tgt_vocab=100,
+        ...     input_size=64,
+        ...     d_model=64,
+        ...     nhead=8,
+        ...     num_encoder_layers=1,
+        ...     num_decoder_layers=0,
+        ...     d_ffn=128,
+        ...     attention_type="RelPosMHAXL",
+        ...     positional_encoding=None,
+        ...     encoder_module="conformer",
+        ...     normalize_before=True,
+        ...     causal=False,
+        ... )
+        >>> ctx = net.make_streaming_context(DynChunkTrainConfig(16, 1))
+        >>> src1 = torch.rand([8, 16, 64])
+        >>> src2 = torch.rand([8, 16, 64])
+        >>> out1 = net.encode_streaming(src1, ctx)
+        >>> out1.shape
+        torch.Size([8, 16, 64])
+        >>> ctx.encoder_context.layers[0].mha_left_context.shape
+        torch.Size([8, 16, 64])
+        >>> out2 = net.encode_streaming(src2, ctx)
+        >>> out2.shape
+        torch.Size([8, 16, 64])
+        >>> ctx.encoder_context.layers[0].mha_left_context.shape
+        torch.Size([8, 16, 64])
+        >>> combined_out = torch.concat((out1, out2), dim=1)
+        >>> combined_out.shape
+        torch.Size([8, 32, 64])
+        """
+
+        if src.dim() == 4:
+            bz, t, ch1, ch2 = src.shape
+            src = src.reshape(bz, t, ch1 * ch2)
+
+        # HACK: our problem here is that the positional_encoding is computed
+        # against the size of our source tensor, but we only know how many left
+        # context frames we're injecting to the encoder within the encoder
+        # context.
+        # so this workaround does just that.
+        #
+        # i'm not sure how this would be best refactored, but an option would be
+        # to let the encoder get the pos embedding itself and have a way to
+        # cache it.
+        #
+        # additionally, positional encoding functions take in a whole source
+        # tensor just to get its attributes (size, device, type) but this is
+        # sort of silly for the embeddings that don't need one.
+        # so we craft a dummy empty (uninitialized) tensor to help...
+        known_left_context = context.encoder_context.layers[0].mha_left_context
+        if known_left_context is None:
+            pos_encoding_dummy = src
+        else:
+            target_shape = list(src.shape)
+            target_shape[-2] += known_left_context.shape[-2]
+            pos_encoding_dummy = torch.empty(size=target_shape).to(src)
+
+        src = self.custom_src_module(src)
+        if self.attention_type == "RelPosMHAXL":
+            pos_embs_source = self.positional_encoding(pos_encoding_dummy)
+
+        elif self.positional_encoding_type == "fixed_abs_sine":
+            src = src + self.positional_encoding(pos_encoding_dummy)
+            pos_embs_source = None
+
+        encoder_out, _ = self.encoder.forward_streaming(
+            src=src, pos_embs=pos_embs_source, context=context.encoder_context
+        )
+        return encoder_out
+
+    def make_streaming_context(
+        self, dynchunktrain_config: DynChunkTrainConfig, encoder_kwargs={}
+    ):
+        """Creates a blank streaming context for this transformer and its
+        encoder.
+
+        Arguments
+        ---------
+        dynchunktrain_config : DynChunkTrainConfig
+            Runtime chunkwise attention configuration.
+
+        encoder_kwargs : dict
+            Parameters to be forward to the encoder's `make_streaming_context`.
+            Metadata required for the encoder could differ depending on the
+            encoder.
+        """
+        return TransformerASRStreamingContext(
+            dynchunktrain_config=dynchunktrain_config,
+            encoder_context=self.encoder.make_streaming_context(
+                dynchunktrain_config, **encoder_kwargs,
+            ),
+        )
+
     def _init_params(self):
         for p in self.parameters():
             if p.dim() > 1:
@@ -363,8 +606,20 @@ class EncoderWrapper(nn.Module):
     def __init__(self, transformer, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.transformer = transformer
+        self.make_streaming_context = self.transformer.make_streaming_context
 
-    def forward(self, x, wav_lens=None, pad_idx=0):
+    def forward(self, x, wav_lens=None, pad_idx=0, **kwargs):
         """ Processes the input tensor x and returns an output tensor."""
-        x = self.transformer.encode(x, wav_lens, pad_idx)
+        x = self.transformer.encode(x, wav_lens, pad_idx, **kwargs,)
         return x
+
+    def forward_streaming(self, x, context):
+        """Processes the input audio chunk tensor `x`, using and updating the
+        mutable encoder `context`"""
+        x = self.transformer.encode_streaming(x, context)
+        return x
+
+    def make_streaming_context(self, *args, **kwargs):
+        """Initializes a streaming context. Forwards all arguments to the
+        underlying transformer. See :meth:`speechbrain.lobes.models.transformer.TransformerASR.make_streaming_context`."""
+        return self.transformer.make_streaming_context(*args, **kwargs)
diff --git a/speechbrain/nnet/CNN.py b/speechbrain/nnet/CNN.py
index e97cb929d47ac494c427d6baa43bcd4c7996fab8..54f83e5ec2827ad9736d795ec1cd46f56bce92cb 100644
--- a/speechbrain/nnet/CNN.py
+++ b/speechbrain/nnet/CNN.py
@@ -544,6 +544,12 @@ class Conv2d(nn.Module):
         documentation for more information.
     bias : bool
         If True, the additive bias b is adopted.
+    max_norm: float
+        kernel max-norm.
+    swap: bool
+        If True, the convolution is done with the format (B, C, W, H).
+        If False, the convolution is dine with (B, H, W, C).
+        Active only if skip_transpose is False.
     skip_transpose : bool
         If False, uses batch x spatial.dim2 x spatial.dim1 x channel convention of speechbrain.
         If True, uses batch x channel x spatial.dim1 x spatial.dim2 convention.
@@ -576,6 +582,8 @@ class Conv2d(nn.Module):
         groups=1,
         bias=True,
         padding_mode="reflect",
+        max_norm=None,
+        swap=False,
         skip_transpose=False,
         weight_norm=False,
         conv_init=None,
@@ -596,6 +604,8 @@ class Conv2d(nn.Module):
         self.padding = padding
         self.padding_mode = padding_mode
         self.unsqueeze = False
+        self.max_norm = max_norm
+        self.swap = swap
         self.skip_transpose = skip_transpose
 
         if input_shape is None and in_channels is None:
@@ -637,6 +647,8 @@ class Conv2d(nn.Module):
         """
         if not self.skip_transpose:
             x = x.transpose(1, -1)
+            if self.swap:
+                x = x.transpose(-1, -2)
 
         if self.unsqueeze:
             x = x.unsqueeze(1)
@@ -659,6 +671,11 @@ class Conv2d(nn.Module):
                 + self.padding
             )
 
+        if self.max_norm is not None:
+            self.conv.weight.data = torch.renorm(
+                self.conv.weight.data, p=2, dim=0, maxnorm=self.max_norm
+            )
+
         wx = self.conv(x)
 
         if self.unsqueeze:
@@ -666,7 +683,8 @@ class Conv2d(nn.Module):
 
         if not self.skip_transpose:
             wx = wx.transpose(1, -1)
-
+            if self.swap:
+                wx = wx.transpose(1, 2)
         return wx
 
     def _manage_padding(
@@ -733,72 +751,6 @@ class Conv2d(nn.Module):
         self.conv = nn.utils.remove_weight_norm(self.conv)
 
 
-class Conv2dWithConstraint(Conv2d):
-    """This function implements 2d convolution with kernel max-norm constaint.
-    This corresponds to set an upper bound for the kernel norm.
-
-    Arguments
-    ---------
-    out_channels : int
-        It is the number of output channels.
-    kernel_size : tuple
-        Kernel size of the 2d convolutional filters over time and frequency
-        axis.
-    input_shape : tuple
-        The shape of the input. Alternatively use ``in_channels``.
-    in_channels : int
-        The number of input channels. Alternatively use ``input_shape``.
-    stride: int
-        Stride factor of the 2d convolutional filters over time and frequency
-        axis.
-    dilation : int
-        Dilation factor of the 2d convolutional filters over time and
-        frequency axis.
-    padding : str
-        (same, valid). If "valid", no padding is performed.
-        If "same" and stride is 1, output shape is same as input shape.
-    padding_mode : str
-        This flag specifies the type of padding. See torch.nn documentation
-        for more information.
-    groups : int
-        This option specifies the convolutional groups. See torch.nn
-        documentation for more information.
-    bias : bool
-        If True, the additive bias b is adopted.
-    max_norm : float
-        kernel  max-norm
-
-    Example
-    -------
-    >>> inp_tensor = torch.rand([10, 40, 16, 8])
-    >>> max_norm = 1
-    >>> cnn_2d_constrained = Conv2dWithConstraint(
-    ...     in_channels=inp_tensor.shape[-1], out_channels=5, kernel_size=(7, 3)
-    ... )
-    >>> out_tensor = cnn_2d_constrained(inp_tensor)
-    >>> torch.any(torch.norm(cnn_2d_constrained.conv.weight.data, p=2, dim=0)>max_norm)
-    tensor(False)
-    """
-
-    def __init__(self, *args, max_norm=1, **kwargs):
-        self.max_norm = max_norm
-        super().__init__(*args, **kwargs)
-
-    def forward(self, x):
-        """Returns the output of the convolution.
-
-        Arguments
-        ---------
-        x : torch.Tensor (batch, time, channel)
-            input to convolve. 2d or 4d tensors are expected.
-
-        """
-        self.conv.weight.data = torch.renorm(
-            self.conv.weight.data, p=2, dim=0, maxnorm=self.max_norm
-        )
-        return super().forward(x)
-
-
 class ConvTranspose1d(nn.Module):
     """This class implements 1d transposed convolution with speechbrain.
     Transpose convolution is normally used to perform upsampling.
diff --git a/speechbrain/nnet/RNN.py b/speechbrain/nnet/RNN.py
index 36a4c8bb7e122d1bff9438703ff22c0efc9bf48e..4191437296f7ff4483d8ef58d61c457add0da236 100644
--- a/speechbrain/nnet/RNN.py
+++ b/speechbrain/nnet/RNN.py
@@ -758,9 +758,10 @@ class AttentionalRNNDecoder(nn.Module):
 
     Example
     -------
-    >>> enc_states = torch.rand([4, 10, 20])
-    >>> wav_len = torch.rand([4])
-    >>> inp_tensor = torch.rand([4, 5, 6])
+    >>> batch_size = 4
+    >>> enc_states = torch.rand([batch_size, 10, 20])
+    >>> wav_len = torch.ones([batch_size])
+    >>> inp_tensor = torch.rand([batch_size, 5, 6])
     >>> net = AttentionalRNNDecoder(
     ...     rnn_type="lstm",
     ...     attn_type="content",
diff --git a/speechbrain/nnet/activations.py b/speechbrain/nnet/activations.py
index 9261996a9ce6b360161a0f295cc3bb8510ab6f90..3ab15157e47b12946d44d84bdbe841740c183ddc 100644
--- a/speechbrain/nnet/activations.py
+++ b/speechbrain/nnet/activations.py
@@ -23,6 +23,8 @@ class Softmax(torch.nn.Module):
         If the dimension where softmax is applied.
     reshape: bool
         whether to apply reshaping (true by default)
+    dtype: torch.dtype
+        dtype of the output tensor
 
     Example
     -------
@@ -33,15 +35,19 @@ class Softmax(torch.nn.Module):
     torch.Size([10, 50, 40])
     """
 
-    def __init__(self, apply_log=False, dim=-1, reshape=True):
+    def __init__(
+        self, apply_log=False, dim=-1, reshape=True, dtype=torch.float32
+    ):
         super().__init__()
 
         if apply_log:
-            self.act = torch.nn.LogSoftmax(dim=dim)
+            self.act = F.log_softmax
         else:
-            self.act = torch.nn.Softmax(dim=dim)
+            self.act = F.softmax
 
+        self.dim = dim
         self.reshape = reshape
+        self.dtype = dtype
 
     def forward(self, x):
         """Returns the softmax of the input tensor.
@@ -61,7 +67,7 @@ class Softmax(torch.nn.Module):
             if len(dims) == 4:
                 x = x.reshape(dims[0] * dims[1], dims[2], dims[3])
 
-        x_act = self.act(x)
+        x_act = self.act(x, dim=self.dim, dtype=self.dtype)
 
         # Retrieving the original shape format
         if self.reshape:
diff --git a/speechbrain/nnet/attention.py b/speechbrain/nnet/attention.py
index 27e9930603ab9e640d01b694c258d91c9b18f9c5..5a9b1650c7e9f83fa84d34ec78176c4a6f10bee1 100644
--- a/speechbrain/nnet/attention.py
+++ b/speechbrain/nnet/attention.py
@@ -591,17 +591,28 @@ class RelPosMHAXL(nn.Module):
             query + self.pos_bias_v.view(1, 1, self.num_heads, self.head_dim)
         ).transpose(1, 2)
 
+        # Moved the `* self.scale` mul from after the `attn_score` sum to prior
+        # to the matmul in order to lower overflow risks on fp16.
+        # This change is inspired by the following paper, but no other changes
+        # were ported from there so far.
+        # ref: E.T.: Re-Thinking Self-Attention for Transformer Models on GPUs
+        # https://asherliu.github.io/docs/sc21a.pdf
+
         # (batch, head, qlen, klen)
-        matrix_ac = torch.matmul(q_with_bias_u, key.permute(0, 2, 3, 1))
+        matrix_ac = torch.matmul(
+            q_with_bias_u * self.scale, key.permute(0, 2, 3, 1)
+        )
         # (batch, num_heads, klen, 2*klen-1)
-        matrix_bd = torch.matmul(q_with_bias_v, p_k.permute(0, 2, 3, 1))
+        matrix_bd = torch.matmul(
+            q_with_bias_v * self.scale, p_k.permute(0, 2, 3, 1)
+        )
         matrix_bd = self.rel_shift(matrix_bd)  # shifting trick
 
         # if klen != qlen:
         #   import ipdb
         #  ipdb.set_trace(
 
-        attn_score = (matrix_ac + matrix_bd) * self.scale
+        attn_score = matrix_ac + matrix_bd  # already scaled above
 
         # compute attention probability
         if attn_mask is not None:
@@ -622,8 +633,25 @@ class RelPosMHAXL(nn.Module):
                 key_padding_mask.view(bsz, 1, 1, klen), self.attn_fill_value,
             )
 
-        attn_score = F.softmax(attn_score, dim=-1)
+        attn_score = F.softmax(attn_score, dim=-1, dtype=torch.float32)
         attn_score = self.dropout_att(attn_score)
+
+        # it is possible for us to hit full NaN when using chunked training
+        # so reapply masks, except with 0.0 instead as we are after the softmax
+        # because -inf would output 0.0 regardless anyway
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_score = attn_score.masked_fill(attn_mask, 0.0)
+            else:
+                # NOTE: the above fix is not implemented for this case as
+                # summing the mask with NaN would still result in NaN
+                pass
+
+        if key_padding_mask is not None:
+            attn_score = attn_score.masked_fill(
+                key_padding_mask.view(bsz, 1, 1, klen), 0.0,
+            )
+
         x = torch.matmul(
             attn_score, value.transpose(1, 2)
         )  # (batch, head, time1, d_k)
@@ -701,7 +729,7 @@ class MultiheadAttention(nn.Module):
         value,
         attn_mask: Optional[torch.Tensor] = None,
         key_padding_mask: Optional[torch.Tensor] = None,
-        return_attn_weights: Optional[torch.Tensor] = True,
+        return_attn_weights: bool = True,
         pos_embs: Optional[torch.Tensor] = None,
     ):
         """
@@ -716,13 +744,6 @@ class MultiheadAttention(nn.Module):
         value : torch.Tensor
             (B, S, E) where S is the source sequence length,
             B is the batch size, E is the embedding dimension.
-        key_padding_mask : torch.Tensor, optional
-            (B, S) where B is the batch size, S is the source sequence
-            length. If a ByteTensor is provided, the non-zero positions will
-            be ignored while the position with the zero positions will be
-            unchanged. If a BoolTensor is provided, the positions with the
-            value of True will be ignored while the position with the value
-            of False will be unchanged.
         attn_mask : torch.Tensor, optional
             2D mask (L, S) where L is the target sequence length, S is
             the source sequence length.
@@ -734,7 +755,16 @@ class MultiheadAttention(nn.Module):
             be unchanged. If a BoolTensor is provided, positions with True is
             not allowed to attend while False values will be unchanged. If a
             FloatTensor is provided, it will be added to the attention weight.
-        pos_embs: torch.Tensor, optional
+        key_padding_mask : torch.Tensor, optional
+            (B, S) where B is the batch size, S is the source sequence
+            length. If a ByteTensor is provided, the non-zero positions will
+            be ignored while the position with the zero positions will be
+            unchanged. If a BoolTensor is provided, the positions with the
+            value of True will be ignored while the position with the value
+            of False will be unchanged.
+        return_attn_weights : bool, optional
+            True to additionally return the attention weights, False otherwise.
+        pos_embs : torch.Tensor, optional
             Positional embeddings added to the attention map of shape (L, S, E) or (L, S, 1).
 
         Outputs
@@ -745,6 +775,7 @@ class MultiheadAttention(nn.Module):
         attn_output_weights : torch.Tensor
             (B, L, S) where B is the batch size, L is the target
             sequence length, S is the source sequence length.
+            This is returned only if `return_attn_weights=True` (True by default).
         """
         # give tensors of shape (time, batch, fea)
         query = query.permute(1, 0, 2)
@@ -759,7 +790,7 @@ class MultiheadAttention(nn.Module):
             else:
                 attn_mask = pos_embs
 
-        output = self.att(
+        output, attention_weights = self.att(
             query,
             key,
             value,
@@ -768,14 +799,13 @@ class MultiheadAttention(nn.Module):
             need_weights=return_attn_weights,
         )
 
+        # reshape the output back to (batch, time, fea)
+        output = output.permute(1, 0, 2)
+
         if return_attn_weights:
-            output, attention_weights = output
-            # reshape the output back to (batch, time, fea)
-            output = output.permute(1, 0, 2)
             return output, attention_weights
-        else:
-            output = output.permute(1, 0, 2)
-            return output
+
+        return output
 
 
 class PositionalwiseFeedForward(nn.Module):
diff --git a/speechbrain/nnet/hypermixing.py b/speechbrain/nnet/hypermixing.py
new file mode 100755
index 0000000000000000000000000000000000000000..7f7b8c7641931e82cbde7b86495e65572eda4278
--- /dev/null
+++ b/speechbrain/nnet/hypermixing.py
@@ -0,0 +1,372 @@
+"""This module mixes information from different tokens via HyperMixing.
+It can be viewed as a linear-time drop-in replacement for (self-)attention.
+
+source: https://arxiv.org/abs/2203.03691
+
+Authors
+ * Florian Mai 2023
+ * Juan Pablo Zuluaga 2023
+"""
+from typing import Optional
+
+import math
+
+import torch
+from torch import nn
+
+
+class HyperMixing(nn.Module):
+    """ This class implements multi-head HyperMixing.
+    It is an implementation of the token-mixing component in HyperMixer, a linear
+    time drop-in replacement for self-attention. In contrast to the original HyperMixer,
+    this module supports multiple heads, which improves the expressiveness of the model
+    while decreasing the number of parameters.
+
+    Reference: https://arxiv.org/abs/2203.03691
+
+    Arguments
+    ----------
+    input_output_dim : int
+        number of features in keys, queries, and values
+    hypernet_size : int
+        determines the size of the hidden layer of the token-mixing MLP.
+    tied : bool
+        If True, then the generated weight matrices of the token-mixing MLP are tied.
+    num_heads : int
+        parallel token-mixing MLPs.
+    fix_tm_hidden_size : bool
+        If True, the hidden-layer size is equal to hypernet_size rather than hypernet_size / num_heads.
+    max_length : int
+        Maximum number of input tokens. Needed for generating sufficiently large position embeddings.
+
+    Example
+    -------
+    >>> import torch
+    >>> inputs = torch.rand([8, 60, 512])
+    >>> net = HyperMixing(512, 2048, num_heads=8)
+    >>> outputs, attn = net(inputs, inputs, inputs)
+    >>> outputs.shape
+    torch.Size([8, 60, 512])
+    """
+
+    def __init__(
+        self,
+        input_output_dim: int,
+        hypernet_size: int,
+        tied: bool = False,
+        num_heads: int = 1,
+        fix_tm_hidden_size=False,
+        max_length=3000,
+    ) -> None:
+        super().__init__()
+        self.input_output_dim = input_output_dim
+        self.hyper = HyperNetwork(
+            input_output_dim,
+            hypernet_size,
+            tied=tied,
+            num_heads=num_heads,
+            keep_output_size=fix_tm_hidden_size,
+        )
+        self.activation = nn.GELU()
+        self.layer_norm = nn.LayerNorm(input_output_dim)
+        self.num_heads = num_heads
+
+        from speechbrain.lobes.models.transformer.Transformer import (
+            PositionalEncoding,
+        )
+
+        # add pos encoding
+        self.positional_encoding = PositionalEncoding(
+            input_output_dim, max_length
+        )
+
+    def _mlp_pass_from_components(self, out, W1, W2, activation):
+        """function to stick MLP1 together manually"""
+        out = torch.bmm(out, W1)
+        out = activation(out)
+        out = torch.bmm(out, W2.transpose(1, 2))
+        return out
+
+    def forward(
+        self,
+        query,
+        key,
+        value,
+        attn_mask: Optional[torch.Tensor] = None,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        return_attn_weights: Optional[bool] = True,
+        pos_embs: Optional[torch.Tensor] = None,
+    ):
+        """
+        The signature of this method is deliberately chosen to be the same as for
+        sb.nnet.attention.MultiHeadAttention for compatibility within SpeechBrain.
+
+        NOTE: key, value, attn_mask and pos_embs have no effect. Query is used for
+        all three. Thus, the module should only be used to replace self-attention at the moment.
+
+        Arguments
+        ----------
+        query : torch.Tensor
+            (B, L, E) where L is the target sequence length,
+            B is the batch size, E is the embedding dimension.
+        key : torch.Tensor
+            (B, S, E) where S is the source sequence length,
+            B is the batch size, E is the embedding dimension.
+            Currently unused. All
+        value : torch.Tensor
+            (B, S, E) where S is the source sequence length,
+            B is the batch size, E is the embedding dimension.
+            Currently unused.
+        attn_mask : torch.Tensor, optional
+            NOTE: Currently has NO effect.
+        key_padding_mask : torch.Tensor, optional
+            (B, S) where B is the batch size, S is the source sequence
+            length. If a ByteTensor is provided, the non-zero positions will
+            be ignored while the position with the zero positions will be
+            unchanged. If a BoolTensor is provided, the positions with the
+            value of True will be ignored while the position with the value
+            of False will be unchanged.
+        return_attn_weights: torch.Tensor, optional
+            NOTE: Currently has NO effect.
+        pos_embs: torch.Tensor, optional
+            NOTE: Currently has NO effect.
+
+        Outputs
+        -------
+        attn_output : torch.Tensor
+            (B, L, E) where L is the target sequence length, B is the
+            batch size, E is the embedding dimension.
+        attn_output_weights : torch.Tensor
+            (B, L, S) where B is the batch size, L is the target
+            sequence length, S is the source sequence length.
+            NOTE: always returns all zeros.
+        """
+
+        # NOTE: We are ignoring keys and values, because HyperMixing can only be used in the encoder atm (where it's all the same)
+        out = query
+
+        bsize = out.size(0)
+        seq_len = out.size(1)
+
+        if key_padding_mask is not None:
+            float_mask = (
+                torch.logical_not(key_padding_mask).unsqueeze(-1).float()
+            )
+            out = out * float_mask
+
+        # add position embedding before passing to hypernetwork
+        hyp_input = out + self.positional_encoding(out)
+        W1, W2 = self.hyper(
+            hyp_input
+        )  # [bsize, num_heads, seq_len, hypernet_size // num_heads]
+
+        if key_padding_mask is not None:
+            # mask the weights
+            W1 = W1 * float_mask.unsqueeze(1)
+            W2 = W2 * float_mask.unsqueeze(1)
+
+        # reshape the num_heads into the batch dimension for parallelizing
+        out = out.transpose(1, 2)  # [bsize, input_output_dim, seq_len]
+        out = out.reshape(
+            (
+                bsize * self.num_heads,
+                self.input_output_dim // self.num_heads,
+                seq_len,
+            )
+        )  # [bsize * num_heads, input_output_dim // num_heads, seq_len]
+        W1 = W1.reshape((bsize * self.num_heads, seq_len, -1))
+        W2 = W2.reshape((bsize * self.num_heads, seq_len, -1))
+
+        # we stick the token-mixing MLP together manually
+        out = self._mlp_pass_from_components(out, W1, W2, self.activation)
+
+        # concatenate heads
+        out = out.reshape((bsize, self.input_output_dim, seq_len))
+
+        # transpose back
+        out = out.transpose(1, 2)
+
+        # apply layer norm on outputs of the TM-MLP
+        out = self.layer_norm(out)
+
+        dummy_att_weights = torch.zeros(
+            (bsize, seq_len, seq_len), device=out.device
+        )
+        return out, dummy_att_weights
+
+
+class HyperNetwork(nn.Module):
+    """This class implements The HyperNetwork. It is an approach of using a one network,
+    also known as a hypernetwork, to generate the weights for another network.
+    Here, it is used to generate the labels of linear layers.
+
+    Reference: https://arxiv.org/abs/1609.09106
+
+    Arguments
+    ----------
+    input_output_dim : int
+        Dimension of the linear layers
+    hypernet_size:
+        Dimension of the HyperNetwork
+    tied : bool, optional
+        Define whether weights of layer 1 and layer 2 are shared
+    num_heads: int, optional
+        Number of heads, akin to heads in MultiHeadAttention
+    keep_output_size: bool, optional
+        Set whether to keep the same output size independent of number of heads
+    """
+
+    def __init__(
+        self,
+        input_output_dim: int,
+        hypernet_size: int,
+        tied=False,
+        num_heads=1,
+        keep_output_size=True,
+    ) -> None:
+        super(HyperNetwork, self).__init__()
+
+        # Define whether the two linear layers have tied weights
+        self.tied = tied
+        self.w1_gen = ParallelMLPs(
+            input_output_dim,
+            input_output_dim,
+            output_size=hypernet_size,
+            num_mlps=num_heads,
+            keep_output_size=keep_output_size,
+        )
+        if self.tied:
+            self.w2_gen = self.w1_gen
+        else:
+            self.w2_gen = ParallelMLPs(
+                input_output_dim,
+                input_output_dim,
+                output_size=hypernet_size,
+                num_mlps=num_heads,
+                keep_output_size=keep_output_size,
+            )
+
+    def forward(self, input_tensor: torch.Tensor):
+        """ Forward computation for a HyperNetwork.
+
+        Arguments
+        ----------
+        input_tensor : [batchsize, max_positions, d]
+            The HyperNetwork is supposed to generate an MLP of the form W_2(GELU(W1 x)), where
+            W1 : N -> k and W2 : k -> N, so it has to return tensors W1 and W2
+
+        Outputs
+        -------
+        W1 : torch.Tensor
+            Generated weights of Layer 1
+        W2 : torch.Tensor
+            Generated weights of Layer 2
+        """
+        W1 = self.w1_gen(input_tensor)
+        if self.tied:
+            W2 = W1
+        else:
+            W2 = self.w2_gen(input_tensor)
+
+        return W1, W2
+
+
+class ParallelMLPs(nn.Module):
+    """Class that implements the MultiHead HyperMixer or HyperConformer.
+
+    Arguments
+    ----------
+    input_size : int
+        Dimension of the linear layers
+    hidden_size: int
+        Dimension of the hidden layer
+    output_size : int
+        Dimension of the HyperNetwork
+    num_mlps : int
+        Number of heads, akin to heads in MultiHeadAttention
+    keep_output_size : bool, optional
+        Set whether to keep the same output size independent of number of heads
+    """
+
+    def __init__(
+        self,
+        input_size,
+        hidden_size,
+        output_size=None,
+        num_mlps=1,
+        keep_output_size=True,
+    ) -> None:
+        super(ParallelMLPs, self).__init__()
+
+        if output_size is None:
+            output_size = input_size
+
+        self.original_in_size = input_size
+        self.original_out_size = output_size
+
+        assert input_size % num_mlps == 0
+        assert output_size % num_mlps == 0
+        assert hidden_size % num_mlps == 0
+        input_size = input_size // num_mlps
+
+        if not keep_output_size:
+            output_size = output_size // num_mlps
+        hidden_size = hidden_size // num_mlps
+
+        self.input_size = input_size
+        self.output_size = output_size
+
+        self.num_mlps = num_mlps
+
+        # set the weights and biases parameters
+        self.fc1_weights = nn.Parameter(
+            torch.empty(num_mlps, hidden_size, input_size)
+        )
+        self.fc1_biases = nn.Parameter(torch.empty(num_mlps, hidden_size))
+        self.fc2_weights = nn.Parameter(
+            torch.empty(num_mlps, output_size, hidden_size)
+        )
+        self.fc2_biases = nn.Parameter(torch.empty(num_mlps, output_size))
+
+        # initialize the weights and biases
+        nn.init.xavier_uniform_(self.fc1_weights, gain=math.sqrt(2.0))
+        nn.init.xavier_uniform_(self.fc1_biases, gain=math.sqrt(2.0))
+        nn.init.xavier_uniform_(self.fc2_weights, gain=math.sqrt(2.0))
+        nn.init.xavier_uniform_(self.fc2_biases, gain=math.sqrt(2.0))
+
+        self.activation = nn.GELU()
+
+    def forward(self, x):
+        """Performs the forward computation of multi parallel MLPs.
+
+        Arguments
+        ----------
+        x : tensor
+            Input tensor
+
+        Outputs
+        -------
+        x : torch.Tensor
+            return output tensor
+        """
+
+        # x [bsize, seq_len, num_features]
+        bsize = x.size(0)
+        seq_len = x.size(1)
+
+        # Reshape the input tensor to match the number of parallel MLPs and their input size
+        x = x.reshape((bsize, seq_len, self.num_mlps, self.input_size))
+
+        # Perform the first linear transformation and add bias
+        # Using einsum so we can do it for multiple MLPs in parallel
+        x = torch.einsum(
+            "blmf,mhf->bmlh", x, self.fc1_weights
+        ) + self.fc1_biases.unsqueeze(0).unsqueeze(2)
+
+        # Apply activation function and perform the second linear transformation and add bias
+        x = self.activation(x)
+        x = torch.einsum(
+            "bmlh,mfh->bmlf", x, self.fc2_weights
+        ) + self.fc2_biases.unsqueeze(0).unsqueeze(2)
+
+        return x
diff --git a/speechbrain/nnet/linear.py b/speechbrain/nnet/linear.py
index c93209c959ca0648b422d3d021c5e2be7839a1d2..7fb007ba98198c62c37532607076660cab5c53a3 100644
--- a/speechbrain/nnet/linear.py
+++ b/speechbrain/nnet/linear.py
@@ -28,7 +28,8 @@ class Linear(torch.nn.Module):
         If True, the additive bias b is adopted.
     combine_dims : bool
         If True and the input is 4D, combine 3rd and 4th dimensions of input.
-
+    max_norm: float
+        weight max-norm.
     Example
     -------
     >>> inputs = torch.rand(10, 50, 40)
@@ -44,9 +45,11 @@ class Linear(torch.nn.Module):
         input_shape=None,
         input_size=None,
         bias=True,
+        max_norm=None,
         combine_dims=False,
     ):
         super().__init__()
+        self.max_norm = max_norm
         self.combine_dims = combine_dims
 
         if input_shape is None and input_size is None:
@@ -71,54 +74,11 @@ class Linear(torch.nn.Module):
         if x.ndim == 4 and self.combine_dims:
             x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
 
+        if self.max_norm is not None:
+            self.w.weight.data = torch.renorm(
+                self.w.weight.data, p=2, dim=0, maxnorm=self.max_norm
+            )
+
         wx = self.w(x)
 
         return wx
-
-
-class LinearWithConstraint(Linear):
-    """Computes a linear transformation y = wx + b with kernel max-norm constaint.
-    This corresponds to set an upper bound for the kernel norm.
-
-    Arguments
-    ---------
-    n_neurons : int
-        It is the number of output neurons (i.e, the dimensionality of the
-        output).
-    input_shape: tuple
-        It is the shape of the input tensor.
-    input_size: int
-        Size of the input tensor.
-    bias : bool
-        If True, the additive bias b is adopted.
-    combine_dims : bool
-        If True and the input is 4D, combine 3rd and 4th dimensions of input.
-    max_norm : float
-        Kernel max-norm
-
-    Example
-    -------
-    >>> inputs = torch.rand(100,)
-    >>> max_norm = 1.
-    >>> lin_t_contrained = LinearWithConstraint(input_size=inputs.shape[0], n_neurons=2, max_norm=max_norm)
-    >>> output = lin_t_contrained(inputs)
-    >>> torch.any(torch.norm(lin_t_contrained.w.weight.data, p=2, dim=0)>max_norm)
-    tensor(False)
-    """
-
-    def __init__(self, *args, max_norm=1, **kwargs):
-        self.max_norm = max_norm
-        super().__init__(*args, **kwargs)
-
-    def forward(self, x):
-        """Returns the linear transformation of input tensor.
-
-        Arguments
-        ---------
-        x : torch.Tensor
-            Input to transform linearly.
-        """
-        self.w.weight.data = torch.renorm(
-            self.w.weight.data, p=2, dim=0, maxnorm=self.max_norm
-        )
-        return super().forward(x)
diff --git a/speechbrain/nnet/loss/transducer_loss.py b/speechbrain/nnet/loss/transducer_loss.py
index 0c55939dfd47610399aad1d7553d2cd92113093e..7dbd2a95d36c0bc93bd82ebda1871e2427533c16 100644
--- a/speechbrain/nnet/loss/transducer_loss.py
+++ b/speechbrain/nnet/loss/transducer_loss.py
@@ -3,14 +3,40 @@ Transducer loss implementation (depends on numba)
 
 Authors
  * Abdelwahab Heba 2020
+ * Titouan Parcollet 2023
 """
 
 import torch
 from torch.autograd import Function
 from torch.nn import Module
+import logging
+import math
+import warnings
+
+NUMBA_VERBOSE = 0
+
+logger = logging.getLogger(__name__)
 
 try:
     from numba import cuda
+
+    # Numba is extra verbose and this may lead to log.txt file of multiple gigabytes... we deactivate
+    if not NUMBA_VERBOSE:
+        logger.info(
+            "Numba verbose is deactivated. To enable it, set NUMBA_VERBOSE to 1."
+        )
+
+        nb_logger = logging.getLogger("numba")
+        nb_logger.setLevel(logging.ERROR)  # only show error
+
+        from numba.core.errors import NumbaPerformanceWarning
+
+        warnings.simplefilter("ignore", category=NumbaPerformanceWarning)
+    else:
+        logger.info(
+            "Numba verbose is enabled. To desactivate it, set NUMBA_VERBOSE to 0."
+        )
+
 except ImportError:
     err_msg = "The optional dependency Numba is needed to use this module\n"
     err_msg += "Cannot import numba. To use Transducer loss\n"
@@ -22,15 +48,11 @@ except ImportError:
     err_msg += "export NUMBAPRO_NVVM='/usr/local/cuda/nvvm/lib64/libnvvm.so' \n"
     err_msg += "================================ \n"
     err_msg += "If you use conda:\n"
-    err_msg += "conda install numba cudatoolkit=9.0"
+    err_msg += "conda install numba cudatoolkit"
     raise ImportError(err_msg)
 
-import math
-
 
-@cuda.jit(
-    "(float32[:,:,:,:], int32[:,:], float32[:,:,:], float32[:], int32[:], int32[:], int32, int32[:,:])"
-)
+@cuda.jit()
 def cu_kernel_forward(log_probs, labels, alpha, log_p, T, U, blank, lock):
     """
     Compute forward pass for the forward-backward algorithm using Numba cuda kernel.
@@ -106,9 +128,7 @@ def cu_kernel_forward(log_probs, labels, alpha, log_p, T, U, blank, lock):
             ) / T[b]
 
 
-@cuda.jit(
-    "(float32[:,:,:,:], int32[:,:], float32[:,:,:], float32[:], int32[:], int32[:], int32, int32[:,:])"
-)
+@cuda.jit()
 def cu_kernel_backward(log_probs, labels, beta, log_p, T, U, blank, lock):
     """
     Compute backward pass for the forward-backward algorithm using Numba cuda kernel.
@@ -180,9 +200,7 @@ def cu_kernel_backward(log_probs, labels, beta, log_p, T, U, blank, lock):
         log_p[b] = beta[b, 0, 0] / T[b]
 
 
-@cuda.jit(
-    "(float32[:,:,:,:], int32[:,:],float32[:,:,:], float32[:,:,:], float32[:,:,:,:], int32[:], int32[:], int32)"
-)
+@cuda.jit()
 def cu_kernel_compute_grad(log_probs, labels, alpha, beta, grads, T, U, blank):
     """
     Compute gradient for the forward-backward algorithm using Numba cuda kernel.
@@ -255,15 +273,23 @@ class Transducer(Function):
         log_probs = log_probs.detach()
         B, maxT, maxU, A = log_probs.shape
         grads = torch.zeros(
-            (B, maxT, maxU, A), dtype=torch.float32, device=log_probs.device
+            (B, maxT, maxU, A), dtype=log_probs.dtype, device=log_probs.device
+        )
+        alpha = torch.zeros(
+            (B, maxT, maxU), device=log_probs.device, dtype=log_probs.dtype
+        )
+        beta = torch.zeros(
+            (B, maxT, maxU), device=log_probs.device, dtype=log_probs.dtype
         )
-        alpha = torch.zeros((B, maxT, maxU), device=log_probs.device)
-        beta = torch.zeros((B, maxT, maxU), device=log_probs.device)
         lock = torch.zeros(
             (B, maxU), dtype=torch.int32, device=log_probs.device
         )
-        log_p_alpha = torch.zeros((B,), device=log_probs.device)
-        log_p_beta = torch.zeros((B,), device=log_probs.device)
+        log_p_alpha = torch.zeros(
+            (B,), device=log_probs.device, dtype=log_probs.dtype
+        )
+        log_p_beta = torch.zeros(
+            (B,), device=log_probs.device, dtype=log_probs.dtype
+        )
         cu_kernel_forward[B, maxU](
             log_probs, labels, alpha, log_p_alpha, T, U, blank, lock,
         )
diff --git a/speechbrain/nnet/losses.py b/speechbrain/nnet/losses.py
index dfcad4859e66d1bdd140933bb23f37c16f7439b5..46638e47c769a5975c261ae622e320298b4c7912 100644
--- a/speechbrain/nnet/losses.py
+++ b/speechbrain/nnet/losses.py
@@ -410,6 +410,7 @@ def nll_loss(
     length=None,
     label_smoothing=0.0,
     allowed_len_diff=3,
+    weight=None,
     reduction="mean",
 ):
     """Computes negative log likelihood loss.
@@ -425,6 +426,9 @@ def nll_loss(
         Length of each utterance, if frame-level loss is desired.
     allowed_len_diff : int
         Length difference that will be tolerated before raising an exception.
+    weight: torch.Tensor
+        A manual rescaling weight given to each class.
+        If given, has to be a Tensor of size C.
     reduction : str
         Options are 'mean', 'batch', 'batchmean', 'sum'.
         See pytorch for 'mean', 'sum'. The 'batch' option returns
@@ -443,7 +447,9 @@ def nll_loss(
         log_probabilities = log_probabilities.transpose(1, -1)
 
     # Pass the loss function but apply reduction="none" first
-    loss = functools.partial(torch.nn.functional.nll_loss, reduction="none")
+    loss = functools.partial(
+        torch.nn.functional.nll_loss, weight=weight, reduction="none"
+    )
     return compute_masked_loss(
         loss,
         log_probabilities,
@@ -959,7 +965,7 @@ def get_si_snr_with_pitwrapper(source, estimate_source):
 
 
 def get_snr_with_pitwrapper(source, estimate_source):
-    """This function wraps si_snr calculation with the speechbrain pit-wrapper.
+    """This function wraps snr calculation with the speechbrain pit-wrapper.
     Arguments:
     ---------
     source: [B, T, E, C],
diff --git a/speechbrain/nnet/pooling.py b/speechbrain/nnet/pooling.py
index be0d64a052980022fedbf82d52109610b586de6b..59e6b00b205f33d147646731b91768cca311160b 100644
--- a/speechbrain/nnet/pooling.py
+++ b/speechbrain/nnet/pooling.py
@@ -6,6 +6,7 @@ Authors
  * Nauman Dawalatabad 2020
  * Jianyuan Zhong 2020
  * Sarthak Yadav 2022
+ * Ha Nguyen 2023
 """
 
 import torch
@@ -525,3 +526,42 @@ class GaussianLowpassPooling(nn.Module):
         pad_value = get_padding_value(kernel_size)
         x = F.pad(x, pad_value, mode=self.padding_mode, value=0)
         return x
+
+
+class AttentionPooling(nn.Module):
+    """ This function implements a self-attention pooling (https://arxiv.org/abs/2008.01077).
+
+    Arguments
+    ---------
+    input_dim: int
+        The dimension of the input Tensor
+
+    Example
+    -------
+    >>> inp_tensor = torch.rand([4, 40])
+    >>> pool = AttentionPooling(input_dim=40)
+    >>> out_tensor = pool(inp_tensor)
+    """
+
+    def __init__(
+        self, input_dim,
+    ):
+        super().__init__()
+
+        self.input_dim = input_dim
+
+        # Matmul
+        self.attn_pooling_w = torch.nn.Linear(input_dim, 1)
+
+    def forward(self, x):
+        """Returns the output the adapter.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Input tensor.
+        """
+        out = self.attn_pooling_w(x).squeeze(-1).float()
+        out = torch.nn.functional.softmax(out, dim=-1).unsqueeze(-1)
+        out = torch.sum(x * out, dim=1)
+        return out
diff --git a/speechbrain/nnet/schedulers.py b/speechbrain/nnet/schedulers.py
index f5dadd42d6f72f0e030d26f30240f968bb19c7ea..9e340a23b6e43df586821215075fc71bcc10bfe1 100644
--- a/speechbrain/nnet/schedulers.py
+++ b/speechbrain/nnet/schedulers.py
@@ -55,6 +55,96 @@ def update_learning_rate(optimizer, new_lr, param_group=None):
             logger.info("Changing lr from %.2g to %.2g" % (old_lr, new_lr))
 
 
+@checkpoints.register_checkpoint_hooks
+class WarmAndExpDecayLRSchedule:
+    """Warms up linearly, and then decay exponentially to ('lr' / 'decay_factor') in 'total_steps' steps.
+
+
+    Arguments
+    ---------
+        lr : float
+            The max learning rate to reach after warmup.
+        warmup : int
+            Number of warmup steps (following a linear increase).
+        total_steps : int
+            Total number of steps (used to decay).
+        decay_factor : float
+            Decay factor applied every decay_every steps. (default: 0.01)
+
+    Example
+    -------
+    >>> from speechbrain.nnet.linear import Linear
+    >>> inp_tensor = torch.rand([1,660,3])
+    >>> model = Linear(input_size=3, n_neurons=4)
+    >>> optim = torch.optim.Adam(model.parameters(), lr=1)
+    >>> output = model(inp_tensor)
+    >>> scheduler = WarmAndExpDecayLRSchedule(lr=1, n_warmup_steps=2, decay_factor=0.01, total_steps=6)
+    >>> scheduler(optim)
+    >>> optim.param_groups[0]["lr"]
+    0.0
+    >>> scheduler(optim)
+    >>> optim.param_groups[0]["lr"]
+    0.5
+    >>> scheduler(optim)
+    >>> optim.param_groups[0]["lr"]
+    1
+    >>> scheduler(optim)
+    >>> optim.param_groups[0]["lr"]
+    0.31622776601683794
+    """
+
+    def __init__(
+        self, lr, n_warmup_steps, total_steps, decay_factor=0.1,
+    ):
+        super(WarmAndExpDecayLRSchedule, self).__init__()
+        self.base_lr = lr
+        self.current_lr = 0
+        self.n_warmup_steps = n_warmup_steps
+        self.decay_factor = decay_factor
+        self.decay_steps = total_steps - self.n_warmup_steps
+        self.current_step = 0
+
+    def __call__(self, opt):
+        if self.current_step < self.n_warmup_steps:
+            # Warming up at the start of training.
+            lr = self.base_lr * self.current_step / self.n_warmup_steps
+        else:
+            decayed_lr = self.base_lr * self.decay_factor ** (
+                (self.current_step - self.n_warmup_steps) / self.decay_steps
+            )
+            lr = min(self.base_lr, decayed_lr)
+
+        for param_group in opt.param_groups:
+            param_group["lr"] = lr
+
+        self.current_lr = lr
+        self.current_step += 1
+
+    @checkpoints.mark_as_saver
+    def save(self, path):
+        """Saves the current metrics on the specified path."""
+        data = {
+            "base_lr": self.base_lr,
+            "n_warmup_steps": self.n_warmup_steps,
+            "decay_factor": self.decay_factor,
+            "decay_steps": self.decay_steps,
+            "current_step": self.current_step,
+        }
+        torch.save(data, path)
+
+    @checkpoints.mark_as_loader
+    def load(self, path, end_of_epoch=False, device=None):
+        """Loads the needed information."""
+        del end_of_epoch
+        del device
+        data = torch.load(path)
+        self.base_lr = data["base_lr"]
+        self.n_warmup_steps = data["n_warmup_steps"]
+        self.decay_steps = data["decay_steps"]
+        self.decay_factor = data["decay_factor"]
+        self.current_step = data["current_step"]
+
+
 @checkpoints.register_checkpoint_hooks
 class NewBobScheduler:
     """Scheduler with new-bob technique, used for LR annealing.
@@ -141,10 +231,9 @@ class NewBobScheduler:
         torch.save(data, path)
 
     @checkpoints.mark_as_loader
-    def load(self, path, end_of_epoch=False, device=None):
+    def load(self, path, end_of_epoch=False):
         """Loads the needed information."""
         del end_of_epoch  # Unused in this class
-        del device  # Unused in here
         data = torch.load(path)
         self.hyperparam_value = data["hyperparam_value"]
         self.metric_values = data["metric_values"]
@@ -274,10 +363,9 @@ class LinearWarmupScheduler:
         torch.save(data, path)
 
     @checkpoints.mark_as_loader
-    def load(self, path, end_of_epoch=False, device=None):
+    def load(self, path, end_of_epoch=False):
         """Loads the needed information."""
         del end_of_epoch  # Unused in this class
-        del device  # Unused in here
         data = torch.load(path)
         self.lr0 = data["initial_value"]
         self.num_warmup_steps = data["num_warmup_steps"]
@@ -447,10 +535,9 @@ class NoamScheduler:
         torch.save(data, path)
 
     @checkpoints.mark_as_loader
-    def load(self, path, end_of_epoch=False, device=None):
+    def load(self, path, end_of_epoch=False):
         """Loads the needed information."""
         del end_of_epoch  # Unused in this class
-        del device
         data = torch.load(path)
         self.losses = data["losses"]
         self.n_steps = data["n_steps"]
@@ -674,10 +761,9 @@ class CyclicCosineScheduler:
         torch.save(data, path)
 
     @checkpoints.mark_as_loader
-    def load(self, path, end_of_epoch=False, device=None):
+    def load(self, path, end_of_epoch=False):
         """Loads the needed information."""
         del end_of_epoch  # Unused in this class
-        del device  # Unused here
         data = torch.load(path)
         self.losses = data["losses"]
         self.n_steps = data["n_steps"]
@@ -786,10 +872,9 @@ class ReduceLROnPlateau:
         torch.save(data, path)
 
     @checkpoints.mark_as_loader
-    def load(self, path, end_of_epoch=False, device=None):
+    def load(self, path, end_of_epoch=False):
         """Loads the needed information."""
         del end_of_epoch  # Unused in this class
-        del device  # Not used
         data = torch.load(path)
         self.losses = data["losses"]
         self.anchor = data["anchor"]
@@ -959,10 +1044,9 @@ class CyclicLRScheduler:
         torch.save(data, path)
 
     @checkpoints.mark_as_loader
-    def load(self, path, end_of_epoch=False, device=None):
+    def load(self, path, end_of_epoch=False):
         """Loads the needed information."""
         del end_of_epoch  # Unused in this class
-        del device
         data = torch.load(path)
         self.losses = data["losses"]
         self.clr_iterations = data["clr_iterations"]
@@ -1064,10 +1148,9 @@ class IntervalScheduler:
         torch.save(data, path)
 
     @checkpoints.mark_as_loader
-    def load(self, path, end_of_epoch=False, device=None):
+    def load(self, path, end_of_epoch=False):
         """Loads the needed information."""
         del end_of_epoch  # Unused in this class
-        del device
         data = torch.load(path)
         self.losses = data["losses"]
         self.n_steps = data["n_steps"]
@@ -1222,10 +1305,9 @@ class WarmCoolDecayLRSchedule:
         torch.save(data, path)
 
     @checkpoints.mark_as_loader
-    def load(self, path, end_of_epoch=False, device=None):
+    def load(self, path, end_of_epoch=False):
         """Loads the needed information."""
         del end_of_epoch
-        del device
         data = torch.load(path)
         self.base_lr = data["base_lr"]
         self.warmup = data["warmup"]
@@ -1321,3 +1403,131 @@ class ScheduledLoss(nn.Module):
                 self.current_loss_fn = item["loss_fn"]
                 self.next_switch = cumulative_steps
                 break
+
+
+@checkpoints.register_checkpoint_hooks
+class TriStageLRSchedule:
+    """Warms up linearly, very slowly decays and cools down linearly again
+    at the end of training. This is a three steps scheduler.
+    Reference
+    https://arxiv.org/pdf/1904.08779.pdf
+
+    Arguments
+    ---------
+        lr : float
+            The max learning rate to reach after warmup.
+        warmup_steps : int
+            Number of warmup steps (following a linear increase).
+        hold_steps : int
+            Number of holding steps (lr remains unchanged).
+        total_steps : int
+            Total number of steps (used to decay).
+        init_lr_scale : float
+            The initial learning rate scale during warmup phase.
+        final_lr_scale : float
+            The final learning rate scale.
+
+    Example
+    -------
+    >>> from speechbrain.nnet.linear import Linear
+    >>> inp_tensor = torch.rand([1,660,3])
+    >>> model = Linear(input_size=3, n_neurons=4)
+    >>> optim = torch.optim.Adam(model.parameters(), lr=1)
+    >>> output = model(inp_tensor)
+    >>> scheduler = TriStageLRSchedule(lr=1, warmup_steps=2, hold_steps=2, decay_steps=2, total_steps=6, init_lr_scale=0.01, final_lr_scale=0.05)
+    >>> optim.param_groups[0]["lr"]
+    1
+    >>> scheduler(optim, 1)
+    >>> optim.param_groups[0]["lr"]
+    0.505
+    >>> scheduler(optim, 2)
+    >>> optim.param_groups[0]["lr"]
+    1
+    >>> scheduler(optim, 3)
+    >>> optim.param_groups[0]["lr"]
+    1
+    >>> scheduler(optim, 4)
+    >>> optim.param_groups[0]["lr"]
+    1.0
+    >>> scheduler(optim, 5)
+    >>> optim.param_groups[0]["lr"]
+    0.223606797749979
+    >>> scheduler(optim, 6)
+    >>> optim.param_groups[0]["lr"]
+    0.05000000000000001
+    """
+
+    def __init__(
+        self,
+        lr,
+        warmup_steps,
+        hold_steps,
+        decay_steps,
+        total_steps,
+        init_lr_scale=0.01,
+        final_lr_scale=0.05,
+    ):
+        super(TriStageLRSchedule, self).__init__()
+        self.peak_lr = lr
+        self.warmup_steps = warmup_steps
+        self.hold_steps = hold_steps
+        self.decay_steps = decay_steps
+        self.total_steps = total_steps
+        self.init_lr_scale = init_lr_scale
+        self.final_lr_scale = final_lr_scale
+
+        self.init_lr = self.init_lr_scale * self.peak_lr
+        self.warmup_rate = (self.peak_lr - self.init_lr) / self.warmup_steps
+        self.decay_factor = -math.log(self.final_lr_scale) / self.decay_steps
+
+    def __call__(self, opt, num_updates):
+        """Calculate the learning rate corresponding to the current step (num_updates)."""
+        if num_updates < self.warmup_steps:
+            # Warming up at the start of training.
+            lr = self.init_lr + self.warmup_rate * num_updates
+        elif num_updates < self.warmup_steps + self.hold_steps:
+            # Hold lr unchanged.
+            lr = self.peak_lr
+        else:
+            # Decay lr
+            lr = self.peak_lr * math.exp(
+                -self.decay_factor
+                * (num_updates - self.hold_steps - self.warmup_steps)
+            )
+
+        for param_group in opt.param_groups:
+            param_group["lr"] = lr
+
+    @checkpoints.mark_as_saver
+    def save(self, path):
+        """Saves the current metrics on the specified path."""
+        data = {
+            "peak_lr": self.peak_lr,
+            "warmup_steps": self.warmup_steps,
+            "hold_steps": self.hold_steps,
+            "decay_steps": self.decay_steps,
+            "total_steps": self.total_steps,
+            "init_lr_scale": self.init_lr_scale,
+            "final_lr_scale": self.final_lr_scale,
+            "init_lr": self.init_lr,
+            "warmup_rate": self.warmup_rate,
+            "decay_factor": self.decay_factor,
+        }
+        torch.save(data, path)
+
+    @checkpoints.mark_as_loader
+    def load(self, path, end_of_epoch=False, device=None):
+        """Loads the needed information."""
+        del end_of_epoch
+        del device
+        data = torch.load(path)
+        self.peak_lr = data["peak_lr"]
+        self.warmup_steps = data["warmup_steps"]
+        self.hold_steps = data["hold_steps"]
+        self.decay_steps = data["decay_steps"]
+        self.total_steps = data["total_steps"]
+        self.init_lr_scale = data["init_lr_scale"]
+        self.final_lr_scale = data["final_lr_scale"]
+        self.init_lr = data["init_lr"]
+        self.warmup_rate = data["warmup_rate"]
+        self.decay_factor = data["decay_factor"]
diff --git a/speechbrain/pretrained/__init__.py b/speechbrain/pretrained/__init__.py
deleted file mode 100644
index 84bf3a3b01cebc126a7d16a248394370245e8f64..0000000000000000000000000000000000000000
--- a/speechbrain/pretrained/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-"""Pretrained models"""
-
-from .interfaces import *  # noqa
diff --git a/speechbrain/pretrained/interfaces.py b/speechbrain/pretrained/interfaces.py
deleted file mode 100644
index 1c36d5077cf3fa14adee9fb854c8f3594e226647..0000000000000000000000000000000000000000
--- a/speechbrain/pretrained/interfaces.py
+++ /dev/null
@@ -1,4009 +0,0 @@
-"""Defines interfaces for simple inference with pretrained models
-
-Authors:
- * Aku Rouhe 2021
- * Peter Plantinga 2021
- * Loren Lugosch 2020
- * Mirco Ravanelli 2020
- * Titouan Parcollet 2021
- * Abdel Heba 2021
- * Andreas Nautsch 2022, 2023
- * Pooneh Mousavi 2023
- * Sylvain de Langen 2023
- * Adel Moumen 2023
-"""
-import logging
-import hashlib
-import sys
-import warnings
-import speechbrain
-import torch
-import torchaudio
-import sentencepiece
-from types import SimpleNamespace
-from torch.nn import SyncBatchNorm
-from torch.nn import DataParallel as DP
-from hyperpyyaml import load_hyperpyyaml
-from copy import copy
-from speechbrain.pretrained.fetching import fetch
-from speechbrain.dataio.preprocess import AudioNormalizer
-import torch.nn.functional as F
-from torch.nn.parallel import DistributedDataParallel as DDP
-from speechbrain.utils.data_utils import split_path
-from speechbrain.utils.distributed import run_on_main
-from speechbrain.dataio.batch import PaddedBatch, PaddedData
-from speechbrain.utils.data_pipeline import DataPipeline
-from speechbrain.utils.callchains import lengths_arg_exists
-from speechbrain.utils.superpowers import import_from_path
-from speechbrain.dataio.dataio import length_to_mask
-from speechbrain.processing.NMF import spectral_phase
-
-logger = logging.getLogger(__name__)
-
-
-def foreign_class(
-    source,
-    hparams_file="hyperparams.yaml",
-    pymodule_file="custom.py",
-    classname="CustomInterface",
-    overrides={},
-    savedir=None,
-    use_auth_token=False,
-    download_only=False,
-    **kwargs,
-):
-    """Fetch and load an interface from an outside source
-
-    The source can be a location on the filesystem or online/huggingface
-
-    The pymodule file should contain a class with the given classname. An
-    instance of that class is returned. The idea is to have a custom Pretrained
-    subclass in the file. The pymodule file is also added to the python path
-    before the Hyperparams YAML file is loaded, so it can contain any custom
-    implementations that are needed.
-
-    The hyperparams file should contain a "modules" key, which is a
-    dictionary of torch modules used for computation.
-
-    The hyperparams file should contain a "pretrainer" key, which is a
-    speechbrain.utils.parameter_transfer.Pretrainer
-
-    Arguments
-    ---------
-    source : str or Path or FetchSource
-        The location to use for finding the model. See
-        ``speechbrain.pretrained.fetching.fetch`` for details.
-    hparams_file : str
-        The name of the hyperparameters file to use for constructing
-        the modules necessary for inference. Must contain two keys:
-        "modules" and "pretrainer", as described.
-    pymodule_file : str
-        The name of the Python file that should be fetched.
-    classname : str
-        The name of the Class, of which an instance is created and returned
-    overrides : dict
-        Any changes to make to the hparams file when it is loaded.
-    savedir : str or Path
-        Where to put the pretraining material. If not given, will use
-        ./pretrained_models/<class-name>-hash(source).
-    use_auth_token : bool (default: False)
-        If true Hugginface's auth_token will be used to load private models from the HuggingFace Hub,
-        default is False because the majority of models are public.
-    download_only : bool (default: False)
-        If true, class and instance creation is skipped.
-
-    Returns
-    -------
-    object
-        An instance of a class with the given classname from the given pymodule file.
-    """
-    if savedir is None:
-        savedir = f"./pretrained_models/{classname}-{hashlib.md5(source.encode('UTF-8', errors='replace')).hexdigest()}"
-    hparams_local_path = fetch(
-        filename=hparams_file,
-        source=source,
-        savedir=savedir,
-        overwrite=False,
-        save_filename=None,
-        use_auth_token=use_auth_token,
-        revision=None,
-    )
-    pymodule_local_path = fetch(
-        filename=pymodule_file,
-        source=source,
-        savedir=savedir,
-        overwrite=False,
-        save_filename=None,
-        use_auth_token=use_auth_token,
-        revision=None,
-    )
-    sys.path.append(str(pymodule_local_path.parent))
-
-    # Load the modules:
-    with open(hparams_local_path) as fin:
-        hparams = load_hyperpyyaml(fin, overrides)
-
-    # Pretraining:
-    pretrainer = hparams["pretrainer"]
-    pretrainer.set_collect_in(savedir)
-    # For distributed setups, have this here:
-    run_on_main(
-        pretrainer.collect_files, kwargs={"default_source": source},
-    )
-    # Load on the CPU. Later the params can be moved elsewhere by specifying
-    if not download_only:
-        # run_opts={"device": ...}
-        pretrainer.load_collected(device="cpu")
-
-        # Import class and create instance
-        module = import_from_path(pymodule_local_path)
-        cls = getattr(module, classname)
-        return cls(modules=hparams["modules"], hparams=hparams, **kwargs)
-
-
-class Pretrained(torch.nn.Module):
-    """Takes a trained model and makes predictions on new data.
-
-    This is a base class which handles some common boilerplate.
-    It intentionally has an interface similar to ``Brain`` - these base
-    classes handle similar things.
-
-    Subclasses of Pretrained should implement the actual logic of how
-    the pretrained system runs, and add methods with descriptive names
-    (e.g. transcribe_file() for ASR).
-
-    Pretrained is a torch.nn.Module so that methods like .to() or .eval() can
-    work. Subclasses should provide a suitable forward() implementation: by
-    convention, it should be a method that takes a batch of audio signals and
-    runs the full model (as applicable).
-
-    Arguments
-    ---------
-    modules : dict of str:torch.nn.Module pairs
-        The Torch modules that make up the learned system. These can be treated
-        in special ways (put on the right device, frozen, etc.). These are available
-        as attributes under ``self.mods``, like self.mods.model(x)
-    hparams : dict
-        Each key:value pair should consist of a string key and a hyperparameter
-        that is used within the overridden methods. These will
-        be accessible via an ``hparams`` attribute, using "dot" notation:
-        e.g., self.hparams.model(x).
-    run_opts : dict
-        Options parsed from command line. See ``speechbrain.parse_arguments()``.
-        List that are supported here:
-         * device
-         * data_parallel_count
-         * data_parallel_backend
-         * distributed_launch
-         * distributed_backend
-         * jit
-         * jit_module_keys
-         * compule
-         * compile_module_keys
-         * compile_mode
-         * compile_using_fullgraph
-         * compile_using_dynamic_shape_tracing
-    freeze_params : bool
-        To freeze (requires_grad=False) parameters or not. Normally in inference
-        you want to freeze the params. Also calls .eval() on all modules.
-    """
-
-    HPARAMS_NEEDED = []
-    MODULES_NEEDED = []
-
-    def __init__(
-        self, modules=None, hparams=None, run_opts=None, freeze_params=True
-    ):
-        super().__init__()
-        # Arguments passed via the run opts dictionary. Set a limited
-        # number of these, since some don't apply to inference.
-        run_opt_defaults = {
-            "device": "cpu",
-            "data_parallel_count": -1,
-            "data_parallel_backend": False,
-            "distributed_launch": False,
-            "distributed_backend": "nccl",
-            "jit": False,
-            "jit_module_keys": None,
-            "compile": False,
-            "compile_module_keys": None,
-            "compile_mode": "reduce-overhead",
-            "compile_using_fullgraph": False,
-            "compile_using_dynamic_shape_tracing": False,
-        }
-        for arg, default in run_opt_defaults.items():
-            if run_opts is not None and arg in run_opts:
-                setattr(self, arg, run_opts[arg])
-            else:
-                # If any arg from run_opt_defaults exist in hparams and
-                # not in command line args "run_opts"
-                if hparams is not None and arg in hparams:
-                    setattr(self, arg, hparams[arg])
-                else:
-                    setattr(self, arg, default)
-
-        # Put modules on the right device, accessible with dot notation
-        self.mods = torch.nn.ModuleDict(modules)
-        for module in self.mods.values():
-            if module is not None:
-                module.to(self.device)
-
-        # Check MODULES_NEEDED and HPARAMS_NEEDED and
-        # make hyperparams available with dot notation
-        if self.HPARAMS_NEEDED and hparams is None:
-            raise ValueError("Need to provide hparams dict.")
-        if hparams is not None:
-            # Also first check that all required params are found:
-            for hp in self.HPARAMS_NEEDED:
-                if hp not in hparams:
-                    raise ValueError(f"Need hparams['{hp}']")
-            self.hparams = SimpleNamespace(**hparams)
-
-        # Prepare modules for computation, e.g. jit
-        self._prepare_modules(freeze_params)
-
-        # Audio normalization
-        self.audio_normalizer = hparams.get(
-            "audio_normalizer", AudioNormalizer()
-        )
-
-    def _prepare_modules(self, freeze_params):
-        """Prepare modules for computation, e.g. jit.
-
-        Arguments
-        ---------
-        freeze_params : bool
-            Whether to freeze the parameters and call ``eval()``.
-        """
-
-        # Make jit-able
-        self._compile()
-        self._wrap_distributed()
-
-        # If we don't want to backprop, freeze the pretrained parameters
-        if freeze_params:
-            self.mods.eval()
-            for p in self.mods.parameters():
-                p.requires_grad = False
-
-    def load_audio(self, path, savedir="audio_cache", **kwargs):
-        """Load an audio file with this model's input spec
-
-        When using a speech model, it is important to use the same type of data,
-        as was used to train the model. This means for example using the same
-        sampling rate and number of channels. It is, however, possible to
-        convert a file from a higher sampling rate to a lower one (downsampling).
-        Similarly, it is simple to downmix a stereo file to mono.
-        The path can be a local path, a web url, or a link to a huggingface repo.
-        """
-        source, fl = split_path(path)
-        kwargs = copy(kwargs)  # shallow copy of references only
-        channels_first = kwargs.pop(
-            "channels_first", False
-        )  # False as default value: SB consistent tensor format
-        if kwargs:
-            fetch_kwargs = dict()
-            for key in [
-                "overwrite",
-                "save_filename",
-                "use_auth_token",
-                "revision",
-                "cache_dir",
-                "silent_local_fetch",
-            ]:
-                if key in kwargs:
-                    fetch_kwargs[key] = kwargs.pop(key)
-            path = fetch(fl, source=source, savedir=savedir, **fetch_kwargs)
-        else:
-            path = fetch(fl, source=source, savedir=savedir)
-        signal, sr = torchaudio.load(
-            str(path), channels_first=channels_first, **kwargs
-        )
-        return self.audio_normalizer(signal, sr)
-
-    def _compile(self):
-        """Compile requested modules with either JIT or TorchInductor."""
-        compile_available = hasattr(torch, "compile")
-
-        if not compile_available and self.compile_module_keys is not None:
-            raise ValueError(
-                "'compile_module_keys' specified, but this install of PyTorch "
-                "seems to be too old to support it."
-            )
-
-        # Modules to compile with torch.compile
-        compile_module_keys = set()
-        if self.compile:
-            if self.compile_module_keys is None:
-                compile_module_keys = set(self.mods)
-            else:
-                compile_module_keys = set(self.compile_module_keys)
-                logger.warning(
-                    "--compile and --compile_module_keys are both specified. "
-                    "Only modules specified in --compile_module_keys will be compiled."
-                )
-
-        # Modules to compile with jit
-        jit_module_keys = set()
-        if self.jit:
-            if self.jit_module_keys is None:
-                jit_module_keys = set(self.mods)
-            else:
-                jit_module_keys = set(self.jit_module_keys)
-                logger.warning(
-                    "--jit and --jit_module_keys are both specified. "
-                    "Only modules specified in --jit_module_keys will be compiled."
-                )
-
-        # find missing keys
-        for name in compile_module_keys | jit_module_keys:
-            if name not in self.mods:
-                raise ValueError(
-                    f"module {name} is not defined in your hparams file."
-                )
-
-        # try 'torch.compile', remove successful compiles from JIT list
-        for name in compile_module_keys:
-            try:
-                module = torch.compile(
-                    self.mods[name],
-                    mode=self.compile_mode,
-                    fullgraph=self.compile_using_fullgraph,
-                    dynamic=self.compile_using_dynamic_shape_tracing,
-                )
-            except Exception as e:
-                logger.warning(
-                    f"'{name}' in 'compile_module_keys' failed to compile "
-                    f"and will be skipped (may fallback onto JIT, if "
-                    f"specified): {e}"
-                )
-                continue
-
-            self.mods[name] = module.to(self.device)
-            jit_module_keys.discard(name)
-
-        for name in jit_module_keys:
-            module = torch.jit.script(self.mods[name])
-            self.mods[name] = module.to(self.device)
-
-    def _compile_jit(self):
-        warnings.warn("'_compile_jit' is deprecated; use '_compile' instead")
-        self._compile()
-
-    def _wrap_distributed(self):
-        """Wrap modules with distributed wrapper when requested."""
-        if not self.distributed_launch and not self.data_parallel_backend:
-            return
-        elif self.distributed_launch:
-            for name, module in self.mods.items():
-                if any(p.requires_grad for p in module.parameters()):
-                    # for ddp, all module must run on same GPU
-                    module = SyncBatchNorm.convert_sync_batchnorm(module)
-                    module = DDP(module, device_ids=[self.device])
-                    self.mods[name] = module
-        else:
-            # data_parallel_backend
-            for name, module in self.mods.items():
-                if any(p.requires_grad for p in module.parameters()):
-                    # if distributed_count = -1 then use all gpus
-                    # otherwise, specify the set of gpu to use
-                    if self.data_parallel_count == -1:
-                        module = DP(module)
-                    else:
-                        module = DP(
-                            module, [i for i in range(self.data_parallel_count)]
-                        )
-                    self.mods[name] = module
-
-    @classmethod
-    def from_hparams(
-        cls,
-        source,
-        hparams_file="hyperparams.yaml",
-        pymodule_file="custom.py",
-        overrides={},
-        savedir=None,
-        use_auth_token=False,
-        revision=None,
-        download_only=False,
-        **kwargs,
-    ):
-        """Fetch and load based from outside source based on HyperPyYAML file
-
-        The source can be a location on the filesystem or online/huggingface
-
-        You can use the pymodule_file to include any custom implementations
-        that are needed: if that file exists, then its location is added to
-        sys.path before Hyperparams YAML is loaded, so it can be referenced
-        in the YAML.
-
-        The hyperparams file should contain a "modules" key, which is a
-        dictionary of torch modules used for computation.
-
-        The hyperparams file should contain a "pretrainer" key, which is a
-        speechbrain.utils.parameter_transfer.Pretrainer
-
-        Arguments
-        ---------
-        source : str or Path or FetchSource
-            The location to use for finding the model. See
-            ``speechbrain.pretrained.fetching.fetch`` for details.
-        hparams_file : str
-            The name of the hyperparameters file to use for constructing
-            the modules necessary for inference. Must contain two keys:
-            "modules" and "pretrainer", as described.
-        pymodule_file : str
-            A Python file can be fetched. This allows any custom
-            implementations to be included. The file's location is added to
-            sys.path before the hyperparams YAML file is loaded, so it can be
-            referenced in YAML.
-            This is optional, but has a default: "custom.py". If the default
-            file is not found, this is simply ignored, but if you give a
-            different filename, then this will raise in case the file is not
-            found.
-        overrides : dict
-            Any changes to make to the hparams file when it is loaded.
-        savedir : str or Path
-            Where to put the pretraining material. If not given, will use
-            ./pretrained_models/<class-name>-hash(source).
-        use_auth_token : bool (default: False)
-            If true Hugginface's auth_token will be used to load private models from the HuggingFace Hub,
-            default is False because the majority of models are public.
-        revision : str
-            The model revision corresponding to the HuggingFace Hub model revision.
-            This is particularly useful if you wish to pin your code to a particular
-            version of a model hosted at HuggingFace.
-        download_only : bool (default: False)
-            If true, class and instance creation is skipped.
-        """
-        if savedir is None:
-            clsname = cls.__name__
-            savedir = f"./pretrained_models/{clsname}-{hashlib.md5(source.encode('UTF-8', errors='replace')).hexdigest()}"
-        hparams_local_path = fetch(
-            filename=hparams_file,
-            source=source,
-            savedir=savedir,
-            overwrite=False,
-            save_filename=None,
-            use_auth_token=use_auth_token,
-            revision=revision,
-        )
-        try:
-            pymodule_local_path = fetch(
-                filename=pymodule_file,
-                source=source,
-                savedir=savedir,
-                overwrite=False,
-                save_filename=None,
-                use_auth_token=use_auth_token,
-                revision=revision,
-            )
-            sys.path.append(str(pymodule_local_path.parent))
-        except ValueError:
-            if pymodule_file == "custom.py":
-                # The optional custom Python module file did not exist
-                # and had the default name
-                pass
-            else:
-                # Custom Python module file not found, but some other
-                # filename than the default was given.
-                raise
-
-        # Load the modules:
-        with open(hparams_local_path) as fin:
-            hparams = load_hyperpyyaml(fin, overrides)
-
-        # Pretraining:
-        pretrainer = hparams["pretrainer"]
-        pretrainer.set_collect_in(savedir)
-        # For distributed setups, have this here:
-        run_on_main(
-            pretrainer.collect_files, kwargs={"default_source": source},
-        )
-        # Load on the CPU. Later the params can be moved elsewhere by specifying
-        if not download_only:
-            # run_opts={"device": ...}
-            pretrainer.load_collected(device="cpu")
-
-            # Now return the system
-            return cls(hparams["modules"], hparams, **kwargs)
-
-
-class EndToEndSLU(Pretrained):
-    """An end-to-end SLU model.
-
-    The class can be used either to run only the encoder (encode()) to extract
-    features or to run the entire model (decode()) to map the speech to its semantics.
-
-    Example
-    -------
-    >>> from speechbrain.pretrained import EndToEndSLU
-    >>> tmpdir = getfixture("tmpdir")
-    >>> slu_model = EndToEndSLU.from_hparams(
-    ...     source="speechbrain/slu-timers-and-such-direct-librispeech-asr",
-    ...     savedir=tmpdir,
-    ... )
-    >>> slu_model.decode_file("tests/samples/single-mic/example6.wav")
-    "{'intent': 'SimpleMath', 'slots': {'number1': 37.67, 'number2': 75.7, 'op': ' minus '}}"
-    """
-
-    HPARAMS_NEEDED = ["tokenizer", "asr_model_source"]
-    MODULES_NEEDED = ["slu_enc", "beam_searcher"]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.tokenizer = self.hparams.tokenizer
-        self.asr_model = EncoderDecoderASR.from_hparams(
-            source=self.hparams.asr_model_source,
-            run_opts={"device": self.device},
-        )
-
-    def decode_file(self, path, **kwargs):
-        """Maps the given audio file to a string representing the
-        semantic dictionary for the utterance.
-
-        Arguments
-        ---------
-        path : str
-            Path to audio file to decode.
-
-        Returns
-        -------
-        str
-            The predicted semantics.
-        """
-        waveform = self.load_audio(path, **kwargs)
-        waveform = waveform.to(self.device)
-        # Fake a batch:
-        batch = waveform.unsqueeze(0)
-        rel_length = torch.tensor([1.0])
-        predicted_words, predicted_tokens = self.decode_batch(batch, rel_length)
-        return predicted_words[0]
-
-    def encode_batch(self, wavs, wav_lens):
-        """Encodes the input audio into a sequence of hidden states
-
-        Arguments
-        ---------
-        wavs : torch.Tensor
-            Batch of waveforms [batch, time, channels] or [batch, time]
-            depending on the model.
-        wav_lens : torch.Tensor
-            Lengths of the waveforms relative to the longest one in the
-            batch, tensor of shape [batch]. The longest one should have
-            relative length 1.0 and others len(waveform) / max_length.
-            Used for ignoring padding.
-
-        Returns
-        -------
-        torch.Tensor
-            The encoded batch
-        """
-        wavs = wavs.float()
-        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
-        ASR_encoder_out = self.asr_model.encode_batch(wavs.detach(), wav_lens)
-        encoder_out = self.mods.slu_enc(ASR_encoder_out)
-        return encoder_out
-
-    def decode_batch(self, wavs, wav_lens):
-        """Maps the input audio to its semantics
-
-        Arguments
-        ---------
-        wavs : torch.Tensor
-            Batch of waveforms [batch, time, channels] or [batch, time]
-            depending on the model.
-        wav_lens : torch.Tensor
-            Lengths of the waveforms relative to the longest one in the
-            batch, tensor of shape [batch]. The longest one should have
-            relative length 1.0 and others len(waveform) / max_length.
-            Used for ignoring padding.
-
-        Returns
-        -------
-        list
-            Each waveform in the batch decoded.
-        tensor
-            Each predicted token id.
-        """
-        with torch.no_grad():
-            wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
-            encoder_out = self.encode_batch(wavs, wav_lens)
-            predicted_tokens, scores = self.mods.beam_searcher(
-                encoder_out, wav_lens
-            )
-            predicted_words = [
-                self.tokenizer.decode_ids(token_seq)
-                for token_seq in predicted_tokens
-            ]
-        return predicted_words, predicted_tokens
-
-    def forward(self, wavs, wav_lens):
-        """Runs full decoding - note: no gradients through decoding"""
-        return self.decode_batch(wavs, wav_lens)
-
-
-class EncoderDecoderASR(Pretrained):
-    """A ready-to-use Encoder-Decoder ASR model
-
-    The class can be used either to run only the encoder (encode()) to extract
-    features or to run the entire encoder-decoder model
-    (transcribe()) to transcribe speech. The given YAML must contain the fields
-    specified in the *_NEEDED[] lists.
-
-    Example
-    -------
-    >>> from speechbrain.pretrained import EncoderDecoderASR
-    >>> tmpdir = getfixture("tmpdir")
-    >>> asr_model = EncoderDecoderASR.from_hparams(
-    ...     source="speechbrain/asr-crdnn-rnnlm-librispeech",
-    ...     savedir=tmpdir,
-    ... )
-    >>> asr_model.transcribe_file("tests/samples/single-mic/example2.flac")
-    "MY FATHER HAS REVEALED THE CULPRIT'S NAME"
-    """
-
-    HPARAMS_NEEDED = ["tokenizer"]
-    MODULES_NEEDED = ["encoder", "decoder"]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.tokenizer = self.hparams.tokenizer
-
-    def transcribe_file(self, path, **kwargs):
-        """Transcribes the given audiofile into a sequence of words.
-
-        Arguments
-        ---------
-        path : str
-            Path to audio file which to transcribe.
-
-        Returns
-        -------
-        str
-            The audiofile transcription produced by this ASR system.
-        """
-        waveform = self.load_audio(path, **kwargs)
-        # Fake a batch:
-        batch = waveform.unsqueeze(0)
-        rel_length = torch.tensor([1.0])
-        predicted_words, predicted_tokens = self.transcribe_batch(
-            batch, rel_length
-        )
-        return predicted_words[0]
-
-    def encode_batch(self, wavs, wav_lens):
-        """Encodes the input audio into a sequence of hidden states
-
-        The waveforms should already be in the model's desired format.
-        You can call:
-        ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
-        to get a correctly converted signal in most cases.
-
-        Arguments
-        ---------
-        wavs : torch.Tensor
-            Batch of waveforms [batch, time, channels] or [batch, time]
-            depending on the model.
-        wav_lens : torch.Tensor
-            Lengths of the waveforms relative to the longest one in the
-            batch, tensor of shape [batch]. The longest one should have
-            relative length 1.0 and others len(waveform) / max_length.
-            Used for ignoring padding.
-
-        Returns
-        -------
-        torch.Tensor
-            The encoded batch
-        """
-        wavs = wavs.float()
-        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
-        encoder_out = self.mods.encoder(wavs, wav_lens)
-        return encoder_out
-
-    def transcribe_batch(self, wavs, wav_lens):
-        """Transcribes the input audio into a sequence of words
-
-        The waveforms should already be in the model's desired format.
-        You can call:
-        ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
-        to get a correctly converted signal in most cases.
-
-        Arguments
-        ---------
-        wavs : torch.Tensor
-            Batch of waveforms [batch, time, channels] or [batch, time]
-            depending on the model.
-        wav_lens : torch.Tensor
-            Lengths of the waveforms relative to the longest one in the
-            batch, tensor of shape [batch]. The longest one should have
-            relative length 1.0 and others len(waveform) / max_length.
-            Used for ignoring padding.
-
-        Returns
-        -------
-        list
-            Each waveform in the batch transcribed.
-        tensor
-            Each predicted token id.
-        """
-        with torch.no_grad():
-            wav_lens = wav_lens.to(self.device)
-            encoder_out = self.encode_batch(wavs, wav_lens)
-            predicted_tokens, scores = self.mods.decoder(encoder_out, wav_lens)
-            predicted_words = [
-                self.tokenizer.decode_ids(token_seq)
-                for token_seq in predicted_tokens
-            ]
-        return predicted_words, predicted_tokens
-
-    def forward(self, wavs, wav_lens):
-        """Runs full transcription - note: no gradients through decoding"""
-        return self.transcribe_batch(wavs, wav_lens)
-
-
-class WaveformEncoder(Pretrained):
-    """A ready-to-use waveformEncoder model
-
-    It can be used to wrap different embedding models such as SSL ones (wav2vec2)
-    or speaker ones (Xvector) etc. Two functions are available: encode_batch and
-    encode_file. They can be used to obtain the embeddings directly from an audio
-    file or from a batch of audio tensors respectively.
-
-    The given YAML must contain the fields specified in the *_NEEDED[] lists.
-
-    Example
-    -------
-    >>> from speechbrain.pretrained import WaveformEncoder
-    >>> tmpdir = getfixture("tmpdir")
-    >>> ssl_model = WaveformEncoder.from_hparams(
-    ...     source="speechbrain/ssl-wav2vec2-base-libri",
-    ...     savedir=tmpdir,
-    ... ) # doctest: +SKIP
-    >>> ssl_model.encode_file("samples/audio_samples/example_fr.wav") # doctest: +SKIP
-    """
-
-    MODULES_NEEDED = ["encoder"]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def encode_file(self, path, **kwargs):
-        """Encode the given audiofile into a sequence of embeddings.
-
-        Arguments
-        ---------
-        path : str
-            Path to audio file which to encode.
-
-        Returns
-        -------
-        torch.Tensor
-            The audiofile embeddings produced by this system.
-        """
-        waveform = self.load_audio(path, **kwargs)
-        # Fake a batch:
-        batch = waveform.unsqueeze(0)
-        rel_length = torch.tensor([1.0])
-        results = self.encode_batch(batch, rel_length)
-        return results["embeddings"]
-
-    def encode_batch(self, wavs, wav_lens):
-        """Encodes the input audio into a sequence of hidden states
-
-        The waveforms should already be in the model's desired format.
-
-        Arguments
-        ---------
-        wavs : torch.Tensor
-            Batch of waveforms [batch, time, channels] or [batch, time]
-            depending on the model.
-        wav_lens : torch.Tensor
-            Lengths of the waveforms relative to the longest one in the
-            batch, tensor of shape [batch]. The longest one should have
-            relative length 1.0 and others len(waveform) / max_length.
-            Used for ignoring padding.
-
-        Returns
-        -------
-        torch.Tensor
-            The encoded batch
-        """
-        wavs = wavs.float()
-        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
-        encoder_out = self.mods.encoder(wavs, wav_lens)
-        return encoder_out
-
-    def forward(self, wavs, wav_lens):
-        """Runs the encoder"""
-        return self.encode_batch(wavs, wav_lens)
-
-
-class EncoderASR(Pretrained):
-    """A ready-to-use Encoder ASR model
-
-    The class can be used either to run only the encoder (encode()) to extract
-    features or to run the entire encoder + decoder function model
-    (transcribe()) to transcribe speech. The given YAML must contain the fields
-    specified in the *_NEEDED[] lists.
-
-    Example
-    -------
-    >>> from speechbrain.pretrained import EncoderASR
-    >>> tmpdir = getfixture("tmpdir")
-    >>> asr_model = EncoderASR.from_hparams(
-    ...     source="speechbrain/asr-wav2vec2-commonvoice-fr",
-    ...     savedir=tmpdir,
-    ... ) # doctest: +SKIP
-    >>> asr_model.transcribe_file("samples/audio_samples/example_fr.wav") # doctest: +SKIP
-    """
-
-    HPARAMS_NEEDED = ["tokenizer", "decoding_function"]
-    MODULES_NEEDED = ["encoder"]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        self.tokenizer = self.hparams.tokenizer
-        self.decoding_function = self.hparams.decoding_function
-
-    def transcribe_file(self, path, **kwargs):
-        """Transcribes the given audiofile into a sequence of words.
-
-        Arguments
-        ---------
-        path : str
-            Path to audio file which to transcribe.
-
-        Returns
-        -------
-        str
-            The audiofile transcription produced by this ASR system.
-        """
-        waveform = self.load_audio(path, **kwargs)
-        # Fake a batch:
-        batch = waveform.unsqueeze(0)
-        rel_length = torch.tensor([1.0])
-        predicted_words, predicted_tokens = self.transcribe_batch(
-            batch, rel_length
-        )
-        return str(predicted_words[0])
-
-    def encode_batch(self, wavs, wav_lens):
-        """Encodes the input audio into a sequence of hidden states
-
-        The waveforms should already be in the model's desired format.
-        You can call:
-        ``normalized = EncoderASR.normalizer(signal, sample_rate)``
-        to get a correctly converted signal in most cases.
-
-        Arguments
-        ---------
-        wavs : torch.Tensor
-            Batch of waveforms [batch, time, channels] or [batch, time]
-            depending on the model.
-        wav_lens : torch.Tensor
-            Lengths of the waveforms relative to the longest one in the
-            batch, tensor of shape [batch]. The longest one should have
-            relative length 1.0 and others len(waveform) / max_length.
-            Used for ignoring padding.
-
-        Returns
-        -------
-        torch.Tensor
-            The encoded batch
-        """
-        wavs = wavs.float()
-        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
-        encoder_out = self.mods.encoder(wavs, wav_lens)
-        return encoder_out
-
-    def transcribe_batch(self, wavs, wav_lens):
-        """Transcribes the input audio into a sequence of words
-
-        The waveforms should already be in the model's desired format.
-        You can call:
-        ``normalized = EncoderASR.normalizer(signal, sample_rate)``
-        to get a correctly converted signal in most cases.
-
-        Arguments
-        ---------
-        wavs : torch.Tensor
-            Batch of waveforms [batch, time, channels] or [batch, time]
-            depending on the model.
-        wav_lens : torch.Tensor
-            Lengths of the waveforms relative to the longest one in the
-            batch, tensor of shape [batch]. The longest one should have
-            relative length 1.0 and others len(waveform) / max_length.
-            Used for ignoring padding.
-
-        Returns
-        -------
-        list
-            Each waveform in the batch transcribed.
-        tensor
-            Each predicted token id.
-        """
-        with torch.no_grad():
-            wav_lens = wav_lens.to(self.device)
-            encoder_out = self.encode_batch(wavs, wav_lens)
-            predictions = self.decoding_function(encoder_out, wav_lens)
-            if isinstance(
-                self.tokenizer, speechbrain.dataio.encoder.CTCTextEncoder
-            ):
-                predicted_words = [
-                    "".join(self.tokenizer.decode_ndim(token_seq))
-                    for token_seq in predictions
-                ]
-            elif isinstance(
-                self.tokenizer, sentencepiece.SentencePieceProcessor
-            ):
-                predicted_words = [
-                    self.tokenizer.decode_ids(token_seq)
-                    for token_seq in predictions
-                ]
-            else:
-                raise ValueError(
-                    "The tokenizer must be sentencepiece or CTCTextEncoder"
-                )
-
-        return predicted_words, predictions
-
-    def forward(self, wavs, wav_lens):
-        """Runs the encoder"""
-        return self.encode_batch(wavs, wav_lens)
-
-
-class EncoderClassifier(Pretrained):
-    """A ready-to-use class for utterance-level classification (e.g, speaker-id,
-    language-id, emotion recognition, keyword spotting, etc).
-
-    The class assumes that an encoder called "embedding_model" and a model
-    called "classifier" are defined in the yaml file. If you want to
-    convert the predicted index into a corresponding text label, please
-    provide the path of the label_encoder in a variable called 'lab_encoder_file'
-    within the yaml.
-
-    The class can be used either to run only the encoder (encode_batch()) to
-    extract embeddings or to run a classification step (classify_batch()).
-    ```
-
-    Example
-    -------
-    >>> import torchaudio
-    >>> from speechbrain.pretrained import EncoderClassifier
-    >>> # Model is downloaded from the speechbrain HuggingFace repo
-    >>> tmpdir = getfixture("tmpdir")
-    >>> classifier = EncoderClassifier.from_hparams(
-    ...     source="speechbrain/spkrec-ecapa-voxceleb",
-    ...     savedir=tmpdir,
-    ... )
-    >>> classifier.hparams.label_encoder.ignore_len()
-
-    >>> # Compute embeddings
-    >>> signal, fs = torchaudio.load("tests/samples/single-mic/example1.wav")
-    >>> embeddings = classifier.encode_batch(signal)
-
-    >>> # Classification
-    >>> prediction = classifier.classify_batch(signal)
-    """
-
-    MODULES_NEEDED = [
-        "compute_features",
-        "mean_var_norm",
-        "embedding_model",
-        "classifier",
-    ]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def encode_batch(self, wavs, wav_lens=None, normalize=False):
-        """Encodes the input audio into a single vector embedding.
-
-        The waveforms should already be in the model's desired format.
-        You can call:
-        ``normalized = <this>.normalizer(signal, sample_rate)``
-        to get a correctly converted signal in most cases.
-
-        Arguments
-        ---------
-        wavs : torch.Tensor
-            Batch of waveforms [batch, time, channels] or [batch, time]
-            depending on the model. Make sure the sample rate is fs=16000 Hz.
-        wav_lens : torch.Tensor
-            Lengths of the waveforms relative to the longest one in the
-            batch, tensor of shape [batch]. The longest one should have
-            relative length 1.0 and others len(waveform) / max_length.
-            Used for ignoring padding.
-        normalize : bool
-            If True, it normalizes the embeddings with the statistics
-            contained in mean_var_norm_emb.
-
-        Returns
-        -------
-        torch.Tensor
-            The encoded batch
-        """
-        # Manage single waveforms in input
-        if len(wavs.shape) == 1:
-            wavs = wavs.unsqueeze(0)
-
-        # Assign full length if wav_lens is not assigned
-        if wav_lens is None:
-            wav_lens = torch.ones(wavs.shape[0], device=self.device)
-
-        # Storing waveform in the specified device
-        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
-        wavs = wavs.float()
-
-        # Computing features and embeddings
-        feats = self.mods.compute_features(wavs)
-        feats = self.mods.mean_var_norm(feats, wav_lens)
-        embeddings = self.mods.embedding_model(feats, wav_lens)
-        if normalize:
-            embeddings = self.hparams.mean_var_norm_emb(
-                embeddings, torch.ones(embeddings.shape[0], device=self.device)
-            )
-        return embeddings
-
-    def classify_batch(self, wavs, wav_lens=None):
-        """Performs classification on the top of the encoded features.
-
-        It returns the posterior probabilities, the index and, if the label
-        encoder is specified it also the text label.
-
-        Arguments
-        ---------
-        wavs : torch.Tensor
-            Batch of waveforms [batch, time, channels] or [batch, time]
-            depending on the model. Make sure the sample rate is fs=16000 Hz.
-        wav_lens : torch.Tensor
-            Lengths of the waveforms relative to the longest one in the
-            batch, tensor of shape [batch]. The longest one should have
-            relative length 1.0 and others len(waveform) / max_length.
-            Used for ignoring padding.
-
-        Returns
-        -------
-        out_prob
-            The log posterior probabilities of each class ([batch, N_class])
-        score:
-            It is the value of the log-posterior for the best class ([batch,])
-        index
-            The indexes of the best class ([batch,])
-        text_lab:
-            List with the text labels corresponding to the indexes.
-            (label encoder should be provided).
-        """
-        emb = self.encode_batch(wavs, wav_lens)
-        out_prob = self.mods.classifier(emb).squeeze(1)
-        score, index = torch.max(out_prob, dim=-1)
-        text_lab = self.hparams.label_encoder.decode_torch(index)
-        return out_prob, score, index, text_lab
-
-    def classify_file(self, path, **kwargs):
-        """Classifies the given audiofile into the given set of labels.
-
-        Arguments
-        ---------
-        path : str
-            Path to audio file to classify.
-
-        Returns
-        -------
-        out_prob
-            The log posterior probabilities of each class ([batch, N_class])
-        score:
-            It is the value of the log-posterior for the best class ([batch,])
-        index
-            The indexes of the best class ([batch,])
-        text_lab:
-            List with the text labels corresponding to the indexes.
-            (label encoder should be provided).
-        """
-        waveform = self.load_audio(path, **kwargs)
-        # Fake a batch:
-        batch = waveform.unsqueeze(0)
-        rel_length = torch.tensor([1.0])
-        emb = self.encode_batch(batch, rel_length)
-        out_prob = self.mods.classifier(emb).squeeze(1)
-        score, index = torch.max(out_prob, dim=-1)
-        text_lab = self.hparams.label_encoder.decode_torch(index)
-        return out_prob, score, index, text_lab
-
-    def forward(self, wavs, wav_lens=None):
-        """Runs the classification"""
-        return self.classify_batch(wavs, wav_lens)
-
-
-class SpeakerRecognition(EncoderClassifier):
-    """A ready-to-use model for speaker recognition. It can be used to
-    perform speaker verification with verify_batch().
-
-    ```
-    Example
-    -------
-    >>> import torchaudio
-    >>> from speechbrain.pretrained import SpeakerRecognition
-    >>> # Model is downloaded from the speechbrain HuggingFace repo
-    >>> tmpdir = getfixture("tmpdir")
-    >>> verification = SpeakerRecognition.from_hparams(
-    ...     source="speechbrain/spkrec-ecapa-voxceleb",
-    ...     savedir=tmpdir,
-    ... )
-
-    >>> # Perform verification
-    >>> signal, fs = torchaudio.load("tests/samples/single-mic/example1.wav")
-    >>> signal2, fs = torchaudio.load("tests/samples/single-mic/example2.flac")
-    >>> score, prediction = verification.verify_batch(signal, signal2)
-    """
-
-    MODULES_NEEDED = [
-        "compute_features",
-        "mean_var_norm",
-        "embedding_model",
-        "mean_var_norm_emb",
-    ]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.similarity = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
-
-    def verify_batch(
-        self, wavs1, wavs2, wav1_lens=None, wav2_lens=None, threshold=0.25
-    ):
-        """Performs speaker verification with cosine distance.
-
-        It returns the score and the decision (0 different speakers,
-        1 same speakers).
-
-        Arguments
-        ---------
-        wavs1 : Torch.Tensor
-                Tensor containing the speech waveform1 (batch, time).
-                Make sure the sample rate is fs=16000 Hz.
-        wavs2 : Torch.Tensor
-                Tensor containing the speech waveform2 (batch, time).
-                Make sure the sample rate is fs=16000 Hz.
-        wav1_lens: Torch.Tensor
-                Tensor containing the relative length for each sentence
-                in the length (e.g., [0.8 0.6 1.0])
-        wav2_lens: Torch.Tensor
-                Tensor containing the relative length for each sentence
-                in the length (e.g., [0.8 0.6 1.0])
-        threshold: Float
-                Threshold applied to the cosine distance to decide if the
-                speaker is different (0) or the same (1).
-
-        Returns
-        -------
-        score
-            The score associated to the binary verification output
-            (cosine distance).
-        prediction
-            The prediction is 1 if the two signals in input are from the same
-            speaker and 0 otherwise.
-        """
-        emb1 = self.encode_batch(wavs1, wav1_lens, normalize=False)
-        emb2 = self.encode_batch(wavs2, wav2_lens, normalize=False)
-        score = self.similarity(emb1, emb2)
-        return score, score > threshold
-
-    def verify_files(self, path_x, path_y, **kwargs):
-        """Speaker verification with cosine distance
-
-        Returns the score and the decision (0 different speakers,
-        1 same speakers).
-
-        Returns
-        -------
-        score
-            The score associated to the binary verification output
-            (cosine distance).
-        prediction
-            The prediction is 1 if the two signals in input are from the same
-            speaker and 0 otherwise.
-        """
-        waveform_x = self.load_audio(path_x, **kwargs)
-        waveform_y = self.load_audio(path_y, **kwargs)
-        # Fake batches:
-        batch_x = waveform_x.unsqueeze(0)
-        batch_y = waveform_y.unsqueeze(0)
-        # Verify:
-        score, decision = self.verify_batch(batch_x, batch_y)
-        # Squeeze:
-        return score[0], decision[0]
-
-
-class VAD(Pretrained):
-    """A ready-to-use class for Voice Activity Detection (VAD) using a
-    pre-trained model.
-
-    Example
-    -------
-    >>> import torchaudio
-    >>> from speechbrain.pretrained import VAD
-    >>> # Model is downloaded from the speechbrain HuggingFace repo
-    >>> tmpdir = getfixture("tmpdir")
-    >>> VAD = VAD.from_hparams(
-    ...     source="speechbrain/vad-crdnn-libriparty",
-    ...     savedir=tmpdir,
-    ... )
-
-    >>> # Perform VAD
-    >>> boundaries = VAD.get_speech_segments("tests/samples/single-mic/example1.wav")
-    """
-
-    HPARAMS_NEEDED = ["sample_rate", "time_resolution", "device"]
-
-    MODULES_NEEDED = ["compute_features", "mean_var_norm", "model"]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.time_resolution = self.hparams.time_resolution
-        self.sample_rate = self.hparams.sample_rate
-
-    def get_speech_prob_file(
-        self,
-        audio_file,
-        large_chunk_size=30,
-        small_chunk_size=10,
-        overlap_small_chunk=False,
-    ):
-        """Outputs the frame-level speech probability of the input audio file
-        using the neural model specified in the hparam file. To make this code
-        both parallelizable and scalable to long sequences, it uses a
-        double-windowing approach.  First, we sequentially read non-overlapping
-        large chunks of the input signal.  We then split the large chunks into
-        smaller chunks and we process them in parallel.
-
-        Arguments
-        ---------
-        audio_file: path
-            Path of the audio file containing the recording. The file is read
-            with torchaudio.
-        large_chunk_size: float
-            Size (in seconds) of the large chunks that are read sequentially
-            from the input audio file.
-        small_chunk_size:
-            Size (in seconds) of the small chunks extracted from the large ones.
-            The audio signal is processed in parallel within the small chunks.
-            Note that large_chunk_size/small_chunk_size must be an integer.
-        overlap_small_chunk: bool
-            True, creates overlapped small chunks. The probabilities of the
-            overlapped chunks are combined using hamming windows.
-
-        Returns
-        -------
-        prob_vad: torch.Tensor
-            Tensor containing the frame-level speech probabilities for the
-            input audio file.
-        """
-        # Getting the total size of the input file
-        sample_rate, audio_len = self._get_audio_info(audio_file)
-
-        if sample_rate != self.sample_rate:
-            raise ValueError(
-                "The detected sample rate is different from that set in the hparam file"
-            )
-
-        # Computing the length (in samples) of the large and small chunks
-        long_chunk_len = int(sample_rate * large_chunk_size)
-        small_chunk_len = int(sample_rate * small_chunk_size)
-
-        # Setting the step size of the small chunk (50% overlapping windows are supported)
-        small_chunk_step = small_chunk_size
-        if overlap_small_chunk:
-            small_chunk_step = small_chunk_size / 2
-
-        # Computing the length (in sample) of the small_chunk step size
-        small_chunk_len_step = int(sample_rate * small_chunk_step)
-
-        # Loop over big chunks
-        prob_chunks = []
-        last_chunk = False
-        begin_sample = 0
-        while True:
-
-            # Reading the big chunk
-            large_chunk, fs = torchaudio.load(
-                audio_file, frame_offset=begin_sample, num_frames=long_chunk_len
-            )
-            large_chunk = large_chunk.to(self.device)
-
-            # Manage padding of the last small chunk
-            if last_chunk or large_chunk.shape[-1] < small_chunk_len:
-                padding = torch.zeros(
-                    1, small_chunk_len, device=large_chunk.device
-                )
-                large_chunk = torch.cat([large_chunk, padding], dim=1)
-
-            # Splitting the big chunk into smaller (overlapped) ones
-            small_chunks = torch.nn.functional.unfold(
-                large_chunk.unsqueeze(1).unsqueeze(2),
-                kernel_size=(1, small_chunk_len),
-                stride=(1, small_chunk_len_step),
-            )
-            small_chunks = small_chunks.squeeze(0).transpose(0, 1)
-
-            # Getting (in parallel) the frame-level speech probabilities
-            small_chunks_prob = self.get_speech_prob_chunk(small_chunks)
-            small_chunks_prob = small_chunks_prob[:, :-1, :]
-
-            # Manage overlapping chunks
-            if overlap_small_chunk:
-                small_chunks_prob = self._manage_overlapped_chunks(
-                    small_chunks_prob
-                )
-
-            # Prepare for folding
-            small_chunks_prob = small_chunks_prob.permute(2, 1, 0)
-
-            # Computing lengths in samples
-            out_len = int(
-                large_chunk.shape[-1] / (sample_rate * self.time_resolution)
-            )
-            kernel_len = int(small_chunk_size / self.time_resolution)
-            step_len = int(small_chunk_step / self.time_resolution)
-
-            # Folding the frame-level predictions
-            small_chunks_prob = torch.nn.functional.fold(
-                small_chunks_prob,
-                output_size=(1, out_len),
-                kernel_size=(1, kernel_len),
-                stride=(1, step_len),
-            )
-
-            # Appending the frame-level speech probabilities of the large chunk
-            small_chunks_prob = small_chunks_prob.squeeze(1).transpose(-1, -2)
-            prob_chunks.append(small_chunks_prob)
-
-            # Check stop condition
-            if last_chunk:
-                break
-
-            # Update counter to process the next big chunk
-            begin_sample = begin_sample + long_chunk_len
-
-            # Check if the current chunk is the last one
-            if begin_sample + long_chunk_len > audio_len:
-                last_chunk = True
-
-        # Converting the list to a tensor
-        prob_vad = torch.cat(prob_chunks, dim=1)
-        last_elem = int(audio_len / (self.time_resolution * sample_rate))
-        prob_vad = prob_vad[:, 0:last_elem, :]
-
-        return prob_vad
-
-    def _manage_overlapped_chunks(self, small_chunks_prob):
-        """This support function manages overlapped the case in which the
-        small chunks have a 50% overlap."""
-
-        # Weighting the frame-level probabilities with a hamming window
-        # reduces uncertainty when overlapping chunks are used.
-        hamming_window = torch.hamming_window(
-            small_chunks_prob.shape[1], device=self.device
-        )
-
-        # First and last chunks require special care
-        half_point = int(small_chunks_prob.shape[1] / 2)
-        small_chunks_prob[0, half_point:] = small_chunks_prob[
-            0, half_point:
-        ] * hamming_window[half_point:].unsqueeze(1)
-        small_chunks_prob[-1, 0:half_point] = small_chunks_prob[
-            -1, 0:half_point
-        ] * hamming_window[0:half_point].unsqueeze(1)
-
-        # Applying the window to all the other probabilities
-        small_chunks_prob[1:-1] = small_chunks_prob[
-            1:-1
-        ] * hamming_window.unsqueeze(0).unsqueeze(2)
-
-        return small_chunks_prob
-
-    def get_speech_prob_chunk(self, wavs, wav_lens=None):
-        """Outputs the frame-level posterior probability for the input audio chunks
-        Outputs close to zero refers to time steps with a low probability of speech
-        activity, while outputs closer to one likely contain speech.
-
-        Arguments
-        ---------
-        wavs : torch.Tensor
-            Batch of waveforms [batch, time, channels] or [batch, time]
-            depending on the model. Make sure the sample rate is fs=16000 Hz.
-        wav_lens : torch.Tensor
-            Lengths of the waveforms relative to the longest one in the
-            batch, tensor of shape [batch]. The longest one should have
-            relative length 1.0 and others len(waveform) / max_length.
-            Used for ignoring padding.
-
-        Returns
-        -------
-        torch.Tensor
-            The encoded batch
-        """
-        # Manage single waveforms in input
-        if len(wavs.shape) == 1:
-            wavs = wavs.unsqueeze(0)
-
-        # Assign full length if wav_lens is not assigned
-        if wav_lens is None:
-            wav_lens = torch.ones(wavs.shape[0], device=self.device)
-
-        # Storing waveform in the specified device
-        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
-        wavs = wavs.float()
-
-        # Computing features and embeddings
-        feats = self.mods.compute_features(wavs)
-        feats = self.mods.mean_var_norm(feats, wav_lens)
-        outputs = self.mods.cnn(feats)
-
-        outputs = outputs.reshape(
-            outputs.shape[0],
-            outputs.shape[1],
-            outputs.shape[2] * outputs.shape[3],
-        )
-
-        outputs, h = self.mods.rnn(outputs)
-        outputs = self.mods.dnn(outputs)
-        output_prob = torch.sigmoid(outputs)
-
-        return output_prob
-
-    def apply_threshold(
-        self, vad_prob, activation_th=0.5, deactivation_th=0.25
-    ):
-        """Scans the frame-level speech probabilities and applies a threshold
-        on them. Speech starts when a value larger than activation_th is
-        detected, while it ends when observing a value lower than
-        the deactivation_th.
-
-        Arguments
-        ---------
-        vad_prob: torch.Tensor
-            Frame-level speech probabilities.
-        activation_th:  float
-            Threshold for starting a speech segment.
-        deactivation_th: float
-            Threshold for ending a speech segment.
-
-        Returns
-        -------
-        vad_th: torch.Tensor
-            Tensor containing 1 for speech regions and 0 for non-speech regions.
-        """
-        vad_activation = (vad_prob >= activation_th).int()
-        vad_deactivation = (vad_prob >= deactivation_th).int()
-        vad_th = vad_activation + vad_deactivation
-
-        # Loop over batches and time steps
-        for batch in range(vad_th.shape[0]):
-            for time_step in range(vad_th.shape[1] - 1):
-                if (
-                    vad_th[batch, time_step] == 2
-                    and vad_th[batch, time_step + 1] == 1
-                ):
-                    vad_th[batch, time_step + 1] = 2
-
-        vad_th[vad_th == 1] = 0
-        vad_th[vad_th == 2] = 1
-        return vad_th
-
-    def get_boundaries(self, prob_th, output_value="seconds"):
-        """Computes the time boundaries where speech activity is detected.
-        It takes in input frame-level binary decisions
-        (1 for speech, 0 for non-speech) and outputs the begin/end second
-        (or sample) of each detected speech region.
-
-        Arguments
-        ---------
-        prob_th: torch.Tensor
-            Frame-level binary decisions (1 for speech frame, 0 for a
-            non-speech one).  The tensor can be obtained from apply_threshold.
-        output_value: 'seconds' or 'samples'
-            When the option 'seconds' is set, the returned boundaries are in
-            seconds, otherwise, it reports them in samples.
-
-        Returns
-        -------
-        boundaries: torch.Tensor
-            Tensor containing the start second (or sample) of speech segments
-            in even positions and their corresponding end in odd positions
-            (e.g, [1.0, 1.5, 5,.0 6.0] means that we have two speech segment;
-             one from 1.0 to 1.5 seconds and another from 5.0 to 6.0 seconds).
-        """
-        # Shifting frame-levels binary decision by 1
-        # This allows detecting changes in speech/non-speech activities
-        prob_th_shifted = torch.roll(prob_th, dims=1, shifts=1)
-        prob_th_shifted[:, 0, :] = 0
-        prob_th = prob_th + prob_th_shifted
-
-        # Needed to first and last time step
-        prob_th[:, 0, :] = (prob_th[:, 0, :] >= 1).int()
-        prob_th[:, -1, :] = (prob_th[:, -1, :] >= 1).int()
-
-        # Fix edge cases (when a speech starts in the last frames)
-        if (prob_th == 1).nonzero().shape[0] % 2 == 1:
-            prob_th = torch.cat(
-                (
-                    prob_th,
-                    torch.Tensor([1.0])
-                    .unsqueeze(0)
-                    .unsqueeze(2)
-                    .to(self.device),
-                ),
-                dim=1,
-            )
-
-        # Where prob_th is 1 there is a change
-        indexes = (prob_th == 1).nonzero()[:, 1].reshape(-1, 2)
-
-        # Remove 1 from end samples
-        indexes[:, -1] = indexes[:, -1] - 1
-
-        # From indexes to samples
-        seconds = (indexes * self.time_resolution).float()
-        samples = (self.sample_rate * seconds).round().int()
-
-        if output_value == "seconds":
-            boundaries = seconds
-        else:
-            boundaries = samples
-        return boundaries
-
-    def merge_close_segments(self, boundaries, close_th=0.250):
-        """Merges segments that are shorter than the given threshold.
-
-        Arguments
-        ---------
-        boundaries : str
-            Tensor containing the speech boundaries. It can be derived using the
-            get_boundaries method.
-        close_th: float
-            If the distance between boundaries is smaller than close_th, the
-            segments will be merged.
-
-        Returns
-        -------
-        new_boundaries
-            The new boundaries with the merged segments.
-        """
-
-        new_boundaries = []
-
-        # Single segment case
-        if boundaries.shape[0] == 0:
-            return boundaries
-
-        # Getting beg and end of previous segment
-        prev_beg_seg = boundaries[0, 0].float()
-        prev_end_seg = boundaries[0, 1].float()
-
-        # Process all the segments
-        for i in range(1, boundaries.shape[0]):
-            beg_seg = boundaries[i, 0]
-            segment_distance = beg_seg - prev_end_seg
-
-            # Merging close segments
-            if segment_distance <= close_th:
-                prev_end_seg = boundaries[i, 1]
-
-            else:
-                # Appending new segments
-                new_boundaries.append([prev_beg_seg, prev_end_seg])
-                prev_beg_seg = beg_seg
-                prev_end_seg = boundaries[i, 1]
-
-        new_boundaries.append([prev_beg_seg, prev_end_seg])
-        new_boundaries = torch.FloatTensor(new_boundaries).to(boundaries.device)
-        return new_boundaries
-
-    def remove_short_segments(self, boundaries, len_th=0.250):
-        """Removes segments that are too short.
-
-        Arguments
-        ---------
-        boundaries : torch.Tensor
-            Tensor containing the speech boundaries. It can be derived using the
-            get_boundaries method.
-        len_th: float
-            If the length of the segment is smaller than close_th, the segments
-            will be merged.
-
-        Returns
-        -------
-        new_boundaries
-            The new boundaries without the short segments.
-        """
-        new_boundaries = []
-
-        # Process the segments
-        for i in range(boundaries.shape[0]):
-            # Computing segment length
-            seg_len = boundaries[i, 1] - boundaries[i, 0]
-
-            # Accept segment only if longer than len_th
-            if seg_len > len_th:
-                new_boundaries.append([boundaries[i, 0], boundaries[i, 1]])
-        new_boundaries = torch.FloatTensor(new_boundaries).to(boundaries.device)
-
-        return new_boundaries
-
-    def save_boundaries(
-        self, boundaries, save_path=None, print_boundaries=True, audio_file=None
-    ):
-        """Saves the boundaries on a file (and/or prints them)  in a readable format.
-
-        Arguments
-        ---------
-        boundaries: torch.Tensor
-            Tensor containing the speech boundaries. It can be derived using the
-            get_boundaries method.
-        save_path: path
-            When to store the text file containing the speech/non-speech intervals.
-        print_boundaries: Bool
-            Prints the speech/non-speech intervals in the standard outputs.
-        audio_file: path
-            Path of the audio file containing the recording. The file is read
-            with torchaudio. It is used here to detect the length of the
-            signal.
-        """
-        # Create a new file if needed
-        if save_path is not None:
-            f = open(save_path, mode="w", encoding="utf-8")
-
-        # Getting the total size of the input file
-        if audio_file is not None:
-            sample_rate, audio_len = self._get_audio_info(audio_file)
-            audio_len = audio_len / sample_rate
-
-        # Setting the rights format for second- or sample-based boundaries
-        if boundaries.dtype == torch.int:
-            value_format = "% i"
-        else:
-            value_format = "% .2f "
-
-        # Printing speech and non-speech intervals
-        last_end = 0
-        cnt_seg = 0
-        for i in range(boundaries.shape[0]):
-            begin_value = boundaries[i, 0]
-            end_value = boundaries[i, 1]
-
-            if last_end != begin_value:
-                cnt_seg = cnt_seg + 1
-                print_str = (
-                    "segment_%03d " + value_format + value_format + "NON_SPEECH"
-                )
-                if print_boundaries:
-                    print(print_str % (cnt_seg, last_end, begin_value))
-                if save_path is not None:
-                    f.write(print_str % (cnt_seg, last_end, begin_value) + "\n")
-
-            cnt_seg = cnt_seg + 1
-            print_str = "segment_%03d " + value_format + value_format + "SPEECH"
-            if print_boundaries:
-                print(print_str % (cnt_seg, begin_value, end_value))
-            if save_path is not None:
-                f.write(print_str % (cnt_seg, begin_value, end_value) + "\n")
-
-            last_end = end_value
-
-        # Managing last segment
-        if audio_file is not None:
-            if last_end < audio_len:
-                cnt_seg = cnt_seg + 1
-                print_str = (
-                    "segment_%03d " + value_format + value_format + "NON_SPEECH"
-                )
-                if print_boundaries:
-                    print(print_str % (cnt_seg, end_value, audio_len))
-                if save_path is not None:
-                    f.write(print_str % (cnt_seg, end_value, audio_len) + "\n")
-
-        if save_path is not None:
-            f.close()
-
-    def energy_VAD(
-        self,
-        audio_file,
-        boundaries,
-        activation_th=0.5,
-        deactivation_th=0.0,
-        eps=1e-6,
-    ):
-        """Applies energy-based VAD within the detected speech segments.The neural
-        network VAD often creates longer segments and tends to merge segments that
-        are close with each other.
-
-        The energy VAD post-processes can be useful for having a fine-grained voice
-        activity detection.
-
-        The energy VAD computes the energy within the small chunks. The energy is
-        normalized within the segment to have mean 0.5 and +-0.5 of std.
-        This helps to set the energy threshold.
-
-        Arguments
-        ---------
-        audio_file: path
-            Path of the audio file containing the recording. The file is read
-            with torchaudio.
-        boundaries : torch.Tensor
-            Tensor containing the speech boundaries. It can be derived using the
-            get_boundaries method.
-        activation_th: float
-            A new speech segment is started it the energy is above activation_th.
-        deactivation_th: float
-            The segment is considered ended when the energy is <= deactivation_th.
-        eps: float
-            Small constant for numerical stability.
-
-
-        Returns
-        -------
-        new_boundaries
-            The new boundaries that are post-processed by the energy VAD.
-        """
-
-        # Getting the total size of the input file
-        sample_rate, audio_len = self._get_audio_info(audio_file)
-
-        if sample_rate != self.sample_rate:
-            raise ValueError(
-                "The detected sample rate is different from that set in the hparam file"
-            )
-
-        # Computing the chunk length of the energy window
-        chunk_len = int(self.time_resolution * sample_rate)
-        new_boundaries = []
-
-        # Processing speech segments
-        for i in range(boundaries.shape[0]):
-            begin_sample = int(boundaries[i, 0] * sample_rate)
-            end_sample = int(boundaries[i, 1] * sample_rate)
-            seg_len = end_sample - begin_sample
-
-            # Reading the speech segment
-            segment, _ = torchaudio.load(
-                audio_file, frame_offset=begin_sample, num_frames=seg_len
-            )
-
-            # Create chunks
-            segment_chunks = self.create_chunks(
-                segment, chunk_size=chunk_len, chunk_stride=chunk_len
-            )
-
-            # Energy computation within each chunk
-            energy_chunks = segment_chunks.abs().sum(-1) + eps
-            energy_chunks = energy_chunks.log()
-
-            # Energy normalization
-            energy_chunks = (
-                (energy_chunks - energy_chunks.mean())
-                / (2 * energy_chunks.std())
-            ) + 0.5
-            energy_chunks = energy_chunks.unsqueeze(0).unsqueeze(2)
-
-            # Apply threshold based on the energy value
-            energy_vad = self.apply_threshold(
-                energy_chunks,
-                activation_th=activation_th,
-                deactivation_th=deactivation_th,
-            )
-
-            # Get the boundaries
-            energy_boundaries = self.get_boundaries(
-                energy_vad, output_value="seconds"
-            )
-
-            # Get the final boundaries in the original signal
-            for j in range(energy_boundaries.shape[0]):
-                start_en = boundaries[i, 0] + energy_boundaries[j, 0]
-                end_end = boundaries[i, 0] + energy_boundaries[j, 1]
-                new_boundaries.append([start_en, end_end])
-
-        # Convert boundaries to tensor
-        new_boundaries = torch.FloatTensor(new_boundaries).to(boundaries.device)
-        return new_boundaries
-
-    def create_chunks(self, x, chunk_size=16384, chunk_stride=16384):
-        """Splits the input into smaller chunks of size chunk_size with
-        an overlap chunk_stride. The chunks are concatenated over
-        the batch axis.
-
-        Arguments
-        ---------
-        x: torch.Tensor
-            Signal to split into chunks.
-        chunk_size : str
-            The size of each chunk.
-        chunk_stride:
-            The stride (hop) of each chunk.
-
-
-        Returns
-        -------
-        x: torch.Tensor
-            A new tensors with the chunks derived from the input signal.
-
-        """
-        x = x.unfold(1, chunk_size, chunk_stride)
-        x = x.reshape(x.shape[0] * x.shape[1], -1)
-        return x
-
-    def _get_audio_info(self, audio_file):
-        """Returns the sample rate and the length of the input audio file"""
-
-        # Getting the total size of the input file
-        metadata = torchaudio.info(audio_file)
-        sample_rate = metadata.sample_rate
-        audio_len = metadata.num_frames
-        return sample_rate, audio_len
-
-    def upsample_VAD(self, vad_out, audio_file, time_resolution=0.01):
-        """Upsamples the output of the vad to help visualization. It creates a
-        signal that is 1 when there is speech and 0 when there is no speech.
-        The vad signal has the same resolution as the input one and can be
-        opened with it (e.g, using audacity) to visually figure out VAD regions.
-
-        Arguments
-        ---------
-        vad_out: torch.Tensor
-            Tensor containing 1 for each frame of speech and 0 for each non-speech
-            frame.
-        audio_file: path
-            The original audio file used to compute vad_out
-        time_resolution : float
-            Time resolution of the vad_out signal.
-
-        Returns
-        -------
-        vad_signal
-            The upsampled version of the vad_out tensor.
-        """
-
-        # Getting the total size of the input file
-        sample_rate, sig_len = self._get_audio_info(audio_file)
-
-        if sample_rate != self.sample_rate:
-            raise ValueError(
-                "The detected sample rate is different from that set in the hparam file"
-            )
-
-        beg_samp = 0
-        step_size = int(time_resolution * sample_rate)
-        end_samp = step_size
-        index = 0
-
-        # Initialize upsampled signal
-        vad_signal = torch.zeros(1, sig_len, device=vad_out.device)
-
-        # Upsample signal
-        while end_samp < sig_len:
-            vad_signal[0, beg_samp:end_samp] = vad_out[0, index, 0]
-            index = index + 1
-            beg_samp = beg_samp + step_size
-            end_samp = beg_samp + step_size
-        return vad_signal
-
-    def upsample_boundaries(self, boundaries, audio_file):
-        """Based on the input boundaries, this method creates a signal that is 1
-        when there is speech and 0 when there is no speech.
-        The vad signal has the same resolution as the input one and can be
-        opened with it (e.g, using audacity) to visually figure out VAD regions.
-
-        Arguments
-        ---------
-        boundaries: torch.Tensor
-            Tensor containing the boundaries of the speech segments.
-        audio_file: path
-            The original audio file used to compute vad_out
-
-        Returns
-        -------
-        vad_signal
-            The output vad signal with the same resolution of the input one.
-        """
-
-        # Getting the total size of the input file
-        sample_rate, sig_len = self._get_audio_info(audio_file)
-
-        if sample_rate != self.sample_rate:
-            raise ValueError(
-                "The detected sample rate is different from that set in the hparam file"
-            )
-
-        # Initialization of the output signal
-        vad_signal = torch.zeros(1, sig_len, device=boundaries.device)
-
-        # Composing the vad signal from boundaries
-        for i in range(boundaries.shape[0]):
-            beg_sample = int(boundaries[i, 0] * sample_rate)
-            end_sample = int(boundaries[i, 1] * sample_rate)
-            vad_signal[0, beg_sample:end_sample] = 1.0
-        return vad_signal
-
-    def double_check_speech_segments(
-        self, boundaries, audio_file, speech_th=0.5
-    ):
-        """Takes in input the boundaries of the detected speech segments and
-        double checks (using the neural VAD) that they actually contain speech.
-
-        Arguments
-        ---------
-        boundaries: torch.Tensor
-            Tensor containing the boundaries of the speech segments.
-        audio_file: path
-            The original audio file used to compute vad_out.
-        speech_th: float
-            Threshold on the mean posterior probability over which speech is
-            confirmed. Below that threshold, the segment is re-assigned to a
-            non-speech region.
-
-        Returns
-        -------
-        new_boundaries
-            The boundaries of the segments where speech activity is confirmed.
-        """
-
-        # Getting the total size of the input file
-        sample_rate, sig_len = self._get_audio_info(audio_file)
-
-        # Double check the segments
-        new_boundaries = []
-        for i in range(boundaries.shape[0]):
-            beg_sample = int(boundaries[i, 0] * sample_rate)
-            end_sample = int(boundaries[i, 1] * sample_rate)
-            len_seg = end_sample - beg_sample
-
-            # Read the candidate speech segment
-            segment, fs = torchaudio.load(
-                audio_file, frame_offset=beg_sample, num_frames=len_seg
-            )
-            speech_prob = self.get_speech_prob_chunk(segment)
-            if speech_prob.mean() > speech_th:
-                # Accept this as a speech segment
-                new_boundaries.append([boundaries[i, 0], boundaries[i, 1]])
-
-        # Convert boundaries from list to tensor
-        new_boundaries = torch.FloatTensor(new_boundaries).to(boundaries.device)
-        return new_boundaries
-
-    def get_segments(
-        self, boundaries, audio_file, before_margin=0.1, after_margin=0.1
-    ):
-        """Returns a list containing all the detected speech segments.
-
-        Arguments
-        ---------
-        boundaries: torch.Tensor
-            Tensor containing the boundaries of the speech segments.
-        audio_file: path
-            The original audio file used to compute vad_out.
-        before_margin: float
-            Used to cut the segments samples a bit before the detected margin.
-        after_margin: float
-            Use to cut the segments samples a bit after the detected margin.
-
-        Returns
-        -------
-        segments: list
-            List containing the detected speech segments
-        """
-        sample_rate, sig_len = self._get_audio_info(audio_file)
-
-        if sample_rate != self.sample_rate:
-            raise ValueError(
-                "The detected sample rate is different from that set in the hparam file"
-            )
-
-        segments = []
-        for i in range(boundaries.shape[0]):
-            beg_sample = boundaries[i, 0] * sample_rate
-            end_sample = boundaries[i, 1] * sample_rate
-
-            beg_sample = int(max(0, beg_sample - before_margin * sample_rate))
-            end_sample = int(
-                min(sig_len, end_sample + after_margin * sample_rate)
-            )
-
-            len_seg = end_sample - beg_sample
-            vad_segment, fs = torchaudio.load(
-                audio_file, frame_offset=beg_sample, num_frames=len_seg
-            )
-            segments.append(vad_segment)
-        return segments
-
-    def get_speech_segments(
-        self,
-        audio_file,
-        large_chunk_size=30,
-        small_chunk_size=10,
-        overlap_small_chunk=False,
-        apply_energy_VAD=False,
-        double_check=True,
-        close_th=0.250,
-        len_th=0.250,
-        activation_th=0.5,
-        deactivation_th=0.25,
-        en_activation_th=0.5,
-        en_deactivation_th=0.0,
-        speech_th=0.50,
-    ):
-        """Detects speech segments within the input file. The input signal can
-        be both a short or a long recording. The function computes the
-        posterior probabilities on large chunks (e.g, 30 sec), that are read
-        sequentially (to avoid storing big signals in memory).
-        Each large chunk is, in turn, split into smaller chunks (e.g, 10 seconds)
-        that are processed in parallel. The pipeline for detecting the speech
-        segments is the following:
-            1- Compute posteriors probabilities at the frame level.
-            2- Apply a threshold on the posterior probability.
-            3- Derive candidate speech segments on top of that.
-            4- Apply energy VAD within each candidate segment (optional).
-            5- Merge segments that are too close.
-            6- Remove segments that are too short.
-            7- Double check speech segments (optional).
-
-
-        Arguments
-        ---------
-        audio_file : str
-            Path to audio file.
-        large_chunk_size: float
-            Size (in seconds) of the large chunks that are read sequentially
-            from the input audio file.
-        small_chunk_size: float
-            Size (in seconds) of the small chunks extracted from the large ones.
-            The audio signal is processed in parallel within the small chunks.
-            Note that large_chunk_size/small_chunk_size must be an integer.
-        overlap_small_chunk: bool
-            If True, it creates overlapped small chunks (with 50% overlap).
-            The probabilities of the overlapped chunks are combined using
-            hamming windows.
-        apply_energy_VAD: bool
-            If True, a energy-based VAD is used on the detected speech segments.
-            The neural network VAD often creates longer segments and tends to
-            merge close segments together. The energy VAD post-processes can be
-            useful for having a fine-grained voice activity detection.
-            The energy thresholds is  managed by activation_th and
-            deactivation_th (see below).
-        double_check: bool
-            If True, double checks (using the neural VAD) that the candidate
-            speech segments actually contain speech. A threshold on the mean
-            posterior probabilities provided by the neural network is applied
-            based on the speech_th parameter (see below).
-        activation_th:  float
-            Threshold of the neural posteriors above which starting a speech segment.
-        deactivation_th: float
-            Threshold of the neural posteriors below which ending a speech segment.
-        en_activation_th: float
-            A new speech segment is started it the energy is above activation_th.
-            This is active only if apply_energy_VAD is True.
-        en_deactivation_th: float
-            The segment is considered ended when the energy is <= deactivation_th.
-            This is active only if apply_energy_VAD is True.
-        speech_th: float
-            Threshold on the mean posterior probability within the candidate
-            speech segment. Below that threshold, the segment is re-assigned to
-            a non-speech region. This is active only if double_check is True.
-        close_th: float
-            If the distance between boundaries is smaller than close_th, the
-            segments will be merged.
-        len_th: float
-            If the length of the segment is smaller than close_th, the segments
-            will be merged.
-
-        Returns
-        -------
-        boundaries: torch.Tensor
-            Tensor containing the start second of speech segments in even
-            positions and their corresponding end in odd positions
-            (e.g, [1.0, 1.5, 5,.0 6.0] means that we have two speech segment;
-             one from 1.0 to 1.5 seconds and another from 5.0 to 6.0 seconds).
-        """
-
-        # Fetch audio file from web if not local
-        source, fl = split_path(audio_file)
-        audio_file = fetch(fl, source=source)
-
-        # Computing speech vs non speech probabilities
-        prob_chunks = self.get_speech_prob_file(
-            audio_file,
-            large_chunk_size=large_chunk_size,
-            small_chunk_size=small_chunk_size,
-            overlap_small_chunk=overlap_small_chunk,
-        )
-
-        # Apply a threshold to get candidate speech segments
-        prob_th = self.apply_threshold(
-            prob_chunks,
-            activation_th=activation_th,
-            deactivation_th=deactivation_th,
-        ).float()
-
-        # Compute the boundaries of the speech segments
-        boundaries = self.get_boundaries(prob_th, output_value="seconds")
-
-        # Apply energy-based VAD on the detected speech segments
-        if apply_energy_VAD:
-            boundaries = self.energy_VAD(
-                audio_file,
-                boundaries,
-                activation_th=en_activation_th,
-                deactivation_th=en_deactivation_th,
-            )
-
-        # Merge short segments
-        boundaries = self.merge_close_segments(boundaries, close_th=close_th)
-
-        # Remove short segments
-        boundaries = self.remove_short_segments(boundaries, len_th=len_th)
-
-        # Double check speech segments
-        if double_check:
-            boundaries = self.double_check_speech_segments(
-                boundaries, audio_file, speech_th=speech_th
-            )
-
-        return boundaries
-
-    def forward(self, wavs, wav_lens=None):
-        """Gets frame-level speech-activity predictions"""
-        return self.get_speech_prob_chunk(wavs, wav_lens)
-
-
-class SepformerSeparation(Pretrained):
-    """A "ready-to-use" speech separation model.
-
-    Uses Sepformer architecture.
-
-    Example
-    -------
-    >>> tmpdir = getfixture("tmpdir")
-    >>> model = SepformerSeparation.from_hparams(
-    ...     source="speechbrain/sepformer-wsj02mix",
-    ...     savedir=tmpdir)
-    >>> mix = torch.randn(1, 400)
-    >>> est_sources = model.separate_batch(mix)
-    >>> print(est_sources.shape)
-    torch.Size([1, 400, 2])
-    """
-
-    MODULES_NEEDED = ["encoder", "masknet", "decoder"]
-
-    def separate_batch(self, mix):
-        """Run source separation on batch of audio.
-
-        Arguments
-        ---------
-        mix : torch.Tensor
-            The mixture of sources.
-
-        Returns
-        -------
-        tensor
-            Separated sources
-        """
-
-        # Separation
-        mix = mix.to(self.device)
-        mix_w = self.mods.encoder(mix)
-        est_mask = self.mods.masknet(mix_w)
-        mix_w = torch.stack([mix_w] * self.hparams.num_spks)
-        sep_h = mix_w * est_mask
-
-        # Decoding
-        est_source = torch.cat(
-            [
-                self.mods.decoder(sep_h[i]).unsqueeze(-1)
-                for i in range(self.hparams.num_spks)
-            ],
-            dim=-1,
-        )
-
-        # T changed after conv1d in encoder, fix it here
-        T_origin = mix.size(1)
-        T_est = est_source.size(1)
-        if T_origin > T_est:
-            est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
-        else:
-            est_source = est_source[:, :T_origin, :]
-        return est_source
-
-    def separate_file(self, path, savedir="audio_cache"):
-        """Separate sources from file.
-
-        Arguments
-        ---------
-        path : str
-            Path to file which has a mixture of sources. It can be a local
-            path, a web url, or a huggingface repo.
-        savedir : path
-            Path where to store the wav signals (when downloaded from the web).
-        Returns
-        -------
-        tensor
-            Separated sources
-        """
-        source, fl = split_path(path)
-        path = fetch(fl, source=source, savedir=savedir)
-
-        batch, fs_file = torchaudio.load(path)
-        batch = batch.to(self.device)
-        fs_model = self.hparams.sample_rate
-
-        # resample the data if needed
-        if fs_file != fs_model:
-            print(
-                "Resampling the audio from {} Hz to {} Hz".format(
-                    fs_file, fs_model
-                )
-            )
-            tf = torchaudio.transforms.Resample(
-                orig_freq=fs_file, new_freq=fs_model
-            ).to(self.device)
-            batch = batch.mean(dim=0, keepdim=True)
-            batch = tf(batch)
-
-        est_sources = self.separate_batch(batch)
-        est_sources = (
-            est_sources / est_sources.abs().max(dim=1, keepdim=True)[0]
-        )
-        return est_sources
-
-    def forward(self, mix):
-        """Runs separation on the input mix"""
-        return self.separate_batch(mix)
-
-
-class SpectralMaskEnhancement(Pretrained):
-    """A ready-to-use model for speech enhancement.
-
-    Arguments
-    ---------
-    See ``Pretrained``.
-
-    Example
-    -------
-    >>> import torch
-    >>> from speechbrain.pretrained import SpectralMaskEnhancement
-    >>> # Model is downloaded from the speechbrain HuggingFace repo
-    >>> tmpdir = getfixture("tmpdir")
-    >>> enhancer = SpectralMaskEnhancement.from_hparams(
-    ...     source="speechbrain/metricgan-plus-voicebank",
-    ...     savedir=tmpdir,
-    ... )
-    >>> enhanced = enhancer.enhance_file(
-    ...     "speechbrain/metricgan-plus-voicebank/example.wav"
-    ... )
-    """
-
-    HPARAMS_NEEDED = ["compute_stft", "spectral_magnitude", "resynth"]
-    MODULES_NEEDED = ["enhance_model"]
-
-    def compute_features(self, wavs):
-        """Compute the log spectral magnitude features for masking.
-
-        Arguments
-        ---------
-        wavs : torch.Tensor
-            A batch of waveforms to convert to log spectral mags.
-        """
-        feats = self.hparams.compute_stft(wavs)
-        feats = self.hparams.spectral_magnitude(feats)
-        return torch.log1p(feats)
-
-    def enhance_batch(self, noisy, lengths=None):
-        """Enhance a batch of noisy waveforms.
-
-        Arguments
-        ---------
-        noisy : torch.Tensor
-            A batch of waveforms to perform enhancement on.
-        lengths : torch.Tensor
-            The lengths of the waveforms if the enhancement model handles them.
-
-        Returns
-        -------
-        torch.Tensor
-            A batch of enhanced waveforms of the same shape as input.
-        """
-        noisy = noisy.to(self.device)
-        noisy_features = self.compute_features(noisy)
-
-        # Perform masking-based enhancement, multiplying output with input.
-        if lengths is not None:
-            mask = self.mods.enhance_model(noisy_features, lengths=lengths)
-        else:
-            mask = self.mods.enhance_model(noisy_features)
-        enhanced = torch.mul(mask, noisy_features)
-
-        # Return resynthesized waveforms
-        return self.hparams.resynth(torch.expm1(enhanced), noisy)
-
-    def enhance_file(self, filename, output_filename=None, **kwargs):
-        """Enhance a wav file.
-
-        Arguments
-        ---------
-        filename : str
-            Location on disk to load file for enhancement.
-        output_filename : str
-            If provided, writes enhanced data to this file.
-        """
-        noisy = self.load_audio(filename, **kwargs)
-        noisy = noisy.to(self.device)
-
-        # Fake a batch:
-        batch = noisy.unsqueeze(0)
-        if lengths_arg_exists(self.enhance_batch):
-            enhanced = self.enhance_batch(batch, lengths=torch.tensor([1.0]))
-        else:
-            enhanced = self.enhance_batch(batch)
-
-        if output_filename is not None:
-            torchaudio.save(output_filename, enhanced, channels_first=False)
-
-        return enhanced.squeeze(0)
-
-
-class EncodeDecodePipelineMixin:
-    """
-    A mixin for pretrained models that makes it possible to specify an encoding pipeline and a decoding pipeline
-    """
-
-    def create_pipelines(self):
-        """
-        Initializes the encode and decode pipeline
-        """
-        self._run_init_steps(self.hparams.encode_pipeline)
-        self._run_init_steps(self.hparams.decode_pipeline)
-        self.encode_pipeline = DataPipeline(
-            static_data_keys=self.INPUT_STATIC_KEYS,
-            dynamic_items=self.hparams.encode_pipeline["steps"],
-            output_keys=self.hparams.encode_pipeline["output_keys"],
-        )
-        self.decode_pipeline = DataPipeline(
-            static_data_keys=self.hparams.model_output_keys,
-            dynamic_items=self.hparams.decode_pipeline["steps"],
-            output_keys=self.OUTPUT_KEYS,
-        )
-
-    def _run_init_steps(self, pipeline_definition):
-        """Encode/decode pipelines may include initialization
-        steps, such as filling text encoders with tokens. Calling
-        this method will run them, if defined"""
-        steps = pipeline_definition.get("init", [])
-        for step in steps:
-            step_func = step.get("func")
-            if not step_func or not callable(step_func):
-                raise ValueError("Invalid pipeline init definition")
-            step_func()
-
-    def _run_pipeline(self, pipeline, input, batch):
-        if batch:
-            output = pipeline(input)
-        else:
-            output = [pipeline(item) for item in input]
-        return output
-
-    def _get_encode_pipeline_input(self, input):
-        return input if self.batch_inputs else self._itemize(input)
-
-    def _get_decode_pipeline_input(self, model_output):
-        model_output_keys = getattr(self.hparams, "model_output_keys", None)
-        pipeline_input = model_output
-        if len(model_output_keys) == 1:
-            pipeline_input = (pipeline_input,)
-        # The input to a pipeline is a dictionary. If model_output_keys
-        # is provided, the output of the model is assumed to be a collection
-        # (e.g. a list or a tuple).
-        if model_output_keys:
-            pipeline_input = dict(zip(model_output_keys, pipeline_input))
-
-        # By default, the pipeline will be applied to in batch mode
-        # to the entire model input
-        if not self.batch_outputs:
-            pipeline_input = self._itemize(pipeline_input)
-        return pipeline_input
-
-    def _itemize(self, pipeline_input):
-        first_item = next(iter(pipeline_input.values()))
-        keys, values = pipeline_input.keys(), pipeline_input.values()
-        batch_length = len(first_item)
-        return [
-            dict(zip(keys, [value[idx] for value in values]))
-            for idx in range(batch_length)
-        ]
-
-    def to_dict(self, data):
-        """
-        Converts padded batches to dictionaries, leaves
-        other data types as is
-
-        Arguments
-        ---------
-        data: object
-            a dictionary or a padded batch
-
-        Returns
-        -------
-        results: dict
-            the dictionary
-        """
-        if isinstance(data, PaddedBatch):
-            data = {
-                key: self._get_value(data, key)
-                for key in self.hparams.encode_pipeline["output_keys"]
-            }
-        return data
-
-    def _get_value(self, data, key):
-        """
-        Retrieves the value associated with the specified key, dereferencing
-        .data where applicable
-
-        Arguments
-        ---------
-        data: PaddedBatch
-            a padded batch
-        key: str
-            the key
-
-        Returns
-        -------
-        result: object
-            the result
-        """
-        value = getattr(data, key)
-        if not self.input_use_padded_data and isinstance(value, PaddedData):
-            value = value.data
-        return value
-
-    @property
-    def batch_inputs(self):
-        """
-        Determines whether the input pipeline
-        operates on batches or individual examples
-        (true means batched)
-
-        Returns
-        -------
-        batch_inputs: bool
-        """
-        return self.hparams.encode_pipeline.get("batch", True)
-
-    @property
-    def input_use_padded_data(self):
-        """
-        If turned on, raw PaddedData instances will be passed to
-        the model. If turned off, only .data will be used
-
-        Returns
-        -------
-        result: bool
-            whether padded data is used as is
-        """
-        return self.hparams.encode_pipeline.get("use_padded_data", False)
-
-    @property
-    def batch_outputs(self):
-        """
-        Determines whether the output pipeline
-        operates on batches or individual examples
-        (true means batched)
-
-        Returns
-        -------
-        batch_outputs: bool
-        """
-        return self.hparams.decode_pipeline.get("batch", True)
-
-    def _collate(self, data):
-        if not self.batch_inputs:
-            collate_fn = getattr(self.hparams, "collate_fn", PaddedBatch)
-            data = collate_fn(data)
-        return data
-
-    def encode_input(self, input):
-        """
-        Encodes the inputs using the pipeline
-
-        Arguments
-        ---------
-        input: dict
-            the raw inputs
-
-        Returns
-        -------
-        results: object
-
-        """
-        pipeline_input = self._get_encode_pipeline_input(input)
-        model_input = self._run_pipeline(
-            pipeline=self.encode_pipeline,
-            input=pipeline_input,
-            batch=self.batch_inputs,
-        )
-        model_input = self._collate(model_input)
-        if hasattr(model_input, "to"):
-            model_input = model_input.to(self.device)
-        return self.to_dict(model_input)
-
-    def decode_output(self, output):
-        """
-        Decodes the raw model outputs
-
-        Arguments
-        ---------
-        output: tuple
-            raw model outputs
-
-        Returns
-        -------
-        result: dict or list
-            the output of the pipeline
-        """
-        pipeline_input = self._get_decode_pipeline_input(output)
-        return self._run_pipeline(
-            pipeline=self.decode_pipeline,
-            input=pipeline_input,
-            batch=self.batch_outputs,
-        )
-
-
-class GraphemeToPhoneme(Pretrained, EncodeDecodePipelineMixin):
-    """
-    A pretrained model implementation for Grapheme-to-Phoneme (G2P) models
-    that take raw natural language text as an input and
-
-    Example
-    -------
-    >>> text = ("English is tough. It can be understood "
-    ...         "through thorough thought though")
-    >>> from speechbrain.pretrained import GraphemeToPhoneme
-    >>> tmpdir = getfixture('tmpdir')
-    >>> g2p = GraphemeToPhoneme.from_hparams('path/to/model', savedir=tmpdir) # doctest: +SKIP
-    >>> phonemes = g2p.g2p(text) # doctest: +SKIP
-    """
-
-    INPUT_STATIC_KEYS = ["txt"]
-    OUTPUT_KEYS = ["phonemes"]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.create_pipelines()
-        self.load_dependencies()
-
-    @property
-    def phonemes(self):
-        """Returns the available phonemes"""
-        return self.hparams.phonemes
-
-    @property
-    def language(self):
-        """Returns the language for which this model is available"""
-        return self.hparams.language
-
-    def g2p(self, text):
-        """Performs the Grapheme-to-Phoneme conversion
-
-        Arguments
-        ---------
-        text: str or list[str]
-            a single string to be encoded to phonemes - or a
-            sequence of strings
-
-        Returns
-        -------
-        result: list
-            if a single example was provided, the return value is a
-            single list of phonemes
-        """
-        single = isinstance(text, str)
-        if single:
-            text = [text]
-
-        model_inputs = self.encode_input({"txt": text})
-        self._update_graphemes(model_inputs)
-        model_outputs = self.mods.model(**model_inputs)
-        decoded_output = self.decode_output(model_outputs)
-        phonemes = decoded_output["phonemes"]
-        if single:
-            phonemes = phonemes[0]
-        return phonemes
-
-    def _update_graphemes(self, model_inputs):
-        grapheme_sequence_mode = getattr(self.hparams, "grapheme_sequence_mode")
-        if grapheme_sequence_mode and grapheme_sequence_mode != "raw":
-            grapheme_encoded_key = f"grapheme_encoded_{grapheme_sequence_mode}"
-            if grapheme_encoded_key in model_inputs:
-                model_inputs["grapheme_encoded"] = model_inputs[
-                    grapheme_encoded_key
-                ]
-
-    def load_dependencies(self):
-        """Loads any relevant model dependencies"""
-        deps_pretrainer = getattr(self.hparams, "deps_pretrainer", None)
-        if deps_pretrainer:
-            deps_pretrainer.collect_files()
-            deps_pretrainer.load_collected(device=self.device)
-
-    def __call__(self, text):
-        """A convenience callable wrapper - same as G2P
-
-        Arguments
-        ---------
-        text: str or list[str]
-            a single string to be encoded to phonemes - or a
-            sequence of strings
-
-        Returns
-        -------
-        result: list
-            if a single example was provided, the return value is a
-            single list of phonemes
-        """
-        return self.g2p(text)
-
-    def forward(self, noisy, lengths=None):
-        """Runs enhancement on the noisy input"""
-        return self.enhance_batch(noisy, lengths)
-
-
-class WaveformEnhancement(Pretrained):
-    """A ready-to-use model for speech enhancement.
-
-    Arguments
-    ---------
-    See ``Pretrained``.
-
-    Example
-    -------
-    >>> from speechbrain.pretrained import WaveformEnhancement
-    >>> # Model is downloaded from the speechbrain HuggingFace repo
-    >>> tmpdir = getfixture("tmpdir")
-    >>> enhancer = WaveformEnhancement.from_hparams(
-    ...     source="speechbrain/mtl-mimic-voicebank",
-    ...     savedir=tmpdir,
-    ... )
-    >>> enhanced = enhancer.enhance_file(
-    ...     "speechbrain/mtl-mimic-voicebank/example.wav"
-    ... )
-    """
-
-    MODULES_NEEDED = ["enhance_model"]
-
-    def enhance_batch(self, noisy, lengths=None):
-        """Enhance a batch of noisy waveforms.
-
-        Arguments
-        ---------
-        noisy : torch.Tensor
-            A batch of waveforms to perform enhancement on.
-        lengths : torch.Tensor
-            The lengths of the waveforms if the enhancement model handles them.
-
-        Returns
-        -------
-        torch.Tensor
-            A batch of enhanced waveforms of the same shape as input.
-        """
-        noisy = noisy.to(self.device)
-        enhanced_wav, _ = self.mods.enhance_model(noisy)
-        return enhanced_wav
-
-    def enhance_file(self, filename, output_filename=None, **kwargs):
-        """Enhance a wav file.
-
-        Arguments
-        ---------
-        filename : str
-            Location on disk to load file for enhancement.
-        output_filename : str
-            If provided, writes enhanced data to this file.
-        """
-        noisy = self.load_audio(filename, **kwargs)
-
-        # Fake a batch:
-        batch = noisy.unsqueeze(0)
-        enhanced = self.enhance_batch(batch)
-
-        if output_filename is not None:
-            torchaudio.save(output_filename, enhanced, channels_first=False)
-
-        return enhanced.squeeze(0)
-
-    def forward(self, noisy, lengths=None):
-        """Runs enhancement on the noisy input"""
-        return self.enhance_batch(noisy, lengths)
-
-
-class SNREstimator(Pretrained):
-    """A "ready-to-use" SNR estimator."""
-
-    MODULES_NEEDED = ["encoder", "encoder_out"]
-    HPARAMS_NEEDED = ["stat_pooling", "snrmax", "snrmin"]
-
-    def estimate_batch(self, mix, predictions):
-        """Run SI-SNR estimation on the estimated sources, and mixture.
-
-        Arguments
-        ---------
-        mix : torch.Tensor
-            The mixture of sources of shape B X T
-        predictions : torch.Tensor
-            of size (B x T x C),
-            where B is batch size
-                  T is number of time points
-                  C is number of sources
-
-        Returns
-        -------
-        tensor
-            Estimate of SNR
-        """
-
-        predictions = predictions.permute(0, 2, 1)
-        predictions = predictions.reshape(-1, predictions.size(-1))
-
-        if hasattr(self.hparams, "separation_norm_type"):
-            if self.hparams.separation_norm_type == "max":
-                predictions = (
-                    predictions / predictions.max(dim=1, keepdim=True)[0]
-                )
-                mix = mix / mix.max(dim=1, keepdim=True)[0]
-
-            elif self.hparams.separation_norm_type == "stnorm":
-                predictions = (
-                    predictions - predictions.mean(dim=1, keepdim=True)
-                ) / predictions.std(dim=1, keepdim=True)
-                mix = (mix - mix.mean(dim=1, keepdim=True)) / mix.std(
-                    dim=1, keepdim=True
-                )
-
-        min_T = min(predictions.shape[1], mix.shape[1])
-        assert predictions.shape[1] == mix.shape[1], "lengths change"
-
-        mix_repeat = mix.repeat(2, 1)
-        inp_cat = torch.cat(
-            [
-                predictions[:, :min_T].unsqueeze(1),
-                mix_repeat[:, :min_T].unsqueeze(1),
-            ],
-            dim=1,
-        )
-
-        enc = self.mods.encoder(inp_cat)
-        enc = enc.permute(0, 2, 1)
-        enc_stats = self.hparams.stat_pooling(enc)
-
-        # this gets the SI-SNR estimate in the compressed range 0-1
-        snrhat = self.mods.encoder_out(enc_stats).squeeze()
-
-        # get the SI-SNR estimate in the true range
-        snrhat = self.gettrue_snrrange(snrhat)
-        return snrhat
-
-    def forward(self, mix, predictions):
-        """Just run the batch estimate"""
-        return self.estimate_batch(mix, predictions)
-
-    def gettrue_snrrange(self, inp):
-        """Convert from 0-1 range to true snr range"""
-        rnge = self.hparams.snrmax - self.hparams.snrmin
-        inp = inp * rnge
-        inp = inp + self.hparams.snrmin
-        return inp
-
-
-class Tacotron2(Pretrained):
-    """
-    A ready-to-use wrapper for Tacotron2 (text -> mel_spec).
-
-    Arguments
-    ---------
-    hparams
-        Hyperparameters (from HyperPyYAML)
-
-    Example
-    -------
-    >>> tmpdir_tts = getfixture('tmpdir') / "tts"
-    >>> tacotron2 = Tacotron2.from_hparams(source="speechbrain/tts-tacotron2-ljspeech", savedir=tmpdir_tts)
-    >>> mel_output, mel_length, alignment = tacotron2.encode_text("Mary had a little lamb")
-    >>> items = [
-    ...   "A quick brown fox jumped over the lazy dog",
-    ...   "How much wood would a woodchuck chuck?",
-    ...   "Never odd or even"
-    ... ]
-    >>> mel_outputs, mel_lengths, alignments = tacotron2.encode_batch(items)
-
-    >>> # One can combine the TTS model with a vocoder (that generates the final waveform)
-    >>> # Intialize the Vocoder (HiFIGAN)
-    >>> tmpdir_vocoder = getfixture('tmpdir') / "vocoder"
-    >>> hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-ljspeech", savedir=tmpdir_vocoder)
-    >>> # Running the TTS
-    >>> mel_output, mel_length, alignment = tacotron2.encode_text("Mary had a little lamb")
-    >>> # Running Vocoder (spectrogram-to-waveform)
-    >>> waveforms = hifi_gan.decode_batch(mel_output)
-    """
-
-    HPARAMS_NEEDED = ["model", "text_to_sequence"]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.text_cleaners = getattr(
-            self.hparams, "text_cleaners", ["english_cleaners"]
-        )
-        self.infer = self.hparams.model.infer
-
-    def text_to_seq(self, txt):
-        """Encodes raw text into a tensor with a customer text-to-equence fuction"""
-        sequence = self.hparams.text_to_sequence(txt, self.text_cleaners)
-        return sequence, len(sequence)
-
-    def encode_batch(self, texts):
-        """Computes mel-spectrogram for a list of texts
-
-        Texts must be sorted in decreasing order on their lengths
-
-        Arguments
-        ---------
-        texts: List[str]
-            texts to be encoded into spectrogram
-
-        Returns
-        -------
-        tensors of output spectrograms, output lengths and alignments
-        """
-        with torch.no_grad():
-            inputs = [
-                {
-                    "text_sequences": torch.tensor(
-                        self.text_to_seq(item)[0], device=self.device
-                    )
-                }
-                for item in texts
-            ]
-            inputs = speechbrain.dataio.batch.PaddedBatch(inputs)
-
-            lens = [self.text_to_seq(item)[1] for item in texts]
-            assert lens == sorted(
-                lens, reverse=True
-            ), "input lengths must be sorted in decreasing order"
-            input_lengths = torch.tensor(lens, device=self.device)
-
-            mel_outputs_postnet, mel_lengths, alignments = self.infer(
-                inputs.text_sequences.data, input_lengths
-            )
-        return mel_outputs_postnet, mel_lengths, alignments
-
-    def encode_text(self, text):
-        """Runs inference for a single text str"""
-        return self.encode_batch([text])
-
-    def forward(self, texts):
-        "Encodes the input texts."
-        return self.encode_batch(texts)
-
-
-class FastSpeech2(Pretrained):
-    """
-    A ready-to-use wrapper for Fastspeech2 (text -> mel_spec).
-    Arguments
-    ---------
-    hparams
-        Hyperparameters (from HyperPyYAML)
-    Example
-    -------
-    >>> tmpdir_tts = getfixture('tmpdir') / "tts"
-    >>> fastspeech2 = FastSpeech2.from_hparams(source="speechbrain/tts-fastspeech2-ljspeech", savedir=tmpdir_tts)
-    >>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text(["Mary had a little lamb"])
-    >>> items = [
-    ...   "A quick brown fox jumped over the lazy dog",
-    ...   "How much wood would a woodchuck chuck?",
-    ...   "Never odd or even"
-    ... ]
-    >>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text(items)
-    >>>
-    >>> # One can combine the TTS model with a vocoder (that generates the final waveform)
-    >>> # Intialize the Vocoder (HiFIGAN)
-    >>> tmpdir_vocoder = getfixture('tmpdir') / "vocoder"
-    >>> hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-ljspeech", savedir=tmpdir_vocoder)
-    >>> # Running the TTS
-    >>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text(["Mary had a little lamb"])
-    >>> # Running Vocoder (spectrogram-to-waveform)
-    >>> waveforms = hifi_gan.decode_batch(mel_outputs)
-    """
-
-    HPARAMS_NEEDED = ["spn_predictor", "model", "input_encoder"]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        lexicon = self.hparams.lexicon
-        lexicon = ["@@"] + lexicon
-        self.input_encoder = self.hparams.input_encoder
-        self.input_encoder.update_from_iterable(lexicon, sequence_input=False)
-        self.input_encoder.add_unk()
-
-        self.g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
-
-        self.spn_token_encoded = (
-            self.input_encoder.encode_sequence_torch(["spn"]).int().item()
-        )
-
-    def encode_text(self, texts, pace=1.0, pitch_rate=1.0, energy_rate=1.0):
-        """Computes mel-spectrogram for a list of texts
-
-        Arguments
-        ---------
-        texts: List[str]
-            texts to be converted to spectrogram
-        pace: float
-            pace for the speech synthesis
-        pitch_rate : float
-            scaling factor for phoneme pitches
-        energy_rate : float
-            scaling factor for phoneme energies
-
-        Returns
-        -------
-        tensors of output spectrograms, output lengths and alignments
-        """
-
-        # Preprocessing required at the inference time for the input text
-        # "label" below contains input text
-        # "phoneme_labels" contain the phoneme sequences corresponding to input text labels
-        # "last_phonemes_combined" is used to indicate whether the index position is for a last phoneme of a word
-        # "punc_positions" is used to add back the silence for punctuations
-        phoneme_labels = list()
-        last_phonemes_combined = list()
-        punc_positions = list()
-
-        for label in texts:
-            phoneme_label = list()
-            last_phonemes = list()
-            punc_position = list()
-
-            words = label.split()
-            words = [word.strip() for word in words]
-            words_phonemes = self.g2p(words)
-
-            for i in range(len(words_phonemes)):
-                words_phonemes_seq = words_phonemes[i]
-                for phoneme in words_phonemes_seq:
-                    if not phoneme.isspace():
-                        phoneme_label.append(phoneme)
-                        last_phonemes.append(0)
-                        punc_position.append(0)
-                last_phonemes[-1] = 1
-                if words[i][-1] in ":;-,.!?":
-                    punc_position[-1] = 1
-
-            phoneme_labels.append(phoneme_label)
-            last_phonemes_combined.append(last_phonemes)
-            punc_positions.append(punc_position)
-
-        # Inserts silent phonemes in the input phoneme sequence
-        all_tokens_with_spn = list()
-        max_seq_len = -1
-        for i in range(len(phoneme_labels)):
-            phoneme_label = phoneme_labels[i]
-            token_seq = (
-                self.input_encoder.encode_sequence_torch(phoneme_label)
-                .int()
-                .to(self.device)
-            )
-            last_phonemes = torch.LongTensor(last_phonemes_combined[i]).to(
-                self.device
-            )
-
-            # Runs the silent phoneme predictor
-            spn_preds = (
-                self.hparams.modules["spn_predictor"]
-                .infer(token_seq.unsqueeze(0), last_phonemes.unsqueeze(0))
-                .int()
-            )
-
-            spn_to_add = torch.nonzero(spn_preds).reshape(-1).tolist()
-
-            for j in range(len(punc_positions[i])):
-                if punc_positions[i][j] == 1:
-                    spn_to_add.append(j)
-
-            tokens_with_spn = list()
-
-            for token_idx in range(token_seq.shape[0]):
-                tokens_with_spn.append(token_seq[token_idx].item())
-                if token_idx in spn_to_add:
-                    tokens_with_spn.append(self.spn_token_encoded)
-
-            tokens_with_spn = torch.LongTensor(tokens_with_spn).to(self.device)
-            all_tokens_with_spn.append(tokens_with_spn)
-            if max_seq_len < tokens_with_spn.shape[-1]:
-                max_seq_len = tokens_with_spn.shape[-1]
-
-        # "tokens_with_spn_tensor" holds the input phoneme sequence with silent phonemes
-        tokens_with_spn_tensor_padded = torch.LongTensor(
-            len(texts), max_seq_len
-        ).to(self.device)
-        tokens_with_spn_tensor_padded.zero_()
-
-        for seq_idx, seq in enumerate(all_tokens_with_spn):
-            tokens_with_spn_tensor_padded[seq_idx, : len(seq)] = seq
-
-        return self.encode_batch(
-            tokens_with_spn_tensor_padded,
-            pace=pace,
-            pitch_rate=pitch_rate,
-            energy_rate=energy_rate,
-        )
-
-    def encode_phoneme(
-        self, phonemes, pace=1.0, pitch_rate=1.0, energy_rate=1.0
-    ):
-        """Computes mel-spectrogram for a list of phoneme sequences
-
-        Arguments
-        ---------
-        phonemes: List[List[str]]
-            phonemes to be converted to spectrogram
-        pace: float
-            pace for the speech synthesis
-        pitch_rate : float
-            scaling factor for phoneme pitches
-        energy_rate : float
-            scaling factor for phoneme energies
-
-        Returns
-        -------
-        tensors of output spectrograms, output lengths and alignments
-        """
-
-        all_tokens = []
-        max_seq_len = -1
-        for phoneme in phonemes:
-            token_seq = (
-                self.input_encoder.encode_sequence_torch(phoneme)
-                .int()
-                .to(self.device)
-            )
-            if max_seq_len < token_seq.shape[-1]:
-                max_seq_len = token_seq.shape[-1]
-            all_tokens.append(token_seq)
-
-        tokens_padded = torch.LongTensor(len(phonemes), max_seq_len).to(
-            self.device
-        )
-        tokens_padded.zero_()
-
-        for seq_idx, seq in enumerate(all_tokens):
-            tokens_padded[seq_idx, : len(seq)] = seq
-
-        return self.encode_batch(
-            tokens_padded,
-            pace=pace,
-            pitch_rate=pitch_rate,
-            energy_rate=energy_rate,
-        )
-
-    def encode_batch(
-        self, tokens_padded, pace=1.0, pitch_rate=1.0, energy_rate=1.0
-    ):
-        """Batch inference for a tensor of phoneme sequences
-        Arguments
-        ---------
-        tokens_padded : torch.Tensor
-            A sequence of encoded phonemes to be converted to spectrogram
-        pace : float
-            pace for the speech synthesis
-        pitch_rate : float
-            scaling factor for phoneme pitches
-        energy_rate : float
-            scaling factor for phoneme energies
-        """
-        with torch.no_grad():
-
-            (
-                _,
-                post_mel_outputs,
-                durations,
-                pitch,
-                _,
-                energy,
-                _,
-                _,
-            ) = self.hparams.model(
-                tokens_padded,
-                pace=pace,
-                pitch_rate=pitch_rate,
-                energy_rate=energy_rate,
-            )
-
-            # Transposes to make in compliant with HiFI GAN expected format
-            post_mel_outputs = post_mel_outputs.transpose(-1, 1)
-
-        return post_mel_outputs, durations, pitch, energy
-
-    def forward(self, text, pace=1.0, pitch_rate=1.0, energy_rate=1.0):
-        """Batch inference for a tensor of phoneme sequences
-        Arguments
-        ---------
-        text : str
-            A text to be converted to spectrogram
-        pace : float
-            pace for the speech synthesis
-        pitch_rate : float
-            scaling factor for phoneme pitches
-        energy_rate : float
-            scaling factor for phoneme energies
-        """
-        return self.encode_text(
-            [text], pace=pace, pitch_rate=pitch_rate, energy_rate=energy_rate
-        )
-
-
-class HIFIGAN(Pretrained):
-    """
-    A ready-to-use wrapper for HiFiGAN (mel_spec -> waveform).
-    Arguments
-    ---------
-    hparams
-        Hyperparameters (from HyperPyYAML)
-    Example
-    -------
-    >>> tmpdir_vocoder = getfixture('tmpdir') / "vocoder"
-    >>> hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-ljspeech", savedir=tmpdir_vocoder)
-    >>> mel_specs = torch.rand(2, 80,298)
-    >>> waveforms = hifi_gan.decode_batch(mel_specs)
-    >>> # You can use the vocoder coupled with a TTS system
-    >>>	# Initialize TTS (tacotron2)
-    >>> tmpdir_tts = getfixture('tmpdir') / "tts"
-    >>>	tacotron2 = Tacotron2.from_hparams(source="speechbrain/tts-tacotron2-ljspeech", savedir=tmpdir_tts)
-    >>>	# Running the TTS
-    >>>	mel_output, mel_length, alignment = tacotron2.encode_text("Mary had a little lamb")
-    >>>	# Running Vocoder (spectrogram-to-waveform)
-    >>>	waveforms = hifi_gan.decode_batch(mel_output)
-    """
-
-    HPARAMS_NEEDED = ["generator"]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.infer = self.hparams.generator.inference
-        self.first_call = True
-
-    def decode_batch(self, spectrogram, mel_lens=None, hop_len=None):
-        """Computes waveforms from a batch of mel-spectrograms
-        Arguments
-        ---------
-        spectrogram: torch.Tensor
-            Batch of mel-spectrograms [batch, mels, time]
-        mel_lens: torch.tensor
-            A list of lengths of mel-spectrograms for the batch
-            Can be obtained from the output of Tacotron/FastSpeech
-        hop_len: int
-            hop length used for mel-spectrogram extraction
-            should be the same value as in the .yaml file
-        Returns
-        -------
-        waveforms: torch.Tensor
-            Batch of mel-waveforms [batch, 1, time]
-        """
-        # Prepare for inference by removing the weight norm
-        if self.first_call:
-            self.hparams.generator.remove_weight_norm()
-            self.first_call = False
-        with torch.no_grad():
-            waveform = self.infer(spectrogram.to(self.device))
-
-        # Mask the noise caused by padding during batch inference
-        if mel_lens is not None and hop_len is not None:
-            waveform = self.mask_noise(waveform, mel_lens, hop_len)
-
-        return waveform
-
-    def mask_noise(self, waveform, mel_lens, hop_len):
-        """Mask the noise caused by padding during batch inference
-        Arguments
-        ---------
-        wavform: torch.tensor
-            Batch of generated waveforms [batch, 1, time]
-        mel_lens: torch.tensor
-            A list of lengths of mel-spectrograms for the batch
-            Can be obtained from the output of Tacotron/FastSpeech
-        hop_len: int
-            hop length used for mel-spectrogram extraction
-            same value as in the .yaml file
-        Returns
-        -------
-        waveform: torch.tensor
-            Batch of waveforms without padded noise [batch, 1, time]
-        """
-        waveform = waveform.squeeze(1)
-        # the correct audio length should be hop_len * mel_len
-        mask = length_to_mask(
-            mel_lens * hop_len, waveform.shape[1], device=waveform.device
-        ).bool()
-        waveform.masked_fill_(~mask, 0.0)
-        return waveform.unsqueeze(1)
-
-    def decode_spectrogram(self, spectrogram):
-        """Computes waveforms from a single mel-spectrogram
-        Arguments
-        ---------
-        spectrogram: torch.Tensor
-            mel-spectrogram [mels, time]
-        Returns
-        -------
-        waveform: torch.Tensor
-            waveform [1, time]
-        audio can be saved by:
-        >>> waveform = torch.rand(1, 666666)
-        >>> sample_rate = 22050
-        >>> torchaudio.save(str(getfixture('tmpdir') / "test.wav"), waveform, sample_rate)
-        """
-        if self.first_call:
-            self.hparams.generator.remove_weight_norm()
-            self.first_call = False
-        with torch.no_grad():
-            waveform = self.infer(spectrogram.unsqueeze(0).to(self.device))
-        return waveform.squeeze(0)
-
-    def forward(self, spectrogram):
-        "Decodes the input spectrograms"
-        return self.decode_batch(spectrogram)
-
-
-class DiffWaveVocoder(Pretrained):
-    """
-    A ready-to-use inference wrapper for DiffWave as vocoder.
-    The wrapper allows to perform generative tasks:
-        locally-conditional generation: mel_spec -> waveform
-    Arguments
-    ---------
-    hparams
-        Hyperparameters (from HyperPyYAML)
-    """
-
-    HPARAMS_NEEDED = ["diffusion"]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        if hasattr(self.hparams, "diffwave"):
-            self.infer = self.hparams.diffusion.inference
-        else:
-            raise NotImplementedError
-
-    def decode_batch(
-        self,
-        mel,
-        hop_len,
-        mel_lens=None,
-        fast_sampling=False,
-        fast_sampling_noise_schedule=None,
-    ):
-        """Generate waveforms from spectrograms
-        Arguments
-        ---------
-        mel: torch.tensor
-            spectrogram [batch, mels, time]
-        hop_len: int
-            Hop length during mel-spectrogram extraction
-            Should be the same value as in the .yaml file
-            Used to determine the output wave length
-            Also used to mask the noise for vocoding task
-        mel_lens: torch.tensor
-            Used to mask the noise caused by padding
-            A list of lengths of mel-spectrograms for the batch
-            Can be obtained from the output of Tacotron/FastSpeech
-        fast_sampling: bool
-            whether to do fast sampling
-        fast_sampling_noise_schedule: list
-            the noise schedules used for fast sampling
-        Returns
-        -------
-        waveforms: torch.tensor
-            Batch of mel-waveforms [batch, 1, time]
-
-        """
-        with torch.no_grad():
-            waveform = self.infer(
-                unconditional=False,
-                scale=hop_len,
-                condition=mel.to(self.device),
-                fast_sampling=fast_sampling,
-                fast_sampling_noise_schedule=fast_sampling_noise_schedule,
-            )
-
-        # Mask the noise caused by padding during batch inference
-        if mel_lens is not None and hop_len is not None:
-            waveform = self.mask_noise(waveform, mel_lens, hop_len)
-        return waveform
-
-    def mask_noise(self, waveform, mel_lens, hop_len):
-        """Mask the noise caused by padding during batch inference
-        Arguments
-        ---------
-        wavform: torch.tensor
-            Batch of generated waveforms [batch, 1, time]
-        mel_lens: torch.tensor
-            A list of lengths of mel-spectrograms for the batch
-            Can be obtained from the output of Tacotron/FastSpeech
-        hop_len: int
-            hop length used for mel-spectrogram extraction
-            same value as in the .yaml file
-        Returns
-        -------
-        waveform: torch.tensor
-            Batch of waveforms without padded noise [batch, 1, time]
-        """
-        waveform = waveform.squeeze(1)
-        # the correct audio length should be hop_len * mel_len
-        mask = length_to_mask(
-            mel_lens * hop_len, waveform.shape[1], device=waveform.device
-        ).bool()
-        waveform.masked_fill_(~mask, 0.0)
-        return waveform.unsqueeze(1)
-
-    def decode_spectrogram(
-        self,
-        spectrogram,
-        hop_len,
-        fast_sampling=False,
-        fast_sampling_noise_schedule=None,
-    ):
-        """Computes waveforms from a single mel-spectrogram
-        Arguments
-        ---------
-        spectrogram: torch.tensor
-            mel-spectrogram [mels, time]
-        hop_len: int
-            hop length used for mel-spectrogram extraction
-            same value as in the .yaml file
-        fast_sampling: bool
-            whether to do fast sampling
-        fast_sampling_noise_schedule: list
-            the noise schedules used for fast sampling
-        Returns
-        -------
-        waveform: torch.tensor
-            waveform [1, time]
-
-        audio can be saved by:
-        >>> waveform = torch.rand(1, 666666)
-        >>> sample_rate = 22050
-        >>> torchaudio.save(str(getfixture('tmpdir') / "test.wav"), waveform, sample_rate)
-        """
-        with torch.no_grad():
-            waveform = self.infer(
-                unconditional=False,
-                scale=hop_len,
-                condition=spectrogram.unsqueeze(0).to(self.device),
-                fast_sampling=fast_sampling,
-                fast_sampling_noise_schedule=fast_sampling_noise_schedule,
-            )
-        return waveform.squeeze(0)
-
-    def forward(self, spectrogram):
-        """Decodes the input spectrograms"""
-        return self.decode_batch(spectrogram)
-
-
-class WhisperASR(Pretrained):
-    """A ready-to-use Whisper ASR model
-
-    The class can be used  to  run the entire encoder-decoder whisper model
-    (transcribe()) to transcribe speech. The given YAML must contains the fields
-    specified in the *_NEEDED[] lists.
-
-    Example
-    -------
-    >>> from speechbrain.pretrained import WhisperASR
-    >>> tmpdir = getfixture("tmpdir")
-    >>> asr_model = WhisperASR.from_hparams(source="speechbrain/asr-whisper-large-v2-commonvoice-fr", savedir=tmpdir,) # doctest: +SKIP
-    >>> asr_model.transcribe_file("speechbrain/asr-whisper-large-v2-commonvoice-fr/example-fr.mp3") # doctest: +SKIP
-    """
-
-    HPARAMS_NEEDED = ["language"]
-    MODULES_NEEDED = ["whisper", "decoder"]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.tokenizer = self.hparams.whisper.tokenizer
-        self.tokenizer.set_prefix_tokens(
-            self.hparams.language, "transcribe", False
-        )
-        self.hparams.decoder.set_decoder_input_tokens(
-            self.tokenizer.prefix_tokens
-        )
-
-    def transcribe_file(self, path):
-        """Transcribes the given audiofile into a sequence of words.
-
-        Arguments
-        ---------
-        path : str
-            Path to audio file which to transcribe.
-
-        Returns
-        -------
-        str
-            The audiofile transcription produced by this ASR system.
-        """
-        waveform = self.load_audio(path)
-        # Fake a batch:
-        batch = waveform.unsqueeze(0)
-        rel_length = torch.tensor([1.0])
-        predicted_words, predicted_tokens = self.transcribe_batch(
-            batch, rel_length
-        )
-        return predicted_words
-
-    def encode_batch(self, wavs, wav_lens):
-        """Encodes the input audio into a sequence of hidden states
-
-        The waveforms should already be in the model's desired format.
-        You can call:
-        ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
-        to get a correctly converted signal in most cases.
-
-        Arguments
-        ---------
-        wavs : torch.tensor
-            Batch of waveforms [batch, time, channels].
-        wav_lens : torch.tensor
-            Lengths of the waveforms relative to the longest one in the
-            batch, tensor of shape [batch]. The longest one should have
-            relative length 1.0 and others len(waveform) / max_length.
-            Used for ignoring padding.
-
-        Returns
-        -------
-        torch.tensor
-            The encoded batch
-        """
-        wavs = wavs.float()
-        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
-        encoder_out = self.mods.whisper.forward_encoder(wavs)
-        return encoder_out
-
-    def transcribe_batch(self, wavs, wav_lens):
-        """Transcribes the input audio into a sequence of words
-
-        The waveforms should already be in the model's desired format.
-        You can call:
-        ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
-        to get a correctly converted signal in most cases.
-
-        Arguments
-        ---------
-        wavs : torch.tensor
-            Batch of waveforms [batch, time, channels].
-        wav_lens : torch.tensor
-            Lengths of the waveforms relative to the longest one in the
-            batch, tensor of shape [batch]. The longest one should have
-            relative length 1.0 and others len(waveform) / max_length.
-            Used for ignoring padding.
-
-        Returns
-        -------
-        list
-            Each waveform in the batch transcribed.
-        tensor
-            Each predicted token id.
-        """
-        with torch.no_grad():
-            wav_lens = wav_lens.to(self.device)
-            encoder_out = self.encode_batch(wavs, wav_lens)
-            predicted_tokens, scores = self.mods.decoder(encoder_out, wav_lens)
-            predicted_words = self.tokenizer.batch_decode(
-                predicted_tokens, skip_special_tokens=True
-            )
-            if self.hparams.normalized_transcripts:
-                predicted_words = [
-                    self.tokenizer._normalize(text).split(" ")
-                    for text in predicted_words
-                ]
-
-        return predicted_words, predicted_tokens
-
-    def forward(self, wavs, wav_lens):
-        """Runs full transcription - note: no gradients through decoding"""
-        return self.transcribe_batch(wavs, wav_lens)
-
-
-class Speech_Emotion_Diarization(Pretrained):
-    """A ready-to-use SED interface (audio -> emotions and their durations)
-
-    Arguments
-    ---------
-    hparams
-        Hyperparameters (from HyperPyYAML)
-
-    Example
-    -------
-    >>> from speechbrain.pretrained import Speech_Emotion_Diarization
-    >>> tmpdir = getfixture("tmpdir")
-    >>> sed_model = Speech_Emotion_Diarization.from_hparams(source="speechbrain/emotion-diarization-wavlm-large", savedir=tmpdir,) # doctest: +SKIP
-    >>> sed_model.diarize_file("speechbrain/emotion-diarization-wavlm-large/example.wav") # doctest: +SKIP
-    """
-
-    MODULES_NEEDED = ["input_norm", "wav2vec", "output_mlp"]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def diarize_file(self, path):
-        """Get emotion diarization of a spoken utterance.
-
-        Arguments
-        ---------
-        path : str
-            Path to audio file which to diarize.
-
-        Returns
-        -------
-        list of dictionary: List[Dict[List]]
-            The emotions and their temporal boundaries.
-        """
-        waveform = self.load_audio(path)
-        # Fake a batch:
-        batch = waveform.unsqueeze(0)
-        rel_length = torch.tensor([1.0])
-        frame_class = self.diarize_batch(batch, rel_length, [path])
-        return frame_class
-
-    def encode_batch(self, wavs, wav_lens):
-        """Encodes audios into fine-grained emotional embeddings
-
-        Arguments
-        ---------
-        wavs : torch.tensor
-            Batch of waveforms [batch, time, channels].
-        wav_lens : torch.tensor
-            Lengths of the waveforms relative to the longest one in the
-            batch, tensor of shape [batch]. The longest one should have
-            relative length 1.0 and others len(waveform) / max_length.
-            Used for ignoring padding.
-
-        Returns
-        -------
-        torch.tensor
-            The encoded batch
-        """
-        if len(wavs.shape) == 1:
-            wavs = wavs.unsqueeze(0)
-
-        # Assign full length if wav_lens is not assigned
-        if wav_lens is None:
-            wav_lens = torch.ones(wavs.shape[0], device=self.device)
-
-        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
-
-        wavs = self.mods.input_norm(wavs, wav_lens)
-        outputs = self.mods.wav2vec2(wavs)
-        return outputs
-
-    def diarize_batch(self, wavs, wav_lens, batch_id):
-        """Get emotion diarization of a batch of waveforms.
-
-        The waveforms should already be in the model's desired format.
-        You can call:
-        ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
-        to get a correctly converted signal in most cases.
-
-        Arguments
-        ---------
-        wavs : torch.tensor
-            Batch of waveforms [batch, time, channels].
-        wav_lens : torch.tensor
-            Lengths of the waveforms relative to the longest one in the
-            batch, tensor of shape [batch]. The longest one should have
-            relative length 1.0 and others len(waveform) / max_length.
-            Used for ignoring padding.
-        batch_id : torch.tensor
-            id of each batch (file names etc.)
-
-        Returns
-        -------
-        list of dictionary: List[Dict[List]]
-            The emotions and their temporal boundaries.
-        """
-        outputs = self.encode_batch(wavs, wav_lens)
-        averaged_out = self.hparams.avg_pool(outputs)
-        outputs = self.mods.output_mlp(averaged_out)
-        outputs = self.hparams.log_softmax(outputs)
-        score, index = torch.max(outputs, dim=-1)
-        preds = self.hparams.label_encoder.decode_torch(index)
-        results = self.preds_to_diarization(preds, batch_id)
-        return results
-
-    def preds_to_diarization(self, prediction, batch_id):
-        """Convert frame-wise predictions into a dictionary of
-        diarization results.
-
-        Returns
-        -------
-        dictionary
-            A dictionary with the start/end of each emotion
-        """
-        results = {}
-
-        for i in range(len(prediction)):
-            pred = prediction[i]
-            lol = []
-            for j in range(len(pred)):
-                start = round(self.hparams.stride * 0.02 * j, 2)
-                end = round(start + self.hparams.window_length * 0.02, 2)
-                lol.append([batch_id[i], start, end, pred[j]])
-
-            lol = self.merge_ssegs_same_emotion_adjacent(lol)
-            results[batch_id[i]] = [
-                {"start": k[1], "end": k[2], "emotion": k[3]} for k in lol
-            ]
-            return results
-
-    def forward(self, wavs, wav_lens, batch_id):
-        """Get emotion diarization for a batch of waveforms."""
-        return self.diarize_batch(wavs, wav_lens, batch_id)
-
-    def is_overlapped(self, end1, start2):
-        """Returns True if segments are overlapping.
-
-        Arguments
-        ---------
-        end1 : float
-            End time of the first segment.
-        start2 : float
-            Start time of the second segment.
-
-        Returns
-        -------
-        overlapped : bool
-            True of segments overlapped else False.
-
-        Example
-        -------
-        >>> from speechbrain.processing import diarization as diar
-        >>> diar.is_overlapped(5.5, 3.4)
-        True
-        >>> diar.is_overlapped(5.5, 6.4)
-        False
-        """
-
-        if start2 > end1:
-            return False
-        else:
-            return True
-
-    def merge_ssegs_same_emotion_adjacent(self, lol):
-        """Merge adjacent sub-segs if they are the same emotion.
-        Arguments
-        ---------
-        lol : list of list
-            Each list contains [utt_id, sseg_start, sseg_end, emo_label].
-        Returns
-        -------
-        new_lol : list of list
-            new_lol contains adjacent segments merged from the same emotion ID.
-        Example
-        -------
-        >>> from speechbrain.utils.EDER import merge_ssegs_same_emotion_adjacent
-        >>> lol=[['u1', 0.0, 7.0, 'a'],
-        ... ['u1', 7.0, 9.0, 'a'],
-        ... ['u1', 9.0, 11.0, 'n'],
-        ... ['u1', 11.0, 13.0, 'n'],
-        ... ['u1', 13.0, 15.0, 'n'],
-        ... ['u1', 15.0, 16.0, 'a']]
-        >>> merge_ssegs_same_emotion_adjacent(lol)
-        [['u1', 0.0, 9.0, 'a'], ['u1', 9.0, 15.0, 'n'], ['u1', 15.0, 16.0, 'a']]
-        """
-        new_lol = []
-
-        # Start from the first sub-seg
-        sseg = lol[0]
-        flag = False
-        for i in range(1, len(lol)):
-            next_sseg = lol[i]
-            # IF sub-segments overlap AND has same emotion THEN merge
-            if (
-                self.is_overlapped(sseg[2], next_sseg[1])
-                and sseg[3] == next_sseg[3]
-            ):
-                sseg[2] = next_sseg[2]  # just update the end time
-                # This is important. For the last sseg, if it is the same emotion then merge
-                # Make sure we don't append the last segment once more. Hence, set FLAG=True
-                if i == len(lol) - 1:
-                    flag = True
-                    new_lol.append(sseg)
-            else:
-                new_lol.append(sseg)
-                sseg = next_sseg
-        # Add last segment only when it was skipped earlier.
-        if flag is False:
-            new_lol.append(lol[-1])
-        return new_lol
-
-
-class AudioClassifier(Pretrained):
-    """A ready-to-use class for utterance-level classification (e.g, speaker-id,
-    language-id, emotion recognition, keyword spotting, etc).
-
-    The class assumes that an encoder called "embedding_model" and a model
-    called "classifier" are defined in the yaml file. If you want to
-    convert the predicted index into a corresponding text label, please
-    provide the path of the label_encoder in a variable called 'lab_encoder_file'
-    within the yaml.
-
-    The class can be used either to run only the encoder (encode_batch()) to
-    extract embeddings or to run a classification step (classify_batch()).
-    ```
-
-    Example
-    -------
-    >>> import torchaudio
-    >>> from speechbrain.pretrained import AudioClassifier
-    >>> tmpdir = getfixture("tmpdir")
-    >>> classifier = AudioClassifier.from_hparams(
-    ...     source="speechbrain/cnn14-esc50",
-    ...     savedir=tmpdir,
-    ... )
-    >>> signal = torch.randn(1, 16000)
-    >>> prediction, _, _, text_lab = classifier.classify_batch(signal)
-    >>> print(prediction.shape)
-    torch.Size([1, 1, 50])
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def classify_batch(self, wavs, wav_lens=None):
-        """Performs classification on the top of the encoded features.
-
-        It returns the posterior probabilities, the index and, if the label
-        encoder is specified it also the text label.
-
-        Arguments
-        ---------
-        wavs : torch.Tensor
-            Batch of waveforms [batch, time, channels] or [batch, time]
-            depending on the model. Make sure the sample rate is fs=16000 Hz.
-        wav_lens : torch.Tensor
-            Lengths of the waveforms relative to the longest one in the
-            batch, tensor of shape [batch]. The longest one should have
-            relative length 1.0 and others len(waveform) / max_length.
-            Used for ignoring padding.
-
-        Returns
-        -------
-        out_prob
-            The log posterior probabilities of each class ([batch, N_class])
-        score:
-            It is the value of the log-posterior for the best class ([batch,])
-        index
-            The indexes of the best class ([batch,])
-        text_lab:
-            List with the text labels corresponding to the indexes.
-            (label encoder should be provided).
-        """
-        wavs = wavs.to(self.device)
-        X_stft = self.mods.compute_stft(wavs)
-        X_stft_power = speechbrain.processing.features.spectral_magnitude(
-            X_stft, power=self.hparams.spec_mag_power
-        )
-
-        if self.hparams.use_melspectra:
-            net_input = self.mods.compute_fbank(X_stft_power)
-        else:
-            net_input = torch.log1p(X_stft_power)
-
-        # Embeddings + sound classifier
-        embeddings = self.mods.embedding_model(net_input)
-        if embeddings.ndim == 4:
-            embeddings = embeddings.mean((-1, -2))
-
-        out_probs = self.mods.classifier(embeddings)
-        score, index = torch.max(out_probs, dim=-1)
-        text_lab = self.hparams.label_encoder.decode_torch(index)
-        return out_probs, score, index, text_lab
-
-    def classify_file(self, path, savedir="audio_cache"):
-        """Classifies the given audiofile into the given set of labels.
-
-        Arguments
-        ---------
-        path : str
-            Path to audio file to classify.
-
-        Returns
-        -------
-        out_prob
-            The log posterior probabilities of each class ([batch, N_class])
-        score:
-            It is the value of the log-posterior for the best class ([batch,])
-        index
-            The indexes of the best class ([batch,])
-        text_lab:
-            List with the text labels corresponding to the indexes.
-            (label encoder should be provided).
-        """
-        source, fl = split_path(path)
-        path = fetch(fl, source=source, savedir=savedir)
-
-        batch, fs_file = torchaudio.load(path)
-        batch = batch.to(self.device)
-        fs_model = self.hparams.sample_rate
-
-        # resample the data if needed
-        if fs_file != fs_model:
-            print(
-                "Resampling the audio from {} Hz to {} Hz".format(
-                    fs_file, fs_model
-                )
-            )
-            tf = torchaudio.transforms.Resample(
-                orig_freq=fs_file, new_freq=fs_model
-            ).to(self.device)
-            batch = batch.mean(dim=0, keepdim=True)
-            batch = tf(batch)
-
-        out_probs, score, index, text_lab = self.classify_batch(batch)
-        return out_probs, score, index, text_lab
-
-    def forward(self, wavs, wav_lens=None):
-        """Runs the classification"""
-        return self.classify_batch(wavs, wav_lens)
-
-
-class PIQAudioInterpreter(Pretrained):
-    """
-    This class implements the interface for the PIQ posthoc interpreter for an audio classifier.
-
-    Example
-    -------
-    >>> from speechbrain.pretrained import PIQAudioInterpreter
-    >>> tmpdir = getfixture("tmpdir")
-    >>> interpreter = PIQAudioInterpreter.from_hparams(
-    ...     source="speechbrain/PIQ-ESC50",
-    ...     savedir=tmpdir,
-    ... )
-    >>> signal = torch.randn(1, 16000)
-    >>> interpretation, _ = interpreter.interpret_batch(signal)
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def preprocess(self, wavs):
-        """Pre-process wavs to calculate STFTs"""
-        X_stft = self.mods.compute_stft(wavs)
-        X_stft_power = speechbrain.processing.features.spectral_magnitude(
-            X_stft, power=self.hparams.spec_mag_power
-        )
-        X_stft_logpower = torch.log1p(X_stft_power)
-
-        return X_stft_logpower, X_stft, X_stft_power
-
-    def classifier_forward(self, X_stft_logpower):
-        """the forward pass for the classifier"""
-        hcat = self.mods.embedding_model(X_stft_logpower)
-        embeddings = hcat.mean((-1, -2))
-        predictions = self.mods.classifier(embeddings).squeeze(1)
-        class_pred = predictions.argmax(1)
-        return hcat, embeddings, predictions, class_pred
-
-    def invert_stft_with_phase(self, X_int, X_stft_phase):
-        """Inverts STFT spectra given phase."""
-        X_stft_phase_sb = torch.cat(
-            (
-                torch.cos(X_stft_phase).unsqueeze(-1),
-                torch.sin(X_stft_phase).unsqueeze(-1),
-            ),
-            dim=-1,
-        )
-
-        X_stft_phase_sb = X_stft_phase_sb[:, : X_int.shape[1], :, :]
-        if X_int.ndim == 3:
-            X_int = X_int.unsqueeze(-1)
-        X_wpsb = X_int * X_stft_phase_sb
-        x_int_sb = self.mods.compute_istft(X_wpsb)
-        return x_int_sb
-
-    def interpret_batch(self, wavs):
-        """Classifies the given audio into the given set of labels.
-        It also provides the interpretation in the audio domain.
-
-        Arguments
-        ---------
-        wavs : torch.Tensor
-            Batch of waveforms [batch, time, channels] or [batch, time]
-            depending on the model. Make sure the sample rate is fs=16000 Hz.
-
-        Returns
-        -------
-        x_int_sound_domain
-            The interpretation in the waveform domain
-        text_lab:
-            The text label for the classification
-        fs_model:
-            The sampling frequency of the model. Useful to save the audio.
-        """
-        wavs = wavs.to(self.device)
-        X_stft_logpower, X_stft, X_stft_power = self.preprocess(wavs)
-        X_stft_phase = spectral_phase(X_stft)
-
-        # Embeddings + sound classifier
-        hcat, embeddings, predictions, class_pred = self.classifier_forward(
-            X_stft_logpower
-        )
-
-        if self.hparams.use_vq:
-            xhat, hcat, z_q_x = self.mods.psi(hcat, class_pred)
-        else:
-            xhat = self.mods.psi.decoder(hcat)
-        xhat = xhat.squeeze(1)
-        Tmax = xhat.shape[1]
-        if self.hparams.use_mask_output:
-            xhat = F.sigmoid(xhat)
-            X_int = xhat * X_stft_logpower[:, :Tmax, :]
-        else:
-            xhat = F.softplus(xhat)
-            th = xhat.max() * self.hparams.mask_th
-            X_int = (xhat > th) * X_stft_logpower[:, :Tmax, :]
-        X_int = torch.expm1(X_int)
-        x_int_sound_domain = self.invert_stft_with_phase(X_int, X_stft_phase)
-        text_lab = self.hparams.label_encoder.decode_torch(
-            class_pred.unsqueeze(0)
-        )
-
-        return x_int_sound_domain, text_lab
-
-    def interpret_file(self, path, savedir="audio_cache"):
-        """Classifies the given audiofile into the given set of labels.
-        It also provides the interpretation in the audio domain.
-
-        Arguments
-        ---------
-        path : str
-            Path to audio file to classify.
-
-        Returns
-        -------
-        x_int_sound_domain
-            The interpretation in the waveform domain
-        text_lab:
-            The text label for the classification
-        fs_model:
-            The sampling frequency of the model. Useful to save the audio.
-        """
-        source, fl = split_path(path)
-        path = fetch(fl, source=source, savedir=savedir)
-
-        batch, fs_file = torchaudio.load(path)
-        batch = batch.to(self.device)
-        fs_model = self.hparams.sample_rate
-
-        # resample the data if needed
-        if fs_file != fs_model:
-            print(
-                "Resampling the audio from {} Hz to {} Hz".format(
-                    fs_file, fs_model
-                )
-            )
-            tf = torchaudio.transforms.Resample(
-                orig_freq=fs_file, new_freq=fs_model
-            ).to(self.device)
-            batch = batch.mean(dim=0, keepdim=True)
-            batch = tf(batch)
-
-        x_int_sound_domain, text_lab = self.interpret_batch(batch)
-        return x_int_sound_domain, text_lab, fs_model
-
-    def forward(self, wavs, wav_lens=None):
-        """Runs the classification"""
-        return self.interpret_batch(wavs, wav_lens)
diff --git a/speechbrain/processing/features.py b/speechbrain/processing/features.py
index 052d99f06796c35bddd178e01710fb48bb97b969..2aba00eeb14bba9308c120ffc03a07b5d504f8b4 100644
--- a/speechbrain/processing/features.py
+++ b/speechbrain/processing/features.py
@@ -43,6 +43,7 @@ from speechbrain.utils.checkpoints import (
     register_checkpoint_hooks,
 )
 from speechbrain.dataio.dataio import length_to_mask
+from speechbrain.utils.filter_analysis import FilterProperties
 
 
 logger = logging.getLogger(__name__)
@@ -177,6 +178,17 @@ class STFT(torch.nn.Module):
 
         return stft
 
+    def get_filter_properties(self) -> FilterProperties:
+        if not self.center:
+            raise ValueError(
+                "ValueProperties cannot model a non-centered STFT, as it "
+                "assumes either centering or causality"
+            )
+
+        return FilterProperties(
+            window_size=self.win_length, stride=self.hop_length
+        )
+
 
 class ISTFT(torch.nn.Module):
     """Computes the Inverse Short-Term Fourier Transform (ISTFT)
@@ -1045,11 +1057,13 @@ class InputNormalization(torch.nn.Module):
                             self.weight = self.avg_factor
 
                         self.spk_dict_mean[spk_id] = (
-                            (1 - self.weight) * self.spk_dict_mean[spk_id]
+                            (1 - self.weight)
+                            * self.spk_dict_mean[spk_id].to(current_mean)
                             + self.weight * current_mean
                         )
                         self.spk_dict_std[spk_id] = (
-                            (1 - self.weight) * self.spk_dict_std[spk_id]
+                            (1 - self.weight)
+                            * self.spk_dict_std[spk_id].to(current_std)
                             + self.weight * current_std
                         )
 
@@ -1081,26 +1095,26 @@ class InputNormalization(torch.nn.Module):
                         self.glob_mean = current_mean
                         self.glob_std = current_std
 
-                    elif epoch < self.update_until_epoch:
+                    elif epoch is None or epoch < self.update_until_epoch:
                         if self.avg_factor is None:
                             self.weight = 1 / (self.count + 1)
                         else:
                             self.weight = self.avg_factor
 
-                        self.glob_mean = (
-                            1 - self.weight
-                        ) * self.glob_mean + self.weight * current_mean
+                        self.glob_mean = (1 - self.weight) * self.glob_mean.to(
+                            current_mean
+                        ) + self.weight * current_mean
 
-                        self.glob_std = (
-                            1 - self.weight
-                        ) * self.glob_std + self.weight * current_std
+                        self.glob_std = (1 - self.weight) * self.glob_std.to(
+                            current_std
+                        ) + self.weight * current_std
 
                     self.glob_mean.detach()
                     self.glob_std.detach()
 
                     self.count = self.count + 1
 
-                x = (x - self.glob_mean.data) / (self.glob_std.data)
+                x = (x - self.glob_mean.data.to(x)) / (self.glob_std.data.to(x))
 
         return x
 
@@ -1162,16 +1176,16 @@ class InputNormalization(torch.nn.Module):
         # Loading the spk_dict_mean in the right device
         self.spk_dict_mean = {}
         for spk in state["spk_dict_mean"]:
-            self.spk_dict_mean[spk] = state["spk_dict_mean"][spk].to(
-                self.device_inp
-            )
+            self.spk_dict_mean[spk] = state["spk_dict_mean"][spk]  # .to(
+            #    self.device_inp
+            # )
 
         # Loading the spk_dict_std in the right device
         self.spk_dict_std = {}
         for spk in state["spk_dict_std"]:
-            self.spk_dict_std[spk] = state["spk_dict_std"][spk].to(
-                self.device_inp
-            )
+            self.spk_dict_std[spk] = state["spk_dict_std"][spk]  # .to(
+            #    self.device_inp
+            # )
 
         self.spk_dict_count = state["spk_dict_count"]
 
@@ -1201,17 +1215,16 @@ class InputNormalization(torch.nn.Module):
 
     @mark_as_transfer
     @mark_as_loader
-    def _load(self, path, end_of_epoch=False, device=None):
+    def _load(self, path, end_of_epoch=False):
         """Load statistic dictionary.
 
         Arguments
         ---------
         path : str
             The path of the statistic dictionary
-        device : str, None
-            Passed to torch.load(..., map_location=device)
         """
         del end_of_epoch  # Unused here.
+        device = "cpu"
         stats = torch.load(path, map_location=device)
         self._load_statistics_dict(stats)
 
diff --git a/speechbrain/processing/signal_processing.py b/speechbrain/processing/signal_processing.py
index 48693a86da9c6c53bc34ce287a27ed1a45676593..ea3be41ddcdc269206311ff9b4b7190ff45fe059 100644
--- a/speechbrain/processing/signal_processing.py
+++ b/speechbrain/processing/signal_processing.py
@@ -45,7 +45,7 @@ def compute_amplitude(waveforms, lengths=None, amp_type="avg", scale="linear"):
     if len(waveforms.shape) == 1:
         waveforms = waveforms.unsqueeze(0)
 
-    assert amp_type in ["avg", "peak"]
+    assert amp_type in ["avg", "rms", "peak"]
     assert scale in ["linear", "dB"]
 
     if amp_type == "avg":
@@ -53,7 +53,21 @@ def compute_amplitude(waveforms, lengths=None, amp_type="avg", scale="linear"):
             out = torch.mean(torch.abs(waveforms), dim=1, keepdim=True)
         else:
             wav_sum = torch.sum(input=torch.abs(waveforms), dim=1, keepdim=True)
+            # Manage multi-channel waveforms
+            if len(wav_sum.shape) == 3 and isinstance(lengths, torch.Tensor):
+                lengths = lengths.unsqueeze(2)
             out = wav_sum / lengths
+    elif amp_type == "rms":
+        if lengths is None:
+            out = torch.sqrt(torch.mean(waveforms ** 2, dim=1, keepdim=True))
+        else:
+            wav_sum = torch.sum(
+                input=torch.pow(waveforms, 2), dim=1, keepdim=True
+            )
+            if len(wav_sum.shape) == 3 and isinstance(lengths, torch.Tensor):
+                lengths = lengths.unsqueeze(2)
+            out = torch.sqrt(wav_sum / lengths)
+
     elif amp_type == "peak":
         out = torch.max(torch.abs(waveforms), dim=1, keepdim=True)[0]
     else:
@@ -104,6 +118,31 @@ def normalize(waveforms, lengths=None, amp_type="avg", eps=1e-14):
     return waveforms / den
 
 
+def mean_std_norm(waveforms, dims=1, eps=1e-06):
+    """This function normalizes the mean and std of the input
+        waveform (along the specified axis).
+
+    Arguments
+    ---------
+    waveforms : tensor
+        The waveforms to normalize.
+        Shape should be `[batch, time]` or `[batch, time, channels]`.
+    dim: int or tuple
+        The dimension(s) on which mean and std are computed
+    eps : float
+        A small number to add to the denominator to prevent NaN.
+
+    Returns
+    -------
+    waveforms : tensor
+        Normalized level waveform.
+    """
+    mean = waveforms.mean(dims, keepdim=True)
+    std = waveforms.std(dims, keepdim=True)
+    waveforms = (waveforms - mean) / (std + eps)
+    return waveforms
+
+
 def rescale(waveforms, lengths, target_lvl, amp_type="avg", scale="linear"):
     """This functions performs signal rescaling to a target level.
 
@@ -219,7 +258,7 @@ def convolve1d(
     # Padding can be a tuple (left_pad, right_pad) or an int
     if isinstance(padding, tuple):
         waveform = torch.nn.functional.pad(
-            input=waveform, pad=padding, mode=pad_type,
+            input=waveform, pad=padding, mode=pad_type
         )
 
     # This approach uses FFT, which is more efficient if the kernel is large
diff --git a/speechbrain/tokenizers/SentencePiece.py b/speechbrain/tokenizers/SentencePiece.py
index 55559e7f1a2656371c693dbd3421fdd95973a5df..20521c157377a4879be0e011f9d011ddc02ff11c 100644
--- a/speechbrain/tokenizers/SentencePiece.py
+++ b/speechbrain/tokenizers/SentencePiece.py
@@ -9,6 +9,8 @@ import torch
 import logging
 import csv
 import json
+from dataclasses import dataclass
+from typing import List
 import sentencepiece as spm
 from speechbrain.dataio.dataio import merge_char
 from speechbrain.utils import edit_distance
@@ -463,3 +465,85 @@ class SentencePiece:
                 ).split(" ")
                 for i, utt_seq in enumerate(batch)
             ]
+
+
+def get_spm_tokens(model_path):
+    """Fetch list of tokens, can be indexed by token id
+
+    The resulting list can be used to map id to token.
+
+    Arguments
+    ---------
+    model_path : str
+        Path to SentencePiece model
+
+    Returns
+    -------
+    list
+        Tokens in order by id (can be indexed by id)
+    """
+    model = spm.SentencePieceProcessor()
+    model.load(model_path)
+    mapping = [model.sp.id_to_piece(i) for i in range(model.sp.vocab_size())]
+    return mapping
+
+
+@dataclass
+class SentencePieceDecoderStreamingContext:
+    """Mutable streaming context for a single SentencePiece streaming session.
+    """
+
+    emitted_symbol_count: int = 0
+    """The number of symbols that have been emitted for this transcription."""
+
+
+def spm_decode_preserve_leading_space(
+    tokenizer: spm.SentencePieceProcessor,
+    hyps: List[int],
+    context: SentencePieceDecoderStreamingContext,
+) -> List[str]:
+    """Assuming the tokenizer is sentencepiece, decodes the input hypothesis
+    but avoids incorrectly stripping leading spaces when streaming.
+    Operates on a single hypothesis, not a batch of hypotheses.
+
+    Normally, the tokenizer always decodes full sentences at a time, with the
+    consequence that the first space in decoding will get removed.
+    However, when streaming, we might be decoding mid-utterance where spaces
+    must not be removed mid-sentence. This function handles this case.
+
+    e.g. if within the same streaming context, you decode `["▁how", "▁are"]`
+    then `["▁you"]`, the decoder would normally return `"how areyou"` instead of
+    `"how are you"` like this function does.
+
+    Arguments
+    ---------
+    tokenizer : sentencepiece.SentencePieceProcessor
+        The SentencePiece processor to use for decoding.
+    hyps : list of output token hypotheses
+        List of tokens to decode of any length `>=0`.
+    context : SentencePieceDecoderStreamingContext
+        Mutable streaming context for the sentencepiece decoder, which should be
+        reused across calls for the same decoding stream.
+
+    Returns
+    -------
+    str
+        Decoded text. Leading spaces are preserved, except at the start of a
+        transcription."""
+
+    proto = tokenizer.decode([hyps], out_type="immutable_proto")[0]
+    text = proto.text
+
+    if len(proto.pieces) >= 1:
+        should_preserve_space = context.emitted_symbol_count > 0
+        # By default, SentencePiece tags spaces with `▁` i.e. \u2581
+        # (unicode for "Lower One Eighth Block").
+        if should_preserve_space and proto.pieces[0].piece.startswith("\u2581"):
+            # We are mid-sentence and the decoder has nuked the first space,
+            # as the decoder believes we are decoding a full sentence.
+            # Insert it back.
+            text = " " + text
+
+        context.emitted_symbol_count += len(proto.pieces)
+
+    return text
diff --git a/speechbrain/utils/__init__.py b/speechbrain/utils/__init__.py
index b7c3dfa83868e32dbafa003dec13a5885546d3b1..6f8753f6c5fe1ccaafd2430ce5811e7cfa00f8ea 100644
--- a/speechbrain/utils/__init__.py
+++ b/speechbrain/utils/__init__.py
@@ -5,7 +5,11 @@ import os
 __all__ = []
 for filename in os.listdir(os.path.dirname(__file__)):
     filename = os.path.basename(filename)
-    if filename.endswith(".py") and not filename.startswith("__"):
+    if (
+        filename.endswith(".py")
+        and not filename.startswith("__")
+        and not filename == "kmeans.py"
+    ):
         __all__.append(filename[:-3])
 
 from . import *  # noqa
diff --git a/speechbrain/utils/_workarounds.py b/speechbrain/utils/_workarounds.py
index ba82c7c7d48a1df60ac2561117e9ca690fcbc542..59cafd76ccc4c016481a8df47e1d7f1406a3ac15 100644
--- a/speechbrain/utils/_workarounds.py
+++ b/speechbrain/utils/_workarounds.py
@@ -17,8 +17,9 @@ def _cycliclrsaver(obj, path):
     torch.save(state_dict, path)
 
 
-def _cycliclrloader(obj, path, end_of_epoch, device=None):
+def _cycliclrloader(obj, path, end_of_epoch):
     del end_of_epoch  # Unused
+    device = "cpu"
     state_dict = torch.load(path, map_location=device)
     if state_dict.get("_scale_fn_ref") == WEAKREF_MARKER:
         if not isinstance(obj._scale_fn_ref, weakref.WeakMethod):
diff --git a/speechbrain/utils/autocast.py b/speechbrain/utils/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c324a966a2367257d90319fb1a2525c6e37bebb
--- /dev/null
+++ b/speechbrain/utils/autocast.py
@@ -0,0 +1,79 @@
+"""This module implements utilities and abstractions for use with
+`torch.autocast`, i.e. Automatic Mixed Precision.
+
+Authors
+ * Sylvain de Langen 2023
+"""
+import functools
+from typing import Callable, Optional
+import torch
+
+
+def fwd_default_precision(
+    fwd: Optional[Callable] = None,
+    cast_inputs: Optional[torch.dtype] = torch.float32,
+):
+    """Decorator for forward methods which, by default, *disables* autocast
+    and casts any floating-point tensor parameters into the specified dtype
+    (much like `torch.cuda.amp.custom_fwd`).
+
+    The *wrapped forward* will gain an additional `force_allow_autocast` keyword
+    parameter.
+    When set to `True`, the function will ignore `cast_inputs` and will not
+    disable autocast, as if this decorator was not specified.
+    (Thus, modules can specify a default recommended precision, and users can
+    override that behavior when desired.)
+
+    Note that as of PyTorch 2.1.1, this will **only** affect **CUDA** AMP.
+    Non-CUDA AMP will be unaffected and no input tensors will be cast!
+    This usecase may be supported by this function in the future.
+
+    When autocast is *not* active, this decorator does not change any behavior.
+
+    Arguments
+    ---------
+    fwd: Optional[Callable]
+        The function to wrap. If omitted, returns a partial application of the
+        decorator, e.g. allowing
+        `new_decorator = fwd_default_precision(cast_inputs=torch.float32)`.
+
+        Reminder: If you are decorating a function directly, this argument is
+        already specified implicitly.
+
+    cast_inputs: Optional[torch.dtype]
+        If not `None` (the default being `torch.float32`), then any
+        floating-point inputs to the wrapped function will be cast to the
+        specified type.
+
+        Note: When autocasting is enabled, output tensors of autocast-compatible
+        operations may be of the autocast data type.
+        Disabling autocast *without* casting inputs will not change this fact,
+        so lower precision operations can happen even inside of an
+        autocast-disabled region, which this argument helps avoid if desired.
+    """
+
+    if fwd is None:
+        return functools.partial(fwd_default_precision, cast_inputs=cast_inputs)
+
+    # NOTE: torch.cuda.amp.custom_fwd is written with the assumption of CUDA
+    # autocast. There does not seem to be a generic equivalent.
+    # Detecting CUDA AMP specifically also seems difficult or impossible, so we
+    # cannot even reliably warn about the issue. For now, we just document the
+    # problem.
+    wrapped_fwd = torch.cuda.amp.custom_fwd(fwd, cast_inputs=cast_inputs)
+
+    @functools.wraps(fwd)
+    def wrapper(*args, force_allow_autocast: bool = False, **kwargs):
+        """Wrapped forward function from fwd_default_precision.
+
+        Arguments
+        ---------
+        force_allow_autocast: bool
+            When `True`, the wrapped function will be executed directly with no
+            change to the autocast context and no input casting."""
+        if force_allow_autocast:
+            return fwd(*args, **kwargs)
+        else:
+            return wrapped_fwd(*args, **kwargs)
+
+    return wrapper
diff --git a/speechbrain/utils/bleu.py b/speechbrain/utils/bleu.py
index cfde75b84d0bbe1ec9a938cc4108e61b28519a42..0a6dcd51304033cfddad12e5ef54648e8ec1c769 100644
--- a/speechbrain/utils/bleu.py
+++ b/speechbrain/utils/bleu.py
@@ -46,12 +46,18 @@ class BLEUStats(MetricStats):
     0.0
     """
 
-    def __init__(
-        self, lang="en", merge_words=True,
-    ):
+    def __init__(self, lang="en", merge_words=True, max_ngram_order=4):
+        # Check extra-dependency for computing the bleu score
+        try:
+            from sacrebleu.metrics import BLEU
+        except ImportError:
+            print(
+                "Please install sacrebleu (https://pypi.org/project/sacrebleu/) in order to use the BLEU metric"
+            )
 
         self.clear()
         self.merge_words = merge_words
+        self.bleu = BLEU(max_ngram_order=max_ngram_order)
 
         self.predicts = []
         self.targets = None
@@ -97,15 +103,7 @@ class BLEUStats(MetricStats):
         * See MetricStats.summarize()
         """
 
-        # Check extra-dependency for computing the bleu score
-        try:
-            import sacrebleu
-        except ImportError:
-            print(
-                "Please install sacrebleu (https://pypi.org/project/sacrebleu/) in order to use the BLEU metric"
-            )
-
-        scores = sacrebleu.corpus_bleu(self.predicts, self.targets)
+        scores = self.bleu.corpus_score(self.predicts, self.targets)
         details = {}
         details["BLEU"] = scores.score
         details["BP"] = scores.bp
diff --git a/speechbrain/utils/checkpoints.py b/speechbrain/utils/checkpoints.py
index 839c035ef35fe93e83d2e880e4cc7680c580b9bc..39fe6eba2668a10331aaf51a25cf5f61070dc50c 100644
--- a/speechbrain/utils/checkpoints.py
+++ b/speechbrain/utils/checkpoints.py
@@ -60,7 +60,12 @@ import logging
 import warnings
 from packaging import version
 import speechbrain.utils._workarounds as __wa
-from speechbrain.utils.distributed import main_process_only, if_main_process
+from speechbrain.utils.distributed import (
+    main_process_only,
+    if_main_process,
+    ddp_barrier,
+    ddp_broadcast,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -69,7 +74,7 @@ METAFNAME = f"{CKPT_PREFIX}.yaml"  # Important that this is not .ckpt
 PARAMFILE_EXT = ".ckpt"  # ...because these files will be
 
 
-def torch_recovery(obj, path, end_of_epoch, device=None):
+def torch_recovery(obj, path, end_of_epoch):
     """Loads a torch.nn.Module state_dict from the given path instantly.
 
     This can be made the default for torch.nn.Modules with:
@@ -83,8 +88,6 @@ def torch_recovery(obj, path, end_of_epoch, device=None):
         Path where to load from.
     end_of_epoch : bool
         Whether the recovery comes from an end of epoch checkpoint.
-    device : str
-        Torch device, where to map the loaded parameters.
 
     Returns
     -------
@@ -92,6 +95,7 @@ def torch_recovery(obj, path, end_of_epoch, device=None):
         Given object is modified in place.
     """
     del end_of_epoch  # Unused
+    device = "cpu"
     try:
         obj.load_state_dict(torch.load(path, map_location=device), strict=True)
     except TypeError:
@@ -123,7 +127,7 @@ def torch_save(obj, path):
     torch.save(state_dict, path)
 
 
-def torch_parameter_transfer(obj, path, device):
+def torch_parameter_transfer(obj, path):
     """Non-strict Torch Module state_dict load.
 
     Loads a set of parameters from path to obj. If obj has layers for which
@@ -143,6 +147,7 @@ def torch_parameter_transfer(obj, path, device):
     None
         The object is modified in place.
     """
+    device = "cpu"
     incompatible_keys = obj.load_state_dict(
         torch.load(path, map_location=device), strict=False
     )
@@ -188,7 +193,7 @@ DEFAULT_TRANSFER_HOOKS = {
 try:
     import sentencepiece as spm
 
-    def _load_spm(obj, path, device=None):
+    def _load_spm(obj, path):
         obj.load(str(path))  # SentencePieceProcessor needs a string.
 
     DEFAULT_TRANSFER_HOOKS[spm.SentencePieceProcessor] = _load_spm
@@ -237,9 +242,9 @@ def mark_as_loader(method):
     ---------
     method : callable
         Method of the class to decorate. Must be callable with
-        signature (instance, path, end_of_epoch, device) using positional
+        signature (instance, path, end_of_epoch) using positional
         arguments. This is satisfied by for example:
-        `def loader(self, path, end_of_epoch, device):`
+        `def loader(self, path, end_of_epoch):`
 
     Note
     ----
@@ -249,9 +254,9 @@ def mark_as_loader(method):
     """
     sig = inspect.signature(method)
     try:
-        sig.bind(object(), pathlib.Path("testpath"), True, None)
+        sig.bind(object(), pathlib.Path("testpath"), True)
     except TypeError:
-        MSG = "Checkpoint loader must have signature (self, path, end_of_epoch, device)"
+        MSG = "Checkpoint loader must have signature (self, path, end_of_epoch)"
         raise TypeError(MSG)
     method._speechbrain_loader = True
     return method
@@ -264,9 +269,9 @@ def mark_as_transfer(method):
     ---------
     method : callable
         Method of the class to decorate. Must be callable with
-        signature (instance, path, device) using positional
+        signature (instance, path) using positional
         arguments. This is satisfied by for example:
-        `def loader(self, path, device):`
+        `def loader(self, path):`
 
     Note
     ----
@@ -282,9 +287,9 @@ def mark_as_transfer(method):
     """
     sig = inspect.signature(method)
     try:
-        sig.bind(object(), pathlib.Path("testpath"), device=None)
+        sig.bind(object(), pathlib.Path("testpath"))
     except TypeError:
-        MSG = "Transfer hook must have signature (self, path, device)"
+        MSG = "Transfer hook must have signature (self, path)"
         raise TypeError(MSG)
     method._speechbrain_transfer = True
     return method
@@ -318,7 +323,7 @@ def register_checkpoint_hooks(cls, save_on_main_only=True):
     ...             fo.write(str(self.param))
     ...
     ...     @mark_as_loader
-    ...     def load(self, path, end_of_epoch, device=None):
+    ...     def load(self, path, end_of_epoch):
     ...         del end_of_epoch  # Unused here
     ...         with open(path) as fi:
     ...             self.param = int(fi.read())
@@ -470,6 +475,7 @@ class Checkpointer:
         self.checkpoints_dir = pathlib.Path(checkpoints_dir)
         os.makedirs(self.checkpoints_dir, exist_ok=True)
         self.recoverables = {}
+        self.optional_recoverables = {}
         if recoverables is not None:
             self.add_recoverables(recoverables)
         self.custom_load_hooks = {}
@@ -481,7 +487,12 @@ class Checkpointer:
         self.allow_partial_load = allow_partial_load
 
     def add_recoverable(
-        self, name, obj, custom_load_hook=None, custom_save_hook=None,
+        self,
+        name,
+        obj,
+        custom_load_hook=None,
+        custom_save_hook=None,
+        optional_load=False,
     ):
         """Register a recoverable with possible custom hooks.
 
@@ -499,8 +510,20 @@ class Checkpointer:
             Called to save the object's parameters. The function/method must
             be callable with signature (instance, path) using positional
             arguments. This is satisfied by for example: def saver(self, path):
+        optional_load : bool, optional
+            If True, allows for the optional loading of an object from a checkpoint.
+            If the checkpoint lacks the specified object, no error is raised.
+            This is particularly useful during transitions between different training
+            configurations, such as changing precision from floating point 32 to 16.
+            For example, suppose you have a training checkpoint that does not includes
+            a `scaler` object. If you intend to continue pre-training in floating point 16,
+            where the `scaler` object is needed, marking it as optional prevents loading errors.
+            Without marking it as optional, attempting to load the `scaler` object from a checkpoint
+            trained in floating point 32 would fail, as the `scaler` object is not present
+            in that checkpoint.
         """
         self.recoverables[name] = obj
+        self.optional_recoverables[name] = optional_load
         if custom_load_hook is not None:
             self.custom_load_hooks[name] = custom_load_hook
         if custom_save_hook is not None:
@@ -575,16 +598,13 @@ class Checkpointer:
                 ckpt_dir = self._new_checkpoint_dirpath()
             else:
                 ckpt_dir = self._custom_checkpoint_dirpath(name)
-            os.makedirs(ckpt_dir)  # May raise FileExistsError, let it.
+            os.makedirs(ckpt_dir, exist_ok=True)
             saved_meta = self._save_checkpoint_metafile(
                 ckpt_dir / METAFNAME, meta, end_of_epoch
             )
 
         # Communicate ckpt_dir to all procs
-        communication_list = [ckpt_dir]
-        if torch.distributed.is_initialized():
-            torch.distributed.broadcast_object_list(communication_list, src=0)
-        ckpt_dir = communication_list[0]
+        ckpt_dir = ddp_broadcast(ckpt_dir, src=0)
 
         saved_paramfiles = {}
         for name, obj in self.recoverables.items():
@@ -725,6 +745,9 @@ class Checkpointer:
             See the filter builtin.
             The function is called with Checkpoint namedtuples (see above).
             By default, all checkpoints are considered.
+        max_num_checkpoints : int, None
+            The maximum number of checkpoints to return, or None to return all
+            found checkpoints.
 
         Returns
         -------
@@ -846,7 +869,6 @@ class Checkpointer:
         max_key=None,
         min_key=None,
         ckpt_predicate=None,
-        device=None,
     ):
         """Picks a checkpoint and recovers from that, if one is found.
 
@@ -874,8 +896,6 @@ class Checkpointer:
             See the filter builtin.
             The function is called with Checkpoint namedtuples (see above).
             By default, all checkpoints are considered.
-        device : torch.device
-            Device to load models to.
 
         Returns
         -------
@@ -888,12 +908,12 @@ class Checkpointer:
             importance_key, max_key, min_key, ckpt_predicate,
         )
         if chosen_ckpt is not None:
-            self.load_checkpoint(chosen_ckpt, device)
+            self.load_checkpoint(chosen_ckpt)
         else:
             logger.info("Would load a checkpoint here, but none found yet.")
         return chosen_ckpt
 
-    def load_checkpoint(self, checkpoint, device=None):
+    def load_checkpoint(self, checkpoint):
         """Loads the specified checkpoint.
 
         Arguments
@@ -901,7 +921,7 @@ class Checkpointer:
         checkpoint : Checkpoint
             Checkpoint to load.
         """
-        self._call_load_hooks(checkpoint, device)
+        self._call_load_hooks(checkpoint)
 
     def list_checkpoints(self):
         """List all checkpoints in the checkpoints directory.
@@ -990,11 +1010,21 @@ class Checkpointer:
                 )
             )
 
+        # Sync before deleting to avoid another process saving at the same time.
+        # This has led to errors as documented here:
+        # https://github.com/speechbrain/speechbrain/issues/2250
+        ddp_barrier()
+
         # Delete unprotected checkpoints
         for ckpt in potential_deletions:
             if ckpt not in protected_checkpoints:
                 Checkpointer._delete_checkpoint(ckpt, verbosity=verbosity)
 
+        # Sync after deleting to avoid another process saving at the same time.
+        # This has led to errors as documented here:
+        # https://github.com/speechbrain/speechbrain/issues/2250
+        ddp_barrier()
+
     @staticmethod
     @main_process_only
     def _delete_checkpoint(checkpoint, verbosity=logging.INFO):
@@ -1003,7 +1033,7 @@ class Checkpointer:
         shutil.rmtree(checkpoint.path)
         logger.log(verbosity, f"Deleted checkpoint in {checkpoint.path}")
 
-    def _call_load_hooks(self, checkpoint, device=None):
+    def _call_load_hooks(self, checkpoint):
         # This internal function finds the correct hook to call for every
         # recoverable, and calls it.
         logger.info(f"Loading a checkpoint from {checkpoint.path}")
@@ -1023,6 +1053,12 @@ class Checkpointer:
                     warnings.warn(MSG, UserWarning)
                     continue
                 else:
+                    if self.optional_recoverables[name]:
+                        MSG = f"Trying to load checkpoint from {checkpoint.path}, \
+                                but missing a load path for {name}. Skipping as this \
+                                recoverable is marked as optional."
+                        warnings.warn(MSG, UserWarning)
+                        continue
                     MSG = f"Loading checkpoint from {checkpoint.path}, \
                             but missing a load path for {name}"
                     raise RuntimeError(MSG)
@@ -1030,13 +1066,13 @@ class Checkpointer:
             # First see if object has custom load hook:
             if name in self.custom_load_hooks:
                 self.custom_load_hooks[name](
-                    obj, loadpath, end_of_epoch, device
+                    obj, loadpath, end_of_epoch,
                 )
                 continue
             # Otherwise find the default saver for that type:
             default_hook = get_default_hook(obj, DEFAULT_LOAD_HOOKS)
             if default_hook is not None:
-                default_hook(obj, loadpath, end_of_epoch, device)
+                default_hook(obj, loadpath, end_of_epoch)
                 continue
             # If we got here, no custom hook or registered default hook exists
             MSG = f"Don't know how to load {type(obj)}. Register default hook \
@@ -1147,7 +1183,6 @@ def average_checkpoints(
     recoverable_name,
     parameter_loader=torch.load,
     averager=average_state_dicts,
-    device=None,
 ):
     """Average parameters from multiple checkpoints.
 
@@ -1208,18 +1243,9 @@ def average_checkpoints(
     >>> model.param.data
     tensor([8.])
     """
-
-    try:
-        # try to map the ckps to the correct device
-        parameter_iterator = (
-            parameter_loader(
-                ckpt.paramfiles[recoverable_name], map_location=device
-            )
-            for ckpt in checkpoint_list
-        )
-    except TypeError:
-        parameter_iterator = (
-            parameter_loader(ckpt.paramfiles[recoverable_name])
-            for ckpt in checkpoint_list
-        )
+    device = "cpu"
+    parameter_iterator = (
+        parameter_loader(ckpt.paramfiles[recoverable_name], map_location=device)
+        for ckpt in checkpoint_list
+    )
     return averager(parameter_iterator)
diff --git a/speechbrain/utils/data_utils.py b/speechbrain/utils/data_utils.py
index 34be37654ae7f43864d2544d0d10b3b16cbaece8..da6d988cfa077487b74cfd5a74146f3014bec687 100644
--- a/speechbrain/utils/data_utils.py
+++ b/speechbrain/utils/data_utils.py
@@ -4,6 +4,8 @@ Authors
  * Mirco Ravanelli 2020
  * Aku Rouhe 2020
  * Samuele Cornell 2020
+ * Adel Moumen 2024
+ * Pierre Champion 2023
 """
 
 import math
@@ -18,6 +20,7 @@ import tqdm
 import pathlib
 import speechbrain as sb
 from numbers import Number
+import gzip
 
 
 def undo_padding(batch, lengths):
@@ -351,7 +354,18 @@ def download_file(
                 if dest_unpack is None:
                     dest_unpack = os.path.dirname(dest)
                 print(f"Extracting {dest} to {dest_unpack}")
-                shutil.unpack_archive(dest, dest_unpack)
+                # shutil unpack_archive does not work with tar.gz files
+                if (
+                    source.endswith(".tar.gz")
+                    or source.endswith(".tgz")
+                    or source.endswith(".gz")
+                ):
+                    out = dest.replace(".gz", "")
+                    with gzip.open(dest, "rb") as f_in:
+                        with open(out, "wb") as f_out:
+                            shutil.copyfileobj(f_in, f_out)
+                else:
+                    shutil.unpack_archive(dest, dest_unpack)
                 if write_permissions:
                     set_writing_permissions(dest_unpack)
 
@@ -376,7 +390,7 @@ def set_writing_permissions(folder_path):
 
 
 def pad_right_to(
-    tensor: torch.Tensor, target_shape: (list, tuple), mode="constant", value=0,
+    tensor: torch.Tensor, target_shape, mode="constant", value=0,
 ):
     """
     This function takes a torch tensor of arbitrary shape and pads it to target
@@ -596,10 +610,10 @@ def split_path(path):
             # Interpret as path to file in current directory.
             return "./", src
 
-    if isinstance(path, sb.pretrained.fetching.FetchSource):
+    if isinstance(path, sb.utils.fetching.FetchSource):
         fetch_from, fetch_path = path
         source, filename = split(fetch_path)
-        return sb.pretrained.fetching.FetchSource(fetch_from, source), filename
+        return sb.utils.fetching.FetchSource(fetch_from, source), filename
     else:
         return split(path)
 
diff --git a/speechbrain/utils/distributed.py b/speechbrain/utils/distributed.py
index 043b2efac03b924fc57821e303cfcaf46a4f5894..6afd548a90d8040afe2366151f0c02892846362a 100644
--- a/speechbrain/utils/distributed.py
+++ b/speechbrain/utils/distributed.py
@@ -3,13 +3,14 @@
 Authors:
  * Abdel Heba 2020
  * Aku Rouhe 2020
+ * Peter Plantinga 2023
 """
+import datetime
 import os
 import torch
-import logging
 from functools import wraps
 
-logger = logging.getLogger(__name__)
+MAIN_PROC_ONLY = 0
 
 
 def run_on_main(
@@ -56,27 +57,17 @@ def run_on_main(
     if post_kwargs is None:
         post_kwargs = {}
 
-    if if_main_process():
-        # Main comes here
-        try:
-            func(*args, **kwargs)
-        finally:
-            ddp_barrier()
-    else:
-        # Others go here
-        ddp_barrier()
+    main_process_only(func)(*args, **kwargs)
+    ddp_barrier()
+
     if post_func is not None:
         if run_post_on_main:
             # Just run on every process without any barrier.
             post_func(*post_args, **post_kwargs)
-        elif not if_main_process():
-            # Others go here
-            try:
-                post_func(*post_args, **post_kwargs)
-            finally:
-                ddp_barrier()
         else:
-            # But main comes here
+            # Do the opposite of `run_on_main`
+            if not if_main_process():
+                post_func(*post_args, **post_kwargs)
             ddp_barrier()
 
 
@@ -105,8 +96,14 @@ def main_process_only(function):
     @wraps(function)
     def main_proc_wrapped_func(*args, **kwargs):
         """This decorated function runs only if this is the main process."""
+        global MAIN_PROC_ONLY
+        MAIN_PROC_ONLY += 1
         if if_main_process():
-            return function(*args, **kwargs)
+            result = function(*args, **kwargs)
+        else:
+            result = None
+        MAIN_PROC_ONLY -= 1
+        return result
 
     return main_proc_wrapped_func
 
@@ -116,10 +113,39 @@ def ddp_barrier():
     torch.distributed.barrier() will block processes until the whole
     group enters this function.
     """
-    if torch.distributed.is_initialized():
+    # Check if we're in a single-threaded section, skip barrier
+    if MAIN_PROC_ONLY >= 1:
+        return
+    elif torch.distributed.is_initialized():
         torch.distributed.barrier()
 
 
+def ddp_broadcast(communication_object, src=0):
+    """In DDP mode, this function will broadcast an object to all
+    processes.
+
+    Arguments
+    ---------
+    communication_object: Any
+        The object to be communicated to all processes. Must be picklable.
+        See docs for ``torch.distributed.broadcast_object_list()``
+    src: int
+        The rank which holds the object to be communicated.
+
+    Returns
+    -------
+    The communication_object passed on rank src.
+    """
+    if MAIN_PROC_ONLY >= 1 or not torch.distributed.is_initialized():
+        return communication_object
+
+    # Wrapping object in a list is required for preventing
+    # a copy of the object, maintaining a pointer instead
+    communication_list = [communication_object]
+    torch.distributed.broadcast_object_list(communication_list, src=src)
+    return communication_list[0]
+
+
 def ddp_init_group(run_opts):
     """This function will initialize the ddp group if
     distributed_launch bool is given in the python command line.
@@ -133,69 +159,44 @@ def ddp_init_group(run_opts):
     run_opts: list
         A list of arguments to parse, most often from `sys.argv[1:]`.
     """
-    if run_opts["distributed_launch"]:
-        if "local_rank" not in run_opts:
-            raise ValueError(
-                "To use DDP backend, start your script with:\n\t"
-                "python -m torch.distributed.launch [args]\n\t"
-                "experiment.py hyperparams.yaml --distributed_launch "
-                "--distributed_backend=nccl"
-            )
-        else:
-            if not run_opts["distributed_backend"] == "gloo":
-                if run_opts["local_rank"] + 1 > torch.cuda.device_count():
-                    raise ValueError(
-                        "Killing process " + str() + "\n"
-                        "Not enough GPUs available!"
-                    )
-        if "RANK" in os.environ is None or os.environ["RANK"] == "":
-            raise ValueError(
-                "To use DDP backend, start your script with:\n\t"
-                "python -m torch.distributed.launch [args]\n\t"
-                "experiment.py hyperparams.yaml --distributed_launch "
-                "--distributed_backend=nccl"
-            )
-        rank = int(os.environ["RANK"])
-
-        if run_opts["distributed_backend"] == "nccl":
-            if not torch.distributed.is_nccl_available():
-                raise ValueError("NCCL is not supported in your machine.")
-        elif run_opts["distributed_backend"] == "gloo":
-            if not torch.distributed.is_gloo_available():
-                raise ValueError("GLOO is not supported in your machine.")
-        elif run_opts["distributed_backend"] == "mpi":
-            if not torch.distributed.is_mpi_available():
-                raise ValueError("MPI is not supported in your machine.")
-        else:
-            logger.info(
-                run_opts["distributed_backend"]
-                + " communcation protocol doesn't exist."
-            )
+    rank = os.environ.get("RANK")
+    local_rank = os.environ.get("LOCAL_RANK")
+    if local_rank is None or rank is None:
+        return
+
+    local_rank = int(local_rank)
+    if not run_opts["distributed_backend"] == "gloo":
+        if local_rank + 1 > torch.cuda.device_count():
             raise ValueError(
-                run_opts["distributed_backend"]
-                + " communcation protocol doesn't exist."
+                "Killing process " + str() + "\n" "Not enough GPUs available!"
             )
-        # rank arg is used to set the right rank of the current process for ddp.
-        # if you have 2 servers with 2 gpu:
-        # server1:
-        #   GPU0: local_rank=device=0, rank=0
-        #   GPU1: local_rank=device=1, rank=1
-        # server2:
-        #   GPU0: local_rank=device=0, rank=2
-        #   GPU1: local_rank=device=1, rank=3
-        torch.distributed.init_process_group(
-            backend=run_opts["distributed_backend"], rank=rank
-        )
+    rank = int(rank)
+
+    if run_opts["distributed_backend"] == "nccl":
+        if not torch.distributed.is_nccl_available():
+            raise ValueError("NCCL is not supported in your machine.")
+    elif run_opts["distributed_backend"] == "gloo":
+        if not torch.distributed.is_gloo_available():
+            raise ValueError("GLOO is not supported in your machine.")
+    elif run_opts["distributed_backend"] == "mpi":
+        if not torch.distributed.is_mpi_available():
+            raise ValueError("MPI is not supported in your machine.")
     else:
-        logger.info(
-            "distributed_launch flag is disabled, "
-            "this experiment will be executed without DDP."
+        raise ValueError(
+            run_opts["distributed_backend"]
+            + " communcation protocol doesn't exist."
         )
-        if "local_rank" in run_opts and run_opts["local_rank"] > 0:
-            raise ValueError(
-                "DDP is disabled, local_rank must not be set.\n"
-                "For DDP training, please use --distributed_launch. "
-                "For example:\n\tpython -m torch.distributed.launch "
-                "experiment.py hyperparams.yaml "
-                "--distributed_launch --distributed_backend=nccl"
-            )
+
+    # rank arg is used to set the right rank of the current process for ddp.
+    # if you have 2 servers with 2 gpu:
+    # server1:
+    #   GPU0: local_rank=device=0, rank=0
+    #   GPU1: local_rank=device=1, rank=1
+    # server2:
+    #   GPU0: local_rank=device=0, rank=2
+    #   GPU1: local_rank=device=1, rank=3
+    torch.distributed.init_process_group(
+        backend=run_opts["distributed_backend"],
+        rank=rank,
+        timeout=datetime.timedelta(seconds=7200),
+    )
diff --git a/speechbrain/utils/dynamic_chunk_training.py b/speechbrain/utils/dynamic_chunk_training.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ea6f931538f82beb2540aad595f070fffb9fc8c
--- /dev/null
+++ b/speechbrain/utils/dynamic_chunk_training.py
@@ -0,0 +1,170 @@
+"""Configuration and utility classes for classes for Dynamic Chunk Training, as
+often used for the training of streaming-capable models in speech recognition.
+
+The definition of Dynamic Chunk Training is based on that of the following
+paper, though a lot of the literature refers to the same definition:
+https://arxiv.org/abs/2012.05481
+
+Authors
+* Sylvain de Langen 2023
+"""
+
+import speechbrain as sb
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+# NOTE: this configuration object is intended to be relatively specific to
+# Dynamic Chunk Training; if you want to implement a different similar type of
+# chunking different from that, you should consider using a different object.
+@dataclass
+class DynChunkTrainConfig:
+    """Dynamic Chunk Training configuration object for use with transformers,
+    often in ASR for streaming.
+
+    This object may be used both to configure masking at training time and for
+    run-time configuration of DynChunkTrain-ready models."""
+
+    chunk_size: int
+    """Size in frames of a single chunk, always `>0`.
+    If chunkwise streaming should be disabled at some point, pass an optional
+    streaming config parameter."""
+
+    left_context_size: Optional[int] = None
+    """Number of *chunks* (not frames) visible to the left, always `>=0`.
+    If zero, then chunks can never attend to any past chunk.
+    If `None`, the left context is infinite (but use
+    `.is_fininite_left_context` for such a check)."""
+
+    def is_infinite_left_context(self) -> bool:
+        """Returns true if the left context is infinite (i.e. any chunk can
+        attend to any past frame)."""
+        return self.left_context_size is None
+
+    def left_context_size_frames(self) -> Optional[int]:
+        """Returns the number of left context *frames* (not chunks).
+        If ``None``, the left context is infinite.
+        See also the ``left_context_size`` field."""
+
+        if self.left_context_size is None:
+            return None
+
+        return self.chunk_size * self.left_context_size
+
+
+@dataclass
+class DynChunkTrainConfigRandomSampler:
+    """Helper class to generate a DynChunkTrainConfig at runtime depending on the current
+    stage.
+
+    Example
+    -------
+    >>> from speechbrain.core import Stage
+    >>> from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
+    >>> from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfigRandomSampler
+    >>> # for the purpose of this example, we test a scenario with a 100%
+    >>> # chance of the (24, None) scenario to occur
+    >>> sampler = DynChunkTrainConfigRandomSampler(
+    ...     chunkwise_prob=1.0,
+    ...     chunk_size_min=24,
+    ...     chunk_size_max=24,
+    ...     limited_left_context_prob=0.0,
+    ...     left_context_chunks_min=16,
+    ...     left_context_chunks_max=16,
+    ...     test_config=DynChunkTrainConfig(32, 16),
+    ...     valid_config=None
+    ... )
+    >>> one_train_config = sampler(Stage.TRAIN)
+    >>> one_train_config
+    DynChunkTrainConfig(chunk_size=24, left_context_size=None)
+    >>> one_train_config.is_infinite_left_context()
+    True
+    >>> sampler(Stage.TEST)
+    DynChunkTrainConfig(chunk_size=32, left_context_size=16)"""
+
+    chunkwise_prob: float
+    """When sampling (during `Stage.TRAIN`), the probability that a finite chunk
+    size will be used.
+    In the other case, any chunk can attend to the full past and future
+    context."""
+
+    chunk_size_min: int
+    """When sampling a random chunk size, the minimum chunk size that can be
+    picked."""
+
+    chunk_size_max: int
+    """When sampling a random chunk size, the maximum chunk size that can be
+    picked."""
+
+    limited_left_context_prob: float
+    """When sampling a random chunk size, the probability that the left context
+    will be limited.
+    In the other case, any chunk can attend to the full past context."""
+
+    left_context_chunks_min: int
+    """When sampling a random left context size, the minimum number of left
+    context chunks that can be picked."""
+
+    left_context_chunks_max: int
+    """When sampling a random left context size, the maximum number of left
+    context chunks that can be picked."""
+
+    test_config: Optional[DynChunkTrainConfig] = None
+    """The configuration that should be used for `Stage.TEST`.
+    When `None`, evaluation is done with full context (i.e. non-streaming)."""
+
+    valid_config: Optional[DynChunkTrainConfig] = None
+    """The configuration that should be used for `Stage.VALID`.
+    When `None`, evaluation is done with full context (i.e. non-streaming)."""
+
+    def _sample_bool(self, prob: float) -> bool:
+        """Samples a random boolean with a probability, in a way that depends on
+        PyTorch's RNG seed.
+
+        Arguments
+        ---------
+        prob : float
+            Probability (0..1) to return True (False otherwise)."""
+        return torch.rand((1,)).item() < prob
+
+    def __call__(self, stage: "sb.core.Stage") -> DynChunkTrainConfig:
+        """In training stage, samples a random DynChunkTrain configuration.
+        During validation or testing, returns the relevant configuration.
+
+        Arguments
+        ---------
+        stage : speechbrain.core.Stage
+            Current stage of training or evaluation.
+            In training mode, a random DynChunkTrainConfig will be sampled
+            according to the specified probabilities and ranges.
+            During evaluation, the relevant DynChunkTrainConfig attribute will
+            be picked.
+        """
+        if stage == sb.core.Stage.TRAIN:
+            # When training for streaming, for each batch, we have a
+            # `dynamic_chunk_prob` probability of sampling a chunk size
+            # between `dynamic_chunk_min` and `_max`, otherwise output
+            # frames can see anywhere in the future.
+            if self._sample_bool(self.chunkwise_prob):
+                chunk_size = torch.randint(
+                    self.chunk_size_min, self.chunk_size_max + 1, (1,),
+                ).item()
+
+                if self._sample_bool(self.limited_left_context_prob):
+                    left_context_chunks = torch.randint(
+                        self.left_context_chunks_min,
+                        self.left_context_chunks_max + 1,
+                        (1,),
+                    ).item()
+                else:
+                    left_context_chunks = None
+
+                return DynChunkTrainConfig(chunk_size, left_context_chunks)
+            return None
+        elif stage == sb.core.Stage.TEST:
+            return self.test_config
+        elif stage == sb.core.Stage.VALID:
+            return self.valid_config
+        else:
+            raise AttributeError(f"Unsupported stage found {stage}")
diff --git a/speechbrain/utils/epoch_loop.py b/speechbrain/utils/epoch_loop.py
index 1d6daac698c0114a5d34a1a974d26e315ad8860c..52727fc58a5acbd02fcfddd775ed963d9a0ceb2a 100644
--- a/speechbrain/utils/epoch_loop.py
+++ b/speechbrain/utils/epoch_loop.py
@@ -55,12 +55,11 @@ class EpochCounter:
             fo.write(str(self.current))
 
     @mark_as_loader
-    def _recover(self, path, end_of_epoch=True, device=None):
+    def _recover(self, path, end_of_epoch=True):
         # NOTE: end_of_epoch = True by default so that when
         #  loaded in parameter transfer, this starts a new epoch.
         #  However, parameter transfer to EpochCounter should
         #  probably never be used really.
-        del device  # Not used.
         with open(path) as fi:
             saved_value = int(fi.read())
             if end_of_epoch:
diff --git a/speechbrain/pretrained/fetching.py b/speechbrain/utils/fetching.py
similarity index 89%
rename from speechbrain/pretrained/fetching.py
rename to speechbrain/utils/fetching.py
index aac28cdb7be1d7e687a496b2898eff0e3e4cfc75..e2567ee9bbd8ae91d4f28d8de09cf61a8837099e 100644
--- a/speechbrain/pretrained/fetching.py
+++ b/speechbrain/utils/fetching.py
@@ -11,7 +11,6 @@ import pathlib
 import logging
 from enum import Enum
 import huggingface_hub
-from typing import Union
 from collections import namedtuple
 from requests.exceptions import HTTPError
 
@@ -58,8 +57,7 @@ def fetch(
     save_filename=None,
     use_auth_token=False,
     revision=None,
-    cache_dir: Union[str, pathlib.Path, None] = None,
-    silent_local_fetch: bool = False,
+    huggingface_cache_dir=None,
 ):
     """Ensures you have a local copy of the file, returns its path
 
@@ -99,10 +97,8 @@ def fetch(
         The model revision corresponding to the HuggingFace Hub model revision.
         This is particularly useful if you wish to pin your code to a particular
         version of a model hosted at HuggingFace.
-    cache_dir: str or Path (default: None)
-        Location of HuggingFace cache for storing pre-trained models, to which symlinks are created.
-    silent_local_fetch: bool (default: False)
-        Surpress logging messages (quiet mode).
+    huggingface_cache_dir: str
+        Path to HuggingFace cache; if None -> "~/.cache/huggingface" (default: None)
 
     Returns
     -------
@@ -122,27 +118,28 @@ def fetch(
     if isinstance(source, FetchSource):
         fetch_from, source = source
     sourcefile = f"{source}/{filename}"
+    destination = savedir / save_filename
+    if destination.exists() and not overwrite:
+        MSG = f"Fetch {filename}: Using existing file/symlink in {str(destination)}."
+        logger.info(MSG)
+        return destination
+
     if pathlib.Path(source).is_dir() and fetch_from not in [
         FetchFrom.HUGGING_FACE,
         FetchFrom.URI,
     ]:
-        # Interpret source as local directory path & return it as destination
+        # Interpret source as local directory path & create a link and return it as destination
         sourcepath = pathlib.Path(sourcefile).absolute()
+        _missing_ok_unlink(destination)
+        destination.symlink_to(sourcepath)
         MSG = f"Destination {filename}: local file in {str(sourcepath)}."
-        if not silent_local_fetch:
-            logger.info(MSG)
-        return sourcepath
-    destination = savedir / save_filename
-    if destination.exists() and not overwrite:
-        MSG = f"Fetch {filename}: Using existing file/symlink in {str(destination)}."
         logger.info(MSG)
         return destination
     if (
         str(source).startswith("http:") or str(source).startswith("https:")
-    ) or fetch_from is FetchFrom.URI:
-        # Interpret source as web address.
+    ) or fetch_from is FetchFrom.URI:  # Interpret source as web address.
         MSG = (
-            f"Fetch {filename}: Downloading from normal URL {str(sourcefile)}."
+            f"Fetch {filename}: Downloading from normal URL {str(sourcefile)} ."
         )
         logger.info(MSG)
         # Download
@@ -163,7 +160,7 @@ def fetch(
                 filename=filename,
                 use_auth_token=use_auth_token,
                 revision=revision,
-                cache_dir=cache_dir,
+                cache_dir=huggingface_cache_dir,
             )
             logger.info(f"HF fetch: {fetched_file}")
         except HTTPError as e:
@@ -171,7 +168,6 @@ def fetch(
                 raise ValueError("File not found on HF hub")
             else:
                 raise
-
         # Huggingface hub downloads to etag filename, symlink to the expected one:
         sourcepath = pathlib.Path(fetched_file).absolute()
         _missing_ok_unlink(destination)
diff --git a/speechbrain/utils/filter_analysis.py b/speechbrain/utils/filter_analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f69dea5500e0a34ded1336373e83c7e6d780f6b
--- /dev/null
+++ b/speechbrain/utils/filter_analysis.py
@@ -0,0 +1,221 @@
+"""Implements utils to model and combine filter properties, i.e. compute how
+window size, stride, etc. behave, which may be useful for certain usecases such
+as streaming.
+
+Authors:
+ * Sylvain de Langen 2024
+"""
+
+from dataclasses import dataclass
+from typing import Any, Iterable
+
+
+@dataclass
+class FilterProperties:
+    """Models the properties of something that behaves like a filter (e.g.
+    convolutions, fbanks, etc.) over time."""
+
+    window_size: int
+    """Size of the filter, i.e. the number of input frames on which a single
+    output depends. Other than dilation, it is assumed that the window operates
+    over a contiguous chunk of frames.
+
+    Example:
+    --------
+    .. code-block:: text
+
+        size = 3, stride = 3
+
+        out  <-a-> <-b-> <-c->
+        in   1 2 3 4 5 6 7 8 9
+    """
+
+    stride: int = 1
+    """Stride of the filter, i.e. how many input frames get skipped over from an
+    output frame to the next (regardless of window size or dilation).
+
+    Example:
+    --------
+    .. code-block:: text
+
+        size = 3, stride = 2
+
+             <-a->
+                 <-b->   <-d->
+        out          <-c->
+        in   1 2 3 4 5 6 7 8 9
+    """
+
+    dilation: int = 1
+    """Dilation rate of the filter. A window will consider every n-th
+    (n=dilation) input frame. With dilation, the filter will still observe
+    `size` input frames, but the window will span more frames.
+
+    Dilation is mostly relevant to "a trous" convolutions.
+    A dilation rate of 1, the default, effectively performs no dilation.
+
+    Example:
+    --------
+    .. code-block:: text
+
+        size = 3, stride = 1, dilation = 3
+
+            <-------> dilation - 1 == 2 skips
+            a        a        a
+            |  b     |  b     |  b
+            |  |  c  |  |  c  |  |  c
+            |  |  |  d  |  |  d  |  |  d
+            |  |  |  |  e  |  |  e  |  |  ..
+        in  1  2  3  4  5  6  7  8  9  10 ..
+            <-> stride == 1
+    """
+
+    causal: bool = False
+    """Whether the filter is causal, i.e. whether an output frame only depends
+    on past input frames (of a lower or equal index).
+
+    In certain cases, such as 1D convolutions, this can simply be achieved by
+    inserting padding to the left of the filter prior to applying the filter to
+    the input tensor.
+
+    Example:
+    --------
+    .. code-block:: text
+
+        size = 3, stride = 1, causal = true
+                 <-e->
+               <-d->
+             <-c->
+             b->
+             a
+        in   1 2 3 4 5
+    """
+
+    def __post_init__(self):
+        assert self.window_size > 0
+        assert self.stride > 0
+        assert (
+            self.dilation > 0
+        ), "Dilation must be >0. NOTE: a dilation of 1 means no dilation."
+
+    @staticmethod
+    def pointwise_filter() -> "FilterProperties":
+        """Returns filter properties for a trivial filter whose output frames
+        only ever depend on their respective input frame."""
+
+        return FilterProperties(window_size=1, stride=1)
+
+    def get_effective_size(self):
+        """The number of input frames that span the window, including those
+        ignored by dilation."""
+        return 1 + ((self.window_size - 1) * self.dilation)
+
+    def get_convolution_padding(self):
+        """The number of frames that need to be inserted on each end for a
+        typical convolution."""
+
+        if self.window_size % 2 == 0:
+            raise ValueError("Cannot determine padding with even window size")
+
+        if self.causal:
+            return self.get_effective_size() - 1
+
+        return (self.get_effective_size() - 1) // 2
+
+    def get_noncausal_equivalent(self):
+        """From a causal filter definition, gets a compatible non-causal filter
+        definition for which each output frame depends on the same input frames,
+        plus some false dependencies."""
+
+        if not self.causal:
+            return self
+
+        return FilterProperties(
+            # NOTE: valid even on even window sizes e.g. (2-1)*2+1 == 3
+            window_size=(self.window_size - 1) * 2 + 1,
+            stride=self.stride,
+            dilation=self.dilation,
+            causal=False,
+        )
+
+    def with_on_top(
+        self, other: "FilterProperties", allow_approximate: bool = True
+    ) -> "FilterProperties":
+        """Considering the chain of filters `other_filter(self(x))`, returns
+        recalculated properties of the resulting filter.
+
+        Arguments
+        ---------
+        other_filter: FilterProperties
+            The filter to combine `self` with.
+
+        allow_approximate: bool, optional
+            If `True` (the default), the resulting properties may be
+            "pessimistic" and express false dependencies instead of erroring
+            out when exact properties cannot be determined.
+            This might be the case when stacking non-causal and causal filters.
+            Depending on the usecase, this might be fine, but functions like
+            `has_overlap` may erroneously start returning `True`.
+        """
+
+        self_size = self.window_size
+
+        if other.window_size % 2 == 0:
+            if allow_approximate:
+                other_size = other.window_size + 1
+            else:
+                raise ValueError(
+                    "The filter to append cannot have an uneven window size. "
+                    "Specify `allow_approximate=True` if you do not need to "
+                    "analyze exact dependencies."
+                )
+        else:
+            other_size = other.window_size
+
+        if (self.causal or other.causal) and not (self.causal and other.causal):
+            if allow_approximate:
+                return self.get_noncausal_equivalent().with_on_top(
+                    other.get_noncausal_equivalent()
+                )
+            else:
+                raise ValueError(
+                    "Cannot express exact properties of causal and non-causal "
+                    "filters. "
+                    "Specify `allow_approximate=True` if you do not need to "
+                    "analyze exact dependencies."
+                )
+
+        out_size = self_size + (self.stride * (other_size - 1))
+        stride = self.stride * other.stride
+        dilation = self.dilation * other.dilation
+        causal = self.causal
+
+        return FilterProperties(out_size, stride, dilation, causal)
+
+
+def stack_filter_properties(
+    filters: Iterable[Any], allow_approximate: bool = True
+) -> FilterProperties:
+    """Returns the filter properties of a sequence of stacked filters.
+    If the sequence is empty, then a no-op filter is returned (with a size and
+    stride of 1).
+
+    Arguments
+    ---------
+    filters: FilterProperties | any
+        The filters to combine, e.g. `[a, b, c]` modelling `c(b(a(x)))`.
+        If an item is not an instance of :class:`FilterProperties`, then this
+        attempts to call `.get_filter_properties()` over it.
+
+    allow_approximate: bool, optional
+        See `FilterProperties.with_on_top`."""
+
+    ret = FilterProperties.pointwise_filter()
+
+    for prop in filters:
+        if not isinstance(prop, FilterProperties):
+            prop = prop.get_filter_properties()
+
+        ret = ret.with_on_top(prop, allow_approximate)
+
+    return ret
diff --git a/speechbrain/utils/kmeans.py b/speechbrain/utils/kmeans.py
new file mode 100644
index 0000000000000000000000000000000000000000..176add37c6219c8f2dcb1189bf3f44b06b70d580
--- /dev/null
+++ b/speechbrain/utils/kmeans.py
@@ -0,0 +1,172 @@
+"""
+Utilities for training kmeans model.
+
+Author
+ * Pooneh Mousavi 2023
+"""
+
+import os
+import logging
+from tqdm.contrib import tqdm
+
+try:
+    from sklearn.cluster import MiniBatchKMeans
+except ImportError:
+    err_msg = "The optional dependency sklearn is needed to use this module\n"
+    err_msg += "Cannot import sklearn.cluster.MiniBatchKMeans to use KMeans/\n"
+    err_msg += "Please follow the instructions below\n"
+    err_msg += "=============================\n"
+    err_msg += "pip install -U scikit-learn\n"
+    raise ImportError(err_msg)
+import joblib
+
+logger = logging.getLogger(__name__)
+
+
+def accumulate_and_extract_features(
+    batch, features_list, ssl_model, ssl_layer_num, device
+):
+    """ Extract features (output of SSL model) and acculamte them on cpu to be used for clustering.
+
+    Arguments
+    ---------
+        batch: tensor
+            Single batch of data.
+        features_list : list
+            accumulate features list.
+        ssl_model
+            SSL-model used to  extract features used for clustering.
+        ssl_layer_num: int
+            specify output of which layer of the ssl_model should be used.
+        device
+            CPU or  GPU.
+    """
+    batch = batch.to(device)
+    wavs, wav_lens = batch.sig
+    wavs, wav_lens = (
+        wavs.to(device),
+        wav_lens.to(device),
+    )
+    feats = ssl_model(wavs, wav_lens)[ssl_layer_num].flatten(end_dim=-2)
+    features_list.extend(feats.to("cpu").detach().numpy())
+
+
+def fetch_kmeans_model(
+    n_clusters,
+    init,
+    max_iter,
+    batch_size,
+    tol,
+    max_no_improvement,
+    n_init,
+    reassignment_ratio,
+    random_state,
+    checkpoint_path,
+):
+    """Return a k-means clustering model with specified parameters.
+
+    Arguments
+    ---------
+        n_clusters : MiniBatchKMeans
+            The number of clusters to form as well as the number of centroids to generate.
+        init : int
+            Method for initialization: {'k-means++'', ''random''}
+        max_iter : int
+            Maximum number of iterations over the complete dataset before stopping independently of any early stopping criterion heuristics.
+        batch_size : int
+            Size of the mini batches.
+        tol : float
+            Control early stopping based on the relative center changes as measured by a smoothed, variance-normalized of the mean center squared position changes.
+        max_no_improvement :int
+            Control early stopping based on the consecutive number of mini batches that does not yield an improvement on the smoothed inertia.
+        n_init : int
+            Number of random initializations that are tried
+        reassignment_ratio : float
+            Control the fraction of the maximum number of counts for a center to be reassigned.
+        random_state :int
+            Determines random number generation for centroid initialization and random reassignment.
+        compute_labels : bool
+            Compute label assignment and inertia for the complete dataset once the minibatch optimization has converged in fit.
+        init_size : int
+            Number of samples to randomly sample for speeding up the initialization.
+        checkpoint_path : str
+            Path to saved model.
+
+    Returns
+    ---------
+        MiniBatchKMeans
+            a k-means clustering model with specified parameters.
+    """
+    if os.path.exists(checkpoint_path):
+        logger.info(f"The checkpoint is loaded from {checkpoint_path}.")
+        return joblib.load(checkpoint_path)
+
+    logger.info(
+        f"No checkpoint is found at {checkpoint_path}. New model is initialized for training."
+    )
+    return MiniBatchKMeans(
+        n_clusters=n_clusters,
+        init=init,
+        max_iter=max_iter,
+        batch_size=batch_size,
+        tol=tol,
+        max_no_improvement=max_no_improvement,
+        n_init=n_init,
+        reassignment_ratio=reassignment_ratio,
+        random_state=random_state,
+        verbose=1,
+        compute_labels=True,
+        init_size=None,
+    )
+
+
+def train(
+    model,
+    train_set,
+    ssl_model,
+    ssl_layer_num,
+    kmeans_batch_size=1000,
+    device="cpu",
+):
+    """Train a  Kmeans model .
+
+    Arguments
+    ---------
+        model : MiniBatchKMeans
+            The initial kmeans model for training.
+        train_set : Dataloader
+            Batches of tarining data.
+        ssl_model
+            SSL-model used to  extract features used for clustering.
+        ssl_layer_num : int
+            Specify output of which layer of the ssl_model should be used.
+        device
+            CPU or  GPU.
+        kmeans_batch_size : int
+            Size of the mini batches.
+    """
+    logger.info("Start training kmeans model.")
+    features_list = []
+    with tqdm(train_set, dynamic_ncols=True,) as t:
+        for batch in t:
+            # train a kmeans model on a single batch if  features_list reaches the kmeans_batch_size.
+            if len(features_list) >= kmeans_batch_size:
+                model = model.fit(features_list)
+                features_list = []
+            # extract features from the SSL model
+            accumulate_and_extract_features(
+                batch, features_list, ssl_model, ssl_layer_num, device
+            )
+
+
+def save_model(model, checkpoint_path):
+    """Save a  Kmeans model .
+
+    Arguments
+    ---------
+        model : MiniBatchKMeans
+            The  kmeans model to be saved.
+        checkpoint_path : str)
+            Path to save the model..
+    """
+    joblib.dump(model, open(checkpoint_path, "wb"))
diff --git a/speechbrain/utils/parameter_transfer.py b/speechbrain/utils/parameter_transfer.py
index 6028f159280b280864127791611903db7c40f879..703664e6ea95d8794e5ef079653786f30b87b756 100644
--- a/speechbrain/utils/parameter_transfer.py
+++ b/speechbrain/utils/parameter_transfer.py
@@ -11,7 +11,7 @@ Authors
 import logging
 import pathlib
 from speechbrain.utils.distributed import run_on_main
-from speechbrain.pretrained.fetching import fetch, FetchFrom, FetchSource
+from speechbrain.utils.fetching import fetch, FetchFrom, FetchSource
 from speechbrain.utils.checkpoints import (
     DEFAULT_LOAD_HOOKS,
     DEFAULT_TRANSFER_HOOKS,
@@ -154,8 +154,7 @@ class Pretrainer:
         """
 
         def split(src):
-            """Core function to split path.
-            """
+            """Core function to split path."""
             if "/" in src:
                 return src.rsplit("/", maxsplit=1)
             else:
@@ -288,15 +287,8 @@ class Pretrainer:
         else:
             return bool(condition)
 
-    def load_collected(self, device=None):
-        """Loads the files that have been collected.
-
-        Arguments
-        ---------
-        device : str
-            Device on which to load, if you want to load to a specific device
-            directly ( otherwise just leave it to None ).
-        """
+    def load_collected(self):
+        """Loads the files that have been collected."""
         logger.info(
             f"Loading pretrained files for: {', '.join(self.loadables)}"
         )
@@ -312,9 +304,9 @@ class Pretrainer:
                     f"Redirecting (loading from local path): {paramfiles[name]} -> {self.paths[name]}"
                 )
                 paramfiles[name] = self.paths[name]
-        self._call_load_hooks(paramfiles, device)
+        self._call_load_hooks(paramfiles)
 
-    def _call_load_hooks(self, paramfiles, device=None):
+    def _call_load_hooks(self, paramfiles):
         # This internal function finds the correct hook to call for every
         # recoverable, and calls it.
         for name, obj in self.loadables.items():
@@ -324,19 +316,19 @@ class Pretrainer:
 
             # First see if object has custom load hook:
             if name in self.custom_hooks:
-                self.custom_hooks[name](obj, loadpath, device=device)
+                self.custom_hooks[name](obj, loadpath)
                 continue
             # Try the default transfer hook:
             default_hook = get_default_hook(obj, DEFAULT_TRANSFER_HOOKS)
             if default_hook is not None:
-                default_hook(obj, loadpath, device=device)
+                default_hook(obj, loadpath)
                 continue
             # Otherwise find the default loader for that type:
             default_hook = get_default_hook(obj, DEFAULT_LOAD_HOOKS)
             if default_hook is not None:
                 # Need to fake end-of-epoch:
                 end_of_epoch = False
-                default_hook(obj, loadpath, end_of_epoch, device)
+                default_hook(obj, loadpath, end_of_epoch)
                 continue
             # If we got here, no custom hook or registered default hook exists
             MSG = f"Don't know how to load {type(obj)}. Register default hook \
diff --git a/speechbrain/pretrained/training.py b/speechbrain/utils/pretrained.py
similarity index 100%
rename from speechbrain/pretrained/training.py
rename to speechbrain/utils/pretrained.py
diff --git a/speechbrain/utils/profiling.py b/speechbrain/utils/profiling.py
index 1fc33898789998d534b937d51b0d83d55cb7301c..a08a33d5d658b0ccf4c9ce1aec4433e5fbe90496 100644
--- a/speechbrain/utils/profiling.py
+++ b/speechbrain/utils/profiling.py
@@ -1,678 +1,39 @@
-"""Polymorphic decorators to handle PyTorch profiling and benchmarking.
+"""Wrapper to handle PyTorch profiling and benchmarking.
 
 Author:
-    * Andreas Nautsch 2022
+    * Titouan Parcollet 2024
 """
-import numpy as np
-from copy import deepcopy
-from torch import profiler
-from functools import wraps
-from typing import Any, Callable, Iterable, Optional
-
-# from typing import List
-# from itertools import chain
-
-"""
-from torch.autograd.profiler_util import (  # pytorch v1.10.1
-    EventList,
-    FunctionEvent,
-    _format_time,
-    _format_memory,
-)
-"""
-
-
-def set_profiler_attr(func: object, set_attr: str, handler: Callable):
-    """Sets handler for profiler: scheduler or trace export.
-    """
-    assert set_attr in [
-        "on_trace_ready",
-        "schedule",
-    ], "Needs to be a callable profiler attribute."
-
-    if (
-        func is None
-    ):  # Polymorph: not used as decorator; func is used as e.g.: trace_export()
-        return handler
-    elif callable(
-        func
-    ):  # Polymorph: decorates a decorator of function/class constructor
-
-        @wraps(func)
-        def wrapper(*args, **kwargs):
-            """Wrapper implementation."""
-            if "__call__" not in dir(
-                func
-            ):  # Decorator for class constructor (directly)
-                result = func(*args, **kwargs)
-                setattr(result.profiler, set_attr, handler)
-                return result  # not tested
-            else:  # Return as additional argument.
-                kwargs[set_attr] = handler
-                return func(*args, **kwargs)
-
-        return wrapper
-    else:  # Polymorph: func is assumed to be an instance of speechbrain.core.Brain
-        # No return: in-place edit
-        if hasattr(func, "profiler"):
-            if func.profiler is profiler.profile:
-                setattr(func.profiler, set_attr, handler)
-
-
-def schedule(
-    func: Optional[object] = None,
-    wait: int = 2,
-    warmup: int = 2,
-    active: int = 2,
-    repeat: int = 1,
-    skip_first: int = 0,
-):
-    """Wrapper to create a ```torch.profiler.schedule``` (sets default parameters for warm-up).
-    """
-    torch_scheduler = profiler.schedule(
-        wait=wait,
-        warmup=warmup,
-        active=active,
-        repeat=repeat,
-        skip_first=skip_first,
-    )
-    """
-    Curious which action a default scheduler suggests at which profiler.step() ?
-        [torch_scheduler(x) for x in range(10)]
-
-        00 = {ProfilerAction} ProfilerAction.NONE
-        01 = {ProfilerAction} ProfilerAction.NONE
-        02 = {ProfilerAction} ProfilerAction.WARMUP
-        03 = {ProfilerAction} ProfilerAction.WARMUP
-        04 = {ProfilerAction} ProfilerAction.RECORD
-        05 = {ProfilerAction} ProfilerAction.RECORD_AND_SAVE
-        06 = {ProfilerAction} ProfilerAction.NONE
-        07 = {ProfilerAction} ProfilerAction.NONE
-        08 = {ProfilerAction} ProfilerAction.NONE
-        09 = {ProfilerAction} ProfilerAction.NONE
-    """
-
-    return set_profiler_attr(
-        func=func, set_attr="schedule", handler=torch_scheduler
-    )
-
-
-def export(
-    func: Optional[object] = None,
-    dir_name: str = "./log/",
-    worker_name: Optional[str] = None,
-    use_gzip: bool = False,
-):
-    """Exports current and aggregated traces for:
-    - Chrome tensorboard
-    - FlameGraph
-    (and sets default parameters for log file folder/filenames).
-    """
-    import os
-    import socket
-    import time
-
-    # Chrome export (default handler); inspired the log_file() function below.
-    tensorboard_handler = profiler.tensorboard_trace_handler(
-        dir_name=dir_name, worker_name=worker_name, use_gzip=use_gzip
-    )
-
-    def trace_handler(prof: profiler.profile):
-        """trace_handler implementation."""
-
-        def log_file(export_chrome: bool = False, info: str = ""):
-            """Implementation of logging file."""
-            nonlocal worker_name
-            if not worker_name:
-                worker_name = "{}_{}".format(
-                    socket.gethostname(), str(os.getpid())
-                )
-            if export_chrome:
-                ext = "pt.trace.json"
-            else:
-                ext = "txt"
-            if info:
-                pattern = "{{}}.{{}}_{}.{{}}".format(info)
-            else:
-                pattern = "{}.{}.{}"
-            file_name = pattern.format(
-                worker_name, int(time.time() * 1000), ext
-            )
-            if use_gzip:
-                file_name = file_name + ".gz"
-            return os.path.join(dir_name, file_name)
-
-        def export_stacks(log_path: str, metric: str):
-            """Implementation of export_stacks."""
-            prof.export_stacks(log_file(), metric)
 
-        def export_traces(aggregated_traces: bool = False):
-            """Implementation of export_traces."""
-            if not aggregated_traces:
-                # Chrome export (also checks for dir_name existing).
-                tensorboard_handler(prof)
-
-            # FlameGraph exports.
-            if prof.with_stack or aggregated_traces:
-                log_path = (
-                    log_file(info="aggregated")
-                    if aggregated_traces
-                    else log_file()
-                )
-                export_stacks(log_path=log_path, metric="self_cpu_time_total")
-                if prof.profiler is not None:
-                    if prof.profiler.use_cuda:
-                        export_stacks(
-                            log_path=log_path, metric="self_cuda_time_total"
-                        )
-
-        # export last logged trace - skip if events are empty (e.g., profiler created w/o any torch.nn call)
-        if prof.events():
-            export_traces()
-
-    return set_profiler_attr(
-        func=func, set_attr="on_trace_ready", handler=trace_handler
-    )
-
-
-def prepare_profiler_for_brain(prof: profiler.profile):
-    """Sets up a ``torch.profiler.profile`` to also (a) aggregate traces issued from various interactions
-    with ``speechbrain.core.Brain``:s and (b) hooks a method to ``merge_traces``.
-    """
-    # Brain functions will be called independently -> traces will be segregated, so we aggregate them.
-    prof.speechbrain_event_traces = list()
-
-    # Preparing the profiler to be re-used during Brain:s' lifecycles.
-    def hook_profiler_stop(stop: Callable):
-        """Implementation of hook_profiler_stop."""
-
-        @wraps(stop)
-        def stop_wrapper():
-            """Implementation of stop_wrapper."""
-            kineto_profiler = prof.profiler
-            if kineto_profiler is not None:
-                stop_result = stop()
-                if (
-                    prof.events()
-                ):  # kineto events are not aggregatable (sticking with parsed kineto events)
-                    # see: torch.autograd.profiler.__exit__
-                    kineto_events = kineto_profiler._parse_kineto_results(
-                        kineto_profiler.kineto_results
-                    )
-                    # add to trace record
-                    prof.speechbrain_event_traces.append(
-                        deepcopy(kineto_events)
-                    )
-                    # set flag to disable the profiler
-                    kineto_profiler.enabled = False
-                return stop_result
-            else:
-                return stop()  # will be: None
-
-        return stop_wrapper
-
-    # Preparing the profiler to be re-started during Brain:s' lifecycles.
-    def hook_profiler_start(start: Callable):
-        """Implementation of hook_profiler_start."""
-
-        @wraps(start)
-        def start_wrapper():
-            """Implementation of start_wrapper."""
-            prof.step_num = 0
-            prof.current_action = prof.schedule(prof.step_num)
-            kineto_profiler = prof.profiler
-            if kineto_profiler is not None:
-                # check flag if profiler is disabled (i.e. as of stop_wrapper); prevents entering its __init__ twice
-                if not kineto_profiler.enabled:
-                    # reset kineto profiler (otherwise, one obtains the same traces over & over again)
-                    kineto_profiler.enabled = True
-            return start()
-
-        return start_wrapper
-
-    """
-    # It's currently designed as hiding an Easter Egg.
-    def merge_traces():
-        " ""Implementation of merge_traces." ""
-        # Alternative re-design quirks: make trace aggregator a GLOBAL -or- create another profiler class.
-        trace_aggregator = "speechbrain_event_traces"
-        if prof.profiler is not None:
-            if trace_aggregator in dir(prof) and prof.events():
-                # clear all assigned parents/children (from previous mergers & trees)
-                for trace in getattr(prof, trace_aggregator):
-                    for event in trace:
-                        event.cpu_parent = None
-                        event.cpu_children: List[FunctionEvent] = []
-                # assemble new list
-                merged_events = EventList(
-                    list(chain.from_iterable(getattr(prof, trace_aggregator))),
-                    use_cuda=prof.profiler.use_cuda,
-                    profile_memory=prof.profiler.profile_memory,
-                    with_flops=prof.profiler.with_flops,
-                )
-                merged_events._build_tree()
-                return merged_events
-            else:  # not tested
-                return prof.events()
-        else:
-            return []
-    """
-
-    # Augment torch's profiler.
-    setattr(prof, "start", hook_profiler_start(getattr(prof, "start")))
-    setattr(prof, "stop", hook_profiler_stop(getattr(prof, "stop")))
-    # setattr(prof, "merge_traces", merge_traces)
-
-    # Return so it can be readily assigned elsewhere :)
-    return prof
-
-
-def hook_brain_methods(
-    func: object,
-    prof: profiler.profile,
-    class_hooks: Optional[Iterable[str]] = None,
-):
-    """For instances of ``speechbrain.core.Brain``, critical functions are hooked to profiler start/stop methods.
-    """
-    # Prepare additional hook decorators for methods of Brain:s.
-    def hook_brain(f: Callable):
-        """Implementation of hook_brain."""
-
-        @wraps(f)
-        def hook(*f_args, **f_kwargs):
-            """Implementation of hook."""
-            # The profiler stopped after __init__ so we need to get it up again and stop it manually also.
-            prof.start()
-            r = f(*f_args, **f_kwargs)
-            prof.stop()
-            return r
-
-        return hook
-
-    # Hook the crucial Brain methods.
-    if class_hooks is None:
-        class_hooks = ["fit", "evaluate"]
-    for method in class_hooks:
-        if method in dir(func):  # func is an instance of Brain
-            setattr(func, method, hook_brain(getattr(func, method)))
+from torch import profiler
+from typing import Optional
+import os
 
 
-def profile(
-    func: Optional[object] = None,
-    class_hooks: Optional[Iterable[str]] = None,
-    activities: Optional[Iterable[profiler.ProfilerActivity]] = None,
-    schedule: Optional[Callable[[int], profiler.ProfilerAction]] = None,
-    on_trace_ready: Optional[Callable[..., Any]] = None,
-    record_shapes: bool = False,
-    profile_memory: bool = False,
-    with_stack: bool = False,
-    with_flops: bool = False,
-    with_modules: bool = False,
+def prepare_profiler(
+    profile_warmup: Optional[int] = 5,
+    profile_steps: Optional[int] = 5,
+    logdir: Optional[str] = "tensorboard_logs",
 ) -> object:
-    """Wrapper to create a PyTorch profiler to benchmark training/inference of speechbrain.core.Brain instances.
+    """Wrapper to create a PyTorch profiler to benchmark training of speechbrain.core.Brain instances.
     See ``torch.profiler.profile`` documentation for details (brief summary below).
 
     Arguments
     ---------
-    func : object
-        ``speechbrain.core.Brain``:s or a (train/eval) function to be profiled.
-    class_hooks : iterable
-        List of method/function names of ``speechbrain.core.Brain``:s that should be profiled also.
-        Otherwise, only the __init__ constructor will be profiled when decorating a Brain class.
-        Default: ``['fit', 'evaluate']`` for classes, and ``None`` for functions.
-    activities : iterable
-        List of activity groups.
-        Default: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA.
-        (Default value should be ok for most cases.)
-    schedule : callable
-        Waits a specified amount of steps for PyTorch to warm-up; see the above ``schedule`` decorator.
-        Default: ``ProfilerAction.RECORD`` (immediately starts recording).
-    on_trace_ready : callable
-        Specifies what benchmark record should be saved (after each scheduled step);
-        see above ``trace_handler`` decorator.
-        Default: ``None`` (pick up collected reporting once profiling ended, but not details per step).
-    record_shapes : bool
-        Save input shapes of operations (enables to group benchmark data by after profiling).
-        Default: ``False``.
-    profile_memory : bool
-        Track tensor memory allocation/deallocation.
-        Default: ``False``.
-    with_stack : bool
-        Record source information (file and line number).
-        Default: ``False``.
-    with_flops: bool
-        Estimate the number of FLOPs.
-        Default: ``False``.
-    with_modules: bool
-        Record module hierarchy (including function names)
-        Default: ``False``
-
-    Example
-    -------
-    >>> import torch
-    >>> @profile
-    ... def run(x : torch.Tensor):
-    ...     y = x ** 2
-    ...     z = y ** 3
-    ...     return y.backward()  # y.backward() returns None --> return value is substituted with profiler
-    >>> data = torch.randn((1, 1), requires_grad=True)
-    >>> prof = run(data)
-    >>> out = [len(prof.events()), len(prof.key_averages()), prof.profiler.total_average().count]
-    """
-    if func is None:  # return a profiler; not tested
-        return prepare_profiler_for_brain(
-            profiler.profile(
-                activities=activities,
-                schedule=schedule,
-                on_trace_ready=on_trace_ready,
-                record_shapes=record_shapes,
-                profile_memory=profile_memory,
-                with_stack=with_stack,
-                with_flops=with_flops,
-                with_modules=with_modules,
-            )
-        )
-    # Polymorph: func is pretrained or an instance of Brain (assumed case)
-    if hasattr(func, "HPARAMS_NEEDED") or not callable(func):
-        with profiler.profile(
-            activities=activities,
-            schedule=schedule,  # scheduler needs to be set directly (fetching is here not possible as for wrappers)
-            on_trace_ready=on_trace_ready,
-            record_shapes=record_shapes,
-            profile_memory=profile_memory,
-            with_stack=with_stack,
-            with_flops=with_flops,
-            with_modules=with_modules,
-        ) as prof:
-            func.profiler = prepare_profiler_for_brain(prof)
-            hook_brain_methods(func=func, class_hooks=class_hooks, prof=prof)
-            return func  # no need to return anything; all done in-place; but if needs to be readily assigned elsewhere
-    else:
-        # callable(func) - polymorph: __init__ Brain constructor -or- function to be wrapped
-        @wraps(func)
-        def wrapper(*args, **kwargs):
-            """Implementation of the wrapper."""
-            # Binding variables.
-            nonlocal class_hooks
-            nonlocal schedule
-            nonlocal on_trace_ready
-            # Check if there's a nested decorators.
-            if schedule is None:
-                if "schedule" in kwargs:
-                    schedule = kwargs.pop("schedule")
-            if on_trace_ready is None:
-                if "on_trace_ready" in kwargs:
-                    on_trace_ready = kwargs.pop("on_trace_ready")
-            with profiler.profile(
-                activities=activities,
-                schedule=schedule,
-                on_trace_ready=on_trace_ready,
-                record_shapes=record_shapes,
-                profile_memory=profile_memory,
-                with_stack=with_stack,
-                with_flops=with_flops,
-                with_modules=with_modules,
-            ) as prof:
-                # Preserves profiler as class attribute if func is not a function (implies: speechbrain.core.Brain).
-                if "__call__" not in dir(func):
-                    # Passing the profiler to Bain:s' __init__ constructor as an additional argument.
-                    kwargs["profiler"] = prepare_profiler_for_brain(prof)
-                    hook_brain_methods(
-                        func=func, class_hooks=class_hooks, prof=prof
-                    )
-
-                # Run & trace to benchmark.
-                result = func(*args, **kwargs)
-
-                # Prof is about to be lost at return.
-                if "__call__" in dir(func):
-                    if result is None:
-                        return prof  # for void function, simply return profiling data
-                    else:  # not tested - returns both
-                        return result, prof
-
-                return result
-
-        return wrapper
-
-
-def profile_analyst(
-    func: Optional[object] = None, class_hooks: Optional[Iterable[str]] = None,
-):  # to diverge, define parameters from scratch: @schedule; @export & @profile
-    """Pre-configured profiling for a fully detailed benchmark - analyst perspective.
-
-    Creating this analyst view will create overheads (disabling some PyTorch optimisations);
-    use @profile_optimiser to take benefits of optimisations and further optimise your modules, accordingly.
-    """
-    profiler_kwargs = {
-        "schedule": schedule(),
-        "on_trace_ready": None,
-        "record_shapes": True,
-        "profile_memory": True,
-        "with_stack": True,
-        "with_flops": True,  # only for: matrix multiplication & 2D conv; see: torch.autograd.profiler.profile
-        "with_modules": True,
-        "class_hooks": class_hooks,
-    }
-    wrapped_func = profile(func, **profiler_kwargs)
-    # Polymorph: func is pretrained or an instance of Brain (assumed case)
-    if hasattr(func, "HPARAMS_NEEDED") or not callable(func):
-        return wrapped_func
-    else:  # callable(func) - polymorph: __init__ Brain constructor -or- function to be wrapped
-
-        @wraps(func)
-        def wrapper(*args, **kwargs):
-            """Implementation of the wrapper."""
-            return wrapped_func(*args, **kwargs)
-
-        return wrapper
-
-
-def profile_optimiser(
-    func: Optional[object] = None, class_hooks: Optional[Iterable[str]] = None,
-):  # to diverge, define parameters from scratch: @schedule; @export & @profile
-    """Pre-configured profiling for a detailed benchmark (better suitable for speed-optimisation than @profile_analyst).
-    """
-    profiler_kwargs = {
-        "schedule": schedule(),
-        "on_trace_ready": None,
-        "record_shapes": False,  # avoid: overheads
-        "profile_memory": True,
-        "with_stack": False,  # avoid: overheads
-        "with_flops": False,  # only for: matrix multiplication & 2D conv; see: torch.autograd.profiler.profile
-        "with_modules": True,
-        "class_hooks": class_hooks,
-    }
-    wrapped_func = profile(func, **profiler_kwargs)
-    # Polymorph: func is pretrained or an instance of Brain (assumed case)
-    if hasattr(func, "HPARAMS_NEEDED") or not callable(func):
-        return wrapped_func
-    else:  # callable(func) - polymorph: __init__ Brain constructor -or- function to be wrapped
-
-        @wraps(func)
-        def wrapper(*args, **kwargs):
-            """Implementation of the wrapper."""
-            return wrapped_func(*args, **kwargs)
-
-        return wrapper
-
-
-def profile_report(  # not part of unittests
-    func: Optional[object] = None, class_hooks: Optional[Iterable[str]] = None,
-):
-    """Pre-configured profiling for a reporting benchmark (changed scheduler to @profile_optimiser).
-    """
-    profiler_kwargs = {
-        "schedule": schedule(
-            wait=1, warmup=2, active=7, repeat=1, skip_first=0,
-        ),  # gives #active, avg:ed of #repeat
-        "on_trace_ready": None,
-        "record_shapes": False,  # avoid: overheads
-        "profile_memory": True,
-        "with_stack": False,  # avoid: overheads
-        "with_flops": False,  # only for: matrix multiplication & 2D conv; see: torch.autograd.profiler.profile
-        "with_modules": True,
-        "class_hooks": class_hooks,
-    }
-    wrapped_func = profile(func, **profiler_kwargs)
-    # Polymorph: func is pretrained or an instance of Brain (assumed case)
-    if hasattr(func, "HPARAMS_NEEDED") or not callable(func):
-        return wrapped_func
-    else:  # callable(func) - polymorph: __init__ Brain constructor -or- function to be wrapped
-
-        @wraps(func)
-        def wrapper(*args, **kwargs):
-            """Implementation of the wrapper."""
-            return wrapped_func(*args, **kwargs)
-
-        return wrapper
-
-
-"""
-def events_diff(
-    a: EventList, b: EventList, filter_by: str = "count",
-):
-    " ""Takes two ``EventList``:s in, filters events of equal value (default: by the count of events).
-
-    The purpose of the results of this diff are for visualisation only (to see the difference between implementations).
-    " ""
-    # Making copies from the originals instead of simply adding the diff directly might be slower (preserves structure).
-    aa = deepcopy(a)
-    bb = deepcopy(b)
-
-    # Maps: function name -> (call count, position) // the position helps to remove alike call numbers later on.
-    a_filter = dict(
-        [(i.key, (getattr(i, filter_by), p)) for p, i in enumerate(aa)]
+    profile_warmup: int
+        Number of warmup step before starting to log.
+    profile_steps: int
+        Number of steps to log after warmup.
+    logdir: str
+        Path to the output folder of the logs.
+    """
+
+    logdir = os.path.join(logdir, "profiler_logs")
+
+    return profiler.profile(
+        schedule=profiler.schedule(
+            wait=0, warmup=profile_warmup, active=profile_steps, repeat=1
+        ),
+        on_trace_ready=profiler.tensorboard_trace_handler(logdir),
+        record_shapes=True,
+        with_stack=True,
     )
-    b_filter = dict(
-        [(i.key, (getattr(i, filter_by), p)) for p, i in enumerate(bb)]
-    )
-
-    # Figuring our which ones to delete.
-    a_to_remove = list([])
-    b_to_remove = list([])
-    for key in a_filter.keys():
-        if key in b_filter.keys():
-            # Equal values are filtered.
-            if a_filter[key][0] == b_filter[key][0]:
-                # Enlist position to be removed.
-                a_to_remove.append(a_filter[key][1])
-                b_to_remove.append(b_filter[key][1])
-
-    # Since EventLists are lists: removing items from the back.
-    if a_to_remove:
-        a_to_remove.sort(reverse=True)
-        for k in a_to_remove:
-            aa.remove(aa[k])
-
-    if b_to_remove:
-        b_to_remove.sort(reverse=True)
-        for k in b_to_remove:
-            bb.remove(bb[k])
-
-    return aa, bb
-"""
-
-
-def report_time(events: object, verbose=False, upper_control_limit=False):
-    """Summary reporting of total time - see: torch.autograd.profiler_util
-    """
-    # Aggregate CPU & CUDA time.
-    """
-    if isinstance(events, FunctionEvent):
-        function_events = events
-    elif
-    """
-    if isinstance(events, profiler.profile):
-        function_events = events.events()
-    elif hasattr(events, "profiler"):  # assumes speechbrain.core.Brain
-        function_events = events.profiler.events()
-    else:
-        raise TypeError(
-            "Expected a FunctionEvent; profiler.profile, or a SpeechBrain."
-        )
-
-    if upper_control_limit:
-        # discerns top-level event (among others) aten:zeros which is in the avg range of 10-20ms on laptop CPU
-        cpu_data = np.array(
-            [e.cpu_time for e in function_events if e.key == "ProfilerStep*"]
-        )
-        cuda_data = np.array(
-            [e.cuda_time for e in function_events if e.key == "ProfilerStep*"]
-        )
-        cpu_time = cpu_data.mean() + 3 * cpu_data.std()
-        cuda_time = cuda_data.mean() + 3 * cuda_data.std()
-    else:
-        total = function_events.total_average()
-        cpu_time = total.self_cpu_time_total
-        cuda_time = total.self_cuda_time_total
-
-    """
-    if verbose:
-        print("CPU time: {}".format(_format_time(cpu_time)))
-        if cuda_time > 0:
-            print("CUDA time: {}".format(_format_time(cuda_time)))
-    """
-
-    return cpu_time, cuda_time
-
-
-def report_memory(handler: object, verbose=False):
-    """Summary reporting of total time - see: torch.autograd.profiler_util
-    """
-    # Aggregate CPU & CUDA time.
-    """
-    if isinstance(handler, FunctionEvent):
-        events = handler
-    elif
-    """
-    if isinstance(handler, profiler.profile):
-        events = handler.events()
-    elif hasattr(handler, "profiler"):  # assumes speechbrain.core.Brain
-        events = handler.profiler.events()
-    else:
-        raise TypeError(
-            "Expected a FunctionEvent; profiler.profile, or a SpeechBrain."
-        )
-
-    """memory allocation during each time step is of relevance, e.g. for visualisation - time intensive for lots events
-    mem_times = np.unique(
-        [[x.time_range.start, x.time_range.end] for x in events]
-    )
-    cpu_memory = np.zeros_like(mem_times)
-    cuda_memory = np.zeros_like(mem_times)
-    for x in events:
-        idx = (x.time_range.start <= mem_times) & (
-            x.time_range.end >= mem_times
-        )
-        cpu_memory[idx] += x.cpu_memory_usage
-        cuda_memory[idx] += x.cuda_memory_usage
-
-    # variable names instead of labeling pandas' columns
-    cpu_mem = np.max(cpu_memory)
-    cuda_mem = np.max(cuda_memory)
-    """
-
-    cpu_mem = cuda_mem = 0
-    for e in events:
-        if len(e.cpu_children) == 0:
-            leaf_cpu_mem = e.cpu_memory_usage
-            leaf_cuda_mem = e.cuda_memory_usage
-            parent = e.cpu_parent
-            while parent is not None:
-                leaf_cpu_mem += parent.cpu_memory_usage
-                leaf_cuda_mem += parent.cuda_memory_usage
-                parent = parent.cpu_parent
-            if leaf_cpu_mem > cpu_mem:
-                cpu_mem = leaf_cpu_mem
-            if leaf_cuda_mem > cuda_mem:
-                cuda_mem = leaf_cuda_mem
-
-    """
-    if verbose:
-        print("Peak CPU Mem: {}".format(_format_memory(cpu_mem)))
-        if cuda_mem > 0:
-            print("Peak CUDA Mem: {}".format(_format_memory(cuda_mem)))
-    """
-
-    return cpu_mem, cuda_mem
diff --git a/speechbrain/utils/streaming.py b/speechbrain/utils/streaming.py
new file mode 100644
index 0000000000000000000000000000000000000000..336f6e197a1d9db9a14e976bf485e3198cb0c1ab
--- /dev/null
+++ b/speechbrain/utils/streaming.py
@@ -0,0 +1,233 @@
+"""Utilities to assist with designing and training streaming models.
+
+Authors
+* Sylvain de Langen 2023
+"""
+
+import math
+import torch
+from typing import Callable, List
+
+
+def split_fixed_chunks(
+    x: torch.Tensor, chunk_size: int, dim: int = -1
+) -> List[torch.Tensor]:
+    """Split an input tensor `x` into a list of chunk tensors of size
+    `chunk_size` alongside dimension `dim`.
+    Useful for splitting up sequences with chunks of fixed sizes.
+
+    If dimension `dim` cannot be evenly split by `chunk_size`, then the last
+    chunk will be smaller than `chunk_size`.
+
+    Arguments
+    ---------
+    x : torch.Tensor
+        The tensor to split into chunks, typically a sequence or audio signal.
+
+    chunk_size : int
+        The size of each chunk, i.e. the max size of each chunk on dimension
+        `dim`.
+
+    dim : int
+        Dimension to split alongside of, typically the time dimension.
+
+    Returns
+    -------
+    List[torch.Tensor]
+        A chunk list of tensors, see description and example.
+        Guarantees `.size(dim) <= chunk_size`.
+
+    Example
+    -------
+    >>> import torch
+    >>> from speechbrain.utils.streaming import split_fixed_chunks
+    >>> x = torch.zeros((16, 10000, 80))
+    >>> chunks = split_fixed_chunks(x, 128, dim=1)
+    >>> len(chunks)
+    79
+    >>> chunks[0].shape
+    torch.Size([16, 128, 80])
+    >>> chunks[-1].shape
+    torch.Size([16, 16, 80])
+    """
+
+    num_chunks = math.ceil(x.size(dim) / chunk_size)
+    split_at_indices = [(i + 1) * chunk_size for i in range(num_chunks - 1)]
+    return torch.tensor_split(x, split_at_indices, dim=1)
+
+
+def split_wav_lens(
+    chunk_lens: List[int], wav_lens: torch.Tensor
+) -> List[torch.Tensor]:
+    """Converts a single `wav_lens` tensor into a list of `chunk_count` tensors,
+    typically useful when chunking signals with `split_fixed_chunks`.
+
+    `wav_lens` represents the relative length of each audio within a batch,
+    which is typically used for masking. This function computes the relative
+    length at chunk level.
+
+    Arguments
+    ---------
+    chunk_lens : List[int]
+        Length of the sequence of every chunk. For example, if `chunks` was
+        returned from `split_fixed_chunks(x, chunk_size, dim=1)`, then this
+        should be `[chk.size(1) for chk in chunks]`.
+
+    wav_lens : torch.Tensor
+        Relative lengths of audio within a batch. For example, for an input
+        signal of 100 frames and a batch of 3 elements, `(1.0, 0.5, 0.25)`
+        would mean the batch holds audio of 100 frames, 50 frames and 25 frames
+        respectively.
+
+    Returns
+    -------
+    List[torch.Tensor]
+        A list of chunked wav_lens, see description and example.
+
+    Example
+    -------
+    >>> import torch
+    >>> from speechbrain.utils.streaming import split_wav_lens, split_fixed_chunks
+    >>> x = torch.zeros((3, 20, 80))
+    >>> chunks = split_fixed_chunks(x, 8, dim=1)
+    >>> len(chunks)
+    3
+    >>> # 20 frames, 13 frames, 17 frames
+    >>> wav_lens = torch.tensor([1.0, 0.65, 0.85])
+    >>> chunked_wav_lens = split_wav_lens([c.size(1) for c in chunks], wav_lens)
+    >>> chunked_wav_lens
+    [tensor([1., 1., 1.]), tensor([1.0000, 0.6250, 1.0000]), tensor([1.0000, 0.0000, 0.2500])]
+    >>> # wav 1 covers 62.5% (5/8) of the second chunk's frames
+    """
+
+    chunk_wav_lens = []
+
+    seq_size = sum(chunk_lens)
+    wav_lens_frames = wav_lens * seq_size
+
+    chunk_start_frame = 0
+    for chunk_len in chunk_lens:
+        chunk_raw_len = (wav_lens_frames - chunk_start_frame) / chunk_len
+        chunk_raw_len = torch.clamp(chunk_raw_len, 0.0, 1.0)
+        chunk_wav_lens.append(chunk_raw_len)
+
+        chunk_start_frame += chunk_len
+
+    return chunk_wav_lens
+
+
+def infer_dependency_matrix(
+    model: Callable, seq_shape: tuple, in_stride: int = 1
+):
+    """
+    Randomizes parts of the input sequence several times in order to detect
+    dependencies between input frames and output frames, aka whether a given
+    output frame depends on a given input frame.
+
+    This can prove useful to check whether a model behaves correctly in a
+    streaming context and does not contain accidental dependencies to future
+    frames that couldn't be known in a streaming scenario.
+
+    Note that this can get very computationally expensive for very long
+    sequences.
+
+    Furthermore, this expects inference to be fully deterministic, else false
+    dependencies may be found. This also means that the model must be in eval
+    mode, to inhibit things like dropout layers.
+
+    Arguments
+    ---------
+    model : Callable
+        Can be a model or a function (potentially emulating streaming
+        functionality). Does not require to be a trained model, random weights
+        should usually suffice.
+    seq_shape : tuple
+        The function tries inferring by randomizing parts of the input sequence
+        in order to detect unwanted dependencies.
+        The shape is expected to look like `[batch_size, seq_len, num_feats]`,
+        where `batch_size` may be `1`.
+    in_stride : int
+        Consider only N-th input, for when the input sequences are very long
+        (e.g. raw audio) and the output is shorter (subsampled, filters, etc.)
+
+    Returns
+    -------
+    dependencies : torch.BoolTensor
+        Matrix representing whether an output is dependent on an input; index
+        using `[in_frame_idx, out_frame_idx]`. `True` indicates a detected
+        dependency.
+    """
+    # TODO: document arguments
+
+    bs, seq_len, feat_len = seq_shape
+
+    base_seq = torch.rand(seq_shape)
+    with torch.no_grad():
+        base_out = model(base_seq)
+
+        if not model(base_seq).equal(base_out):
+            raise ValueError(
+                "Expected deterministic model, but inferring twice on the same "
+                "data yielded different results. Make sure that you use "
+                "`eval()` mode so that it does not include randomness."
+            )
+    out_len, _out_feat_len = base_out.shape[1:]
+
+    deps = torch.zeros(
+        ((seq_len + (in_stride - 1)) // in_stride, out_len), dtype=torch.bool
+    )
+
+    for in_frame_idx in range(0, seq_len, in_stride):
+        test_seq = base_seq.clone()
+        test_seq[:, in_frame_idx, :] = torch.rand(bs, feat_len)
+
+        with torch.no_grad():
+            test_out = model(test_seq)
+
+        for out_frame_idx in range(out_len):
+            if not torch.allclose(
+                test_out[:, out_frame_idx, :], base_out[:, out_frame_idx, :]
+            ):
+                deps[in_frame_idx // in_stride][out_frame_idx] = True
+
+    return deps
+
+
+def plot_dependency_matrix(deps):
+    """
+    Returns a matplotlib figure of a dependency matrix generated by
+    `infer_dependency_matrix`.
+
+    At a given point, a red square indicates that a given output frame (y-axis)
+    was to depend on a given input frame (x-axis).
+
+    For example, a fully red image means that all output frames were dependent
+    on all the history. This could be the case of a bidirectional RNN, or a
+    transformer model, for example.
+
+    Arguments
+    ---------
+    deps : torch.BoolTensor
+        Matrix returned by `infer_dependency_matrix` or one in a compatible
+        format.
+    """
+    import matplotlib.pyplot as plt
+    from matplotlib.colors import ListedColormap
+
+    cmap = ListedColormap(["white", "red"])
+
+    fig, ax = plt.subplots()
+
+    ax.pcolormesh(
+        torch.permute(deps, (1, 0)),
+        cmap=cmap,
+        vmin=False,
+        vmax=True,
+        edgecolors="gray",
+        linewidth=0.5,
+    )
+    ax.set_title("Dependency plot")
+    ax.set_xlabel("in")
+    ax.set_ylabel("out")
+    ax.set_aspect("equal")
+    return fig
diff --git a/speechbrain/utils/text_to_sequence.py b/speechbrain/utils/text_to_sequence.py
index 5d73db1782d244707b9e8a3f0e2a749bcb5f2531..08819e9411325af5b2fdb79fb5e6d9a7d7057ee4 100644
--- a/speechbrain/utils/text_to_sequence.py
+++ b/speechbrain/utils/text_to_sequence.py
@@ -26,7 +26,9 @@
 #
 # *****************************************************************************
 import re
+import logging
 
+logger = logging.getLogger(__name__)
 
 valid_symbols = [
     "AA",
@@ -316,3 +318,69 @@ def _should_keep_symbol(s):
     """whether to keep a certain symbol
     """
     return s in _symbol_to_id and s != "_" and s != "~"
+
+
+def _g2p_keep_punctuations(g2p_model, text):
+    """do grapheme to phoneme and keep the punctuations between the words
+    Arguments
+    ---------
+    g2p_model: speechbrain.inference.text g2p model
+    text: string
+        the input text
+
+    Example
+    -------
+    >>> from speechbrain.inference.text import GraphemeToPhoneme
+    >>> g2p_model = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p") # doctest: +SKIP
+    >>> from speechbrain.utils.text_to_sequence import _g2p_keep_punctuations # doctest: +SKIP
+    >>> text = "Hi, how are you?" # doctest: +SKIP
+    >>> _g2p_keep_punctuations(g2p_model, text) # doctest: +SKIP
+    ['HH', 'AY', ',', ' ', 'HH', 'AW', ' ', 'AA', 'R', ' ', 'Y', 'UW', '?']
+    """
+    # find the words where a "-" or "'" or "." or ":" appears in the middle
+    special_words = re.findall(r"\w+[-':\.][-':\.\w]*\w+", text)
+
+    # remove intra-word punctuations ("-':."), this does not change the output of speechbrain g2p
+    for special_word in special_words:
+        rmp = special_word.replace("-", "")
+        rmp = rmp.replace("'", "")
+        rmp = rmp.replace(":", "")
+        rmp = rmp.replace(".", "")
+        text = text.replace(special_word, rmp)
+
+    # keep inter-word punctuations
+    all_ = re.findall(r"[\w]+|[-!'(),.:;? ]", text)
+    try:
+        phonemes = g2p_model(text)
+    except RuntimeError:
+        logger.info(f"error with text: {text}")
+        quit()
+    word_phonemes = "-".join(phonemes).split(" ")
+
+    phonemes_with_punc = []
+    count = 0
+    try:
+        # if the g2p model splits the words correctly
+        for i in all_:
+            if i not in "-!'(),.:;? ":
+                phonemes_with_punc.extend(word_phonemes[count].split("-"))
+                count += 1
+            else:
+                phonemes_with_punc.append(i)
+    except IndexError:
+        # sometimes the g2p model cannot split the words correctly
+        logger.warning(
+            f"Do g2p word by word because of unexpected ouputs from g2p for text: {text}"
+        )
+
+        for i in all_:
+            if i not in "-!'(),.:;? ":
+                p = g2p_model.g2p(i)
+                p_without_space = [i for i in p if i != " "]
+                phonemes_with_punc.extend(p_without_space)
+            else:
+                phonemes_with_punc.append(i)
+
+    while "" in phonemes_with_punc:
+        phonemes_with_punc.remove("")
+    return phonemes_with_punc
diff --git a/speechbrain/utils/torch_audio_backend.py b/speechbrain/utils/torch_audio_backend.py
index 49e6e2f692389dd566f7375b03fd038266e09cef..de694c6a46ded6f080206db87b459ab7e09b2013 100644
--- a/speechbrain/utils/torch_audio_backend.py
+++ b/speechbrain/utils/torch_audio_backend.py
@@ -6,17 +6,62 @@ Authors
 import platform
 import logging
 import torchaudio
+from typing import Optional
 
 logger = logging.getLogger(__name__)
 
 
+def try_parse_torchaudio_major_version() -> Optional[int]:
+    """Tries parsing the torchaudio major version.
+
+    Returns
+    -------
+    The parsed major version, otherwise ``None``."""
+
+    if not hasattr(torchaudio, "__version__"):
+        return None
+
+    version_split = torchaudio.__version__.split(".")
+
+    # expect in format x.y.zwhatever; we care only about x
+
+    if len(version_split) <= 2:
+        # not sure how to parse this
+        return None
+
+    try:
+        version = int(version_split[0])
+    except Exception:
+        return None
+
+    return version
+
+
 def check_torchaudio_backend():
     """Checks the torchaudio backend and sets it to soundfile if
     windows is detected.
     """
-    current_system = platform.system()
-    if current_system == "Windows":
+
+    torchaudio_major = try_parse_torchaudio_major_version()
+
+    if torchaudio_major is None:
+        logger.warning(
+            "Failed to detect torchaudio major version; unsure how to check your setup. We recommend that you keep torchaudio up-to-date."
+        )
+    elif torchaudio_major >= 2:
+        available_backends = torchaudio.list_audio_backends()
+
+        if len(available_backends) == 0:
+            logger.warning(
+                "SpeechBrain could not find any working torchaudio backend. Audio files may fail to load. Follow this link for instructions and troubleshooting: https://pytorch.org/audio/stable/index.html"
+            )
+    else:
         logger.warning(
-            "The torchaudio backend is switched to 'soundfile'. Note that 'sox_io' is not supported on Windows."
+            "This version of torchaudio is old. SpeechBrain no longer tries using the torchaudio global backend mechanism in recipes, so if you encounter issues, update torchaudio."
         )
-        torchaudio.set_audio_backend("soundfile")
+        current_system = platform.system()
+        if current_system == "Windows":
+            logger.warning(
+                'Switched audio backend to "soundfile" because you are running Windows and you are running an old torchaudio version.'
+            )
+            torchaudio.set_audio_backend("soundfile")
diff --git a/speechbrain/utils/train_logger.py b/speechbrain/utils/train_logger.py
index d3bb864f17046a7232b4aa72db345e081f6e159a..56b9b467fad828a40fcaf613e9718d6979468274 100644
--- a/speechbrain/utils/train_logger.py
+++ b/speechbrain/utils/train_logger.py
@@ -2,9 +2,9 @@
 
 Authors
  * Peter Plantinga 2020
+ * Jarod Duret 2023
 """
 import logging
-import ruamel.yaml
 import torch
 import os
 from speechbrain.utils.distributed import main_process_only, if_main_process
@@ -182,22 +182,46 @@ class TensorboardLogger(TrainLogger):
 
 
 class WandBLogger(TrainLogger):
-    """Logger for wandb. To be used the same way as TrainLogger. Handles nested dicts as well.
-    An example on how to use this can be found in recipes/Voicebank/MTL/CoopNet/"""
+    """
+    Logger for WandB (Weights & Biases). This logger is designed to be used in the same way as TrainLogger
+    and supports handling nested dictionaries as well.
 
-    def __init__(self, *args, **kwargs):
-        try:
-            yaml_file = kwargs.pop("yaml_config")
-            with open(yaml_file, "r") as yaml_stream:
-                # Read yaml with ruamel to ignore bangs
-                config_dict = ruamel.yaml.YAML().load(yaml_stream)
+    Arguments
+    ---------
+    initializer: callable
+        A callable function that initializes the WandB run.
+        For more information on the parameters that can be passed to the initializer, refer to
+        the documentation: https://docs.wandb.ai/ref/python/init
+    *args: tuple
+        Positional arguments to be passed to the initializer function.
+    **kwargs: dict
+        Keyword arguments to be passed to the initializer function.
+
+    Example
+    -------
+    To initialize the logger, use the following pattern in hparams.yaml:
+
+    ```
+    train_logger: !new:speechbrain.utils.train_logger.WandBLogger
+        initializer: !name:wandb.init
+            entity: speechbrain
+            project: sb_project
+            name: sb_run
+            reinit: True
+            resume: False
+            dir: !ref <output_folder>/wandb
+    ```
+
+    NOTE
+    ----
+    If there is an issue with the WandB Logger initialization, it raises an exception.
+    """
 
-            # Run initializer only on main
+    def __init__(self, initializer, *args, **kwargs):
+        try:
             self.run = None
             if if_main_process():
-                self.run = kwargs.pop("initializer", None)(
-                    *args, **kwargs, config=config_dict
-                )
+                self.run = initializer(*args, **kwargs)
         except Exception as e:
             raise e("There was an issue with the WandB Logger initialization")
 
diff --git a/speechbrain/version.txt b/speechbrain/version.txt
index 83ac1cc02fe62e730fd785d6322247fecf700a83..3eefcb9dd5b38e2c1dc061052455dd97bcd51e6c 100644
--- a/speechbrain/version.txt
+++ b/speechbrain/version.txt
@@ -1 +1 @@
-0.5.14
+1.0.0
diff --git a/templates/enhancement/train.py b/templates/enhancement/train.py
index c944b1b7b6f2babd9902c844a3e721dfc3d06fcb..dc20338310e7df618b26c855c96c10f69deac13d 100644
--- a/templates/enhancement/train.py
+++ b/templates/enhancement/train.py
@@ -48,7 +48,12 @@ class SEBrain(sb.Brain):
         # We first move the batch to the appropriate device, and
         # compute the features necessary for masking.
         batch = batch.to(self.device)
-        noisy_wavs, lens = batch.noisy_sig
+        self.clean_wavs, self.lens = batch.clean_sig
+
+        noisy_wavs, self.lens = self.hparams.wav_augment(
+            self.clean_wavs, self.lens
+        )
+
         noisy_feats = self.compute_feats(noisy_wavs)
 
         # Masking is done here with the "signal approximation (SA)" algorithm.
@@ -102,15 +107,20 @@ class SEBrain(sb.Brain):
         """
 
         # Prepare clean targets for comparison
-        clean_wavs, lens = batch.clean_sig
-        clean_spec = self.compute_feats(clean_wavs)
+        clean_spec = self.compute_feats(self.clean_wavs)
 
         # Directly compare the masked spectrograms with the clean targets
-        loss = sb.nnet.losses.mse_loss(predictions["spec"], clean_spec, lens)
+        loss = sb.nnet.losses.mse_loss(
+            predictions["spec"], clean_spec, self.lens
+        )
 
         # Append this batch of losses to the loss metric for easy
         self.loss_metric.append(
-            batch.id, predictions["spec"], clean_spec, lens, reduction="batch"
+            batch.id,
+            predictions["spec"],
+            clean_spec,
+            self.lens,
+            reduction="batch",
         )
 
         # Some evaluations are slower, and we only want to perform them
@@ -121,8 +131,8 @@ class SEBrain(sb.Brain):
             self.stoi_metric.append(
                 batch.id,
                 predictions["wav"],
-                clean_wavs,
-                lens,
+                self.clean_wavs,
+                self.lens,
                 reduction="batch",
             )
 
@@ -220,15 +230,12 @@ def dataio_prep(hparams):
     # Define audio pipeline. Adds noise, reverb, and babble on-the-fly.
     # Of course for a real enhancement dataset, you'd want a fixed valid set.
     @sb.utils.data_pipeline.takes("wav")
-    @sb.utils.data_pipeline.provides("noisy_sig", "clean_sig")
+    @sb.utils.data_pipeline.provides("clean_sig")
     def audio_pipeline(wav):
         """Load the signal, and pass it and its length to the corruption class.
         This is done on the CPU in the `collate_fn`."""
         clean_sig = sb.dataio.dataio.read_audio(wav)
-        noisy_sig = hparams["env_corruption"](
-            clean_sig.unsqueeze(0), torch.ones(1)
-        ).squeeze(0)
-        return noisy_sig, clean_sig
+        return clean_sig
 
     # Define datasets sorted by ascending lengths for efficiency
     datasets = {}
@@ -243,7 +250,7 @@ def dataio_prep(hparams):
             json_path=data_info[dataset],
             replacements={"data_root": hparams["data_folder"]},
             dynamic_items=[audio_pipeline],
-            output_keys=["id", "noisy_sig", "clean_sig"],
+            output_keys=["id", "clean_sig"],
         ).filtered_sorted(sort_key="length")
     return datasets
 
@@ -279,6 +286,7 @@ if __name__ == "__main__":
                 "save_json_test": hparams["test_annotation"],
             },
         )
+    sb.utils.distributed.run_on_main(hparams["prepare_noise_data"])
 
     # Create dataset objects "train" and "valid"
     datasets = dataio_prep(hparams)
diff --git a/templates/enhancement/train.yaml b/templates/enhancement/train.yaml
index c3cd1a21951e4b459f8ef7917532f608a87b78c8..a6e4fc26fd45752a664a905ee0783d26596183bc 100755
--- a/templates/enhancement/train.yaml
+++ b/templates/enhancement/train.yaml
@@ -25,7 +25,6 @@ data_folder: ./data
 output_folder: !ref ./results/<seed>
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
-rir_folder: !ref <data_folder>
 
 # Path where data manifest files will be stored
 # The data manifest files are created by the data preparation script.
@@ -34,6 +33,12 @@ valid_annotation: valid.json
 test_annotation: test.json
 skip_prep: False
 
+# Data for augmentation
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+noise_annotation: noise.csv #The data manifest files are created by the data preparation script
+
+
 # The train logger writes training statistics to a file, as well as stdout.
 train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
     save_file: !ref <train_log>
@@ -49,8 +54,10 @@ window_fn: !name:torch.hamming_window
 number_of_epochs: 20
 batch_size: 8
 learning_rate: 0.0001
+num_workers: 4
 dataloader_options:
     batch_size: !ref <batch_size>
+    num_workers: !ref <num_workers>
 
 # The mask operates on log-spectral features, computed using these
 # STFT parameters, as well as computing magnitude and log1p.
@@ -71,17 +78,38 @@ resynth: !name:speechbrain.processing.signal_processing.resynthesize
     stft: !ref <compute_STFT>
     istft: !ref <compute_ISTFT>
 
-# Added noise and reverb come from OpenRIR dataset, automatically
-# downloaded and prepared with this Environmental Corruption class.
-# The babble is generated from other utterances in each batch.
-env_corruption: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    openrir_max_noise_len: 10
-    noise_snr_low: 0
-    noise_snr_high: 15
-    babble_speaker_count: !ref <batch_size> - 1
-    babble_snr_low: 0
-    babble_snr_high: 15
+
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+
+# Add noise to input signal
+snr_low: 0  # Min SNR for noise augmentation
+snr_high: 15  # Max SNR for noise augmentation
+
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: !ref <snr_low>
+    snr_high: !ref <snr_high>
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    parallel_augment: False
+    concat_original: False
+    repeat_augment: 1
+    shuffle_augmentations: False
+    min_augmentations: 1
+    max_augmentations: 1
+    augment_prob: 1.0
+    augmentations: [!ref <add_noise>]
 
 # To design a custom model, either just edit the simple CustomModel
 # class that's listed here, or replace this `!new` call with a line
diff --git a/templates/hyperparameter_optimization_speaker_id/train.py b/templates/hyperparameter_optimization_speaker_id/train.py
index a4218c8aa409de881892091762c41b46bdfd026c..283d27aefa87fcdb3ccba12509efd2bf9b8c0924 100644
--- a/templates/hyperparameter_optimization_speaker_id/train.py
+++ b/templates/hyperparameter_optimization_speaker_id/train.py
@@ -33,7 +33,6 @@ Authors
 """
 import os
 import sys
-import torch
 import speechbrain as sb
 from hyperpyyaml import load_hyperpyyaml
 from mini_librispeech_prepare import prepare_mini_librispeech
@@ -83,18 +82,9 @@ class SpkIdBrain(sb.Brain):
         """
         wavs, lens = wavs
 
-        # Add augmentation if specified. In this version of augmentation, we
-        # concatenate the original and the augment batches in a single bigger
-        # batch. This is more memory-demanding, but helps to improve the
-        # performance. Change it if you run OOM.
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.modules, "env_corrupt"):
-                wavs_noise = self.modules.env_corrupt(wavs, lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                lens = torch.cat([lens, lens])
-
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, lens = self.hparams.wav_augment(wavs, lens)
 
         # Feature extraction and normalization
         feats = self.modules.compute_features(wavs)
@@ -123,10 +113,9 @@ class SpkIdBrain(sb.Brain):
         _, lens = batch.sig
         spkid, _ = batch.spk_id_encoded
 
-        # Concatenate labels (due to data augmentation)
-        if stage == sb.Stage.TRAIN and hasattr(self.modules, "env_corrupt"):
-            spkid = torch.cat([spkid, spkid], dim=0)
-            lens = torch.cat([lens, lens])
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            spkid = self.hparams.wav_augment.replicate_labels(spkid)
+            lens = self.hparams.wav_augment.replicate_labels(lens)
 
         # Compute the cost function
         loss = sb.nnet.losses.nll_loss(predictions, spkid, lens)
@@ -324,6 +313,7 @@ if __name__ == "__main__":
                     "split_ratio": hparams["split_ratio"],
                 },
             )
+        sb.utils.distributed.run_on_main(hparams["prepare_noise_data"])
 
         # Create dataset objects "train", "valid", and "test".
         datasets = dataio_prep(hparams)
diff --git a/templates/hyperparameter_optimization_speaker_id/train.yaml b/templates/hyperparameter_optimization_speaker_id/train.yaml
index a2784a88bc79f2e1ee690dfc53ab01a48f8cbf03..ac04c89156342d8a554051e72dd389c524976ab2 100644
--- a/templates/hyperparameter_optimization_speaker_id/train.yaml
+++ b/templates/hyperparameter_optimization_speaker_id/train.yaml
@@ -24,7 +24,6 @@ data_folder: ./data
 output_folder: !ref ./results/speaker_id/<seed>
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
-rir_folder: !ref <data_folder>
 
 # Path where data manifest files will be stored
 # The data manifest files are created by the data preparation script.
@@ -34,6 +33,12 @@ test_annotation: test.json
 split_ratio: [80, 10, 10]
 skip_prep: False
 
+
+# Data for augmentation
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+noise_annotation: !ref <save_folder>/noise.csv #The data manifest files are created by the data preparation script
+
 # The train logger writes training statistics to a file, as well as stdout.
 train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
     save_file: !ref <train_log>
@@ -58,25 +63,78 @@ tdnn_channels: 512
 tdnn_channels_out: 1500
 n_classes: 28 # In this case, we have 28 speakers
 emb_dim: 512 # dimensionality of the embeddings
+num_workers: 4
 dataloader_options:
     batch_size: !ref <batch_size>
-
-
-# Added noise and reverb come from OpenRIR dataset, automatically
-# downloaded and prepared with this Environmental Corruption class.
-env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    babble_prob: 0.0
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-
-# Adds speech change + time and frequency dropouts (time-domain implementation)
-# # A small speed change help to improve the performance of speaker-id as well.
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
+    num_workers: !ref <num_workers>
+
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+
+# Add noise to input signal
+snr_low: 0  # Min SNR for noise augmentation
+snr_high: 15  # Max SNR for noise augmentation
+
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: !ref <snr_low>
+    snr_high: !ref <snr_high>
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    parallel_augment: False
+    concat_original: True
+    repeat_augment: 1
+    shuffle_augmentations: False
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <add_noise>,
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
 
 # Feature extraction
 compute_features: !new:speechbrain.lobes.features.Fbank
@@ -121,8 +179,6 @@ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
 # device, as well as having train()/eval() called on them by the Brain class.
 modules:
     compute_features: !ref <compute_features>
-    env_corrupt: !ref <env_corrupt>
-    augmentation: !ref <augmentation>
     embedding_model: !ref <embedding_model>
     classifier: !ref <classifier>
     mean_var_norm: !ref <mean_var_norm>
diff --git a/templates/speaker_id/train.py b/templates/speaker_id/train.py
index 5a2d180135c27ba518c81b931a3712e3710d8b44..2549c70888d9710358ab1b1136595a9724203839 100644
--- a/templates/speaker_id/train.py
+++ b/templates/speaker_id/train.py
@@ -26,7 +26,6 @@ Authors
 """
 import os
 import sys
-import torch
 import speechbrain as sb
 from hyperpyyaml import load_hyperpyyaml
 from mini_librispeech_prepare import prepare_mini_librispeech
@@ -75,18 +74,9 @@ class SpkIdBrain(sb.Brain):
         """
         wavs, lens = wavs
 
-        # Add augmentation if specified. In this version of augmentation, we
-        # concatenate the original and the augment batches in a single bigger
-        # batch. This is more memory-demanding, but helps to improve the
-        # performance. Change it if you run OOM.
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.modules, "env_corrupt"):
-                wavs_noise = self.modules.env_corrupt(wavs, lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                lens = torch.cat([lens, lens])
-
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, lens = self.hparams.wav_augment(wavs, lens)
 
         # Feature extraction and normalization
         feats = self.modules.compute_features(wavs)
@@ -116,9 +106,9 @@ class SpkIdBrain(sb.Brain):
         spkid, _ = batch.spk_id_encoded
 
         # Concatenate labels (due to data augmentation)
-        if stage == sb.Stage.TRAIN and hasattr(self.modules, "env_corrupt"):
-            spkid = torch.cat([spkid, spkid], dim=0)
-            lens = torch.cat([lens, lens])
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            spkid = self.hparams.wav_augment.replicate_labels(spkid)
+            lens = self.hparams.wav_augment.replicate_labels(lens)
 
         # Compute the cost function
         loss = sb.nnet.losses.nll_loss(predictions, spkid, lens)
@@ -308,6 +298,7 @@ if __name__ == "__main__":
                 "split_ratio": hparams["split_ratio"],
             },
         )
+    sb.utils.distributed.run_on_main(hparams["prepare_noise_data"])
 
     # Create dataset objects "train", "valid", and "test".
     datasets = dataio_prep(hparams)
diff --git a/templates/speaker_id/train.yaml b/templates/speaker_id/train.yaml
index 26a8e3ac318c30daa6c5e38b3fc0fff908e3ced7..09c08cbbe7c4c715b60d2aeb32d95617edbde83f 100644
--- a/templates/speaker_id/train.yaml
+++ b/templates/speaker_id/train.yaml
@@ -22,7 +22,6 @@ data_folder: ./data
 output_folder: !ref ./results/speaker_id/<seed>
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
-rir_folder: !ref <data_folder>
 
 # Path where data manifest files will be stored
 # The data manifest files are created by the data preparation script.
@@ -32,6 +31,11 @@ test_annotation: test.json
 split_ratio: [80, 10, 10]
 skip_prep: False
 
+# Data for augmentation
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+noise_annotation: !ref <save_folder>/noise.csv #The data manifest files are created by the data preparation script
+
 # The train logger writes training statistics to a file, as well as stdout.
 train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
     save_file: !ref <train_log>
@@ -53,25 +57,79 @@ lr_start: 0.001
 lr_final: 0.0001
 n_classes: 28 # In this case, we have 28 speakers
 emb_dim: 512 # dimensionality of the embeddings
+num_workers: 4
 dataloader_options:
     batch_size: !ref <batch_size>
-
-
-# Added noise and reverb come from OpenRIR dataset, automatically
-# downloaded and prepared with this Environmental Corruption class.
-env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <rir_folder>
-    babble_prob: 0.0
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-
-# Adds speech change + time and frequency dropouts (time-domain implementation)
-# # A small speed change help to improve the performance of speaker-id as well.
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
+    num_workers: !ref <num_workers>
+
+
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+
+# Add noise to input signal
+snr_low: 0  # Min SNR for noise augmentation
+snr_high: 15  # Max SNR for noise augmentation
+
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: !ref <snr_low>
+    snr_high: !ref <snr_high>
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    parallel_augment: False
+    concat_original: True
+    repeat_augment: 1
+    shuffle_augmentations: False
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <add_noise>,
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
 
 # Feature extraction
 compute_features: !new:speechbrain.lobes.features.Fbank
@@ -111,8 +169,6 @@ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
 # device, as well as having train()/eval() called on them by the Brain class.
 modules:
     compute_features: !ref <compute_features>
-    env_corrupt: !ref <env_corrupt>
-    augmentation: !ref <augmentation>
     embedding_model: !ref <embedding_model>
     classifier: !ref <classifier>
     mean_var_norm: !ref <mean_var_norm>
diff --git a/templates/speech_recognition/ASR/train.py b/templates/speech_recognition/ASR/train.py
index 4dd9711479c6c0eb485af6e70b0ac9e7c72b79a2..8c0017333481729c4dba35ff59b7d3e17e92b538 100644
--- a/templates/speech_recognition/ASR/train.py
+++ b/templates/speech_recognition/ASR/train.py
@@ -78,6 +78,7 @@ class ASR(sb.Brain):
         """
         # We first move the batch to the appropriate device.
         batch = batch.to(self.device)
+
         feats, self.feat_lens = self.prepare_features(stage, batch.sig)
         tokens_bos, _ = self.prepare_tokens(stage, batch.tokens_bos)
 
@@ -85,7 +86,7 @@ class ASR(sb.Brain):
         encoded_signal = self.modules.encoder(feats.detach())
 
         # Embed tokens and pass tokens & encoded signal to decoder
-        embedded_tokens = self.modules.embedding(tokens_bos)
+        embedded_tokens = self.modules.embedding(tokens_bos.detach())
         decoder_outputs, _ = self.modules.decoder(
             embedded_tokens, encoded_signal, self.feat_lens
         )
@@ -98,14 +99,18 @@ class ASR(sb.Brain):
             # Output layer for ctc log-probabilities
             ctc_logits = self.modules.ctc_lin(encoded_signal)
             predictions["ctc_logprobs"] = self.hparams.log_softmax(ctc_logits)
-        elif stage == sb.Stage.VALID:
-            predictions["tokens"], _ = self.hparams.valid_search(
-                encoded_signal, self.feat_lens
-            )
-        elif stage == sb.Stage.TEST:
-            predictions["tokens"], _ = self.hparams.test_search(
-                encoded_signal, self.feat_lens
-            )
+
+        elif stage != sb.Stage.TRAIN:
+            if stage == sb.Stage.VALID:
+                hyps, _, _, _ = self.hparams.valid_search(
+                    encoded_signal, self.feat_lens
+                )
+            elif stage == sb.Stage.TEST:
+                hyps, _, _, _ = self.hparams.test_search(
+                    encoded_signal, self.feat_lens
+                )
+
+            predictions["tokens"] = hyps
 
         return predictions
 
@@ -134,27 +139,24 @@ class ASR(sb.Brain):
         """
         wavs, wav_lens = wavs
 
-        # Add augmentation if specified. In this version of augmentation, we
-        # concatenate the original and the augment batches in a single bigger
-        # batch. This is more memory-demanding, but helps to improve the
-        # performance. Change it if you run OOM.
-        if stage == sb.Stage.TRAIN:
-            if hasattr(self.modules, "env_corrupt"):
-                wavs_noise = self.modules.env_corrupt(wavs, wav_lens)
-                wavs = torch.cat([wavs, wavs_noise], dim=0)
-                wav_lens = torch.cat([wav_lens, wav_lens])
-
-            if hasattr(self.hparams, "augmentation"):
-                wavs = self.hparams.augmentation(wavs, wav_lens)
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
+            wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
 
         # Feature computation and normalization
+        fea_lens = wav_lens  # Relative lenghs are preserved
+
+        # Add feature augmentation if specified.
         feats = self.hparams.compute_features(wavs)
-        feats = self.modules.normalize(feats, wav_lens)
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "fea_augment"):
+            feats, fea_lens = self.hparams.fea_augment(feats, fea_lens)
+        feats = self.modules.normalize(feats, fea_lens)
 
-        return feats, wav_lens
+        return feats, fea_lens
 
     def prepare_tokens(self, stage, tokens):
-        """Double the tokens batch if features are doubled.
+        """
+        Augments the tokens batch if needed.
 
         Arguments
         ---------
@@ -162,11 +164,24 @@ class ASR(sb.Brain):
             Currently executing stage.
         tokens : tuple
             The tokens (tensor) and their lengths (tensor).
+
+        Returns
+        -------
+        tuple
+            Augmented tokens and their lengths.
         """
         tokens, token_lens = tokens
-        if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
-            tokens = torch.cat([tokens, tokens], dim=0)
-            token_lens = torch.cat([token_lens, token_lens], dim=0)
+        if stage == sb.Stage.TRAIN:
+            if hasattr(self.hparams, "wav_augment"):
+                tokens = self.hparams.wav_augment.replicate_labels(tokens)
+                token_lens = self.hparams.wav_augment.replicate_labels(
+                    token_lens
+                )
+            if hasattr(self.hparams, "fea_augment"):
+                tokens = self.hparams.fea_augment.replicate_labels(tokens)
+                token_lens = self.hparams.fea_augment.replicate_labels(
+                    token_lens
+                )
         return tokens, token_lens
 
     def compute_objectives(self, predictions, batch, stage):
@@ -188,6 +203,7 @@ class ASR(sb.Brain):
         loss : torch.Tensor
             A one-element tensor used for backpropagating the gradient.
         """
+
         # Compute sequence loss against targets with EOS
         tokens_eos, tokens_eos_lens = self.prepare_tokens(
             stage, batch.tokens_eos
@@ -373,12 +389,20 @@ def dataio_prepare(hparams):
     # does not harm the performance.
     if hparams["sorting"] == "ascending":
         datasets["train"] = datasets["train"].filtered_sorted(sort_key="length")
+        datasets["valid"] = datasets["valid"].filtered_sorted(sort_key="length")
+        datasets["test"] = datasets["test"].filtered_sorted(sort_key="length")
         hparams["train_dataloader_opts"]["shuffle"] = False
 
     elif hparams["sorting"] == "descending":
         datasets["train"] = datasets["train"].filtered_sorted(
             sort_key="length", reverse=True
         )
+        datasets["valid"] = datasets["valid"].filtered_sorted(
+            sort_key="length", reverse=True
+        )
+        datasets["test"] = datasets["test"].filtered_sorted(
+            sort_key="length", reverse=True
+        )
         hparams["train_dataloader_opts"]["shuffle"] = False
 
     elif hparams["sorting"] == "random":
@@ -422,6 +446,8 @@ if __name__ == "__main__":
                 "save_json_test": hparams["test_annotation"],
             },
         )
+    sb.utils.distributed.run_on_main(hparams["prepare_noise_data"])
+    sb.utils.distributed.run_on_main(hparams["prepare_rir_data"])
 
     # We can now directly create the datasets for training, valid, and test
     datasets = dataio_prepare(hparams)
@@ -432,7 +458,7 @@ if __name__ == "__main__":
     # We download the pretrained LM from HuggingFace (or elsewhere depending on
     # the path given in the YAML file). The tokenizer is loaded at the same time.
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     # Trainer initialization
     asr_brain = ASR(
diff --git a/templates/speech_recognition/ASR/train.yaml b/templates/speech_recognition/ASR/train.yaml
index f911774071cc714fa6ac328d10ff3d03094853a3..0acc053869354e6c61c10d7a70e364510945de8f 100644
--- a/templates/speech_recognition/ASR/train.yaml
+++ b/templates/speech_recognition/ASR/train.yaml
@@ -23,7 +23,13 @@ __set_seed: !apply:torch.manual_seed [!ref <seed>]
 # It allows you to read the data much faster without slowing down the shared filesystem.
 
 data_folder: ../data # In this case, data will be automatically downloaded here.
-data_folder_rirs: !ref <data_folder> # noise/ris dataset will automatically be downloaded here
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+data_folder_rir: !ref <data_folder>/rir # The impulse responses used for data augmentation will automatically be downloaded here.
+
+# Data for augmentation
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+RIR_DATASET_URL: https://www.dropbox.com/scl/fi/linhy77c36mu10965a836/RIRs.zip?rlkey=pg9cu8vrpn2u173vhiqyu743u&dl=1
+
 output_folder: !ref results/CRDNN_BPE_960h_LM/<seed>
 test_wer_file: !ref <output_folder>/wer_test.txt
 save_folder: !ref <output_folder>/save
@@ -43,6 +49,9 @@ pretrained_path: speechbrain/asr-crdnn-rnnlm-librispeech
 train_annotation: ../train.json
 valid_annotation: ../valid.json
 test_annotation: ../test.json
+noise_annotation: ../noise.csv
+rir_annotation: ../rir.csv
+
 skip_prep: False
 
 # The train logger writes training statistics to a file, as well as stdout.
@@ -59,22 +68,112 @@ sorting: ascending
 ckpt_interval_minutes: 15 # save checkpoint every N min
 label_smoothing: 0.1
 
+#Optimal Number of Workers for Data Reading
+#The ideal value depends on your machine's hardware, such as the number of available CPUs.
+num_workers: 4
+
 # Dataloader options
 train_dataloader_opts:
     batch_size: !ref <batch_size>
+    num_workers: !ref <num_workers>
 
 valid_dataloader_opts:
     batch_size: !ref <batch_size>
+    num_workers: !ref <num_workers>
 
 test_dataloader_opts:
     batch_size: !ref <batch_size>
-
+    num_workers: !ref <num_workers>
 
 # Feature parameters
 sample_rate: 16000
 n_fft: 400
 n_mels: 40
 
+# NOTE ON DATA AUGMENTATION
+# This template demonstrates the use of all available data augmentation strategies
+# to illustrate how they work and how you can combine them with the augmenter.
+# In practical applications (e.g., refer to other recipes), it is usually advisable
+# to select a subset of these strategies for better performance.
+
+# Waveform Augmentation Functions
+snr_low: 0  # Min SNR for noise augmentation
+snr_high: 15  # Max SNR for noise augmentation
+speed_changes: [85, 90, 95, 105, 110, 115]  # List of speed changes for time-stretching
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 3  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+clip_low: 0.1  # Min amplitude to clip
+clip_high: 0.5  # Max amplitude to clip
+amp_low: 0.05  # Min waveform amplitude
+amp_high: 1.0  # Max waveform amplitude
+babble_snr_low: 5  # Min SNR for babble (batch sum noise)
+babble_snr_high: 15  # Max SNR for babble (batch sum noise)
+
+# Feature Augmentation Functions
+min_time_shift: 0  # Min random shift of spectrogram in time
+max_time_shift: 15  # Max random shift of spectrogram in time
+min_freq_shift: 0  # Min random shift of spectrogram in frequency
+max_freq_shift: 5  # Max random shift of spectrogram in frequency
+time_drop_length_low: 5  # Min length for temporal chunk to drop in spectrogram
+time_drop_length_high: 15  # Max length for temporal chunk to drop in spectrogram
+time_drop_count_low: 1  # Min number of chunks to drop in time in the spectrogram
+time_drop_count_high: 3  # Max number of chunks to drop in time in the spectrogram
+time_drop_replace: "zeros"  # Method of dropping chunks
+freq_drop_length_low: 1  # Min length for chunks to drop in frequency in the spectrogram
+freq_drop_length_high: 5  # Max length for chunks to drop in frequency in the spectrogram
+freq_drop_count_low: 1  # Min number of chunks to drop in frequency in the spectrogram
+freq_drop_count_high: 3  # Max number of chunks to drop in frequency in the spectrogram
+freq_drop_replace: "zeros"  # Method of dropping chunks
+time_warp_window: 20  # Length of time warping window
+time_warp_mode: "bicubic"  # Time warping method
+freq_warp_window: 4  # Length of frequency warping window
+freq_warp_mode: "bicubic"  # Frequency warping method
+
+# Enable Waveform Augmentation Flags (useful for hyperparameter tuning)
+enable_codec_augment: False
+enable_add_reverb: True
+enable_add_noise: True
+enable_speed_perturb: True
+enable_drop_freq: True
+enable_drop_chunk: True
+enable_clipping: True
+enable_rand_amp: True
+enable_babble_noise: True
+enable_drop_resolution: True
+
+# Enable Feature Augmentations Flags (useful for hyperparameter tuning)
+enable_time_shift: True
+enable_freq_shift: True
+enable_time_drop: True
+enable_freq_drop: True
+enable_time_warp: True
+enable_freq_warp: True
+
+# Waveform Augmenter (combining augmentations)
+time_parallel_augment: False  # Apply augmentations in parallel if True, or sequentially if False
+time_concat_original: True  # Concatenate original signals to the training batch if True
+time_repeat_augment: 1  # Number of times to apply augmentation
+time_shuffle_augmentations: True  # Shuffle order of augmentations if True, else use specified order
+time_min_augmentations: 1  # Min number of augmentations to apply
+time_max_augmentations: 10  # Max number of augmentations to apply
+time_augment_prob: 1.0     # Probability to apply time augmentation
+
+# Feature Augmenter (combining augmentations)
+fea_parallel_augment: False  # Apply feature augmentations in parallel if True, or sequentially if False
+fea_concat_original: True  # Concatenate original signals to the training batch if True
+fea_repeat_augment: 1  # Number of times to apply feature augmentation
+fea_shuffle_augmentations: True  # Shuffle order of feature augmentations if True, else use specified order
+fea_min_augmentations: 1  # Min number of feature augmentations to apply
+fea_max_augmentations: 6  # Max number of feature augmentations to app
+fea_augment_prob: 1.0     # Probability to apply feature augmentation
+
 # Model parameters
 activation: !name:torch.nn.LeakyReLU
 dropout: 0.15
@@ -126,20 +225,189 @@ compute_features: !new:speechbrain.lobes.features.Fbank
 normalize: !new:speechbrain.processing.features.InputNormalization
     norm_type: global
 
-# Added noise and reverb come from OpenRIR dataset, automatically
-# downloaded and prepared with this Environmental Corruption class.
-env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <data_folder_rirs>
-    babble_prob: 0.0
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
-
-# Adds speech change + time and frequency dropouts (time-domain implementation).
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
+
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+# Download and prepare the dataset of room impulse responses for augmentation
+prepare_rir_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <RIR_DATASET_URL>
+    dest_folder: !ref <data_folder_rir>
+    ext: wav
+    csv_file: !ref <rir_annotation>
+
+
+# ----- WAVEFORM AUGMENTATION ----- #
+
+# Codec augmentation
+codec_augment: !new:speechbrain.augment.codec.CodecAugment
     sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
+
+# Add reverberation to input signal
+add_reverb: !new:speechbrain.augment.time_domain.AddReverb
+    csv_file: !ref <rir_annotation>
+    reverb_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Add noise to input signal
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: !ref <snr_low>
+    snr_high: !ref <snr_high>
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Speed perturbation
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
+
+# Clipping
+clipping: !new:speechbrain.augment.time_domain.DoClip
+    clip_low: !ref <clip_low>
+    clip_high: !ref <clip_high>
+
+# Random Amplitude
+rand_amp: !new:speechbrain.augment.time_domain.RandAmp
+    amp_low: !ref <amp_low>
+    amp_high: !ref <amp_high>
+
+# Noise sequence derived by summing up all the signals in the batch
+# It is similar to babble noise
+sum_batch: !name:torch.sum
+    dim: 0
+    keepdim: True
+
+babble_noise: !new:speechbrain.augment.time_domain.AddNoise
+    snr_low: !ref <babble_snr_low>
+    snr_high: !ref <babble_snr_high>
+    noise_funct: !ref <sum_batch>
+
+drop_resolution: !new:speechbrain.augment.time_domain.DropBitResolution
+    target_dtype: 'random'
+
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    parallel_augment: !ref <time_parallel_augment>
+    concat_original: !ref <time_concat_original>
+    repeat_augment: !ref <time_repeat_augment>
+    shuffle_augmentations: !ref <time_shuffle_augmentations>
+    min_augmentations: !ref <time_min_augmentations>
+    max_augmentations: !ref <time_max_augmentations>
+    augment_prob: !ref <time_augment_prob>
+    augmentations: [
+        !ref <codec_augment>,
+        !ref <add_reverb>,
+        !ref <add_noise>,
+        !ref <babble_noise>,
+        !ref <speed_perturb>,
+        !ref <clipping>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>,
+        !ref <rand_amp>,
+        !ref <drop_resolution>]
+    enable_augmentations: [
+        !ref <enable_codec_augment>,
+        !ref <enable_add_reverb>,
+        !ref <enable_add_noise>,
+        !ref <enable_babble_noise>,
+        !ref <enable_speed_perturb>,
+        !ref <enable_clipping>,
+        !ref <enable_drop_freq>,
+        !ref <enable_drop_chunk>,
+        !ref <enable_rand_amp>,
+        !ref <enable_drop_resolution>]
+
+
+# ----- FEATURE AUGMENTATION ----- #
+
+# Time shift
+time_shift: !new:speechbrain.augment.freq_domain.RandomShift
+    min_shift: !ref <min_time_shift>
+    max_shift: !ref <max_time_shift>
+    dim: 1
+
+# Frequency shift
+freq_shift: !new:speechbrain.augment.freq_domain.RandomShift
+    min_shift: !ref <min_freq_shift>
+    max_shift: !ref <max_freq_shift>
+    dim: 2
+
+# Time Drop
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: !ref <time_drop_length_low>
+    drop_length_high: !ref <time_drop_length_high>
+    drop_count_low: !ref <time_drop_count_low>
+    drop_count_high: !ref <time_drop_count_high>
+    replace: !ref <time_drop_replace>
+    dim: 1
+
+# Frequency Drop
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: !ref <freq_drop_length_low>
+    drop_length_high: !ref <freq_drop_length_high>
+    drop_count_low: !ref <freq_drop_count_low>
+    drop_count_high: !ref <freq_drop_count_high>
+    replace: !ref <freq_drop_replace>
+    dim: 2
+
+# Time warp
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+    warp_window: !ref <time_warp_window>
+    warp_mode: !ref <time_warp_mode>
+    dim: 1
+
+freq_warp: !new:speechbrain.augment.freq_domain.Warping
+    warp_window: !ref <freq_warp_window>
+    warp_mode: !ref <freq_warp_mode>
+    dim: 2
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    parallel_augment: !ref <fea_parallel_augment>
+    concat_original: !ref <fea_concat_original>
+    repeat_augment: !ref <fea_repeat_augment>
+    shuffle_augmentations: !ref <fea_shuffle_augmentations>
+    min_augmentations: !ref <fea_min_augmentations>
+    max_augmentations: !ref <fea_max_augmentations>
+    augment_start_index: !ref <batch_size> # This leaves unchanges original inputs
+    concat_end_index: !ref <batch_size> # This leaves unchanges original inputs
+    augment_prob: !ref <fea_augment_prob>
+    augmentations: [
+        !ref <time_shift>,
+        !ref <freq_shift>,
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>,
+        !ref <freq_warp>]
+    enable_augmentations: [
+        !ref <enable_time_shift>,
+        !ref <enable_freq_shift>,
+        !ref <enable_time_drop>,
+        !ref <enable_freq_drop>,
+        !ref <enable_time_warp>,
+        !ref <enable_freq_warp>]
 
 # The CRDNN model is an encoder that combines CNNs, RNNs, and DNNs.
 encoder: !new:speechbrain.lobes.models.CRDNN.CRDNN
@@ -213,7 +481,6 @@ modules:
     ctc_lin: !ref <ctc_lin>
     seq_lin: !ref <seq_lin>
     normalize: !ref <normalize>
-    env_corrupt: !ref <env_corrupt>
     lm_model: !ref <lm_model>
 
 # Gathering all the submodels in a single model object.
@@ -237,54 +504,88 @@ lm_model: !new:speechbrain.lobes.models.RNNLM.RNNLM
     dnn_neurons: 512
     return_hidden: True  # For inference
 
-# Beamsearch is applied on the top of the decoder. If the language model is
-# given, a language model is applied (with a weight specified in lm_weight).
-# If ctc_weight is set, the decoder uses CTC + attention beamsearch. This
-# improves the performance, but slows down decoding. For a description of
-# the other parameters, please see the speechbrain.decoders.S2SRNNBeamSearchLM.
+# Define scorers for beam search
+
+# If ctc_scorer is set, the decoder uses CTC + attention beamsearch. This
+# improves the performance, but slows down decoding.
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
+    eos_index: !ref <eos_index>
+    blank_index: !ref <blank_index>
+    ctc_fc: !ref <ctc_lin>
+
+# If coverage_scorer is set, coverage penalty is applied based on accumulated
+# attention weights during beamsearch.
+coverage_scorer: !new:speechbrain.decoders.scorer.CoverageScorer
+    vocab_size: !ref <output_neurons>
+
+# If the lm_scorer is set, a language model
+# is applied (with a weight specified in scorer).
+rnnlm_scorer: !new:speechbrain.decoders.scorer.RNNLMScorer
+    language_model: !ref <lm_model>
+    temperature: !ref <temperature_lm>
+
+# Gathering all scorers in a scorer instance for beamsearch:
+# - full_scorers are scorers which score on full vocab set, while partial_scorers
+# are scorers which score on pruned tokens.
+# - The number of pruned tokens is decided by scorer_beam_scale * beam_size.
+# - For some scorers like ctc_scorer, ngramlm_scorer, putting them
+# into full_scorers list would be too heavy. partial_scorers are more
+# efficient because they score on pruned tokens at little cost of
+# performance drop. For other scorers, please see the speechbrain.decoders.scorer.
+test_scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
+    scorer_beam_scale: 1.5
+    full_scorers: [
+        !ref <rnnlm_scorer>,
+        !ref <coverage_scorer>]
+    partial_scorers: [!ref <ctc_scorer>]
+    weights:
+        rnnlm: !ref <lm_weight>
+        coverage: !ref <coverage_penalty>
+        ctc: !ref <ctc_weight_decode>
+
+valid_scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
+    full_scorers: [!ref <coverage_scorer>]
+    weights:
+        coverage: !ref <coverage_penalty>
+
+# Beamsearch is applied on the top of the decoder. For a description of
+# the other parameters, please see the speechbrain.decoders.S2SRNNBeamSearcher.
 
 # It makes sense to have a lighter search during validation. In this case,
-# we don't use the LM and CTC probabilities during decoding.
+# we don't use scorers during decoding.
 valid_search: !new:speechbrain.decoders.S2SRNNBeamSearcher
     embedding: !ref <embedding>
     decoder: !ref <decoder>
     linear: !ref <seq_lin>
-    ctc_linear: !ref <ctc_lin>
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <valid_beam_size>
     eos_threshold: !ref <eos_threshold>
     using_max_attn_shift: !ref <using_max_attn_shift>
     max_attn_shift: !ref <max_attn_shift>
-    coverage_penalty: !ref <coverage_penalty>
     temperature: !ref <temperature>
+    scorer: !ref <valid_scorer>
 
 # The final decoding on the test set can be more computationally demanding.
-# In this case, we use the LM + CTC probabilities during decoding as well.
-# Please, remove this part if you need a faster decoder.
-test_search: !new:speechbrain.decoders.S2SRNNBeamSearchLM
+# In this case, we use the LM + CTC probabilities during decoding as well,
+# which are defined in scorer.
+# Please, remove scorer if you need a faster decoder.
+test_search: !new:speechbrain.decoders.S2SRNNBeamSearcher
     embedding: !ref <embedding>
     decoder: !ref <decoder>
     linear: !ref <seq_lin>
-    ctc_linear: !ref <ctc_lin>
-    language_model: !ref <lm_model>
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <test_beam_size>
     eos_threshold: !ref <eos_threshold>
     using_max_attn_shift: !ref <using_max_attn_shift>
     max_attn_shift: !ref <max_attn_shift>
-    coverage_penalty: !ref <coverage_penalty>
-    lm_weight: !ref <lm_weight>
-    ctc_weight: !ref <ctc_weight_decode>
     temperature: !ref <temperature>
-    temperature_lm: !ref <temperature_lm>
+    scorer: !ref <test_scorer>
 
 # This function manages learning rate annealing over the epochs.
 # We here use the NewBoB algorithm, that anneals the learning rate if
diff --git a/templates/speech_recognition/LM/RNNLM.yaml b/templates/speech_recognition/LM/RNNLM.yaml
index 582b279d781dd279bc3475177dae1a00880d30f7..0495ace2e5c25f623b06557839c43f68d4a6326c 100644
--- a/templates/speech_recognition/LM/RNNLM.yaml
+++ b/templates/speech_recognition/LM/RNNLM.yaml
@@ -7,7 +7,7 @@
 # Seed needs to be set at top of yaml, before objects with parameters are made
 seed: 2602
 __set_seed: !apply:torch.manual_seed [!ref <seed>]
-data_folder: ./data/
+data_folder: data/
 output_folder: !ref results/RNNLM/
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
@@ -35,7 +35,7 @@ tokenizer_file: ../Tokenizer/save/1000_unigram.model
 number_of_epochs: 20
 batch_size: 80
 lr: 0.001
-accu_steps: 1 # Gradient accumulation to simulate large batch training
+grad_accumulation_factor: 1 # Gradient accumulation to simulate large batch training
 ckpt_interval_minutes: 15 # save checkpoint every N min
 
 # Dataloader options
diff --git a/templates/speech_recognition/LM/train.py b/templates/speech_recognition/LM/train.py
index add47eeca86c2c9ab47fb6c2ed8aaf4bad21a400..f6deaca473e448dfca5ab832f65e9a1f2a6b14d1 100755
--- a/templates/speech_recognition/LM/train.py
+++ b/templates/speech_recognition/LM/train.py
@@ -72,37 +72,9 @@ class LM(sb.core.Brain):
         )
         return loss
 
-    def fit_batch(self, batch):
-        """Runs all the steps needed to train the model on a single batch.
-
-        Arguments
-        ---------
-        batch : PaddedBatch
-            This batch object contains all the relevant tensors for computation.
-
-        Returns
-        -------
-        Loss : torch.Tensor
-            A tensor containing the loss (single real number).
-        """
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-
-        # Loss backpropagation (gradient computation)
-        (loss / self.hparams.accu_steps).backward()
-
-        # Manage gradient accumulation
-        if self.step % self.hparams.accu_steps == 0:
-
-            # Gradient clipping & early stop if loss is not fini
-            self.check_gradients(loss)
-
-            # Update the parameters
-            self.optimizer.step()
-
-            # Reset the gradient
-            self.optimizer.zero_grad()
-
+    def on_fit_batch_end(self, batch, outputs, loss, should_step):
+        """At the end of the optimizer step, apply noam annealing."""
+        if should_step:
             if isinstance(
                 self.hparams.lr_annealing, sb.nnet.schedulers.NoamScheduler
             ) or isinstance(
@@ -111,8 +83,6 @@ class LM(sb.core.Brain):
             ):
                 self.hparams.lr_annealing(self.optimizer)
 
-        return loss
-
     def on_stage_end(self, stage, stage_loss, epoch):
         """Gets called at the end of an epoch.
 
@@ -255,7 +225,7 @@ if __name__ == "__main__":
     # We download the tokenizer from HuggingFace (or elsewhere depending on
     # the path given in the YAML file).
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     # Create dataset objects "train", "valid", and "test"
     train_data, valid_data, test_data = dataio_prepare(hparams)
diff --git a/tests/.run-doctests.sh b/tests/.run-doctests.sh
index fcc3ed79f54f2fd73f87bb4e3b6b3d789eeaecad..bb463ee2acc2841a2b26dbe42d5b7d811e1c9cfc 100755
--- a/tests/.run-doctests.sh
+++ b/tests/.run-doctests.sh
@@ -5,5 +5,5 @@ set -e -u -o pipefail
 # > pytest --doctest-modules speechbrain/
 # However, we take this more complex approach to avoid testing files not
 # tracked by git. We filter out tests that require optional dependencies.
-avoid="transducer_loss.py\|fairseq_wav2vec.py\|huggingface_wav2vec.py\|bleu.py\|ctc_segmentation.py\|check_url.py\|huggingface_whisper.py"
+avoid="transducer_loss.py\|fairseq_wav2vec.py\|huggingface_wav2vec.py\|bleu.py\|ctc_segmentation.py\|check_url.py\|huggingface_whisper.py\|language_model.py\|vocos.py|discrete_wav2vec2.py\|discrete_wavlm.py\|discrete_hubert.py"
 git ls-files speechbrain | grep -e "\.py$" | grep -v $avoid | xargs pytest --doctest-modules
diff --git a/tests/.run-unittests.sh b/tests/.run-unittests.sh
index 2d2f2c5cedfa7027738d3f081fc210a0893f8871..85ede08c9434160fac971612b6d97ef83f25f357 100755
--- a/tests/.run-unittests.sh
+++ b/tests/.run-unittests.sh
@@ -1,4 +1,4 @@
 #!/bin/bash
 set -e -u -o pipefail
 
-git ls-files tests/unittests | grep -e "\.py$" | xargs pytest
+git ls-files tests/unittests | grep -e "\.py$" | xargs pytest
\ No newline at end of file
diff --git a/tests/PRE-RELEASE-TESTS.md b/tests/PRE-RELEASE-TESTS.md
index 430b6080dac409bdb06eec22dfc2a194c7b9a744..11d44d2a0c381a3ec87cd7bf12152f2673e6b776 100644
--- a/tests/PRE-RELEASE-TESTS.md
+++ b/tests/PRE-RELEASE-TESTS.md
@@ -2,7 +2,7 @@
 
 1. Create a new environment. For instance, using conda:
 ```
-conda create --name fresh_env python=3.9
+conda create --name fresh_env python=3.11
 ```
 2. Activate the new environment
 ```
@@ -26,24 +26,30 @@ find recipes | grep extra | xargs cat | sort -u | grep -v \# | xargs -I {} pip i
 pip install fairseq
 conda install 'ffmpeg<4.4'
 ```
-7. Run the basic tests by typing:
+7. Update the PERFORMANCE.md file:
+```
+python tools/readme_builder.py --recipe_info_dir tests/recipes/ --output_file PERFORMANCE.md
+```
+Remember to push it.
+
+8. Run the basic tests by typing:
 ```
 pytest
 ```
-8. Run load yaml test:
+9. Run load yaml test:
 ```
 tests/.run-load-yaml-tests.sh
 ```
-9. Run recipe tests
+10. Run recipe tests
 ```
 tests/.run-recipe-tests.sh
 ```
-10. Make sure all HuggingFace repos are working
+11. Make sure all HuggingFace repos are working
 ```
 tests/.run-HF-checks.sh
 ```
-10. Make sure all HuggingFace API Interfaces are up to date and working (see [here](#huggingface-api-testing)])
-11. Check URLs
+12. Make sure all HuggingFace API Interfaces are up to date and working (see [here](#huggingface-api-testing)])
+13. Check URLs
 ```
 tests/.run-url-checks.sh
 ```
diff --git a/tests/consistency/test_recipe.py b/tests/consistency/test_recipe.py
index d2d1cfcf0acec80834982d0b68d0895637505d3d..55d17ef26ab3449661aac1714dbcfb1d77a7e654 100644
--- a/tests/consistency/test_recipe.py
+++ b/tests/consistency/test_recipe.py
@@ -23,6 +23,8 @@ def test_recipe_list(
         "recipes/Voicebank/MTL/CoopNet/hparams/logger.yaml",
         "recipes/LibriParty/generate_dataset/dataset.yaml",
         "hpopt.yaml",
+        "recipes/LJSpeech/TTS/quantization/hparams/kmeans.yaml",
+        "recipes/DNS/noisyspeech_synthesizer/noisyspeech_synthesizer.yaml",
     ],
 ):
     """This test checks if all the all hparam file of all the recipes are listed
diff --git a/tests/consistency/test_yaml.py b/tests/consistency/test_yaml.py
index 2ee2cba7f950c096767ea830142dc4913f85bc16..0ee5c8c57950506e1c6e42eff77b5bac7494ef3a 100644
--- a/tests/consistency/test_yaml.py
+++ b/tests/consistency/test_yaml.py
@@ -22,6 +22,7 @@ def test_yaml_script_consistency(recipe_folder="tests/recipes"):
 
     # Use this list to itemize special yaml for which we do not have to test
     avoid_check = []
+    check = True
 
     # Loop over all recipe CSVs
     for recipe_csvfile in os.listdir(recipe_folder):
@@ -30,7 +31,6 @@ def test_yaml_script_consistency(recipe_folder="tests/recipes"):
         with open(
             os.path.join(recipe_folder, recipe_csvfile), newline=""
         ) as csvfile:
-            check = True
             reader = csv.DictReader(
                 csvfile, delimiter=",", skipinitialspace=True
             )
diff --git a/tests/integration/ASR_CTC/example_asr_ctc_experiment.py b/tests/integration/ASR_CTC/example_asr_ctc_experiment.py
old mode 100755
new mode 100644
diff --git a/tests/integration/ASR_ConformerTransducer_streaming/example_asr_conformertransducer_streaming_experiment.py b/tests/integration/ASR_ConformerTransducer_streaming/example_asr_conformertransducer_streaming_experiment.py
new file mode 100644
index 0000000000000000000000000000000000000000..893c511397303dcd969af614e81f4ec0acf6d408
--- /dev/null
+++ b/tests/integration/ASR_ConformerTransducer_streaming/example_asr_conformertransducer_streaming_experiment.py
@@ -0,0 +1,281 @@
+#!/usr/bin/env/python3
+"""This minimal example trains a RNNT-based speech recognizer on a tiny dataset.
+The encoder is based on a Conformer model with the use of Dynamic Chunk Training
+ (with a Dynamic Chunk Convolution within the convolution modules) that predict
+phonemes. A greedy search is used on top of the output probabilities.
+Given the tiny dataset, the expected behavior is to overfit the training dataset
+(with a validation performance that stays high).
+"""
+import pathlib
+import speechbrain as sb
+from hyperpyyaml import load_hyperpyyaml
+import torch
+
+
+class ConformerTransducerBrain(sb.Brain):
+    def compute_forward(self, batch, stage):
+        """Forward computations from the waveform batches to the output probabilities."""
+        batch = batch.to(self.device)
+        wavs, wav_lens = batch.sig
+        phn_with_bos, phn_with_bos_lens = batch.phn_encoded_bos
+
+        # Add waveform augmentation if specified.
+        if stage == sb.Stage.TRAIN:
+            if hasattr(self.hparams, "wav_augment"):
+                wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
+                phn_with_bos = self.hparams.wav_augment.replicate_labels(
+                    phn_with_bos
+                )
+
+        feats = self.hparams.compute_features(wavs)
+
+        # Add feature augmentation if specified.
+        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "fea_augment"):
+            feats, fea_lens = self.hparams.fea_augment(feats, wav_lens)
+            phn_with_bos = self.hparams.fea_augment.replicate_labels(
+                phn_with_bos
+            )
+
+        current_epoch = self.hparams.epoch_counter.current
+
+        # Old models may not have the streaming hparam, we don't break them in
+        # any other way so just check for its presence
+        if hasattr(self.hparams, "streaming") and self.hparams.streaming:
+            dynchunktrain_config = self.hparams.dynchunktrain_config_sampler(
+                stage
+            )
+        else:
+            dynchunktrain_config = None
+
+        feats = self.modules.normalize(feats, wav_lens, epoch=current_epoch)
+
+        src = self.modules.CNN(feats)
+        x = self.modules.enc(
+            src,
+            wav_lens,
+            pad_idx=self.hparams.pad_index,
+            dynchunktrain_config=dynchunktrain_config,
+        )
+        x = self.modules.proj_enc(x)
+
+        e_in = self.modules.emb(phn_with_bos)
+        e_in = torch.nn.functional.dropout(
+            e_in,
+            self.hparams.dec_emb_dropout,
+            training=(stage == sb.Stage.TRAIN),
+        )
+        h, _ = self.modules.dec(e_in)
+        h = torch.nn.functional.dropout(
+            h, self.hparams.dec_dropout, training=(stage == sb.Stage.TRAIN)
+        )
+        h = self.modules.proj_dec(h)
+
+        # Joint network
+        # add labelseq_dim to the encoder tensor: [B,T,H_enc] => [B,T,1,H_enc]
+        # add timeseq_dim to the decoder tensor: [B,U,H_dec] => [B,1,U,H_dec]
+        joint = self.modules.Tjoint(x.unsqueeze(2), h.unsqueeze(1))
+
+        # Output layer for transducer log-probabilities
+        logits_transducer = self.modules.transducer_lin(joint)
+
+        # Compute outputs
+        if stage == sb.Stage.TRAIN:
+            p_ctc = None
+            p_ce = None
+
+            if self.hparams.ctc_weight > 0.0:
+                # Output layer for ctc log-probabilities
+                out_ctc = self.modules.proj_ctc(x)
+                p_ctc = self.hparams.log_softmax(out_ctc)
+
+            if self.hparams.ce_weight > 0.0:
+                # Output layer for ctc log-probabilities
+                p_ce = self.modules.dec_lin(h)
+                p_ce = self.hparams.log_softmax(p_ce)
+
+            return p_ctc, p_ce, logits_transducer, wav_lens
+
+        best_hyps, scores, _, _ = self.hparams.Greedysearcher(x)
+        return logits_transducer, wav_lens, best_hyps
+
+    def compute_objectives(self, predictions, batch, stage):
+        """Computes the loss (Transducer+(CTC+NLL)) given predictions and targets."""
+
+        ids = batch.id
+        phn, phn_lens = batch.phn_encoded
+        phn_with_eos, phn_with_eos_lens = batch.phn_encoded_eos
+
+        # Train returns 4 elements vs 3 for val and test
+        if len(predictions) == 4:
+            p_ctc, p_ce, logits_transducer, wav_lens = predictions
+        else:
+            logits_transducer, wav_lens, predicted_phn = predictions
+
+        if stage == sb.Stage.TRAIN:
+            if hasattr(self.hparams, "wav_augment"):
+                phn = self.hparams.wav_augment.replicate_labels(phn)
+                phn_lens = self.hparams.wav_augment.replicate_labels(phn_lens)
+                phn_with_eos = self.hparams.wav_augment.replicate_labels(
+                    phn_with_eos
+                )
+                phn_with_eos_lens = self.hparams.wav_augment.replicate_labels(
+                    phn_with_eos_lens
+                )
+            if hasattr(self.hparams, "fea_augment"):
+                phn = self.hparams.fea_augment.replicate_labels(phn)
+                phn_lens = self.hparams.fea_augment.replicate_labels(phn_lens)
+                phn_with_eos = self.hparams.fea_augment.replicate_labels(
+                    phn_with_eos
+                )
+                phn_with_eos_lens = self.hparams.fea_augment.replicate_labels(
+                    phn_with_eos_lens
+                )
+
+        if stage == sb.Stage.TRAIN:
+            CTC_loss = 0.0
+            CE_loss = 0.0
+            if p_ctc is not None:
+                CTC_loss = self.hparams.ctc_cost(p_ctc, phn, wav_lens, phn_lens)
+            if p_ce is not None:
+                CE_loss = self.hparams.ce_cost(
+                    p_ce, phn_with_eos, length=phn_with_eos_lens
+                )
+            loss_transducer = self.hparams.transducer_cost(
+                logits_transducer, phn, wav_lens, phn_lens
+            )
+            loss = (
+                self.hparams.ctc_weight * CTC_loss
+                + self.hparams.ce_weight * CE_loss
+                + (1 - (self.hparams.ctc_weight + self.hparams.ce_weight))
+                * loss_transducer
+            )
+        else:
+            loss = self.hparams.transducer_cost(
+                logits_transducer, phn, wav_lens, phn_lens
+            )
+
+        if stage != sb.Stage.TRAIN:
+            self.per_metrics.append(
+                ids, predicted_phn, phn, target_len=phn_lens
+            )
+
+        return loss
+
+    def on_stage_start(self, stage, epoch=None):
+        "Gets called when a stage (either training, validation, test) starts."
+        if stage != sb.Stage.TRAIN:
+            self.per_metrics = self.hparams.per_stats()
+
+    def on_stage_end(self, stage, stage_loss, epoch=None):
+        """Gets called at the end of a stage."""
+        if stage == sb.Stage.TRAIN:
+            self.train_loss = stage_loss
+        if stage == sb.Stage.VALID and epoch is not None:
+            print("Epoch %d complete" % epoch)
+            print("Train loss: %.2f" % self.train_loss)
+        if stage != sb.Stage.TRAIN:
+            print(stage, "loss: %.2f" % stage_loss)
+            print(stage, "PER: %.2f" % self.per_metrics.summarize("error_rate"))
+
+
+def data_prep(data_folder, hparams):
+    "Creates the datasets and their data processing pipelines."
+
+    # 1. Declarations:
+    train_data = sb.dataio.dataset.DynamicItemDataset.from_json(
+        json_path=data_folder / "../annotation/ASR_train.json",
+        replacements={"data_root": data_folder},
+    )
+    valid_data = sb.dataio.dataset.DynamicItemDataset.from_json(
+        json_path=data_folder / "../annotation/ASR_dev.json",
+        replacements={"data_root": data_folder},
+    )
+    datasets = [train_data, valid_data]
+    label_encoder = sb.dataio.encoder.CTCTextEncoder()
+    label_encoder.expect_len(hparams["num_labels"])
+
+    # 2. Define audio pipeline:
+    @sb.utils.data_pipeline.takes("wav")
+    @sb.utils.data_pipeline.provides("sig")
+    def audio_pipeline(wav):
+        sig = sb.dataio.dataio.read_audio(wav)
+        return sig
+
+    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
+
+    # 3. Define text pipeline:
+    @sb.utils.data_pipeline.takes("phn")
+    @sb.utils.data_pipeline.provides(
+        "phn_list", "phn_encoded", "phn_encoded_bos", "phn_encoded_eos"
+    )
+    def text_pipeline(phn):
+        phn_list = phn.strip().split()
+        yield phn_list
+        phn_encoded = label_encoder.encode_sequence_torch(phn_list)
+        yield phn_encoded
+        phn_encoded_bos = label_encoder.prepend_bos_index(phn_encoded).long()
+        yield phn_encoded_bos
+        phn_encoded_eos = label_encoder.append_eos_index(phn_encoded).long()
+        yield phn_encoded_eos
+
+    sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
+
+    # 3. Fit encoder:
+    # NOTE: In this minimal example, also update from valid data
+    label_encoder.insert_blank(index=hparams["blank_index"])
+    label_encoder.insert_bos_eos(
+        bos_index=hparams["bos_index"], eos_label="<bos>"
+    )
+    label_encoder.update_from_didataset(train_data, output_key="phn_list")
+    label_encoder.update_from_didataset(valid_data, output_key="phn_list")
+
+    # 4. Set output:
+    sb.dataio.dataset.set_output_keys(
+        datasets,
+        ["id", "sig", "phn_encoded", "phn_encoded_bos", "phn_encoded_eos"],
+    )
+    return train_data, valid_data, label_encoder
+
+
+def main(device="cpu"):
+    experiment_dir = pathlib.Path(__file__).resolve().parent
+    hparams_file = experiment_dir / "hyperparams.yaml"
+    data_folder = "../../samples/ASR"
+    data_folder = (experiment_dir / data_folder).resolve()
+
+    # Load model hyper parameters:
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin)
+
+    # Dataset creation
+    train_data, valid_data, label_encoder = data_prep(data_folder, hparams)
+
+    # Trainer initialization
+    transducer_brain = ConformerTransducerBrain(
+        hparams["modules"],
+        hparams["opt_class"],
+        hparams,
+        run_opts={"device": device},
+    )
+
+    # Training/validation loop
+    transducer_brain.fit(
+        range(hparams["number_of_epochs"]),
+        train_data,
+        valid_data,
+        train_loader_kwargs=hparams["dataloader_options"],
+        valid_loader_kwargs=hparams["dataloader_options"],
+    )
+    # Evaluation is run separately (now just evaluating on valid data)
+    transducer_brain.evaluate(valid_data)
+
+    # Check that model overfits for integration test
+    assert transducer_brain.train_loss < 90.0
+
+
+if __name__ == "__main__":
+    main()
+
+
+def test_error(device):
+    main(device)
diff --git a/tests/integration/ASR_ConformerTransducer_streaming/hyperparams.yaml b/tests/integration/ASR_ConformerTransducer_streaming/hyperparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a988432a34a0d5c34b155ba04b173a5bfaa4744f
--- /dev/null
+++ b/tests/integration/ASR_ConformerTransducer_streaming/hyperparams.yaml
@@ -0,0 +1,268 @@
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 3407
+__set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
+
+# Training parameters
+# To make Transformers converge, the global bath size should be large enough.
+# The global batch size is computed as batch_size * n_gpus * grad_accumulation_factor.
+# Empirically, we found that this value should be >= 128.
+# Please, set your parameters accordingly.
+number_of_epochs: 30
+lr: 1.0
+ctc_weight: 0.3 # Multitask with CTC for the encoder (0.0 = disabled)
+ce_weight: 0.0 # Multitask with CE for the decoder (0.0 = disabled)
+max_grad_norm: 5.0
+loss_reduction: 'batchmean'
+precision: fp32 # bf16, fp16 or fp32
+
+# Feature parameters
+sample_rate: 16000
+n_fft: 512
+n_mels: 80
+win_length: 32
+
+# Streaming & dynamic chunk training options
+# At least for the current architecture on LibriSpeech, we found out that
+# non-streaming accuracy is very similar between `streaming: True` and
+# `streaming: False`.
+streaming: True  # controls all Dynamic Chunk Training & chunk size & left context mechanisms
+
+# Configuration for Dynamic Chunk Training.
+# In this model, a chunk is roughly equivalent to 40ms of audio.
+dynchunktrain_config_sampler: !new:speechbrain.utils.dynamic_chunk_training.DynChunkTrainConfigRandomSampler # yamllint disable-line rule:line-length
+   chunkwise_prob: 0.6 # Probability during a batch to limit attention and sample a random chunk size in the following range
+   chunk_size_min: 2 # Minimum chunk size (if in a DynChunkTrain batch)
+   chunk_size_max: 8 # Maximum chunk size (if in a DynChunkTrain batch)
+   limited_left_context_prob: 0.75 # If in a DynChunkTrain batch, the probability during a batch to restrict left context to a random number of chunks
+   left_context_chunks_min: 1 # Minimum left context size (in # of chunks)
+   left_context_chunks_max: 8 # Maximum left context size (in # of chunks)
+   # If you specify a valid/test config, you can optionally have evaluation be
+   # done with a specific DynChunkTrain configuration.
+   # valid_config: !new:speechbrain.utils.dynamic_chunk_training.DynChunkTrainConfig
+   #    chunk_size: 24
+   #    left_context_size: 16
+   # test_config: ...
+
+dataloader_options:
+   batch_size: 1
+
+# Model parameters
+# Transformer
+d_model: 64
+joint_dim: 128
+nhead: 2
+num_encoder_layers: 1
+num_decoder_layers: 0
+d_ffn: 128
+transformer_dropout: 0.1
+activation: !name:torch.nn.GELU
+output_neurons: !ref <num_labels>
+dec_dim: 128
+dec_emb_dropout: 0.2
+dec_dropout: 0.1
+
+# Decoding parameters
+# Special tokens and labels
+blank_index: 0
+bos_index: 1
+pad_index: 1
+num_labels: 45
+beam_size: 10
+nbest: 1
+
+# If True uses torchaudio loss. Otherwise, the numba one
+use_torchaudio: True
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+   limit: !ref <number_of_epochs>
+
+normalize: !new:speechbrain.processing.features.InputNormalization
+   norm_type: global
+   update_until_epoch: 4
+
+compute_features: !new:speechbrain.lobes.features.Fbank
+   sample_rate: !ref <sample_rate>
+   n_fft: !ref <n_fft>
+   n_mels: !ref <n_mels>
+   win_length: !ref <win_length>
+
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+   orig_freq: !ref <sample_rate>
+   speeds: !ref <speed_changes>
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+   parallel_augment: False
+   concat_original: False
+   repeat_augment: 1
+   shuffle_augmentations: False
+   min_augmentations: 1
+   max_augmentations: 1
+   augment_prob: 1.0
+   augmentations: [!ref <speed_perturb>]
+
+
+# Time Drop
+time_drop_length_low: 15  # Min length for temporal chunk to drop in spectrogram
+time_drop_length_high: 25  # Max length for temporal chunk to drop in spectrogram
+time_drop_count_low: 5  # Min number of chunks to drop in time in the spectrogram
+time_drop_count_high: 5  # Max number of chunks to drop in time in the spectrogram
+time_drop_replace: "zeros"  # Method of dropping chunks
+
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+   drop_length_low: !ref <time_drop_length_low>
+   drop_length_high: !ref <time_drop_length_high>
+   drop_count_low: !ref <time_drop_count_low>
+   drop_count_high: !ref <time_drop_count_high>
+   replace: !ref <time_drop_replace>
+   dim: 1
+
+# Frequency Drop
+freq_drop_length_low: 25  # Min length for chunks to drop in frequency in the spectrogram
+freq_drop_length_high: 35  # Max length for chunks to drop in frequency in the spectrogram
+freq_drop_count_low: 2  # Min number of chunks to drop in frequency in the spectrogram
+freq_drop_count_high: 2  # Max number of chunks to drop in frequency in the spectrogram
+freq_drop_replace: "zeros"  # Method of dropping chunks
+
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+   drop_length_low: !ref <freq_drop_length_low>
+   drop_length_high: !ref <freq_drop_length_high>
+   drop_count_low: !ref <freq_drop_count_low>
+   drop_count_high: !ref <freq_drop_count_high>
+   replace: !ref <freq_drop_replace>
+   dim: 2
+
+# Time warp
+time_warp_window: 5  # Length of time warping window
+time_warp_mode: "bicubic"  # Time warping method
+
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+   warp_window: !ref <time_warp_window>
+   warp_mode: !ref <time_warp_mode>
+   dim: 1
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+   parallel_augment: False
+   concat_original: False
+   repeat_augment: 1
+   shuffle_augmentations: False
+   min_augmentations: 3
+   max_augmentations: 3
+   augment_prob: 1.0
+   augmentations: [
+      !ref <time_drop>,
+      !ref <freq_drop>,
+      !ref <time_warp>]
+
+CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
+   input_shape: (8, 10, 80)
+   num_blocks: 2
+   num_layers_per_block: 1
+   out_channels: (64, 32)
+   kernel_sizes: (3, 3)
+   strides: (2, 2)
+   residuals: (False, False)
+
+Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length
+   input_size: 640
+   tgt_vocab: !ref <output_neurons>
+   d_model: !ref <d_model>
+   nhead: !ref <nhead>
+   num_encoder_layers: !ref <num_encoder_layers>
+   num_decoder_layers: !ref <num_decoder_layers>
+   d_ffn: !ref <d_ffn>
+   dropout: !ref <transformer_dropout>
+   activation: !ref <activation>
+   encoder_module: conformer
+   attention_type: RelPosMHAXL
+   normalize_before: True
+   causal: False
+
+# We must call an encoder wrapper so the decoder isn't run (we don't have any)
+enc: !new:speechbrain.lobes.models.transformer.TransformerASR.EncoderWrapper
+   transformer: !ref <Transformer>
+
+# For MTL CTC over the encoder
+proj_ctc: !new:speechbrain.nnet.linear.Linear
+   input_size: !ref <joint_dim>
+   n_neurons: !ref <output_neurons>
+
+# Define some projection layers to make sure that enc and dec
+# output dim are the same before joining
+proj_enc: !new:speechbrain.nnet.linear.Linear
+   input_size: !ref <d_model>
+   n_neurons: !ref <joint_dim>
+   bias: False
+
+proj_dec: !new:speechbrain.nnet.linear.Linear
+   input_size: !ref <dec_dim>
+   n_neurons: !ref <joint_dim>
+   bias: False
+
+# Uncomment for MTL with CTC
+ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
+   blank_index: !ref <blank_index>
+   reduction: !ref <loss_reduction>
+
+emb: !new:speechbrain.nnet.embedding.Embedding
+   num_embeddings: !ref <output_neurons>
+   consider_as_one_hot: True
+   blank_id: !ref <blank_index>
+
+dec: !new:speechbrain.nnet.RNN.LSTM
+   input_shape: [null, null, !ref <output_neurons> - 1]
+   hidden_size: !ref <dec_dim>
+   num_layers: 1
+   re_init: True
+
+# For MTL
+ce_cost: !name:speechbrain.nnet.losses.nll_loss
+   label_smoothing: 0.1
+
+Tjoint: !new:speechbrain.nnet.transducer.transducer_joint.Transducer_joint
+   joint: sum # joint [sum | concat]
+   nonlinearity: !ref <activation>
+
+transducer_lin: !new:speechbrain.nnet.linear.Linear
+   input_size: !ref <joint_dim>
+   n_neurons: !ref <output_neurons>
+   bias: False
+
+log_softmax: !new:speechbrain.nnet.activations.Softmax
+   apply_log: True
+
+transducer_cost: !name:speechbrain.nnet.losses.transducer_loss
+   blank_index: !ref <blank_index>
+   use_torchaudio: !ref <use_torchaudio>
+
+modules:
+   CNN: !ref <CNN>
+   enc: !ref <enc>
+   emb: !ref <emb>
+   dec: !ref <dec>
+   Tjoint: !ref <Tjoint>
+   transducer_lin: !ref <transducer_lin>
+   normalize: !ref <normalize>
+   proj_ctc: !ref <proj_ctc>
+   proj_dec: !ref <proj_dec>
+   proj_enc: !ref <proj_enc>
+
+model: !new:torch.nn.ModuleList
+   - [!ref <CNN>, !ref <enc>, !ref <emb>, !ref <dec>, !ref <proj_enc>, !ref <proj_dec>, !ref <proj_ctc>, !ref <transducer_lin>]
+
+Greedysearcher: !new:speechbrain.decoders.transducer.TransducerBeamSearcher
+   decode_network_lst: [!ref <emb>, !ref <dec>, !ref <proj_dec>]
+   tjoint: !ref <Tjoint>
+   classifier_network: [!ref <transducer_lin>]
+   blank_id: !ref <blank_index>
+   beam_size: 1
+   nbest: 1
+
+opt_class: !name:torch.optim.Adadelta
+   lr: !ref <lr>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+
+per_stats: !name:speechbrain.utils.metric_stats.ErrorRateStats
diff --git a/tests/integration/ASR_Transducer/example_asr_transducer_experiment.py b/tests/integration/ASR_Transducer/example_asr_transducer_experiment.py
index 01efb0304d90556749bc053b48a2da75d2970384..b2e4e28b3e70cc6f610fcc9208bb05fe155f5bef 100755
--- a/tests/integration/ASR_Transducer/example_asr_transducer_experiment.py
+++ b/tests/integration/ASR_Transducer/example_asr_transducer_experiment.py
@@ -38,7 +38,7 @@ class TransducerBrain(sb.Brain):
         if stage == sb.Stage.TRAIN:
             return outputs, lens
         else:
-            hyps, scores, _, _ = self.hparams.searcher(TN_output)
+            hyps, _, _, _ = self.hparams.searcher(TN_output)
             return outputs, lens, hyps
 
     def compute_objectives(self, predictions, batch, stage):
diff --git a/tests/integration/ASR_seq2seq/example_asr_seq2seq_experiment.py b/tests/integration/ASR_seq2seq/example_asr_seq2seq_experiment.py
index 65f7a01c3bca080e03463222328239fe189d1222..821a6de4ad98ce332b56cb60441680fb9a068894 100755
--- a/tests/integration/ASR_seq2seq/example_asr_seq2seq_experiment.py
+++ b/tests/integration/ASR_seq2seq/example_asr_seq2seq_experiment.py
@@ -27,18 +27,15 @@ class seq2seqBrain(sb.Brain):
         logits = self.modules.lin(h)
         outputs = self.hparams.softmax(logits)
 
+        seq = None
         if stage != sb.Stage.TRAIN:
-            seq, _ = self.hparams.searcher(x, wav_lens)
-            return outputs, seq
+            seq, _, _, _ = self.hparams.searcher(x, wav_lens)
 
-        return outputs
+        return outputs, seq
 
     def compute_objectives(self, predictions, batch, stage):
         "Given the network predictions and targets computed the NLL loss."
-        if stage == sb.Stage.TRAIN:
-            outputs = predictions
-        else:
-            outputs, seq = predictions
+        outputs, seq = predictions
 
         ids = batch.id
         phns, phn_lens = batch.phn_encoded_eos
@@ -50,22 +47,6 @@ class seq2seqBrain(sb.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Fits train batches"""
-        preds = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(preds, batch, sb.Stage.TRAIN)
-        loss.backward()
-        if self.check_gradients(loss):
-            self.optimizer.step()
-        self.optimizer.zero_grad()
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage=sb.Stage.TEST):
-        """Evaluates test batches"""
-        out = self.compute_forward(batch, stage)
-        loss = self.compute_objectives(out, batch, stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch=None):
         "Gets called when a stage (either training, validation, test) starts."
         if stage != sb.Stage.TRAIN:
diff --git a/tests/integration/G2P/example_g2p.py b/tests/integration/G2P/example_g2p.py
index 7a5b70f965fa970bd7a129d30ceda3beda67eb48..6345f94a75c31e22c1b48b9d07554a36eaf7a3d2 100755
--- a/tests/integration/G2P/example_g2p.py
+++ b/tests/integration/G2P/example_g2p.py
@@ -27,18 +27,15 @@ class seq2seqBrain(sb.Brain):
         logits = self.modules.lin(h)
         outputs = self.hparams.softmax(logits)
 
+        seq = None
         if stage != sb.Stage.TRAIN:
-            seq, _ = self.hparams.searcher(x, char_lens)
-            return outputs, seq
+            seq, _, _, _ = self.hparams.searcher(x, char_lens)
 
-        return outputs
+        return outputs, seq
 
     def compute_objectives(self, predictions, batch, stage):
         "Given the network predictions and targets computed the NLL loss."
-        if stage == sb.Stage.TRAIN:
-            outputs = predictions
-        else:
-            outputs, seq = predictions
+        outputs, seq = predictions
 
         phns, phn_lens = batch.phn_encoded_eos
         loss = self.hparams.compute_cost(outputs, phns, length=phn_lens)
diff --git a/tests/integration/augmentation/hyperparams.yaml b/tests/integration/augmentation/hyperparams.yaml
index 454465441878fa41b6742ea3dfbb78f3e9b0d70c..8300664de4ecc0d59a4a8d0e06d445a0425651c2 100644
--- a/tests/integration/augmentation/hyperparams.yaml
+++ b/tests/integration/augmentation/hyperparams.yaml
@@ -11,18 +11,13 @@ sample_data: !new:speechbrain.dataio.legacy.ExtendedCSVDataset
     replacements:
         data_folder: !ref <data_folder>/single-mic
 
-add_babble: !new:speechbrain.processing.speech_augmentation.AddBabble
-    speaker_count: 4  # Must set batch size to 5 or more
-    snr_low: 0
-    snr_high: 0
-
-add_reverb: !new:speechbrain.processing.speech_augmentation.AddReverb
+add_reverb: !new:speechbrain.augment.speech_augment.AddReverb
     csv_file: !ref <data_folder>/annotation/RIRs.csv
     sorting: descending
     replacements:
         rir_folder: !ref <data_folder>/RIRs
 
-add_noise: !new:speechbrain.processing.speech_augmentation.AddNoise
+add_noise: !new:speechbrain.augment.speech_augment.AddNoise
     csv_file: !ref <data_folder>/annotation/noise.csv
     sorting: descending
     snr_low: 0
@@ -32,14 +27,14 @@ add_noise: !new:speechbrain.processing.speech_augmentation.AddNoise
     replacements:
         noise_folder: !ref <data_folder>/noise
 
-drop_freq: !new:speechbrain.processing.speech_augmentation.DropFreq
+drop_freq: !new:speechbrain.augment.speech_augment.DropFreq
     drop_freq_low: 0.5
     drop_freq_high: 0.5
     drop_count_low: 1
     drop_count_high: 1
     drop_width: 0.05
 
-drop_chunk: !new:speechbrain.processing.speech_augmentation.DropChunk
+drop_chunk: !new:speechbrain.augment.speech_augment.DropChunk
     drop_length_low: 1000
     drop_length_high: 1000
     drop_count_low: 1
@@ -47,10 +42,10 @@ drop_chunk: !new:speechbrain.processing.speech_augmentation.DropChunk
     drop_start: 1000
     drop_end: 2000
 
-do_clip: !new:speechbrain.processing.speech_augmentation.DoClip
+do_clip: !new:speechbrain.augment.speech_augment.DoClip
     clip_low: 0.01
     clip_high: 0.01
 
-speed_perturb: !new:speechbrain.processing.speech_augmentation.SpeedPerturb
+speed_perturb: !new:speechbrain.augment.speech_augment.SpeedPerturb
     orig_freq: !ref <sample_rate>
     speeds: [90]
diff --git a/tests/integration/enhance_GAN/hyperparams.yaml b/tests/integration/enhance_GAN/hyperparams.yaml
index 3591924d8525bbe66b89eb26b43b9546ba15ac86..055a99a4e2d363c10c242df5d92ed1130db5f117 100644
--- a/tests/integration/enhance_GAN/hyperparams.yaml
+++ b/tests/integration/enhance_GAN/hyperparams.yaml
@@ -11,7 +11,7 @@ dataloader_options:
 
 models: !include:models.yaml
 
-add_noise: !new:speechbrain.processing.speech_augmentation.AddNoise
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
 
 modules:
     generator: !ref <models[generator]>
diff --git a/tests/integration/sampling/example_sorting.py b/tests/integration/sampling/example_sorting.py
index 0f856280d19892f4b6c8fd8b956c91fab201bc27..ebd578c009ede11d7ad3b83dc5781b0a6d09880b 100644
--- a/tests/integration/sampling/example_sorting.py
+++ b/tests/integration/sampling/example_sorting.py
@@ -153,7 +153,7 @@ def recipe(device="cpu", yaml_file="hyperparams.yaml", run_opts=None):
     if run_opts is None:
         run_opts = {}
     else:
-        hparams["rank"] = run_opts["local_rank"]
+        hparams["rank"] = os.environ["RANK"]
     run_opts["device"] = device
 
     ctc_brain = SamplingBrain(
@@ -186,13 +186,13 @@ def ddp_recipes(rank, size, backend="gloo"):
     """ Initialize the distributed environment. """
     os.environ["WORLD_SIZE"] = f"{size}"
     os.environ["RANK"] = f"{rank}"
+    os.environ["LOCAL_RANK"] = f"{rank}"
     os.environ["MASTER_ADDR"] = "127.0.0.1"
     os.environ["MASTER_PORT"] = "29500"
 
     run_opts = dict()
     run_opts["distributed_launch"] = True
     run_opts["distributed_backend"] = backend
-    run_opts["local_rank"] = rank
 
     sb.utils.distributed.ddp_init_group(run_opts)
 
diff --git a/tests/recipes/AISHELL-1.csv b/tests/recipes/AISHELL-1.csv
index b3fe65c9f18527617c3789ad4de3a52058741b62..3bf402c6c98d5ceebea9d70387a91283194af223 100644
--- a/tests/recipes/AISHELL-1.csv
+++ b/tests/recipes/AISHELL-1.csv
@@ -1,7 +1,7 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-ASR,AISHELL-1,recipes/AISHELL-1/ASR/CTC/train_with_wav2vec.py,recipes/AISHELL-1/ASR/CTC/hparams/train_with_wav2vec.yaml,recipes/AISHELL-1/ASR/CTC/aishell_prepare.py,recipes/AISHELL-1/ASR/CTC/README.md,https://www.dropbox.com/sh/e4bth1bylk7c6h8/AADFq3cWzBBKxuDv09qjvUMta?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-ctc-aishell,--data_folder=tests/samples/ASR/ --train_data=tests/samples/annotation/ASR_train.csv --valid_data=tests/samples/annotation/ASR_train.csv --test_data=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_log.txt,log.txt,train_with_wav2vec.py,env.log,cer.txt,hyperparams.yaml]"
-ASR,AISHELL-1,recipes/AISHELL-1/ASR/seq2seq/train.py,recipes/AISHELL-1/ASR/seq2seq/hparams/train.yaml,recipes/AISHELL-1/ASR/seq2seq/aishell_prepare.py,recipes/AISHELL-1/ASR/seq2seq/README.md,https://www.dropbox.com/sh/kefuzzf6jaljqbr/AADBRWRzHz74GCMDqJY9BES4a?dl=0,,--data_folder=tests/samples/ASR/ --train_data=tests/samples/annotation/ASR_train.csv --valid_data=tests/samples/annotation/ASR_train.csv --test_data=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --number_of_ctc_epochs=1 --skip_prep=True --output_neurons=42 --dnn_neurons=128 --dec_neurons=64 --data_folder_rirs=tests/tmp/,"file_exists=[train_log.txt,log.txt,env.log,train.py,cer.txt,hyperparams.yaml,save/tokenizer/tokenizer.ckpt]"
-ASR,AISHELL-1,recipes/AISHELL-1/ASR/transformer/train.py,recipes/AISHELL-1/ASR/transformer/hparams/train_ASR_transformer.yaml,recipes/AISHELL-1/ASR/transformer/aishell_prepare.py,recipes/AISHELL-1/ASR/transformer/README.md,https://www.dropbox.com/sh/tp6tjmysorgvsr4/AAD7KNqi1ot0gR4N406JbKM6a?dl=0,https://huggingface.co/speechbrain/asr-transformer-aishell,--data_folder=tests/samples/ASR/ --train_data=tests/samples/annotation/ASR_train.csv --valid_data=tests/samples/annotation/ASR_train.csv --test_data=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=42 --d_model=64 --num_encoder_layers=3 --num_decoder_layers=3 --d_ffn=256 --stage_one_epochs=1 --data_folder_rirs=tests/tmp/,"file_exists=[train_log.txt,log.txt,env.log,train.py,cer.txt,hyperparams.yaml,save/tokenizer.ckpt]"
-ASR,AISHELL-1,recipes/AISHELL-1/ASR/transformer/train_with_wav2vect.py,recipes/AISHELL-1/ASR/transformer/hparams/train_ASR_transformer_with_wav2vect.yaml,recipes/AISHELL-1/ASR/transformer/aishell_prepare.py,recipes/AISHELL-1/ASR/transformer/README.md,https://www.dropbox.com/sh/tp6tjmysorgvsr4/AAD7KNqi1ot0gR4N406JbKM6a?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-transformer-aishell,--data_folder=tests/samples/ASR/ --train_data=tests/samples/annotation/ASR_train.csv --valid_data=tests/samples/annotation/ASR_train.csv --test_data=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=42 --d_model=64 --num_encoder_layers=3 --num_decoder_layers=3 --d_ffn=256 --stage_one_epochs=1 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint --data_folder_rirs=tests/tmp/,"file_exists=[train_log.txt,log.txt,env.log,train_with_wav2vect.py,cer.txt,hyperparams.yaml,save/tokenizer.ckpt]"
-Tokenizer,AISHELL-1,recipes/AISHELL-1/Tokenizer/train.py,recipes/AISHELL-1/Tokenizer/hparams/tokenizer_bpe5000.yaml,recipes/AISHELL-1/Tokenizer/aishell_prepare.py,recipes/AISHELL-1/Tokenizer/README.md,https://www.dropbox.com/sh/gh1qyf833t7h3op/AADG0y1bGGIL4yufsXtuBgXma?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True --token_output=23,"file_exists=[23_unigram.model,23_unigram.vocab,log.txt,ASR_train.txt,env.log,train.py,hyperparams.yaml]"
-Tokenizer,AISHELL-1,recipes/AISHELL-1/Tokenizer/train.py,recipes/AISHELL-1/Tokenizer/hparams/train_transformer_tokenizer_bpe5000.yaml,recipes/AISHELL-1/Tokenizer/aishell_prepare.py,recipes/AISHELL-1/Tokenizer/README.md,https://www.dropbox.com/sh/gh1qyf833t7h3op/AADG0y1bGGIL4yufsXtuBgXma?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True --token_output=23,"file_exists=[23_unigram.model,23_unigram.vocab,log.txt,ASR_train.txt,env.log,train.py,hyperparams.yaml]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+ASR,AISHELL-1,recipes/AISHELL-1/ASR/CTC/train_with_wav2vec.py,recipes/AISHELL-1/ASR/CTC/hparams/train_with_wav2vec.yaml,recipes/AISHELL-1/ASR/CTC/aishell_prepare.py,recipes/AISHELL-1/ASR/CTC/README.md,https://www.dropbox.com/sh/e4bth1bylk7c6h8/AADFq3cWzBBKxuDv09qjvUMta?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-ctc-aishell,--data_folder=tests/samples/ASR/ --train_data=tests/samples/annotation/ASR_train.csv --valid_data=tests/samples/annotation/ASR_train.csv --test_data=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_log.txt,log.txt,train_with_wav2vec.py,env.log,cer.txt,hyperparams.yaml]",Test-CER=5.06
+ASR,AISHELL-1,recipes/AISHELL-1/ASR/seq2seq/train.py,recipes/AISHELL-1/ASR/seq2seq/hparams/train.yaml,recipes/AISHELL-1/ASR/seq2seq/aishell_prepare.py,recipes/AISHELL-1/ASR/seq2seq/README.md,https://www.dropbox.com/sh/kefuzzf6jaljqbr/AADBRWRzHz74GCMDqJY9BES4a?dl=0,,--data_folder=tests/samples/ASR/ --train_data=tests/samples/annotation/ASR_train.csv --valid_data=tests/samples/annotation/ASR_train.csv --test_data=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --number_of_ctc_epochs=1 --skip_prep=True --output_neurons=42 --dnn_neurons=128 --dec_neurons=64,"file_exists=[train_log.txt,log.txt,env.log,train.py,cer.txt,hyperparams.yaml,save/tokenizer/tokenizer.ckpt]",Test-CER=7.51
+ASR,AISHELL-1,recipes/AISHELL-1/ASR/transformer/train.py,recipes/AISHELL-1/ASR/transformer/hparams/train_ASR_transformer.yaml,recipes/AISHELL-1/ASR/transformer/aishell_prepare.py,recipes/AISHELL-1/ASR/transformer/README.md,https://www.dropbox.com/sh/tp6tjmysorgvsr4/AAD7KNqi1ot0gR4N406JbKM6a?dl=0,https://huggingface.co/speechbrain/asr-transformer-aishell,--data_folder=tests/samples/ASR/ --train_data=tests/samples/annotation/ASR_train.csv --valid_data=tests/samples/annotation/ASR_train.csv --test_data=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=42 --d_model=64 --num_encoder_layers=3 --num_decoder_layers=3 --d_ffn=256 --stage_one_epochs=1,"file_exists=[train_log.txt,log.txt,env.log,train.py,cer.txt,hyperparams.yaml,save/tokenizer.ckpt]",Test-CER=6.04
+ASR,AISHELL-1,recipes/AISHELL-1/ASR/transformer/train_with_wav2vect.py,recipes/AISHELL-1/ASR/transformer/hparams/train_ASR_transformer_with_wav2vect.yaml,recipes/AISHELL-1/ASR/transformer/aishell_prepare.py,recipes/AISHELL-1/ASR/transformer/README.md,https://www.dropbox.com/sh/tp6tjmysorgvsr4/AAD7KNqi1ot0gR4N406JbKM6a?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-transformer-aishell,--data_folder=tests/samples/ASR/ --train_data=tests/samples/annotation/ASR_train.csv --valid_data=tests/samples/annotation/ASR_train.csv --test_data=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=42 --d_model=64 --num_encoder_layers=3 --num_decoder_layers=3 --d_ffn=256 --stage_one_epochs=1 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_log.txt,log.txt,env.log,train_with_wav2vect.py,cer.txt,hyperparams.yaml,save/tokenizer.ckpt]",Test-CER=5.58
+Tokenizer,AISHELL-1,recipes/AISHELL-1/Tokenizer/train.py,recipes/AISHELL-1/Tokenizer/hparams/tokenizer_bpe5000.yaml,recipes/AISHELL-1/Tokenizer/aishell_prepare.py,recipes/AISHELL-1/Tokenizer/README.md,https://www.dropbox.com/sh/gh1qyf833t7h3op/AADG0y1bGGIL4yufsXtuBgXma?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True --token_output=23,"file_exists=[23_unigram.model,23_unigram.vocab,log.txt,ASR_train.txt,env.log,train.py,hyperparams.yaml]",
+Tokenizer,AISHELL-1,recipes/AISHELL-1/Tokenizer/train.py,recipes/AISHELL-1/Tokenizer/hparams/train_transformer_tokenizer_bpe5000.yaml,recipes/AISHELL-1/Tokenizer/aishell_prepare.py,recipes/AISHELL-1/Tokenizer/README.md,https://www.dropbox.com/sh/gh1qyf833t7h3op/AADG0y1bGGIL4yufsXtuBgXma?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True --token_output=23,"file_exists=[23_unigram.model,23_unigram.vocab,log.txt,ASR_train.txt,env.log,train.py,hyperparams.yaml]",
diff --git a/tests/recipes/Aishell1Mix.csv b/tests/recipes/Aishell1Mix.csv
index fed0cfa5a713867de3ceb178594150de0c4fa2d0..1e1df29b725bce96249a9f29d7619393aaa6e0d9 100644
--- a/tests/recipes/Aishell1Mix.csv
+++ b/tests/recipes/Aishell1Mix.csv
@@ -1,5 +1,5 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-Separation,Aishell1Mix,recipes/Aishell1Mix/separation/train.py,recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix2-wham.yaml,recipes/Aishell1Mix/separation/dynamic_mixing.py recipes/Aishell1Mix/prepare_data.py recipes/LibriMix/separation/train.py,recipes/Aishell1Mix/separation/README.md,https://www.dropbox.com/sh/6x9356yuybj8lue/AABPlpS03Vcci_E3jA69oKoXa?dl=0,,--data_folder_nspks=tests/samples/separation --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --N_encoder_out=32 --out_channels=64 --d_ffn=128,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Separation,Aishell1Mix,recipes/Aishell1Mix/separation/train.py,recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix2.yaml,recipes/Aishell1Mix/separation/dynamic_mixing.py recipes/Aishell1Mix/prepare_data.py recipes/LibriMix/separation/train.py,recipes/Aishell1Mix/separation/README.md,https://www.dropbox.com/sh/6x9356yuybj8lue/AABPlpS03Vcci_E3jA69oKoXa?dl=0,,--data_folder_nspks=tests/samples/separation --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --N_encoder_out=32 --out_channels=64 --d_ffn=128,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Separation,Aishell1Mix,recipes/Aishell1Mix/separation/train.py,recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix3-wham.yaml,recipes/Aishell1Mix/separation/dynamic_mixing.py recipes/Aishell1Mix/prepare_data.py recipes/LibriMix/separation/train.py,recipes/Aishell1Mix/separation/README.md,https://www.dropbox.com/sh/6x9356yuybj8lue/AABPlpS03Vcci_E3jA69oKoXa?dl=0,,--data_folder_nspks=tests/samples/separation --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --N_encoder_out=32 --out_channels=64 --d_ffn=128,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Separation,Aishell1Mix,recipes/Aishell1Mix/separation/train.py,recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix3.yaml,recipes/Aishell1Mix/separation/dynamic_mixing.py recipes/Aishell1Mix/prepare_data.py recipes/LibriMix/separation/train.py,recipes/Aishell1Mix/separation/README.md,https://www.dropbox.com/sh/6x9356yuybj8lue/AABPlpS03Vcci_E3jA69oKoXa?dl=0,,--data_folder_nspks=tests/samples/separation --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --N_encoder_out=32 --out_channels=64 --d_ffn=128,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+Separation,Aishell1Mix,recipes/Aishell1Mix/separation/train.py,recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix2-wham.yaml,recipes/Aishell1Mix/separation/dynamic_mixing.py recipes/Aishell1Mix/prepare_data.py recipes/LibriMix/separation/train.py,recipes/Aishell1Mix/separation/README.md,https://www.dropbox.com/sh/6x9356yuybj8lue/AABPlpS03Vcci_E3jA69oKoXa?dl=0,,--data_folder_nspks=tests/samples/separation --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --N_encoder_out=32 --out_channels=64 --d_ffn=128 --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",
+Separation,Aishell1Mix,recipes/Aishell1Mix/separation/train.py,recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix2.yaml,recipes/Aishell1Mix/separation/dynamic_mixing.py recipes/Aishell1Mix/prepare_data.py recipes/LibriMix/separation/train.py,recipes/Aishell1Mix/separation/README.md,https://www.dropbox.com/sh/6x9356yuybj8lue/AABPlpS03Vcci_E3jA69oKoXa?dl=0,,--data_folder_nspks=tests/samples/separation --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --N_encoder_out=32 --out_channels=64 --d_ffn=128 --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",SI-SNRi=13.4dB
+Separation,Aishell1Mix,recipes/Aishell1Mix/separation/train.py,recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix3-wham.yaml,recipes/Aishell1Mix/separation/dynamic_mixing.py recipes/Aishell1Mix/prepare_data.py recipes/LibriMix/separation/train.py,recipes/Aishell1Mix/separation/README.md,https://www.dropbox.com/sh/6x9356yuybj8lue/AABPlpS03Vcci_E3jA69oKoXa?dl=0,,--data_folder_nspks=tests/samples/separation --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --N_encoder_out=32 --out_channels=64 --d_ffn=128 --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",
+Separation,Aishell1Mix,recipes/Aishell1Mix/separation/train.py,recipes/Aishell1Mix/separation/hparams/sepformer-aishell1mix3.yaml,recipes/Aishell1Mix/separation/dynamic_mixing.py recipes/Aishell1Mix/prepare_data.py recipes/LibriMix/separation/train.py,recipes/Aishell1Mix/separation/README.md,https://www.dropbox.com/sh/6x9356yuybj8lue/AABPlpS03Vcci_E3jA69oKoXa?dl=0,,--data_folder_nspks=tests/samples/separation --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --N_encoder_out=32 --out_channels=64 --d_ffn=128 --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",SI-SNRi=11.2dB
diff --git a/tests/recipes/BinauralWSJ0Mix.csv b/tests/recipes/BinauralWSJ0Mix.csv
index 6da2621e0ddf88715449ebeed2fb3c095577ccae..a530192d05d8d151dab745490aa75a8db799ae76 100644
--- a/tests/recipes/BinauralWSJ0Mix.csv
+++ b/tests/recipes/BinauralWSJ0Mix.csv
@@ -1,6 +1,6 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-Separation,BinauralWSJ0Mix,recipes/BinauralWSJ0Mix/separation/train.py,recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-cross.yaml,recipes/BinauralWSJ0Mix/separation/dynamic_mixing.py recipes/BinauralWSJ0Mix/prepare_data.py recipes/WSJ0Mix/separation/train.py,recipes/BinauralWSJ0Mix/separation/README.md,https://www.dropbox.com/sh/i7fhu7qswjb84gw/AABsX1zP-GOTmyl86PtU8GGua?dl=0,,--data_folder=tests/samples/stereo --wsj_root=tests/samples/stereo --base_folder_dm=tests/samples/stereo --train_data=tests/samples/annotation/separation_train_stereo.csv --valid_data=tests/samples/annotation/separation_dev_stereo.csv --test_data=tests/samples/annotation/separation_dev_stereo.csv --skip_prep=True --N_epochs=2,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Separation,BinauralWSJ0Mix,recipes/BinauralWSJ0Mix/separation/train.py,recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-independent.yaml,recipes/BinauralWSJ0Mix/separation/dynamic_mixing.py recipes/BinauralWSJ0Mix/prepare_data.py recipes/WSJ0Mix/separation/train.py,recipes/BinauralWSJ0Mix/separation/README.md,https://www.dropbox.com/sh/i7fhu7qswjb84gw/AABsX1zP-GOTmyl86PtU8GGua?dl=0,,--data_folder=tests/samples/stereo --wsj_root=tests/samples/stereo --base_folder_dm=tests/samples/stereo --train_data=tests/samples/annotation/separation_train_stereo.csv --valid_data=tests/samples/annotation/separation_dev_stereo.csv --test_data=tests/samples/annotation/separation_dev_stereo.csv --skip_prep=True --N_epochs=2,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Separation,BinauralWSJ0Mix,recipes/BinauralWSJ0Mix/separation/train.py,recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-parallel-noise.yaml,recipes/BinauralWSJ0Mix/separation/dynamic_mixing.py recipes/BinauralWSJ0Mix/prepare_data.py recipes/WSJ0Mix/separation/train.py,recipes/BinauralWSJ0Mix/separation/README.md,https://www.dropbox.com/sh/i7fhu7qswjb84gw/AABsX1zP-GOTmyl86PtU8GGua?dl=0,,--data_folder=tests/samples/stereo --wsj_root=tests/samples/stereo --base_folder_dm=tests/samples/stereo --train_data=tests/samples/annotation/separation_train_stereo.csv --valid_data=tests/samples/annotation/separation_dev_stereo.csv --test_data=tests/samples/annotation/separation_dev_stereo.csv --skip_prep=True --N_epochs=2,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Separation,BinauralWSJ0Mix,recipes/BinauralWSJ0Mix/separation/train.py,recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-parallel-reverb.yaml,recipes/BinauralWSJ0Mix/separation/dynamic_mixing.py recipes/BinauralWSJ0Mix/prepare_data.py recipes/WSJ0Mix/separation/train.py,recipes/BinauralWSJ0Mix/separation/README.md,https://www.dropbox.com/sh/i7fhu7qswjb84gw/AABsX1zP-GOTmyl86PtU8GGua?dl=0,,--data_folder=tests/samples/stereo --wsj_root=tests/samples/stereo --base_folder_dm=tests/samples/stereo --train_data=tests/samples/annotation/separation_train_stereo.csv --valid_data=tests/samples/annotation/separation_dev_stereo.csv --test_data=tests/samples/annotation/separation_dev_stereo.csv --skip_prep=True --N_epochs=2,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Separation,BinauralWSJ0Mix,recipes/BinauralWSJ0Mix/separation/train.py,recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-parallel.yaml,recipes/BinauralWSJ0Mix/separation/dynamic_mixing.py recipes/BinauralWSJ0Mix/prepare_data.py recipes/WSJ0Mix/separation/train.py,recipes/BinauralWSJ0Mix/separation/README.md,https://www.dropbox.com/sh/i7fhu7qswjb84gw/AABsX1zP-GOTmyl86PtU8GGua?dl=0,,--data_folder=tests/samples/stereo --wsj_root=tests/samples/stereo --base_folder_dm=tests/samples/stereo --train_data=tests/samples/annotation/separation_train_stereo.csv --valid_data=tests/samples/annotation/separation_dev_stereo.csv --test_data=tests/samples/annotation/separation_dev_stereo.csv --skip_prep=True --N_epochs=2,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+Separation,BinauralWSJ0Mix,recipes/BinauralWSJ0Mix/separation/train.py,recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-cross.yaml,recipes/BinauralWSJ0Mix/separation/dynamic_mixing.py recipes/BinauralWSJ0Mix/prepare_data.py recipes/WSJ0Mix/separation/train.py,recipes/BinauralWSJ0Mix/separation/README.md,https://www.dropbox.com/sh/i7fhu7qswjb84gw/AABsX1zP-GOTmyl86PtU8GGua?dl=0,,--data_folder=tests/samples/stereo --wsj_root=tests/samples/stereo --base_folder_dm=tests/samples/stereo --train_data=tests/samples/annotation/separation_train_stereo.csv --valid_data=tests/samples/annotation/separation_dev_stereo.csv --test_data=tests/samples/annotation/separation_dev_stereo.csv --skip_prep=True --N_epochs=2 --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",SI-SNRi=12.39dB
+Separation,BinauralWSJ0Mix,recipes/BinauralWSJ0Mix/separation/train.py,recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-independent.yaml,recipes/BinauralWSJ0Mix/separation/dynamic_mixing.py recipes/BinauralWSJ0Mix/prepare_data.py recipes/WSJ0Mix/separation/train.py,recipes/BinauralWSJ0Mix/separation/README.md,https://www.dropbox.com/sh/i7fhu7qswjb84gw/AABsX1zP-GOTmyl86PtU8GGua?dl=0,,--data_folder=tests/samples/stereo --wsj_root=tests/samples/stereo --base_folder_dm=tests/samples/stereo --train_data=tests/samples/annotation/separation_train_stereo.csv --valid_data=tests/samples/annotation/separation_dev_stereo.csv --test_data=tests/samples/annotation/separation_dev_stereo.csv --skip_prep=True --N_epochs=2 --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",SI-SNRi=11.90dB
+Separation,BinauralWSJ0Mix,recipes/BinauralWSJ0Mix/separation/train.py,recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-parallel-noise.yaml,recipes/BinauralWSJ0Mix/separation/dynamic_mixing.py recipes/BinauralWSJ0Mix/prepare_data.py recipes/WSJ0Mix/separation/train.py,recipes/BinauralWSJ0Mix/separation/README.md,https://www.dropbox.com/sh/i7fhu7qswjb84gw/AABsX1zP-GOTmyl86PtU8GGua?dl=0,,--data_folder=tests/samples/stereo --wsj_root=tests/samples/stereo --base_folder_dm=tests/samples/stereo --train_data=tests/samples/annotation/separation_train_stereo.csv --valid_data=tests/samples/annotation/separation_dev_stereo.csv --test_data=tests/samples/annotation/separation_dev_stereo.csv --skip_prep=True --N_epochs=2 --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",SI-SNRi=18.25dB
+Separation,BinauralWSJ0Mix,recipes/BinauralWSJ0Mix/separation/train.py,recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-parallel-reverb.yaml,recipes/BinauralWSJ0Mix/separation/dynamic_mixing.py recipes/BinauralWSJ0Mix/prepare_data.py recipes/WSJ0Mix/separation/train.py,recipes/BinauralWSJ0Mix/separation/README.md,https://www.dropbox.com/sh/i7fhu7qswjb84gw/AABsX1zP-GOTmyl86PtU8GGua?dl=0,,--data_folder=tests/samples/stereo --wsj_root=tests/samples/stereo --base_folder_dm=tests/samples/stereo --train_data=tests/samples/annotation/separation_train_stereo.csv --valid_data=tests/samples/annotation/separation_dev_stereo.csv --test_data=tests/samples/annotation/separation_dev_stereo.csv --skip_prep=True --N_epochs=2 --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",SI-SNRi=6.95dB
+Separation,BinauralWSJ0Mix,recipes/BinauralWSJ0Mix/separation/train.py,recipes/BinauralWSJ0Mix/separation/hparams/convtasnet-parallel.yaml,recipes/BinauralWSJ0Mix/separation/dynamic_mixing.py recipes/BinauralWSJ0Mix/prepare_data.py recipes/WSJ0Mix/separation/train.py,recipes/BinauralWSJ0Mix/separation/README.md,https://www.dropbox.com/sh/i7fhu7qswjb84gw/AABsX1zP-GOTmyl86PtU8GGua?dl=0,,--data_folder=tests/samples/stereo --wsj_root=tests/samples/stereo --base_folder_dm=tests/samples/stereo --train_data=tests/samples/annotation/separation_train_stereo.csv --valid_data=tests/samples/annotation/separation_dev_stereo.csv --test_data=tests/samples/annotation/separation_dev_stereo.csv --skip_prep=True --N_epochs=2 --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",SI-SNRi=16.93dB
diff --git a/tests/recipes/CVSS.csv b/tests/recipes/CVSS.csv
new file mode 100644
index 0000000000000000000000000000000000000000..a3a0c79bf6862c88fd2ef81fad701b8722659e36
--- /dev/null
+++ b/tests/recipes/CVSS.csv
@@ -0,0 +1,2 @@
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,test_download,performance
+S2ST,CVSS,recipes/CVSS/S2ST/train.py,recipes/CVSS/S2ST/hparams/train_fr-en.yaml,recipes/CVSS/cvss_prepare.py,recipes/CVSS/S2ST/README.md, https://www.dropbox.com/sh/woz4i1p8pkfkqhf/AACmOvr3sS7p95iXl3twCj_xa?dl=0, , --epochs=1 --train_json=tests/download/S2ST/S2ST_train.json --valid_json=tests/download/S2ST/S2ST_train.json --test_json=tests/download/S2ST/S2ST_train.json --valid_small_json=tests/download/S2ST/S2ST_train.json --codes_folder=tests/download/S2ST/S2ST/codes --skip_prep=True --skip_extract=True --sample_rate=16000 --src_data_folder=null --tgt_data_folder=null --dynamic_batching=False --train_batch_size=2 --test_bs=1 --evaluation_interval=1,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,samples/1/bleu.txt]","download_file('https://www.dropbox.com/scl/fi/gv5m6wyl3blis3eiw9emh/S2ST.zip?rlkey=54quqf4yiugitvespz939hh6u&dl=1', 'tests/download/S2ST.zip', unpack=True, dest_unpack='tests/download/', write_permissions=True)",Test-sacrebleu=24.47
diff --git a/tests/recipes/CommonLanguage.csv b/tests/recipes/CommonLanguage.csv
index 29cfbe0d32da24226f89ab05ed470b3f38e78328..52dcca479c973a27a70e7d3f11ac79c44747513d 100644
--- a/tests/recipes/CommonLanguage.csv
+++ b/tests/recipes/CommonLanguage.csv
@@ -1,2 +1,2 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-Language-id,CommonLanguage,recipes/CommonLanguage/lang_id/train.py,recipes/CommonLanguage/lang_id/hparams/train_ecapa_tdnn.yaml,recipes/CommonLanguage/lang_id/common_language_prepare.py,recipes/CommonLanguage/lang_id/README.md,https://www.dropbox.com/sh/1fxpzyv67ouwd2c/AAAeMUWYP2f1ycpE1Lp1CwEla?dl=0,https://huggingface.co/speechbrain/lang-id-commonlanguage_ecapa,--data_folder=tests/samples/lang/ --rir_folder=tests/tmp --train_csv=tests/samples/annotation/multi_annotation.csv --dev_csv=tests/samples/annotation/multi_annotation.csv --test_csv=tests/samples/annotation/multi_annotation.csv --skip_prep=True --number_of_epochs=2 --n_languages=2,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/language_encoder.txt,save/embedding_model.ckpt]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+Language-id,CommonLanguage,recipes/CommonLanguage/lang_id/train.py,recipes/CommonLanguage/lang_id/hparams/train_ecapa_tdnn.yaml,recipes/CommonLanguage/lang_id/common_language_prepare.py,recipes/CommonLanguage/lang_id/README.md,https://www.dropbox.com/sh/1fxpzyv67ouwd2c/AAAeMUWYP2f1ycpE1Lp1CwEla?dl=0,https://huggingface.co/speechbrain/lang-id-commonlanguage_ecapa,--data_folder=tests/samples/lang/ --train_csv=tests/samples/annotation/multi_annotation.csv --dev_csv=tests/samples/annotation/multi_annotation.csv --test_csv=tests/samples/annotation/multi_annotation.csv --skip_prep=True --number_of_epochs=2 --n_languages=2 --drop_last=False,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/language_encoder.txt,save/embedding_model.ckpt]",Error=15.1%
diff --git a/tests/recipes/CommonVoice.csv b/tests/recipes/CommonVoice.csv
index 911b9b27276ee9e5e1cf1acd81a75240cd399db0..3a787093a29c5ec259f92a848344056882e5e66c 100644
--- a/tests/recipes/CommonVoice.csv
+++ b/tests/recipes/CommonVoice.csv
@@ -1,24 +1,34 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-ASR,CommonVoice,recipes/CommonVoice/ASR/CTC/train_with_wav2vec.py,recipes/CommonVoice/ASR/CTC/hparams/train_en_with_wav2vec.yaml,recipes/CommonVoice/ASR/CTC/common_voice_prepare.py,recipes/CommonVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=27 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_log.txt,log.txt,wer_test.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]"
-ASR,CommonVoice,recipes/CommonVoice/ASR/CTC/train_with_wav2vec.py,recipes/CommonVoice/ASR/CTC/hparams/train_fr_with_wav2vec.yaml,recipes/CommonVoice/ASR/CTC/common_voice_prepare.py,recipes/CommonVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-fr,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=27 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_log.txt,log.txt,wer_test.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/27_char.model,save/ASR_train.txt,save/27_char.vocab]"
-ASR,CommonVoice,recipes/CommonVoice/ASR/CTC/train_with_wav2vec.py,recipes/CommonVoice/ASR/CTC/hparams/train_it_with_wav2vec.yaml,recipes/CommonVoice/ASR/CTC/common_voice_prepare.py,recipes/CommonVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=27 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_log.txt,log.txt,wer_test.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]"
-ASR,CommonVoice,recipes/CommonVoice/ASR/CTC/train_with_wav2vec.py,recipes/CommonVoice/ASR/CTC/hparams/train_rw_with_wav2vec.yaml,recipes/CommonVoice/ASR/CTC/common_voice_prepare.py,recipes/CommonVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=27 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_log.txt,log.txt,wer_test.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]"
-ASR,CommonVoice,recipes/CommonVoice/ASR/CTC/train_with_wav2vec.py,recipes/CommonVoice/ASR/CTC/hparams/train_de_with_wav2vec.yaml,recipes/CommonVoice/ASR/CTC/common_voice_prepare.py,recipes/CommonVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/vdz7apt16nbq94g/AADI5o23Ll_NmjiPlg9bzPjta?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-de,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_log.txt,log.txt,wer_test.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/32_char.vocab,save/ASR_train.txt,save/32_char.model]"
-ASR,CommonVoice,recipes/CommonVoice/ASR/seq2seq/train.py,recipes/CommonVoice/ASR/seq2seq/hparams/train_de.yaml,recipes/CommonVoice/ASR/seq2seq/common_voice_prepare.py,recipes/CommonVoice/ASR/seq2seq/README.md,,https://huggingface.co/speechbrain/asr-crdnn-commonvoice-de,"--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --number_of_ctc_epochs=1 --skip_prep=True --output_neurons=27 --rnn_neurons=128 --dnn_neurons=128 --dec_neurons=128 --emb_size=64 --cnn_channels=[64, 100, 128]","file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]"
-ASR,CommonVoice,recipes/CommonVoice/ASR/seq2seq/train.py,recipes/CommonVoice/ASR/seq2seq/hparams/train_en.yaml,recipes/CommonVoice/ASR/seq2seq/common_voice_prepare.py,recipes/CommonVoice/ASR/seq2seq/README.md,https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0,,"--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --number_of_ctc_epochs=1 --skip_prep=True --output_neurons=27 --rnn_neurons=128 --dnn_neurons=128 --dec_neurons=128 --emb_size=64 --cnn_channels=[64, 100, 128]","file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]"
-ASR,CommonVoice,recipes/CommonVoice/ASR/seq2seq/train.py,recipes/CommonVoice/ASR/seq2seq/hparams/train_fr.yaml,recipes/CommonVoice/ASR/seq2seq/common_voice_prepare.py,recipes/CommonVoice/ASR/seq2seq/README.md,https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0,https://huggingface.co/speechbrain/asr-crdnn-commonvoice-fr,"--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --number_of_ctc_epochs=1 --skip_prep=True --output_neurons=27 --rnn_neurons=128 --dnn_neurons=128 --dec_neurons=128 --emb_size=64 --cnn_channels=[64, 100, 128]","file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]"
-ASR,CommonVoice,recipes/CommonVoice/ASR/seq2seq/train.py,recipes/CommonVoice/ASR/seq2seq/hparams/train_it.yaml,recipes/CommonVoice/ASR/seq2seq/common_voice_prepare.py,recipes/CommonVoice/ASR/seq2seq/README.md,https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0,https://huggingface.co/speechbrain/asr-crdnn-commonvoice-it,"--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --number_of_ctc_epochs=1 --skip_prep=True --output_neurons=27 --rnn_neurons=128 --dnn_neurons=128 --dec_neurons=128 --emb_size=64 --cnn_channels=[64, 100, 128]","file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]"
-ASR,CommonVoice,recipes/CommonVoice/ASR/seq2seq/train.py,recipes/CommonVoice/ASR/seq2seq/hparams/train_rw.yaml,recipes/CommonVoice/ASR/seq2seq/common_voice_prepare.py,recipes/CommonVoice/ASR/seq2seq/README.md,https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0,,"--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --number_of_ctc_epochs=1 --skip_prep=True --output_neurons=27 --rnn_neurons=128 --dnn_neurons=128 --dec_neurons=128 --emb_size=64 --cnn_channels=[64, 100, 128]","file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]"
-ASR,CommonVoice,recipes/CommonVoice/ASR/seq2seq/train_with_wav2vec.py,recipes/CommonVoice/ASR/seq2seq/hparams/train_en_with_wav2vec.yaml,recipes/CommonVoice/ASR/seq2seq/common_voice_prepare.py,recipes/CommonVoice/ASR/seq2seq/README.md,https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-en,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --number_of_ctc_epochs=1 --skip_prep=True --output_neurons=27 --dnn_neurons=128 --emb_size=64 --dec_neurons=128 --beam_size=3 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_log.txt,log.txt,wer_test.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]"
-ASR,CommonVoice,recipes/CommonVoice/ASR/seq2seq/train_with_wav2vec.py,recipes/CommonVoice/ASR/seq2seq/hparams/train_fr_with_wav2vec.yaml,recipes/CommonVoice/ASR/seq2seq/common_voice_prepare.py,recipes/CommonVoice/ASR/seq2seq/README.md,https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --number_of_ctc_epochs=1 --skip_prep=True --output_neurons=27 --dnn_neurons=128 --emb_size=64 --dec_neurons=128 --beam_size=3 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_log.txt,log.txt,wer_test.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]"
-ASR,CommonVoice,recipes/CommonVoice/ASR/seq2seq/train_with_wav2vec.py,recipes/CommonVoice/ASR/seq2seq/hparams/train_it_with_wav2vec.yaml,recipes/CommonVoice/ASR/seq2seq/common_voice_prepare.py,recipes/CommonVoice/ASR/seq2seq/README.md,https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0,https://huggingface.co/speechbrain/asr-crdnn-commonvoice-it,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --number_of_ctc_epochs=1 --skip_prep=True --output_neurons=27 --dnn_neurons=128 --emb_size=64 --dec_neurons=128 --beam_size=3 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_log.txt,log.txt,wer_test.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]"
-ASR,CommonVoice,recipes/CommonVoice/ASR/seq2seq/train_with_wav2vec.py,recipes/CommonVoice/ASR/seq2seq/hparams/train_rw_with_wav2vec.yaml,recipes/CommonVoice/ASR/seq2seq/common_voice_prepare.py,recipes/CommonVoice/ASR/seq2seq/README.md,https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-rw,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --number_of_ctc_epochs=1 --skip_prep=True --output_neurons=27 --dnn_neurons=128 --emb_size=64 --dec_neurons=128 --beam_size=3 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_log.txt,log.txt,wer_test.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]"
-ASR,CommonVoice,recipes/CommonVoice/ASR/transducer/train.py,recipes/CommonVoice/ASR/transducer/hparams/train_fr.yaml,recipes/CommonVoice/ASR/transducer/common_voice_prepare.py,recipes/CommonVoice/ASR/transducer/README.md,https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0,,"--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=27 --rnn_neurons=64 --dnn_neurons=64 --dec_neurons=64 --joint_dim=64 --cnn_channels=[64, 100, 128]","file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]"
-ASR,CommonVoice,recipes/CommonVoice/ASR/transformer/train.py,recipes/CommonVoice/ASR/transformer/hparams/train_fr.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=27 --d_model=128 --num_encoder_layers=3 --num_decoder_layers=3 --d_ffn=256 --stage_one_epochs=1,"file_exists=[wer_valid.txt,train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]"
-ASR,CommonVoice,recipes/CommonVoice/ASR/transformer/train_with_whisper.py,recipes/CommonVoice/ASR/transformer/hparams/train_ar_hf_whisper.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,"file_exists=[train_with_whisper.py,wer_valid.txt,train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml]"
-ASR,CommonVoice,recipes/CommonVoice/ASR/transformer/train_with_whisper.py,recipes/CommonVoice/ASR/transformer/hparams/train_fa_hf_whisper.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,"file_exists=[train_with_whisper.py,wer_valid.txt,train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml]"
-ASR,CommonVoice,recipes/CommonVoice/ASR/transformer/train_with_whisper.py,recipes/CommonVoice/ASR/transformer/hparams/train_fr_hf_whisper.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,"file_exists=[train_with_whisper.py,wer_valid.txt,train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml]"
-ASR,CommonVoice,recipes/CommonVoice/ASR/transformer/train_with_whisper.py,recipes/CommonVoice/ASR/transformer/hparams/train_sr_hf_whisper.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,"file_exists=[train_with_whisper.py,wer_valid.txt,train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml]"
-ASR,CommonVoice,recipes/CommonVoice/ASR/transformer/train_with_whisper.py,recipes/CommonVoice/ASR/transformer/hparams/train_mn_hf_whisper.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,"file_exists=[train_with_whisper.py,wer_valid.txt,train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml]"
-ASR,CommonVoice,recipes/CommonVoice/ASR/transformer/train_with_whisper.py,recipes/CommonVoice/ASR/transformer/hparams/train_hi_hf_whisper.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://www.dropbox.com/sh/852eq7pbt6d65ai/AACv4wAzk1pWbDo4fjVKLICYa?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,"file_exists=[train_with_whisper.py,wer_valid.txt,train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml]"
-SSL,CommonVoice,recipes/CommonVoice/self-supervised-learning/wav2vec2/train_hf_wav2vec2.py,recipes/CommonVoice/self-supervised-learning/wav2vec2/hparams/wav2vec2_base.yaml,recipes/CommonVoice/self-supervised-learning/wav2vec2/common_voice_prepare.py,recipes/CommonVoice/self-supervised-learning/wav2vec2/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --d_model=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_hf_wav2vec2.py,train_log.txt,log.txt,env.log,hyperparams.yaml]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+ASR-CTC,CommonVoice,recipes/CommonVoice/ASR/CTC/train_with_wav2vec.py,recipes/CommonVoice/ASR/CTC/hparams/train_en_with_wav2vec.yaml,recipes/CommonVoice/ASR/CTC/common_voice_prepare.py,recipes/CommonVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/ch10cnbhf1faz3w/AACdHFG65LC6582H0Tet_glTa?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-en,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=27 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_log.txt,log.txt,wer_test.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]",Test-WER=16.16%
+ASR-CTC,CommonVoice,recipes/CommonVoice/ASR/CTC/train_with_wav2vec.py,recipes/CommonVoice/ASR/CTC/hparams/train_fr_with_wav2vec.yaml,recipes/CommonVoice/ASR/CTC/common_voice_prepare.py,recipes/CommonVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/0i7esfa8jp3rxpp/AAArdi8IuCRmob2WAS7lg6M4a?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-fr,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=22 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_log.txt,log.txt,wer_test.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/22_char.model,save/ASR_train.txt,save/22_char.vocab]",Test-WER=9.71%
+ASR-CTC,CommonVoice,recipes/CommonVoice/ASR/CTC/train_with_wav2vec.py,recipes/CommonVoice/ASR/CTC/hparams/train_it_with_wav2vec.yaml,recipes/CommonVoice/ASR/CTC/common_voice_prepare.py,recipes/CommonVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/hthxqzh5boq15rn/AACftSab_FM6EFWWPgHpKw82a?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-it,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=27 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_log.txt,log.txt,wer_test.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]",Test-WER=7.99%
+ASR-CTC,CommonVoice,recipes/CommonVoice/ASR/CTC/train_with_wav2vec.py,recipes/CommonVoice/ASR/CTC/hparams/train_rw_with_wav2vec.yaml,recipes/CommonVoice/ASR/CTC/common_voice_prepare.py,recipes/CommonVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/4iax0l4yfry37gn/AABuQ31JY-Sbyi1VlOJfV7haa?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-rw,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=27 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_log.txt,log.txt,wer_test.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]",Test-WER=22.52%
+ASR-CTC,CommonVoice,recipes/CommonVoice/ASR/CTC/train_with_wav2vec.py,recipes/CommonVoice/ASR/CTC/hparams/train_de_with_wav2vec.yaml,recipes/CommonVoice/ASR/CTC/common_voice_prepare.py,recipes/CommonVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/dn7plq4wfsujsi1/AABS1kqB_uqLJVkg-bFkyPpVa?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-de,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --output_neurons=24 --skip_prep=True --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_log.txt,log.txt,wer_test.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/24_char.model,save/ASR_train.txt,save/24_char.vocab]",Test-WER=8.39%
+ASR-CTC,CommonVoice,recipes/CommonVoice/ASR/CTC/train_with_wav2vec.py,recipes/CommonVoice/ASR/CTC/hparams/train_ar_with_wav2vec.yaml,recipes/CommonVoice/ASR/CTC/common_voice_prepare.py,recipes/CommonVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/7tnuqqbr4vy96cc/AAA_5_R0RmqFIiyR0o1nVS4Ia?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-ar,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=27 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_log.txt,log.txt,wer_test.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]",Test-WER=28.53%
+ASR-CTC,CommonVoice,recipes/CommonVoice/ASR/CTC/train_with_wav2vec.py,recipes/CommonVoice/ASR/CTC/hparams/train_es_with_wav2vec.yaml,recipes/CommonVoice/ASR/CTC/common_voice_prepare.py,recipes/CommonVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/ejvzgl3d3g8g9su/AACYtbSWbDHvBr06lAb7A4mVa?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-es,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=27 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_log.txt,log.txt,wer_test.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]",Test-WER=12.67%
+ASR-CTC,CommonVoice,recipes/CommonVoice/ASR/CTC/train_with_wav2vec.py,recipes/CommonVoice/ASR/CTC/hparams/train_pt_with_wav2vec.yaml,recipes/CommonVoice/ASR/CTC/common_voice_prepare.py,recipes/CommonVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/80wucrvijdvao2a/AAD6-SZ2_ZZXmlAjOTw6fVloa?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-pt,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=27 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_log.txt,log.txt,wer_test.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]",Test-WER=21.69%
+ASR-CTC,CommonVoice,recipes/CommonVoice/ASR/CTC/train_with_wav2vec.py,recipes/CommonVoice/ASR/CTC/hparams/train_zh-CN_with_wav2vec.yaml,recipes/CommonVoice/ASR/CTC/common_voice_prepare.py,recipes/CommonVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/2bikr81vgufoglf/AABMpD0rLIaZBxjtwBHgrNpga?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-zh-CN,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=27 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_log.txt,log.txt,wer_test.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/27_unigram.vocab,save/ASR_train.txt,save/27_unigram.model]",Test-WER=23.17%
+ASR-seq2seq,CommonVoice,recipes/CommonVoice/ASR/seq2seq/train.py,recipes/CommonVoice/ASR/seq2seq/hparams/train_de.yaml,recipes/CommonVoice/ASR/seq2seq/common_voice_prepare.py,recipes/CommonVoice/ASR/seq2seq/README.md,https://www.dropbox.com/sh/zgatirb118f79ef/AACmjh-D94nNDWcnVI4Ef5K7a?dl=0,https://huggingface.co/speechbrain/asr-crdnn-commonvoice-14-de,"--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --number_of_ctc_epochs=1 --skip_prep=True --output_neurons=27 --rnn_neurons=128 --dnn_neurons=128 --dec_neurons=128 --emb_size=64 --cnn_channels=[64, 100, 128]","file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]",Test-WER=12.25%
+ASR-seq2seq,CommonVoice,recipes/CommonVoice/ASR/seq2seq/train.py,recipes/CommonVoice/ASR/seq2seq/hparams/train_en.yaml,recipes/CommonVoice/ASR/seq2seq/common_voice_prepare.py,recipes/CommonVoice/ASR/seq2seq/README.md,https://www.dropbox.com/sh/h8ged0yu3ztypkh/AAAu-12k_Ceg-tTjuZnrg7dza?dl=0,https://huggingface.co/speechbrain/asr-crdnn-commonvoice-14-en,"--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --number_of_ctc_epochs=1 --skip_prep=True --output_neurons=27 --rnn_neurons=128 --dnn_neurons=128 --dec_neurons=128 --emb_size=64 --cnn_channels=[64, 100, 128]","file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]",Test-WER=23.88%
+ASR-seq2seq,CommonVoice,recipes/CommonVoice/ASR/seq2seq/train.py,recipes/CommonVoice/ASR/seq2seq/hparams/train_fr.yaml,recipes/CommonVoice/ASR/seq2seq/common_voice_prepare.py,recipes/CommonVoice/ASR/seq2seq/README.md,https://www.dropbox.com/sh/07a5lt21wxp98x5/AABhNwmWFaNFyA734bNZUO03a?dl=0,https://huggingface.co/speechbrain/asr-crdnn-commonvoice-14-fr,"--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --number_of_ctc_epochs=1 --skip_prep=True --output_neurons=27 --rnn_neurons=128 --dnn_neurons=128 --dec_neurons=128 --emb_size=64 --cnn_channels=[64, 100, 128]","file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]",Test-WER=14.88%
+ASR-seq2seq,CommonVoice,recipes/CommonVoice/ASR/seq2seq/train.py,recipes/CommonVoice/ASR/seq2seq/hparams/train_it.yaml,recipes/CommonVoice/ASR/seq2seq/common_voice_prepare.py,recipes/CommonVoice/ASR/seq2seq/README.md,https://www.dropbox.com/sh/ss59uu0j5boscvp/AAASsiFhlB1nDWPkFX410bzna?dl=0,https://huggingface.co/speechbrain/asr-crdnn-commonvoice-14-it,"--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --number_of_ctc_epochs=1 --skip_prep=True --output_neurons=27 --rnn_neurons=128 --dnn_neurons=128 --dec_neurons=128 --emb_size=64 --cnn_channels=[64, 100, 128]","file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]",Test-WER=17.02%
+ASR-seq2seq,CommonVoice,recipes/CommonVoice/ASR/seq2seq/train.py,recipes/CommonVoice/ASR/seq2seq/hparams/train_rw.yaml,recipes/CommonVoice/ASR/seq2seq/common_voice_prepare.py,recipes/CommonVoice/ASR/seq2seq/README.md,https://www.dropbox.com/sh/i1fv4f8miilqgii/AAB3gE97kmFDA0ISkIDSUW_La?dl=0,https://huggingface.co/speechbrain/asr-crdnn-commonvoice-14-rw,"--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --number_of_ctc_epochs=1 --skip_prep=True --output_neurons=27 --rnn_neurons=128 --dnn_neurons=128 --dec_neurons=128 --emb_size=64 --cnn_channels=[64, 100, 128]","file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]",Test-WER=29.22%
+ASR-seq2seq,CommonVoice,recipes/CommonVoice/ASR/seq2seq/train.py,recipes/CommonVoice/ASR/seq2seq/hparams/train_es.yaml,recipes/CommonVoice/ASR/seq2seq/common_voice_prepare.py,recipes/CommonVoice/ASR/seq2seq/README.md,https://www.dropbox.com/sh/r3w0b2tm1p73vft/AADCxdhUwDN6j4PVT9TYe-d5a?dl=0,https://huggingface.co/speechbrain/asr-crdnn-commonvoice-14-es,"--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --number_of_ctc_epochs=1 --skip_prep=True --output_neurons=27 --rnn_neurons=128 --dnn_neurons=128 --dec_neurons=128 --emb_size=64 --cnn_channels=[64, 100, 128]","file_exists=[train_log.txt,log.txt,wer_test.txt,train.py,env.log,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]",Test-WER=14.77%
+ASR-transducer,CommonVoice,recipes/CommonVoice/ASR/transducer/train.py,recipes/CommonVoice/ASR/transducer/hparams/train_fr.yaml,recipes/CommonVoice/ASR/transducer/common_voice_prepare.py,recipes/CommonVoice/ASR/transducer/README.md,https://www.dropbox.com/sh/nv2pnpo5n3besn3/AADZ7l41oLt11ZuOE4MqoJhCa?dl=0,https://huggingface.co/speechbrain/asr-transducer-commonvoice-14-fr,"--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=27 --rnn_neurons=64 --dnn_neurons=64 --dec_neurons=64 --joint_dim=64 --cnn_channels=[64, 100, 128]","file_exists=[train_log.txt,log.txt,wer_test.txt,train.py,env.log,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]",Test-WER=17.58%
+ASR-transducer,CommonVoice,recipes/CommonVoice/ASR/transducer/train.py,recipes/CommonVoice/ASR/transducer/hparams/train_it.yaml,recipes/CommonVoice/ASR/transducer/common_voice_prepare.py,recipes/CommonVoice/ASR/transducer/README.md,https://www.dropbox.com/sh/ksm08x0wwiomrgs/AABnjPePWGPxqIqW7bJHp1jea?dl=0,https://huggingface.co/speechbrain/asr-transducer-commonvoice-14-it,"--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=27 --rnn_neurons=64 --dnn_neurons=64 --dec_neurons=64 --joint_dim=64 --cnn_channels=[64, 100, 128]","file_exists=[train_log.txt,log.txt,wer_test.txt,train.py,env.log,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]",Test-WER=14.88%
+ASR-transducer,CommonVoice,recipes/CommonVoice/ASR/transducer/train.py,recipes/CommonVoice/ASR/transducer/hparams/train_de.yaml,recipes/CommonVoice/ASR/transducer/common_voice_prepare.py,recipes/CommonVoice/ASR/transducer/README.md,https://www.dropbox.com/sh/jfge6ixbtoje64t/AADeAgL5un0A8uEjPSM84ex8a?dl=0,https://huggingface.co/speechbrain/asr-transducer-commonvoice-14-de,"--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=27 --rnn_neurons=64 --dnn_neurons=64 --dec_neurons=64 --joint_dim=64 --cnn_channels=[64, 100, 128]","file_exists=[train_log.txt,log.txt,wer_test.txt,train.py,env.log,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]",Test-WER=15.25%
+ASR-transformer,CommonVoice,recipes/CommonVoice/ASR/transformer/train.py,recipes/CommonVoice/ASR/transformer/hparams/train_fr.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://www.dropbox.com/sh/zvu9h9pctksnuvp/AAD1kyS3-N0YtmcoMgjM-_Tba?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=27 --d_model=128 --num_encoder_layers=3 --num_decoder_layers=3 --d_ffn=256 --stage_one_epochs=1,"file_exists=[train_log.txt,log.txt,wer_test.txt,train.py,env.log,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]",Test-WER=17.61%
+ASR-transformer,CommonVoice,recipes/CommonVoice/ASR/transformer/train.py,recipes/CommonVoice/ASR/transformer/hparams/train_it.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://www.dropbox.com/sh/yy8du12jgbkm3qe/AACBHhTCM-cU-oGvAKJ9kTtaa?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=27 --d_model=128 --num_encoder_layers=3 --num_decoder_layers=3 --d_ffn=256 --stage_one_epochs=1,"file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]",Test-WER=16.80%
+ASR-transformer,CommonVoice,recipes/CommonVoice/ASR/transformer/train.py,recipes/CommonVoice/ASR/transformer/hparams/train_de.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://www.dropbox.com/sh/umfq986o3d9o1px/AAARNF2BFYELOWx3xhIOEoZka?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=27 --d_model=128 --num_encoder_layers=3 --num_decoder_layers=3 --d_ffn=256 --stage_one_epochs=1,"file_exists=[wer_valid.txt,train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,save/27_unigram.model,save/ASR_train.txt,save/27_unigram.vocab]",Test-WER=16.76%
+ASR-transformer,CommonVoice,recipes/CommonVoice/ASR/transformer/train_with_whisper.py,recipes/CommonVoice/ASR/transformer/hparams/train_ar_hf_whisper.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://www.dropbox.com/sh/0e4vtvbg6hf2e13/AAD-tfzCZGUrh85aeAeJj8I9a?dl=0,https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-ar,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,"file_exists=[train_with_whisper.py,wer_valid.txt,train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml]",Test-WER=16.96%
+ASR-transformer,CommonVoice,recipes/CommonVoice/ASR/transformer/train_with_whisper.py,recipes/CommonVoice/ASR/transformer/hparams/train_fa_hf_whisper.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://www.dropbox.com/sh/w1urihacmtoulmi/AADMtK3qeAF5mLYk5LMHyiOra?dl=0,https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-fa,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,"file_exists=[train_with_whisper.py,wer_valid.txt,train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml]",Test-WER=31.75%
+ASR-transformer,CommonVoice,recipes/CommonVoice/ASR/transformer/train_with_whisper.py,recipes/CommonVoice/ASR/transformer/hparams/train_fr_hf_whisper.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://www.dropbox.com/sh/7zlk07yxnslk4yy/AAANcI3EaG0ZFy6UrKk1Mm2Ga?dl=0,https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-fr,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,"file_exists=[train_with_whisper.py,wer_valid.txt,train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml]",Test-WER=10.62%
+ASR-transformer,CommonVoice,recipes/CommonVoice/ASR/transformer/train_with_whisper.py,recipes/CommonVoice/ASR/transformer/hparams/train_sr_hf_whisper.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://www.dropbox.com/sh/5lhk230q45sd97z/AAD-U9b_Ws_vFPs-cazsbOY0a?dl=0,https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-sr,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,"file_exists=[train_with_whisper.py,wer_valid.txt,train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml]",Test-WER=22.29%
+ASR-transformer,CommonVoice,recipes/CommonVoice/ASR/transformer/train_with_whisper.py,recipes/CommonVoice/ASR/transformer/hparams/train_mn_hf_whisper.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://www.dropbox.com/sh/6fbhmey7q1udykf/AAAiGObWTTe2cdXHt2Uv2VQXa?dl=0,https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-mn,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,"file_exists=[train_with_whisper.py,wer_valid.txt,train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml]",Test-WER=67.84%
+ASR-transformer,CommonVoice,recipes/CommonVoice/ASR/transformer/train_with_whisper.py,recipes/CommonVoice/ASR/transformer/hparams/train_hi_hf_whisper.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://www.dropbox.com/sh/z9vriyy3i6xqvif/AAB7ql-40yWTjKEQJiuhYUr5a?dl=0,https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-hi,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,"file_exists=[train_with_whisper.py,wer_valid.txt,train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml]",Test-WER=15.27%
+ASR-transformer,CommonVoice,recipes/CommonVoice/ASR/transformer/train_with_whisper.py,recipes/CommonVoice/ASR/transformer/hparams/train_it_hf_whisper.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://www.dropbox.com/sh/u5tex3nvzzs5pex/AAD-J7cOBE_fNfBono8waTKCa?dl=0,https://huggingface.co/speechbrain/asr-whisper-medium-commonvoice-it,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,"file_exists=[train_with_whisper.py,wer_valid.txt,train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml]",Test-WER=9.63%
+SSL,CommonVoice,recipes/CommonVoice/self-supervised-learning/wav2vec2/train_hf_wav2vec2.py,recipes/CommonVoice/self-supervised-learning/wav2vec2/hparams/wav2vec2_base.yaml,recipes/CommonVoice/self-supervised-learning/wav2vec2/common_voice_prepare.py,recipes/CommonVoice/self-supervised-learning/wav2vec2/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --d_model=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_hf_wav2vec2.py,train_log.txt,log.txt,env.log,hyperparams.yaml]",
+quantization,CommonVoice,recipes/CommonVoice/quantization/train.py,recipes/CommonVoice/quantization/hparams/train_with_hubert.yaml,recipes/CommonVoice/quantization/common_voice_prepare.py,recipes/CommonVoice/quantization/README.md,https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True,"file_exists=[log.txt,train.py,env.log,hyperparams.yaml]",
+quantization,CommonVoice,recipes/CommonVoice/quantization/train.py,recipes/CommonVoice/quantization/hparams/train_with_wav2vec.yaml,recipes/CommonVoice/quantization/common_voice_prepare.py,recipes/CommonVoice/quantization/README.md,https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True,"file_exists=[log.txt,train.py,env.log,hyperparams.yaml]",
+quantization,CommonVoice,recipes/CommonVoice/quantization/train.py,recipes/CommonVoice/quantization/hparams/train_with_wavlm.yaml,recipes/CommonVoice/quantization/common_voice_prepare.py,recipes/CommonVoice/quantization/README.md,https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True,"file_exists=[log.txt,train.py,env.log,hyperparams.yaml]",
+LM,CommonVoice,recipes/CommonVoice/LM/train.py,recipes/CommonVoice/LM/hparams/train_kenlm.yaml,recipes/CommonVoice/LM/common_voice_prepare.py,recipes/CommonVoice/LM/README.md,https://www.dropbox.com/scl/fo/zw505t10kesqpvkt6m3tu/h?rlkey=6626h1h665tvlo1mtekop9rx5&dl=0,,--data_folder=tests/samples/ASR/ --text_file=tests/samples/annotation/LM_train.txt --skip_prep=True,"file_exists=[log.txt,train.py,env.log,hyperparams.yaml]",
diff --git a/tests/recipes/DNS.csv b/tests/recipes/DNS.csv
new file mode 100644
index 0000000000000000000000000000000000000000..f58921c68292c7aa139c5440f0d1faba420d159b
--- /dev/null
+++ b/tests/recipes/DNS.csv
@@ -0,0 +1,2 @@
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,test_download,performance
+Enhancement,DNS,recipes/DNS/enhancement/train.py,recipes/DNS/enhancement/hparams/sepformer-dns-16k.yaml,recipes/DNS/create_wds_shards.py,recipes/DNS/enhancement/README.md,https://www.dropbox.com/sh/d3rp5d3gjysvy7c/AACmwcEkm_IFvaW1lt2GdtQka?dl=0,https://huggingface.co/speechbrain/sepformer-dns4-16k-enhancement,--data_folder=tests/download/DNS/ --train_data=tests/download/DNS/train_shards/ --valid_data=tests/download/DNS/train_shards/ --baseline_noisy_shards_folder=tests/download/DNS/baseline/ --baseline_shards=tests/download/DNS/baseline/shard-000000.tar --N_epochs=1 --batch_size=1,"file_exists=[valid_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]","download_file('https://www.dropbox.com/scl/fi/i3iwzmrnyw8pgputkqvgq/DNS.zip?rlkey=1ka0g2ig4x488fg1exnxmbprd&dl=1', 'tests/download/DNS.zip', unpack=True, dest_unpack='tests/download/', write_permissions=True)",valid-PESQ=2.06 test-SIG=2.999 test-BAK=3.076 test-OVRL=2.437
diff --git a/tests/recipes/DVoice.csv b/tests/recipes/DVoice.csv
index 405e2f744cbe0d0f319f887284131b21b74ab151..282026353264463b9dc7ff8f24edd37a2ca691d1 100644
--- a/tests/recipes/DVoice.csv
+++ b/tests/recipes/DVoice.csv
@@ -1,7 +1,7 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-ASR,DVoice,recipes/DVoice/ASR/CTC/train_with_wav2vec2.py,recipes/DVoice/ASR/CTC/hparams/train_amh_with_wav2vec.yaml,recipes/DVoice/ASR/CTC/dvoice_prepare.py,recipes/DVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/pyu40jq1ebv6hcc/AADQO_lAD-F9Q0vlVq8KoXHqa?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-dvoice-amharic,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=23 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[wer_test.txt,train_log.txt,log.txt,train_with_wav2vec2.py,env.log,hyperparams.yaml,save/23_char.model,save/ASR_train.txt,save/23_char.vocab]"
-ASR,DVoice,recipes/DVoice/ASR/CTC/train_with_wav2vec2.py,recipes/DVoice/ASR/CTC/hparams/train_dar_with_wav2vec.yaml,recipes/DVoice/ASR/CTC/dvoice_prepare.py,recipes/DVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/pyu40jq1ebv6hcc/AADQO_lAD-F9Q0vlVq8KoXHqa?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-dvoice-darija,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=23 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[wer_test.txt,train_log.txt,log.txt,train_with_wav2vec2.py,env.log,hyperparams.yaml,save/23_char.model,save/ASR_train.txt,save/23_char.vocab]"
-ASR,DVoice,recipes/DVoice/ASR/CTC/train_with_wav2vec2.py,recipes/DVoice/ASR/CTC/hparams/train_fon_with_wav2vec.yaml,recipes/DVoice/ASR/CTC/dvoice_prepare.py,recipes/DVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/pyu40jq1ebv6hcc/AADQO_lAD-F9Q0vlVq8KoXHqa?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-dvoice-fongbe,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=23 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[wer_test.txt,train_log.txt,log.txt,train_with_wav2vec2.py,env.log,hyperparams.yaml,save/23_char.model,save/ASR_train.txt,save/23_char.vocab]"
-ASR,DVoice,recipes/DVoice/ASR/CTC/train_with_wav2vec2.py,recipes/DVoice/ASR/CTC/hparams/train_multi_with_wav2vec.yaml,recipes/DVoice/ASR/CTC/dvoice_prepare.py,recipes/DVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/pyu40jq1ebv6hcc/AADQO_lAD-F9Q0vlVq8KoXHqa?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=17 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[wer_test.txt,train_log.txt,log.txt,train_with_wav2vec2.py,env.log,hyperparams.yaml,save/17_char.model,save/ASR_train.txt,save/17_char.vocab]"
-ASR,DVoice,recipes/DVoice/ASR/CTC/train_with_wav2vec2.py,recipes/DVoice/ASR/CTC/hparams/train_sw_with_wav2vec.yaml,recipes/DVoice/ASR/CTC/dvoice_prepare.py,recipes/DVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/pyu40jq1ebv6hcc/AADQO_lAD-F9Q0vlVq8KoXHqa?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-dvoice-swahili,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=23 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[wer_test.txt,train_log.txt,log.txt,train_with_wav2vec2.py,env.log,hyperparams.yaml,save/23_char.model,save/ASR_train.txt,save/23_char.vocab]"
-ASR,DVoice,recipes/DVoice/ASR/CTC/train_with_wav2vec2.py,recipes/DVoice/ASR/CTC/hparams/train_wol_with_wav2vec.yaml,recipes/DVoice/ASR/CTC/dvoice_prepare.py,recipes/DVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/pyu40jq1ebv6hcc/AADQO_lAD-F9Q0vlVq8KoXHqa?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-dvoice-wolof,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=23 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[wer_test.txt,train_log.txt,log.txt,train_with_wav2vec2.py,env.log,hyperparams.yaml,save/23_char.model,save/ASR_train.txt,save/23_char.vocab]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+ASR-CTC,DVoice,recipes/DVoice/ASR/CTC/train_with_wav2vec2.py,recipes/DVoice/ASR/CTC/hparams/train_amh_with_wav2vec.yaml,recipes/DVoice/ASR/CTC/dvoice_prepare.py,recipes/DVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/pyu40jq1ebv6hcc/AADQO_lAD-F9Q0vlVq8KoXHqa?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-dvoice-amharic,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=21 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[wer_test.txt,train_log.txt,log.txt,train_with_wav2vec2.py,env.log,hyperparams.yaml,save/21_char.model,save/ASR_train.txt,save/21_char.vocab]",Test-WER=24.92%
+ASR-CTC,DVoice,recipes/DVoice/ASR/CTC/train_with_wav2vec2.py,recipes/DVoice/ASR/CTC/hparams/train_dar_with_wav2vec.yaml,recipes/DVoice/ASR/CTC/dvoice_prepare.py,recipes/DVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/pyu40jq1ebv6hcc/AADQO_lAD-F9Q0vlVq8KoXHqa?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-dvoice-darija,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=21 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[wer_test.txt,train_log.txt,log.txt,train_with_wav2vec2.py,env.log,hyperparams.yaml,save/21_char.model,save/ASR_train.txt,save/21_char.vocab]",Test-WER=18.28%
+ASR-CTC,DVoice,recipes/DVoice/ASR/CTC/train_with_wav2vec2.py,recipes/DVoice/ASR/CTC/hparams/train_fon_with_wav2vec.yaml,recipes/DVoice/ASR/CTC/dvoice_prepare.py,recipes/DVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/pyu40jq1ebv6hcc/AADQO_lAD-F9Q0vlVq8KoXHqa?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-dvoice-fongbe,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=21 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[wer_test.txt,train_log.txt,log.txt,train_with_wav2vec2.py,env.log,hyperparams.yaml,save/21_char.model,save/ASR_train.txt,save/21_char.vocab]",Test-WER=9.00%
+Multilingual-ASR-CTC,DVoice,recipes/DVoice/ASR/CTC/train_with_wav2vec2.py,recipes/DVoice/ASR/CTC/hparams/train_multi_with_wav2vec.yaml,recipes/DVoice/ASR/CTC/dvoice_prepare.py,recipes/DVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/pyu40jq1ebv6hcc/AADQO_lAD-F9Q0vlVq8KoXHqa?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=17 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[wer_test.txt,train_log.txt,log.txt,train_with_wav2vec2.py,env.log,hyperparams.yaml,save/17_char.model,save/ASR_train.txt,save/17_char.vocab]",Test-WER-Darija=13.27% Test-WER-Swahili=29.31% Test-WER-Fongbe=10.26% Test-WER-Fongbe-Wolof=21.54% Test-WER-Amharic=31.15%
+ASR-CTC,DVoice,recipes/DVoice/ASR/CTC/train_with_wav2vec2.py,recipes/DVoice/ASR/CTC/hparams/train_sw_with_wav2vec.yaml,recipes/DVoice/ASR/CTC/dvoice_prepare.py,recipes/DVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/pyu40jq1ebv6hcc/AADQO_lAD-F9Q0vlVq8KoXHqa?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-dvoice-swahili,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=21 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[wer_test.txt,train_log.txt,log.txt,train_with_wav2vec2.py,env.log,hyperparams.yaml,save/21_char.model,save/ASR_train.txt,save/21_char.vocab]",Test-WER=23.16%
+ASR-CTC,DVoice,recipes/DVoice/ASR/CTC/train_with_wav2vec2.py,recipes/DVoice/ASR/CTC/hparams/train_wol_with_wav2vec.yaml,recipes/DVoice/ASR/CTC/dvoice_prepare.py,recipes/DVoice/ASR/CTC/README.md,https://www.dropbox.com/sh/pyu40jq1ebv6hcc/AADQO_lAD-F9Q0vlVq8KoXHqa?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-dvoice-wolof,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=21 --dnn_neurons=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[wer_test.txt,train_log.txt,log.txt,train_with_wav2vec2.py,env.log,hyperparams.yaml,save/21_char.model,save/ASR_train.txt,save/21_char.vocab]",Test-WER=16.05%
diff --git a/tests/recipes/ESC50.csv b/tests/recipes/ESC50.csv
index ed1f5e97a434226993685c755a4fe06acc6242b9..5068d01acb7e7f5196fa23475e0f401af7b68e96 100644
--- a/tests/recipes/ESC50.csv
+++ b/tests/recipes/ESC50.csv
@@ -1,7 +1,7 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-SoundClassification,ESC50,recipes/ESC50/classification/train_classifier.py,recipes/ESC50/classification/hparams/cnn14_classifier.yaml,recipes/ESC50/classification/esc50_prepare.py,recipes/ESC50/classification/README.md,https://www.dropbox.com/sh/fbe7l14o3n8f5rw/AACABE1BQGBbX4j6A1dIhBcSa?dl=0,,--data_folder tests/samples/ESC50 --debug --out_n_neurons 3 --use_pretrained 0,"file_exists=[train_classifier.py,train_log.txt,log.txt,env.log,hyperparams.yaml,save/label_encoder.txt]"
-SoundClassification,ESC50,recipes/ESC50/classification/train_classifier.py,recipes/ESC50/classification/hparams/conv2d_classifier.yaml,recipes/ESC50/classification/esc50_prepare.py,recipes/ESC50/classification/README.md,https://www.dropbox.com/sh/tl2pbfkreov3z7e/AADwwhxBLw1sKvlSWzp6DMEia?dl=0,,--data_folder tests/samples/ESC50 --debug --out_n_neurons 3 --use_pretrained 0,"file_exists=[train_classifier.py,train_log.txt,log.txt,env.log,hyperparams.yaml,save/label_encoder.txt]"
-Interpretability,ESC50,recipes/ESC50/interpret/train_nmf.py,recipes/ESC50/interpret/hparams/nmf.yaml,recipes/ESC50/interpret/esc50_prepare.py,recipes/ESC50/interpret/README.md,https://www.dropbox.com/sh/01exv8dt3k6l1kk/AADuKmikAPwMw5wlulojd5Ira?dl=0,,--data_folder tests/samples/ESC50 --debug,"file_exists=[train_log.txt,train_nmf.py,log.txt,env.log,hyperparams.yaml,save/label_encoder.txt]"
-Interpretability,ESC50,recipes/ESC50/interpret/train_l2i.py,recipes/ESC50/interpret/hparams/l2i_cnn14.yaml,recipes/ESC50/interpret/esc50_prepare.py,recipes/ESC50/interpret/README.md,https://www.dropbox.com/sh/cli2gm8nb4bthow/AAAKnzU0c80s_Rm7wx4i_Orza?dl=0,,--data_folder tests/samples/ESC50 --debug --use_pretrained 0,"file_exists=[train_l2i.py,train_log.txt,log.txt,env.log,hyperparams.yaml,save/label_encoder.txt]"
-Interpretability,ESC50,recipes/ESC50/interpret/train_l2i.py,recipes/ESC50/interpret/hparams/l2i_conv2dclassifier.yaml,recipes/ESC50/interpret/esc50_prepare.py,recipes/ESC50/interpret/README.md,https://www.dropbox.com/sh/gcpk9jye9ka08n0/AAB-m10r1YEH0rJdUMrCwizUa?dl=0,,--data_folder tests/samples/ESC50 --debug --use_pretrained 0,"file_exists=[train_l2i.py,train_log.txt,log.txt,env.log,hyperparams.yaml,save/label_encoder.txt]"
-Interpretability,ESC50,recipes/ESC50/interpret/train_piq.py,recipes/ESC50/interpret/hparams/piq.yaml,recipes/ESC50/interpret/esc50_prepare.py,recipes/ESC50/interpret/README.md,https://www.dropbox.com/sh/v1x5ks9t67ftysp/AABo494rDElHTiTpKR_6PP_ua?dl=0,,--data_folder tests/samples/ESC50 --debug --use_pretrained 0 --out_n_neurons 3 --number_of_epochs 1,"file_exists=[train_log.txt,log.txt,train_piq.py,env.log,hyperparams.yaml,save/label_encoder.txt]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+SoundClassification,ESC50,recipes/ESC50/classification/train_classifier.py,recipes/ESC50/classification/hparams/cnn14_classifier.yaml,recipes/ESC50/classification/esc50_prepare.py,recipes/ESC50/classification/README.md,https://www.dropbox.com/sh/fbe7l14o3n8f5rw/AACABE1BQGBbX4j6A1dIhBcSa?dl=0,,--data_folder tests/samples/ESC50 --debug --out_n_neurons 3 --use_pretrained 0,"file_exists=[train_classifier.py,train_log.txt,log.txt,env.log,hyperparams.yaml,save/label_encoder.txt]",Accuracy=82%
+SoundClassification,ESC50,recipes/ESC50/classification/train_classifier.py,recipes/ESC50/classification/hparams/conv2d_classifier.yaml,recipes/ESC50/classification/esc50_prepare.py,recipes/ESC50/classification/README.md,https://www.dropbox.com/sh/tl2pbfkreov3z7e/AADwwhxBLw1sKvlSWzp6DMEia?dl=0,,--data_folder tests/samples/ESC50 --debug --out_n_neurons 3 --use_pretrained 0,"file_exists=[train_classifier.py,train_log.txt,log.txt,env.log,hyperparams.yaml,save/label_encoder.txt]",Accuracy=75%
+Interpretability,ESC50,recipes/ESC50/interpret/train_nmf.py,recipes/ESC50/interpret/hparams/nmf.yaml,recipes/ESC50/interpret/esc50_prepare.py,recipes/ESC50/interpret/README.md,https://www.dropbox.com/sh/01exv8dt3k6l1kk/AADuKmikAPwMw5wlulojd5Ira?dl=0,,--data_folder tests/samples/ESC50 --debug,"file_exists=[train_log.txt,train_nmf.py,log.txt,env.log,hyperparams.yaml,save/label_encoder.txt]",
+Interpretability,ESC50,recipes/ESC50/interpret/train_l2i.py,recipes/ESC50/interpret/hparams/l2i_cnn14.yaml,recipes/ESC50/interpret/esc50_prepare.py,recipes/ESC50/interpret/README.md,https://www.dropbox.com/sh/cli2gm8nb4bthow/AAAKnzU0c80s_Rm7wx4i_Orza?dl=0,,--data_folder tests/samples/ESC50 --debug --use_pretrained 0,"file_exists=[train_l2i.py,train_log.txt,log.txt,env.log,hyperparams.yaml,save/label_encoder.txt]",
+Interpretability,ESC50,recipes/ESC50/interpret/train_l2i.py,recipes/ESC50/interpret/hparams/l2i_conv2dclassifier.yaml,recipes/ESC50/interpret/esc50_prepare.py,recipes/ESC50/interpret/README.md,https://www.dropbox.com/sh/gcpk9jye9ka08n0/AAB-m10r1YEH0rJdUMrCwizUa?dl=0,,--data_folder tests/samples/ESC50 --debug --use_pretrained 0,"file_exists=[train_l2i.py,train_log.txt,log.txt,env.log,hyperparams.yaml,save/label_encoder.txt]",
+Interpretability,ESC50,recipes/ESC50/interpret/train_piq.py,recipes/ESC50/interpret/hparams/piq.yaml,recipes/ESC50/interpret/esc50_prepare.py,recipes/ESC50/interpret/README.md,https://www.dropbox.com/sh/v1x5ks9t67ftysp/AABo494rDElHTiTpKR_6PP_ua?dl=0,,--data_folder tests/samples/ESC50 --debug --use_pretrained 0 --out_n_neurons 3 --number_of_epochs 1,"file_exists=[train_log.txt,log.txt,train_piq.py,env.log,hyperparams.yaml,save/label_encoder.txt]",
diff --git a/tests/recipes/Fisher-Callhome-Spanish.csv b/tests/recipes/Fisher-Callhome-Spanish.csv
index d47abd477d0e6f20b22cfd5f1b349b016649e3b2..fad6f60878a0f418db5c6acd44c23ee5023b761b 100644
--- a/tests/recipes/Fisher-Callhome-Spanish.csv
+++ b/tests/recipes/Fisher-Callhome-Spanish.csv
@@ -1,4 +1,4 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-Tokenizer,Fisher-Callhome-Spanish,recipes/Fisher-Callhome-Spanish/Tokenizer/train.py,recipes/Fisher-Callhome-Spanish/Tokenizer/hparams/train_bpe_1k.yaml,recipes/Fisher-Callhome-Spanish/fisher_callhome_prepare.py,recipes/Fisher-Callhome-Spanish/README.md,,,--data_folder=tests/samples/single-mic --train_annotation=tests/samples/annotation/multi_annotation.json --skip_prep=True --token_output=30 --original_data_folder=tests/samples/single-mic,"file_exists=[30_bpe.model,30_bpe.vocab,log.txt,multi_annotation.txt,env.log,train.py,hyperparams.yaml]"
-Speech_Translation,Fisher-Callhome-Spanish,recipes/Fisher-Callhome-Spanish/ST/transformer/train.py,recipes/Fisher-Callhome-Spanish/ST/transformer/hparams/conformer.yaml,recipes/Fisher-Callhome-Spanish/fisher_callhome_prepare.py,recipes/Fisher-Callhome-Spanish/README.md,https://www.dropbox.com/sh/tmh7op8xwthdta0/AACuU9xHDHPs8ToxIIwoTLB0a?dl=0 https://www.dropbox.com/sh/qz33qjr10y351gk/AADApachs3WtDXx67pIz5fCZa?dl=0,,--data_folder=tests/samples/single-mic --number_of_epochs=2 --vocab_size=30 --output_neurons=30 --tokenizer_file=tests/tmp/Fisher-Callhome-Spanish_row_2/30_bpe.model,"file_exists=[train_log.txt,log.txt,env.log,train.py,bleu.txt,hyperparams.yaml]"
-Speech_Translation,Fisher-Callhome-Spanish,recipes/Fisher-Callhome-Spanish/ST/transformer/train.py,recipes/Fisher-Callhome-Spanish/ST/transformer/hparams/transformer.yaml,recipes/Fisher-Callhome-Spanish/fisher_callhome_prepare.py,recipes/Fisher-Callhome-Spanish/README.md,https://www.dropbox.com/sh/tmh7op8xwthdta0/AACuU9xHDHPs8ToxIIwoTLB0a?dl=0 https://www.dropbox.com/sh/qz33qjr10y351gk/AADApachs3WtDXx67pIz5fCZa?dl=0,,--data_folder=tests/samples/single-mic --number_of_epochs=2 --vocab_size=30 --output_neurons=30 --tokenizer_file=tests/tmp/Fisher-Callhome-Spanish_row_2/30_bpe.model,"file_exists=[train_log.txt,log.txt,env.log,train.py,bleu.txt,hyperparams.yaml]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+Tokenizer,Fisher-Callhome-Spanish,recipes/Fisher-Callhome-Spanish/Tokenizer/train.py,recipes/Fisher-Callhome-Spanish/Tokenizer/hparams/train_bpe_1k.yaml,recipes/Fisher-Callhome-Spanish/fisher_callhome_prepare.py,recipes/Fisher-Callhome-Spanish/README.md,,,--data_folder=tests/samples/single-mic --train_annotation=tests/samples/annotation/multi_annotation.json --skip_prep=True --token_output=30 --original_data_folder=tests/samples/single-mic,"file_exists=[30_bpe.model,30_bpe.vocab,log.txt,multi_annotation.txt,env.log,train.py,hyperparams.yaml]",
+Speech_Translation,Fisher-Callhome-Spanish,recipes/Fisher-Callhome-Spanish/ST/transformer/train.py,recipes/Fisher-Callhome-Spanish/ST/transformer/hparams/transformer.yaml,recipes/Fisher-Callhome-Spanish/fisher_callhome_prepare.py,recipes/Fisher-Callhome-Spanish/README.md,https://www.dropbox.com/sh/tmh7op8xwthdta0/AACuU9xHDHPs8ToxIIwoTLB0a?dl=0,,--data_folder=tests/samples/single-mic --number_of_epochs=2 --vocab_size=30 --output_neurons=30 --tokenizer_file=tests/tmp/Fisher-Callhome-Spanish_row_02/30_bpe.model,"file_exists=[train_log.txt,log.txt,env.log,train.py,bleu.txt,hyperparams.yaml]",Test-sacrebleu=47.31
+Speech_Translation,Fisher-Callhome-Spanish,recipes/Fisher-Callhome-Spanish/ST/transformer/train.py,recipes/Fisher-Callhome-Spanish/ST/transformer/hparams/conformer.yaml,recipes/Fisher-Callhome-Spanish/fisher_callhome_prepare.py,recipes/Fisher-Callhome-Spanish/README.md,https://www.dropbox.com/sh/tmh7op8xwthdta0/AACuU9xHDHPs8ToxIIwoTLB0a?dl=0,,--data_folder=tests/samples/single-mic --number_of_epochs=2 --vocab_size=30 --output_neurons=30 --tokenizer_file=tests/tmp/Fisher-Callhome-Spanish_row_02/30_bpe.model,"file_exists=[train_log.txt,log.txt,env.log,train.py,bleu.txt,hyperparams.yaml]",Test-sacrebleu=48.04
diff --git a/tests/recipes/Google-speech-commands.csv b/tests/recipes/Google-speech-commands.csv
index a9baaf05b5d626d5529260612eabc4c946f8e8b7..c60035fd6cf89d7ff774b2c7dea7d8e5c6cf87b9 100644
--- a/tests/recipes/Google-speech-commands.csv
+++ b/tests/recipes/Google-speech-commands.csv
@@ -1,3 +1,3 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-Command_recognition,Google-speech-commands,recipes/Google-speech-commands/train.py,recipes/Google-speech-commands/hparams/xvect.yaml,recipes/Google-speech-commands/prepare_GSC.py,recipes/Google-speech-commands/README.md,https://www.dropbox.com/sh/9n9q42pugbx0g7a/AADihpfGKuWf6gkwQznEFINDa?dl=0,https://huggingface.co/speechbrain/google_speech_command_xvector,--rir_folder=tests/tmp --data_folder=tests/samples/ASR --train_annotation=tests/samples/annotation/ASR_train.csv --valid_annotation=tests/samples/annotation/ASR_train.csv --test_annotation=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/label_encoder.txt]"
-Command_recognition,Google-speech-commands,recipes/Google-speech-commands/train.py,recipes/Google-speech-commands/hparams/xvect_leaf.yaml,recipes/Google-speech-commands/prepare_GSC.py,recipes/Google-speech-commands/README.md,https://www.dropbox.com/sh/r63w4gytft4s1x6/AAApP8-pp179QKGCZHV_OuD8a?dl=0,,--rir_folder=tests/tmp --data_folder=tests/samples/ASR --train_annotation=tests/samples/annotation/ASR_train.csv --valid_annotation=tests/samples/annotation/ASR_train.csv --test_annotation=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/label_encoder.txt]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+Command_recognition,Google-speech-commands,recipes/Google-speech-commands/train.py,recipes/Google-speech-commands/hparams/xvect.yaml,recipes/Google-speech-commands/prepare_GSC.py,recipes/Google-speech-commands/README.md,https://www.dropbox.com/sh/9n9q42pugbx0g7a/AADihpfGKuWf6gkwQznEFINDa?dl=0,https://huggingface.co/speechbrain/google_speech_command_xvector,--data_folder=tests/samples/ASR --train_annotation=tests/samples/annotation/ASR_train.csv --valid_annotation=tests/samples/annotation/ASR_train.csv --test_annotation=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/label_encoder.txt]",Test-accuracy=97.43%
+Command_recognition,Google-speech-commands,recipes/Google-speech-commands/train.py,recipes/Google-speech-commands/hparams/xvect_leaf.yaml,recipes/Google-speech-commands/prepare_GSC.py,recipes/Google-speech-commands/README.md,https://www.dropbox.com/sh/r63w4gytft4s1x6/AAApP8-pp179QKGCZHV_OuD8a?dl=0,,--data_folder=tests/samples/ASR --train_annotation=tests/samples/annotation/ASR_train.csv --valid_annotation=tests/samples/annotation/ASR_train.csv --test_annotation=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/label_encoder.txt]",Test-accuracy=96.79%
diff --git a/tests/recipes/IEMOCAP.csv b/tests/recipes/IEMOCAP.csv
index 605896cd780b9af885edc07bf2ab53f75ae21a33..626ff711437eb9b716944afea38d9318aeec50a9 100644
--- a/tests/recipes/IEMOCAP.csv
+++ b/tests/recipes/IEMOCAP.csv
@@ -1,3 +1,6 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-Emotion_recognition,IEMOCAP,recipes/IEMOCAP/emotion_recognition/train_with_wav2vec2.py,recipes/IEMOCAP/emotion_recognition/hparams/train_with_wav2vec2.yaml,recipes/IEMOCAP/emotion_recognition/iemocap_prepare.py,recipes/IEMOCAP/README.md,https://www.dropbox.com/sh/lmebg4li83sgkhg/AACooPKbNlwd-7n5qSJMbc7ya?dl=0 https://www.dropbox.com/sh/ikjwnwebekf2xx2/AADyaJKPiaR0_iO0nntucH5pa?dl=0 https://www.dropbox.com/sh/ke4fxiry97z58m8/AACPEOM5bIyxo9HxG2mT9v_aa?dl=0,https://huggingface.co/speechbrain/emotion-recognition-wav2vec2-IEMOCAP/,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_dev.json --test_annotation=tests/samples/annotation/ASR_dev.json --number_of_epochs=2 --skip_prep=True --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_with_wav2vec2.py,train_log.txt,log.txt,env.log,hyperparams.yaml,save/label_encoder.txt]"
-Emotion_recognition,IEMOCAP,recipes/IEMOCAP/emotion_recognition/train.py,recipes/IEMOCAP/emotion_recognition/hparams/train.yaml,recipes/IEMOCAP/emotion_recognition/iemocap_prepare.py,recipes/IEMOCAP/README.md,https://www.dropbox.com/sh/lmebg4li83sgkhg/AACooPKbNlwd-7n5qSJMbc7ya?dl=0 https://www.dropbox.com/sh/ikjwnwebekf2xx2/AADyaJKPiaR0_iO0nntucH5pa?dl=0 https://www.dropbox.com/sh/ke4fxiry97z58m8/AACPEOM5bIyxo9HxG2mT9v_aa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_dev.json --test_annotation=tests/samples/annotation/ASR_dev.json --number_of_epochs=2 --skip_prep=True,"file_exists=[train_log.txt,log.txt,predictions.csv,env.log,train.py,hyperparams.yaml,save/label_encoder.txt]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+Emotion_recognition,IEMOCAP,recipes/IEMOCAP/emotion_recognition/train_with_wav2vec2.py,recipes/IEMOCAP/emotion_recognition/hparams/train_with_wav2vec2.yaml,recipes/IEMOCAP/emotion_recognition/iemocap_prepare.py,recipes/IEMOCAP/README.md,https://www.dropbox.com/sh/lmebg4li83sgkhg/AACooPKbNlwd-7n5qSJMbc7ya?dl=0,https://huggingface.co/speechbrain/emotion-recognition-wav2vec2-IEMOCAP/,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_dev.json --test_annotation=tests/samples/annotation/ASR_dev.json --number_of_epochs=2 --skip_prep=True --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[train_with_wav2vec2.py,train_log.txt,log.txt,env.log,hyperparams.yaml,save/label_encoder.txt]",Test-Accuracy=65.7%
+Emotion_recognition,IEMOCAP,recipes/IEMOCAP/emotion_recognition/train.py,recipes/IEMOCAP/emotion_recognition/hparams/train.yaml,recipes/IEMOCAP/emotion_recognition/iemocap_prepare.py,recipes/IEMOCAP/README.md,https://www.dropbox.com/sh/ke4fxiry97z58m8/AACPEOM5bIyxo9HxG2mT9v_aa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_dev.json --test_annotation=tests/samples/annotation/ASR_dev.json --number_of_epochs=2 --skip_prep=True,"file_exists=[train_log.txt,log.txt,predictions.csv,env.log,train.py,hyperparams.yaml,save/label_encoder.txt]",Test-Accuracy=77.0%
+quantization,IEMOCAP,recipes/IEMOCAP/quantization/train.py,recipes/IEMOCAP/quantization/hparams/train_with_hubert.yaml,recipes/IEMOCAP/quantization/iemocap_prepare.py,recipes/IEMOCAP/quantization/README.md,https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_dev.json --test_annotation=tests/samples/annotation/ASR_dev.json --skip_prep=True,"file_exists=[log.txt,train.py,env.log,hyperparams.yaml]",
+quantization,IEMOCAP,recipes/IEMOCAP/quantization/train.py,recipes/IEMOCAP/quantization/hparams/train_with_wav2vec.yaml,recipes/IEMOCAP/quantization/iemocap_prepare.py,recipes/IEMOCAP/quantization/README.md,https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_dev.json --test_annotation=tests/samples/annotation/ASR_dev.json --skip_prep=True,"file_exists=[log.txt,train.py,env.log,hyperparams.yaml]",
+quantization,IEMOCAP,recipes/IEMOCAP/quantization/train.py,recipes/IEMOCAP/quantization/hparams/train_with_wavlm.yaml,recipes/IEMOCAP/quantization/iemocap_prepare.py,recipes/IEMOCAP/quantization/README.md,https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_dev.json --test_annotation=tests/samples/annotation/ASR_dev.json --skip_prep=True,"file_exists=[log.txt,train.py,env.log,hyperparams.yaml]",
diff --git a/tests/recipes/IWSLT22_lowresource.csv b/tests/recipes/IWSLT22_lowresource.csv
index 52c757dffddcd6eda8b4f71cbc2f27894e31f3ad..46cb4e5976674fb41b1aeae835557c12febd05c1 100644
--- a/tests/recipes/IWSLT22_lowresource.csv
+++ b/tests/recipes/IWSLT22_lowresource.csv
@@ -1,2 +1,7 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-Speech_Translation,IWSLT22_lowresource,recipes/IWSLT22_lowresource/train.py,recipes/IWSLT22_lowresource/hparams/train_w2v2_st.yaml,recipes/IWSLT22_lowresource/prepare_iwslt22.py,recipes/IWSLT22_lowresource/README.md,,,--root_data_folder=tests/samples/ASR/ --data_folder=tests/samples/ASR/ --number_of_epochs=2 --vocab_size=42 --annotation_train=tests/samples/annotation/ASR_train.json --annotation_valid=tests/samples/annotation/ASR_dev.json --annotation_test=tests/samples/annotation/ASR_dev.json --skip_prep=True,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/42_unigram.model,save/ASR_train.txt,save/42_unigram.vocab]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+Speech_Translation,IWSLT22_lowresource,recipes/IWSLT22_lowresource/AST/transformer/train.py,recipes/IWSLT22_lowresource/AST/transformer/hparams/train_w2v2_st.yaml,recipes/IWSLT22_lowresource/AST/transformer/prepare_iwslt22.py,recipes/IWSLT22_lowresource/AST/transformer/README.md,,,--root_data_folder=tests/samples/ASR/ --data_folder=tests/samples/ASR/ --number_of_epochs=2 --vocab_size=42 --annotation_train=tests/samples/annotation/ASR_train.json --annotation_valid=tests/samples/annotation/ASR_dev.json --annotation_test=tests/samples/annotation/ASR_dev.json --skip_prep=True,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/42_unigram.model,save/ASR_train.txt,save/42_unigram.vocab]",
+Speech_Translation,IWSLT22_lowresource,recipes/IWSLT22_lowresource/AST/transformer/train_with_w2v_mbart.py,recipes/IWSLT22_lowresource/AST/transformer/hparams/train_w2v2_mbart_st.yaml,recipes/IWSLT22_lowresource/AST/transformer/prepare_iwslt22.py,recipes/IWSLT22_lowresource/AST/transformer/README.md,https://www.dropbox.com/sh/xjo0ou739oksnus/AAAgyrCwywmDRRuUiDnUva2za?dl=0,,--root_data_folder=tests/samples/ASR/ --data_folder=tests/samples/ASR/ --number_of_epochs=2 --vocab_size=250054 --annotation_train=tests/samples/annotation/ASR_train.json --annotation_valid=tests/samples/annotation/ASR_dev.json --annotation_test=tests/samples/annotation/ASR_dev.json --skip_prep=True,"file_exists=[train_log.txt,log.txt,env.log,train_with_w2v_mbart.py,hyperparams.yaml]",Test-BLEU=7.73
+Speech_Translation,IWSLT22_lowresource,recipes/IWSLT22_lowresource/AST/transformer/train_with_w2v_mbart.py,recipes/IWSLT22_lowresource/AST/transformer/hparams/train_w2v2_nllb_st.yaml,recipes/IWSLT22_lowresource/AST/transformer/prepare_iwslt22.py,recipes/IWSLT22_lowresource/AST/transformer/README.md,https://www.dropbox.com/sh/spp2ijgfdbzuz26/AABkJ97e72D7aKzNLTm1qmWEa?dl=0,,--root_data_folder=tests/samples/ASR/ --data_folder=tests/samples/ASR/ --number_of_epochs=2 --vocab_size=256206 --annotation_train=tests/samples/annotation/ASR_train.json --annotation_valid=tests/samples/annotation/ASR_dev.json --annotation_test=tests/samples/annotation/ASR_dev.json --skip_prep=True,"file_exists=[train_log.txt,log.txt,env.log,train_with_w2v_mbart.py,hyperparams.yaml]",Test-BLEU=8.70
+SAMU_Pretraining,IWSLT22_lowresource,recipes/IWSLT22_lowresource/AST/transformer/train_samu.py,recipes/IWSLT22_lowresource/AST/transformer/hparams/train_samu.yaml,recipes/IWSLT22_lowresource/AST/transformer/prepare_iwslt22.py,recipes/IWSLT22_lowresource/AST/transformer/README.md,,,--root_data_folder=tests/samples/ASR/ --data_folder=tests/samples/ASR/ --number_of_epochs=2 --train_set=tests/samples/annotation/ASR_train.json --valid_set=tests/samples/annotation/ASR_dev.json --test_set=tests/samples/annotation/ASR_dev.json --skip_prep=True,"file_exists=[train_log.txt,log.txt,env.log,train_samu.py,hyperparams.yaml]",
+Speech_Translation,IWSLT22_lowresource,recipes/IWSLT22_lowresource/AST/transformer/train_with_samu_mbart.py,recipes/IWSLT22_lowresource/AST/transformer/hparams/train_samu_mbart_st.yaml,recipes/IWSLT22_lowresource/AST/transformer/prepare_iwslt22.py,recipes/IWSLT22_lowresource/AST/transformer/README.md,https://www.dropbox.com/sh/98s1xyc3chreaw6/AABom3FnwY5SsIvg4en9tWC2a?dl=0,,--root_data_folder=tests/samples/ASR/ --data_folder=tests/samples/ASR/ --number_of_epochs=2 --vocab_size=250054 --annotation_train=tests/samples/annotation/ASR_train.json --annotation_valid=tests/samples/annotation/ASR_dev.json --annotation_test=tests/samples/annotation/ASR_dev.json --pre_trained_samu=tests/tmp/IWSLT22_lowresource_row_05/save/CKPT+checkpoint_epoch2/wav2vec2.ckpt --skip_prep=True,"file_exists=[train_log.txt,log.txt,env.log,train_with_samu_mbart.py,hyperparams.yaml]",Test-BLEU=10.28
+Speech_Translation,IWSLT22_lowresource,recipes/IWSLT22_lowresource/AST/transformer/train_with_samu_mbart.py,recipes/IWSLT22_lowresource/AST/transformer/hparams/train_samu_nllb_st.yaml,recipes/IWSLT22_lowresource/AST/transformer/prepare_iwslt22.py,recipes/IWSLT22_lowresource/AST/transformer/README.md,https://www.dropbox.com/sh/ekkpl9c3kxsgllj/AABa0q2LrJe_o7JF-TTbfxZ-a?dl=0,,--root_data_folder=tests/samples/ASR/ --data_folder=tests/samples/ASR/ --number_of_epochs=2 --vocab_size=256206 --annotation_train=tests/samples/annotation/ASR_train.json --annotation_valid=tests/samples/annotation/ASR_dev.json --annotation_test=tests/samples/annotation/ASR_dev.json --pre_trained_samu=tests/tmp/IWSLT22_lowresource_row_05/save/CKPT+checkpoint_epoch2/wav2vec2.ckpt --skip_prep=True,"file_exists=[train_log.txt,log.txt,env.log,train_with_samu_mbart.py,hyperparams.yaml]",Test-BLEU=11.32
diff --git a/tests/recipes/KsponSpeech.csv b/tests/recipes/KsponSpeech.csv
index 982896302c65a32046cb101f6c578036ff7308e2..763f7042edac61f319fe96e503ffe0a18a48980e 100644
--- a/tests/recipes/KsponSpeech.csv
+++ b/tests/recipes/KsponSpeech.csv
@@ -1,4 +1,4 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-ASR,KsponSpeech,recipes/KsponSpeech/ASR/transformer/train.py,recipes/KsponSpeech/ASR/transformer/hparams/conformer_medium.yaml,recipes/KsponSpeech/ASR/transformer/ksponspeech_prepare.py,recipes/KsponSpeech/ASR/transformer/README.md,https://www.dropbox.com/sh/uibokbz83o8ybv3/AACtO5U7mUbu_XhtcoOphAjza?dl=0,https://huggingface.co/speechbrain/asr-conformer-transformerlm-ksponspeech,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=2 --skip_prep=True --num_encoder_layers=3 --num_decoder_layers=3,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt,save/lm.ckpt]"
-LM,KsponSpeech,recipes/KsponSpeech/LM/train.py,recipes/KsponSpeech/LM/hparams/transformer.yaml,recipes/KsponSpeech/LM/ksponspeech_prepare.py,recipes/KsponSpeech/LM/README.md,https://www.dropbox.com/sh/egv5bdn8b5i45eo/AAB7a8gFt2FqbnO4yhL6DQ8na?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=2 --d_model=120 --d_ffn=96,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt]"
-Tokenizer,KsponSpeech,recipes/KsponSpeech/Tokenizer/train.py,recipes/KsponSpeech/Tokenizer/hparams/5K_unigram_subword_bpe.yaml,recipes/KsponSpeech/Tokenizer/ksponspeech_prepare.py,recipes/KsponSpeech/Tokenizer/README.md,https://www.dropbox.com/sh/prnqt09e7xpc1kr/AAB-HkfUazPifn7kXnKnAJSga?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True --token_output=23,"file_exists=[23_unigram.model,23_unigram.vocab,log.txt,ASR_train.txt,env.log,train.py,hyperparams.yaml]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+ASR,KsponSpeech,recipes/KsponSpeech/ASR/transformer/train.py,recipes/KsponSpeech/ASR/transformer/hparams/conformer_medium.yaml,recipes/KsponSpeech/ASR/transformer/ksponspeech_prepare.py,recipes/KsponSpeech/ASR/transformer/README.md,https://www.dropbox.com/sh/uibokbz83o8ybv3/AACtO5U7mUbu_XhtcoOphAjza?dl=0,https://huggingface.co/speechbrain/asr-conformer-transformerlm-ksponspeech,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=2 --skip_prep=True --num_encoder_layers=3 --num_decoder_layers=3,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt,save/lm.ckpt]",Test-clean-WER=20.78% Test-others-WER=25.73%
+LM,KsponSpeech,recipes/KsponSpeech/LM/train.py,recipes/KsponSpeech/LM/hparams/transformer.yaml,recipes/KsponSpeech/LM/ksponspeech_prepare.py,recipes/KsponSpeech/LM/README.md,https://www.dropbox.com/sh/egv5bdn8b5i45eo/AAB7a8gFt2FqbnO4yhL6DQ8na?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=2 --d_model=120 --d_ffn=96,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt]",
+Tokenizer,KsponSpeech,recipes/KsponSpeech/Tokenizer/train.py,recipes/KsponSpeech/Tokenizer/hparams/5K_unigram_subword_bpe.yaml,recipes/KsponSpeech/Tokenizer/ksponspeech_prepare.py,recipes/KsponSpeech/Tokenizer/README.md,https://www.dropbox.com/sh/prnqt09e7xpc1kr/AAB-HkfUazPifn7kXnKnAJSga?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True --token_output=23,"file_exists=[23_unigram.model,23_unigram.vocab,log.txt,ASR_train.txt,env.log,train.py,hyperparams.yaml]",
diff --git a/tests/recipes/LJSpeech.csv b/tests/recipes/LJSpeech.csv
index b2018c74124aae535234b64aee20aabdc6e334c8..2aeabc8f70b2d77331d31d3f3e129bcaf9ef5a81 100644
--- a/tests/recipes/LJSpeech.csv
+++ b/tests/recipes/LJSpeech.csv
@@ -1,5 +1,10 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-TTS,LJSpeech,recipes/LJSpeech/TTS/fastspeech2/train.py,recipes/LJSpeech/TTS/fastspeech2/hparams/train.yaml,recipes/LJSpeech/TTS/ljspeech_prepare.py,recipes/LJSpeech/TTS/README.md,https://www.dropbox.com/sh/tqyp58ogejqfres/AAAtmq7cRoOR3XTsq0iSgyKBa?dl=0,https://huggingface.co/speechbrain/tts-fastspeech2-ljspeech,--batch_size=2 --epochs=2 --data_folder=tests/samples/TTS --train_json=tests/samples/annotation/TTS_train.json --valid_json=tests/samples/annotation/TTS_train.json --test_json=tests/samples/annotation/TTS_train.json --skip_prep=True --sample_rate=16000 --num_workers_train 0 --num_workers_valid 0,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-TTS,LJSpeech,recipes/LJSpeech/TTS/tacotron2/train.py,recipes/LJSpeech/TTS/tacotron2/hparams/train.yaml,recipes/LJSpeech/TTS/ljspeech_prepare.py,recipes/LJSpeech/TTS/README.md,https://www.dropbox.com/sh/1npvo1g1ncafipf/AAC5DR1ErF2Q9V4bd1DHqX43a?dl=0,https://huggingface.co/speechbrain/tts-tacotron2-ljspeech,--epochs=2 --data_folder=tests/samples/ASR --train_json=tests/samples/annotation/ASR_train.json --valid_json=tests/samples/annotation/ASR_dev.json --test_json=tests/samples/annotation/ASR_dev.json --skip_prep=True --sample_rate=16000 --num_workers 0,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,samples/1/inference_mel_out.png,samples/1/raw_batch.pth,samples/1/target.png,samples/1/alignments.png,samples/1/output_postnet.png,samples/1/output.png,samples/2/inference_mel_out.png,samples/2/raw_batch.pth,samples/2/target.png,samples/2/alignments.png,samples/2/output_postnet.png,samples/2/output.png]"
-TTS,LJSpeech,recipes/LJSpeech/TTS/vocoder/hifi_gan/train.py,recipes/LJSpeech/TTS/vocoder/hifi_gan/hparams/train.yaml,recipes/LJSpeech/TTS/ljspeech_prepare.py,recipes/LJSpeech/TTS/README.md,https://www.dropbox.com/sh/m2xrdssiroipn8g/AAD-TqPYLrSg6eNxUkcImeg4a?dl=0,https://huggingface.co/speechbrain/tts-hifigan-ljspeech,--epochs=2 --data_folder=tests/samples/ASR --train_json=tests/samples/annotation/ASR_train.json --valid_json=tests/samples/annotation/ASR_dev.json --test_json=tests/samples/annotation/ASR_dev.json --skip_prep=True --sample_rate=16000,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,samples/1/synthesized.wav,samples/1/target.wav,samples/2/synthesized.wav,samples/2/target.wav]"
-TTS,LJSpeech,recipes/LJSpeech/TTS/vocoder/diffwave/train.py,recipes/LJSpeech/TTS/vocoder/diffwave/hparams/train.yaml,recipes/LJSpeech/TTS/ljspeech_prepare.py,recipes/LJSpeech/TTS/README.md,,,--number_of_epochs=2 --data_folder=tests/samples/ASR --train_json=tests/samples/annotation/ASR_train.json --valid_json=tests/samples/annotation/ASR_dev.json --test_json=tests/samples/annotation/ASR_dev.json --skip_prep=True --sample_rate=16000 --num_workers 0,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
+TTS,LJSpeech,recipes/LJSpeech/TTS/fastspeech2/train_internal_alignment.py,recipes/LJSpeech/TTS/fastspeech2/hparams/train_internal_alignment.yaml,recipes/LJSpeech/ljspeech_prepare.py,recipes/LJSpeech/TTS/README.md,https://www.dropbox.com/scl/fo/4ctkc6jjas3uij9dzcwta/h?rlkey=i0k086d77flcsdx40du1ppm2d&dl=0,https://huggingface.co/speechbrain/tts-fastspeech2-internal-alignment-ljspeech,--batch_size=2 --epochs=2 --data_folder=tests/samples/TTS --train_json=tests/samples/annotation/TTS_train.json --valid_json=tests/samples/annotation/TTS_train.json --test_json=tests/samples/annotation/TTS_train.json --skip_prep=True --sample_rate=16000 --num_workers_train 0 --num_workers_valid 0,"file_exists=[train_log.txt,log.txt,env.log,train_internal_alignment.py,hyperparams.yaml]"
+TTS,LJSpeech,recipes/LJSpeech/TTS/fastspeech2/train.py,recipes/LJSpeech/TTS/fastspeech2/hparams/train.yaml,recipes/LJSpeech/ljspeech_prepare.py,recipes/LJSpeech/TTS/README.md,https://www.dropbox.com/scl/fo/vtgbltqdrvw9r0vs7jz67/h?rlkey=cm2mwh5rce5ad9e90qaciypox&dl=0,https://huggingface.co/speechbrain/tts-fastspeech2-ljspeech,--batch_size=2 --epochs=2 --data_folder=tests/samples/TTS --train_json=tests/samples/annotation/TTS_train.json --valid_json=tests/samples/annotation/TTS_train.json --test_json=tests/samples/annotation/TTS_train.json --skip_prep=True --sample_rate=16000 --num_workers_train 0 --num_workers_valid 0,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
+TTS,LJSpeech,recipes/LJSpeech/TTS/tacotron2/train.py,recipes/LJSpeech/TTS/tacotron2/hparams/train.yaml,recipes/LJSpeech/ljspeech_prepare.py,recipes/LJSpeech/TTS/README.md,https://www.dropbox.com/sh/1npvo1g1ncafipf/AAC5DR1ErF2Q9V4bd1DHqX43a?dl=0,https://huggingface.co/speechbrain/tts-tacotron2-ljspeech,--epochs=2 --data_folder=tests/samples/ASR --train_json=tests/samples/annotation/ASR_train.json --valid_json=tests/samples/annotation/ASR_dev.json --test_json=tests/samples/annotation/ASR_dev.json --skip_prep=True --sample_rate=16000 --num_workers 0,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
+TTS,LJSpeech,recipes/LJSpeech/TTS/vocoder/hifi_gan/train.py,recipes/LJSpeech/TTS/vocoder/hifi_gan/hparams/train.yaml,recipes/LJSpeech/ljspeech_prepare.py,recipes/LJSpeech/TTS/README.md,https://www.dropbox.com/sh/m2xrdssiroipn8g/AAD-TqPYLrSg6eNxUkcImeg4a?dl=0,https://huggingface.co/speechbrain/tts-hifigan-ljspeech,--epochs=2 --data_folder=tests/samples/ASR --train_json=tests/samples/annotation/ASR_train.json --valid_json=tests/samples/annotation/ASR_dev.json --test_json=tests/samples/annotation/ASR_dev.json --skip_prep=True --sample_rate=16000,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,samples/1/synthesized.wav,samples/1/target.wav,samples/2/synthesized.wav,samples/2/target.wav]"
+TTS,LJSpeech,recipes/LJSpeech/TTS/vocoder/diffwave/train.py,recipes/LJSpeech/TTS/vocoder/diffwave/hparams/train.yaml,recipes/LJSpeech/ljspeech_prepare.py,recipes/LJSpeech/TTS/README.md,,,--number_of_epochs=2 --data_folder=tests/samples/ASR --train_json=tests/samples/annotation/ASR_train.json --valid_json=tests/samples/annotation/ASR_dev.json --test_json=tests/samples/annotation/ASR_dev.json --skip_prep=True --sample_rate=16000 --num_workers 0,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
+TTS,LJSpeech,recipes/LJSpeech/TTS/vocoder/hifi_gan_unit/train.py,recipes/LJSpeech/TTS/vocoder/hifi_gan_unit/hparams/train.yaml,recipes/LJSpeech/ljspeech_prepare.py,recipes/LJSpeech/TTS/README.md,,,--batch_size=2 --epochs=2 --data_folder=tests/samples/TTS --train_json=tests/samples/annotation/TTS_train.json --valid_json=tests/samples/annotation/TTS_train.json --test_json=tests/samples/annotation/TTS_train.json --skip_prep=True --sample_rate=16000 --codes_folder=tests/samples/TTS/codes --kmeans_folder=null,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,samples/1/synthesized.wav,samples/1/target.wav,samples/2/synthesized.wav,samples/2/target.wav]"
+quantization,LJSpeech,recipes/LJSpeech/quantization/train.py,recipes/LJSpeech/quantization/hparams/train_with_hubert.yaml,recipes/LJSpeech/quantization/ljspeech_prepare.py,recipes/LJSpeech/quantization/README.md,https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0,,--data_folder=tests/samples/TTS --train_json=tests/samples/annotation/TTS_train.json --skip_prep=True,"file_exists=[log.txt,train.py,env.log,hyperparams.yaml]"
+quantization,LJSpeech,recipes/LJSpeech/quantization/train.py,recipes/LJSpeech/quantization/hparams/train_with_wav2vec.yaml,recipes/LJSpeech/quantization/ljspeech_prepare.py,recipes/LJSpeech/quantization/README.md,https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0,,--data_folder=tests/samples/TTS --train_json=tests/samples/annotation/TTS_train.json --skip_prep=True,"file_exists=[log.txt,train.py,env.log,hyperparams.yaml]"
+quantization,LJSpeech,recipes/LJSpeech/quantization/train.py,recipes/LJSpeech/quantization/hparams/train_with_wavlm.yaml,recipes/LJSpeech/quantization/ljspeech_prepare.py,recipes/LJSpeech/quantization/README.md,https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0,,--data_folder=tests/samples/TTS --train_json=tests/samples/annotation/TTS_train.json --skip_prep=True,"file_exists=[log.txt,train.py,env.log,hyperparams.yaml]"
diff --git a/tests/recipes/LibriMix.csv b/tests/recipes/LibriMix.csv
index ada1bacf138cc4a43546b440c37236af695dad8e..671562ab01a2ef8f8d22736fa5b7dadec5d8f7d7 100644
--- a/tests/recipes/LibriMix.csv
+++ b/tests/recipes/LibriMix.csv
@@ -1,3 +1,3 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-Separation,LibriMix,recipes/LibriMix/separation/train.py,recipes/LibriMix/separation/hparams/sepformer-libri2mix.yaml,recipes/LibriMix/separation/dynamic_mixing.py,recipes/LibriMix/separation/README.md ,https://www.dropbox.com/sh/skkiozml92xtgdo/AAD0eJxgbCTK03kAaILytGtVa?dl=0,,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/encoder.ckpt,save/masknet.ckpt,save/decoder.ckpt]"
-Separation,LibriMix,recipes/LibriMix/separation/train.py,recipes/LibriMix/separation/hparams/sepformer-libri3mix.yaml,recipes/LibriMix/separation/dynamic_mixing.py,recipes/LibriMix/separation/README.md ,https://www.dropbox.com/sh/kmyz7tts9tyg198/AACsDcRwKvelXxEB-k5q1OaIa?dl=0,,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/encoder.ckpt,save/masknet.ckpt,save/decoder.ckpt]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+Separation,LibriMix,recipes/LibriMix/separation/train.py,recipes/LibriMix/separation/hparams/sepformer-libri2mix.yaml,recipes/LibriMix/separation/dynamic_mixing.py,recipes/LibriMix/separation/README.md ,https://www.dropbox.com/sh/skkiozml92xtgdo/AAD0eJxgbCTK03kAaILytGtVa?dl=0,,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/encoder.ckpt,save/masknet.ckpt,save/decoder.ckpt]",SI-SNR=20.4dB
+Separation,LibriMix,recipes/LibriMix/separation/train.py,recipes/LibriMix/separation/hparams/sepformer-libri3mix.yaml,recipes/LibriMix/separation/dynamic_mixing.py,recipes/LibriMix/separation/README.md ,https://www.dropbox.com/sh/kmyz7tts9tyg198/AACsDcRwKvelXxEB-k5q1OaIa?dl=0,,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/encoder.ckpt,save/masknet.ckpt,save/decoder.ckpt]",SI-SNR=19.0dB
diff --git a/tests/recipes/LibriParty.csv b/tests/recipes/LibriParty.csv
index 85c41746fadf4187d99bf69341adfb2d23cde6a6..4f1ce06740df901697c0b30cd48600416371937f 100644
--- a/tests/recipes/LibriParty.csv
+++ b/tests/recipes/LibriParty.csv
@@ -1,2 +1,2 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-VAD,LibriParty,recipes/LibriParty/VAD/train.py,recipes/LibriParty/VAD/hparams/train.yaml,recipes/LibriParty/VAD/libriparty_prepare.py,recipes/LibriParty/VAD/README.md,https://www.dropbox.com/sh/6yguuzn4pybjasd/AABpUF8LAQ8d2TJyC8aK2OBga?dl=0 ,https://huggingface.co/speechbrain/vad-crdnn-libriparty,--open_rir_folder=tests/tmp --data_folder=tests/samples/VAD --skip_prep=True --musan_folder= --commonlanguage_folder= --annotation_train=tests/samples/annotation/VAD_train_root.json --annotation_valid=tests/samples/annotation/VAD_dev_root.json --annotation_test=tests/samples/annotation/VAD_dev_root.json --noise_csv=tests/samples/annotation/noise_paths.csv --music_csv=tests/samples/annotation/noise_paths.csv --speech_csv=tests/samples/annotation/noise_paths.csv --multilang_speech_csv=tests/samples/annotation/noise_paths.csv --example_length=5.79 --N_epochs=2,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+VAD,LibriParty,recipes/LibriParty/VAD/train.py,recipes/LibriParty/VAD/hparams/train.yaml,recipes/LibriParty/VAD/libriparty_prepare.py,recipes/LibriParty/VAD/README.md,https://www.dropbox.com/sh/6yguuzn4pybjasd/AABpUF8LAQ8d2TJyC8aK2OBga?dl=0 ,https://huggingface.co/speechbrain/vad-crdnn-libriparty,--data_folder=tests/samples/VAD --skip_prep=True --musan_folder= --commonlanguage_folder= --annotation_train=tests/samples/annotation/VAD_train_root.json --annotation_valid=tests/samples/annotation/VAD_dev_root.json --annotation_test=tests/samples/annotation/VAD_dev_root.json --noise_csv=tests/samples/annotation/noise_paths.csv --music_csv=tests/samples/annotation/noise_paths.csv --speech_csv=tests/samples/annotation/noise_paths.csv --multilang_speech_csv=tests/samples/annotation/noise_paths.csv --example_length=5.79 --N_epochs=2,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",Test-Precision=0.9518 Test Recall=0.9437 Test F-Score=0.9477
diff --git a/tests/recipes/LibriSpeech.csv b/tests/recipes/LibriSpeech.csv
index a4455ad59ff2fa44a63976234ab5ba8cd08b0f80..bf36c953d0f9fae3bff2151e44b1c6a07187597a 100644
--- a/tests/recipes/LibriSpeech.csv
+++ b/tests/recipes/LibriSpeech.csv
@@ -1,34 +1,45 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-ASR,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py,recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec.yaml,recipes/LibriSpeech/ASR/CTC/librispeech_prepare.py,recipes/LibriSpeech/ASR/CTC/README.md,https://www.dropbox.com/sh/qj2ps85g8oiicrj/AAAxlkQw5Pfo0M9EyHMi8iAra?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-librispeech,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/label_encoder.txt] performance_check=[train_log.txt, train loss, <3.5, epoch: 10]"
-ASR,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py,recipes/LibriSpeech/ASR/CTC/hparams/train_sb_wav2vec.yaml,recipes/LibriSpeech/ASR/CTC/librispeech_prepare.py,recipes/LibriSpeech/ASR/CTC/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --output_neurons=21 --number_of_epochs=2 --skip_prep=True --wav2vec2_hub=speechbrain/ssl-wav2vec2-base-librispeech,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/label_encoder.txt,save/extractor.ckpt,save/encoder_wrapper.ckpt]"
-ASR,LibriSpeech,recipes/LibriSpeech/ASR/seq2seq/train.py,recipes/LibriSpeech/ASR/seq2seq/hparams/train_BPE_1000.yaml,recipes/LibriSpeech/ASR/seq2seq/librispeech_prepare.py,recipes/LibriSpeech/ASR/seq2seq/README.md,https://www.dropbox.com/sh/1ycv07gyxdq8hdl/AABUDYzza4SLYtY45RcGf2_0a?dl=0 https://www.dropbox.com/sh/a39wq3h60luv552/AABBnCM2Uf-CNax_cgMWdqDda?dl=0,https://huggingface.co/speechbrain/asr-crdnn-rnnlm-librispeech,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True --data_folder_rirs=tests/tmp,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt,save/lm.ckpt] performance_check=[train_log.txt, train loss, <=18.5, epoch: 10]"
-ASR,LibriSpeech,recipes/LibriSpeech/ASR/seq2seq/train.py,recipes/LibriSpeech/ASR/seq2seq/hparams/train_BPE_5000.yaml,recipes/LibriSpeech/ASR/seq2seq/librispeech_prepare.py,recipes/LibriSpeech/ASR/seq2seq/README.md,https://www.dropbox.com/sh/1ycv07gyxdq8hdl/AABUDYzza4SLYtY45RcGf2_0a?dl=0 https://www.dropbox.com/sh/a39wq3h60luv552/AABBnCM2Uf-CNax_cgMWdqDda?dl=0,https://huggingface.co/speechbrain/asr-crdnn-transformerlm-librispeech,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True --data_folder_rirs=tests/tmp,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt,save/lm.ckpt] performance_check=[train_log.txt, train loss, <=18.5, epoch: 10]"
-ASR,LibriSpeech,recipes/LibriSpeech/ASR/seq2seq/train.py,recipes/LibriSpeech/ASR/seq2seq/hparams/train_BPE_1000_sligru.yaml,recipes/LibriSpeech/ASR/seq2seq/librispeech_prepare.py,recipes/LibriSpeech/ASR/seq2seq/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True --data_folder_rirs=tests/tmp,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt,save/lm.ckpt] performance_check=[train_log.txt, train loss, <=18.5, epoch: 10]"
-ASR,LibriSpeech,recipes/LibriSpeech/ASR/transducer/train.py,recipes/LibriSpeech/ASR/transducer/hparams/conformer_transducer.yaml,recipes/LibriSpeech/ASR/transducer/librispeech_prepare.py,recipes/LibriSpeech/ASR/transducer/README.md,https://drive.google.com/drive/folders/1QtQz1Bkd_QPYnf3CyxhJ57ovbSZC2EhN?usp=sharing,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True --use_torchaudio=True --beam_size=1,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt,save/lm.ckpt] performance_check=[train_log.txt, train loss, <=1000, epoch: 10]"
-ASR,LibriSpeech,recipes/LibriSpeech/ASR/transducer/train.py,recipes/LibriSpeech/ASR/transducer/hparams/conformer_transducer.yaml,recipes/LibriSpeech/ASR/transducer/librispeech_prepare.py,recipes/LibriSpeech/ASR/transducer/README.md,https://drive.google.com/drive/folders/1QtQz1Bkd_QPYnf3CyxhJ57ovbSZC2EhN?usp=sharing,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True --use_torchaudio=False --beam_size=1,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt,save/lm.ckpt] performance_check=[train_log.txt, train loss, <=1000, epoch: 10]"
-ASR,LibriSpeech,recipes/LibriSpeech/ASR/transformer/train.py,recipes/LibriSpeech/ASR/transformer/hparams/conformer_small.yaml,recipes/LibriSpeech/ASR/transformer/librispeech_prepare.py,recipes/LibriSpeech/ASR/transformer/README.md,https://www.dropbox.com/sh/s0x6ni124858b8i/AAALaCH6sGTMRUVTjh8Tm8Jwa?dl=0,https://huggingface.co/speechbrain/asr-conformersmall-transformerlm-librispeech,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt,save/lm.ckpt] performance_check=[train_log.txt, train loss, <350, epoch: 10]"
-ASR,LibriSpeech,recipes/LibriSpeech/ASR/transformer/train.py,recipes/LibriSpeech/ASR/transformer/hparams/transformer.yaml,recipes/LibriSpeech/ASR/transformer/librispeech_prepare.py,recipes/LibriSpeech/ASR/transformer/README.md,https://www.dropbox.com/sh/653kq8h2k87md4p/AAByAaAryXtQKpRzYtzV9ih5a?dl=0,https://huggingface.co/speechbrain/asr-transformer-transformerlm-librispeech,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt,save/lm.ckpt] performance_check=[train_log.txt, train loss, <=350, epoch: 10]"
-ASR,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py,recipes/LibriSpeech/ASR/CTC/hparams/downsampled/train_hf_wavlm_signal_downsampling.yaml,recipes/LibriSpeech/ASR/CTC/librispeech_prepare.py,recipes/LibriSpeech/ASR/CTC/README.md,,,--downsampling_factor 2 --data_folder=tests/samples/ASR/ --skip_prep=True --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --use_language_modelling=False --ngram_lm_path=empty,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/label_encoder.txt]"
-ASR,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py,recipes/LibriSpeech/ASR/CTC/hparams/downsampled/train_hf_wavlm_conv_downsampling.yaml,recipes/LibriSpeech/ASR/CTC/librispeech_prepare.py,recipes/LibriSpeech/ASR/CTC/README.md,,,--downsampling_factor 2 --data_folder=tests/samples/ASR/ --skip_prep=True --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --use_language_modelling=False --ngram_lm_path=empty,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/label_encoder.txt]"
-ASR,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py,recipes/LibriSpeech/ASR/CTC/hparams/downsampled/train_hf_wavlm_average_downsampling.yaml,recipes/LibriSpeech/ASR/CTC/librispeech_prepare.py,recipes/LibriSpeech/ASR/CTC/README.md,,,--downsampling_factor 2 --data_folder=tests/samples/ASR/ --skip_prep=True --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --use_language_modelling=False --ngram_lm_path=empty,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/label_encoder.txt]"
-ASR,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py,recipes/LibriSpeech/ASR/CTC/hparams/downsampled/train_hf_wavlm_average_downsampling.yaml,recipes/LibriSpeech/ASR/CTC/librispeech_prepare.py,recipes/LibriSpeech/ASR/CTC/README.md,,,--downsampling_factor 3 --upsampling=True --ctc_neurons=58 --data_folder=tests/samples/ASR/ --skip_prep=True --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=2 --use_language_modelling=False --ngram_lm_path=empty,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/label_encoder.txt]"
-ASR,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py,recipes/LibriSpeech/ASR/CTC/hparams/downsampled/train_hf_wavlm_conv_downsampling.yaml,recipes/LibriSpeech/ASR/CTC/librispeech_prepare.py,recipes/LibriSpeech/ASR/CTC/README.md,,,--downsampling_factor 3 --upsampling=True --ctc_neurons=58 --data_folder=tests/samples/ASR/ --skip_prep=True --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=2 --use_language_modelling=False --ngram_lm_path=empty,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/label_encoder.txt]"
-ASR,minilibrispeech,templates/speech_recognition/ASR/train.py,templates/speech_recognition/ASR/train.yaml,templates/speech_recognition/ASR/mini_librispeech_prepare.py,templates/speech_recognition/ASR/README.md,,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_dev.json --test_annotation=tests/samples/annotation/ASR_dev.json --number_of_epochs=2 --number_of_ctc_epochs=1 --skip_prep=True --data_folder_rirs=tests/tmp,"file_exists=[env.log,hyperparams.yaml,log.txt,save/lm.ckpt,save/model.ckpt,save/tokenizer.ckpt,train_log.txt,train.py,wer_test.txt,save/CKPT+latest/brain.ckpt,save/CKPT+latest/CKPT.yaml,save/CKPT+latest/counter.ckpt,save/CKPT+latest/dataloader-TRAIN.ckpt,save/CKPT+latest/model.ckpt,save/CKPT+latest/normalizer.ckpt,save/CKPT+latest/optimizer.ckpt,save/CKPT+latest/scheduler.ckpt]"
-Enhancement,minilibrispeech,templates/enhancement/train.py,templates/enhancement/train.yaml,templates/enhancement/mini_librispeech_prepare.py,templates/enhancement/README.md,,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhancement_train.json --valid_annotation=tests/samples/annotation/enhancement_dev.json --test_annotation=tests/samples/annotation/enhancement_dev.json --skip_prep=True --rir_folder=tests/tmp,"file_exists=[env.log,hyperparams.yaml,log.txt,save,train_log.txt,train.py]"
-G2P,LibriSpeech,recipes/LibriSpeech/G2P/train.py,recipes/LibriSpeech/G2P/hparams/hparams_g2p_rnn.yaml,recipes/LibriSpeech/G2P/librispeech_prepare.py,recipes/LibriSpeech/G2P/README.md,https://www.dropbox.com/sh/qmcl1obp8pxqaap/AAC3yXvjkfJ3mL-RKyAUxPdNa?dl=0,,--data_folder=tests/samples/ASR/ --tokenizer_train_data=tests/samples/annotation/ASR_train.json --tokenizer_valid_data=tests/samples/annotation/ASR_dev.json --lexicon_epochs=2 --skip_prep=True --phn_token_output=42 --use_tensorboard=False --debug,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/wer_sentence.txt,save/homograph_stats.txt,save/wer_homograph.txt,save/wer_lexicon.txt,save/pretrained_models/ctc_lin.ckpt,save/pretrained_models/model.ckpt,reports/lexicon/1/wer_lexicon.txt,reports/lexicon/2/wer_lexicon.txt,reports/homograph/1/homograph_stats.txt,reports/homograph/1/wer_homograph.txt,reports/homograph/2/homograph_stats.txt,reports/homograph/2/wer_homograph.txt,reports/sentence/1/wer_sentence.txt,reports/sentence/2/wer_sentence.txt]"
-G2P,LibriSpeech,recipes/LibriSpeech/G2P/train.py,recipes/LibriSpeech/G2P/hparams/hparams_g2p_transformer.yaml,recipes/LibriSpeech/G2P/librispeech_prepare.py,recipes/LibriSpeech/G2P/README.md,https://www.dropbox.com/sh/zhrxg7anuhje7e8/AADTeJtdsja_wClkE2DsF9Ewa?dl=0,,--data_folder=tests/samples/ASR/ --tokenizer_train_data=tests/samples/annotation/ASR_train.json --tokenizer_valid_data=tests/samples/annotation/ASR_dev.json --lexicon_epochs=2 --skip_prep=True --phn_token_output=42 --use_tensorboard=False --debug,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/wer_sentence.txt,save/homograph_stats.txt,save/wer_homograph.txt,save/wer_lexicon.txt,save/pretrained_models/ctc_lin.ckpt,save/pretrained_models/model.ckpt,reports/lexicon/1/wer_lexicon.txt,reports/lexicon/2/wer_lexicon.txt,reports/homograph/1/homograph_stats.txt,reports/homograph/1/wer_homograph.txt,reports/homograph/2/homograph_stats.txt,reports/homograph/2/wer_homograph.txt,reports/sentence/1/wer_sentence.txt,reports/sentence/2/wer_sentence.txt]"
-G2P,LibriSpeech,recipes/LibriSpeech/G2P/train_lm.py,recipes/LibriSpeech/G2P/hparams/hparams_lm_rnn.yaml,recipes/LibriSpeech/G2P/librispeech_prepare.py,recipes/LibriSpeech/G2P/README.md,https://www.dropbox.com/sh/pig0uk80xxii7cg/AACQ1rrRLYthvpNZ5FadPLtRa?dl=0,,--data_folder=tests/samples/ASR/ --tokenizer_train_data=tests/samples/annotation/ASR_train.json --tokenizer_valid_data=tests/samples/annotation/ASR_dev.json --train_data=tests/samples/annotation/ASR_train.json --valid_data=tests/samples/annotation/ASR_dev.json --test_data=tests/samples/annotation/ASR_dev.json --number_of_epochs=2 --phn_token_output=42,"file_exists=[train_lm.py,train_log.txt,log.txt,env.log,hyperparams.yaml,save/tokenizer_annotation_train.json,save/tokenizer_annotation_valid.json,save/tokenizer_annotation_test.json,save/phoneme_annotations.txt,save/phoneme_tokenizer/42_unigram.model,save/phoneme_tokenizer/42_unigram.vocab]"
-G2P,LibriSpeech,recipes/LibriSpeech/G2P/train_lm.py,recipes/LibriSpeech/G2P/hparams/hparams_lm_transformer.yaml,recipes/LibriSpeech/G2P/librispeech_prepare.py,recipes/LibriSpeech/G2P/README.md,https://www.dropbox.com/sh/tkf6di10edpz4i6/AAArnGAkE0bEEOvOGfc6KWuma?dl=0,,--data_folder=tests/samples/ASR/ --train_data=tests/samples/annotation/ASR_train.json --valid_data=tests/samples/annotation/ASR_dev.json --test_data=tests/samples/annotation/ASR_dev.json --number_of_epochs=2 --emb_dim=64 --debug,"file_exists=[train_lm.py,train_log.txt,log.txt,env.log,hyperparams.yaml]"
-LM,LibriSpeech,recipes/LibriSpeech/LM/train.py,recipes/LibriSpeech/LM/hparams/RNNLM.yaml,recipes/LibriSpeech/LM/librispeech_prepare.py,recipes/LibriSpeech/LM/README.md,https://www.dropbox.com/sh/8xpybezuv70ibcg/AAByv2NuNv_ZFXuDdG89-MVPa?dl=0 https://www.dropbox.com/sh/8462ef441wvava2/AABNfHr07J_0SsdaM1yO5qkxa?dl=0 https://www.dropbox.com/sh/6uwqlw2tvv3kiy6/AACgvTR5jihyMrugBrpZPFNha?dl=0,,--data_folder=tests/samples/annotation/ --lm_corpus_path=tests/samples/annotation/LM_train.txt.gz --train_transcripts_pattern=LM_train.txt --dev_transcripts_pattern=LM_dev.txt --test_transcripts_pattern=LM_dev.txt --number_of_epochs=2,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt]"
-LM,LibriSpeech,recipes/LibriSpeech/LM/train.py,recipes/LibriSpeech/LM/hparams/transformer.yaml,recipes/LibriSpeech/LM/librispeech_prepare.py,recipes/LibriSpeech/LM/README.md,https://www.dropbox.com/sh/8xpybezuv70ibcg/AAByv2NuNv_ZFXuDdG89-MVPa?dl=0 https://www.dropbox.com/sh/8462ef441wvava2/AABNfHr07J_0SsdaM1yO5qkxa?dl=0 https://www.dropbox.com/sh/6uwqlw2tvv3kiy6/AACgvTR5jihyMrugBrpZPFNha?dl=0,,--data_folder=tests/samples/annotation/ --lm_corpus_path=tests/samples/annotation/LM_train.txt.gz --train_transcripts_pattern=LM_train.txt --dev_transcripts_pattern=LM_dev.txt --test_transcripts_pattern=LM_dev.txt --number_of_epochs=2,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt]"
-Tokenizer,minilibrispeech,templates/speech_recognition/Tokenizer/train.py,templates/speech_recognition/Tokenizer/tokenizer.yaml,templates/speech_recognition/Tokenizer/mini_librispeech_prepare.py,templates/speech_recognition/Tokenizer/README.md,,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_train.json --test_annotation=tests/samples/annotation/ASR_train.json --skip_prep=True --token_output=24 --annotation_read=wrd,"file_exists=[env.log,hyperparams.yaml,log.txt,train.py,24_unigram.model,24_unigram.vocab,ASR_train.txt]"
-LM,minilibrispeech,templates/speech_recognition/LM/train.py,templates/speech_recognition/LM/RNNLM.yaml,templates/speech_recognition/mini_librispeech_prepare.py,templates/speech_recognition/README.md,,,--data_folder=tests/samples/ASR/ --lm_train_data=tests/samples/annotation/LM_train.txt --lm_valid_data=tests/samples/annotation/LM_dev.txt --lm_test_data=tests/samples/annotation/LM_dev.txt --number_of_epochs=2 --emb_dim=64 --rnn_size=128 --tokenizer_file=tests/tmp/LibriSpeech_row_24/24_unigram.model,"file_exists=[env.log,hyperparams.yaml,log.txt,train_log.txt,train.py]"
-Speaker_recognition,minilibrispeech,templates/hyperparameter_optimization_speaker_id/train.py,templates/hyperparameter_optimization_speaker_id/train.yaml,templates/hyperparameter_optimization_speaker_id/mini_librispeech_prepare.py,templates/hyperparameter_optimization_speaker_id/README.md,,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_dev.json --test_annotation=tests/samples/annotation/ASR_dev.json --number_of_epochs=2 --n_classes=2 --emb_dim=42 --skip_prep=True --rir_folder=tests/tmp,"file_exists=[env.log,hyperparams.yaml,log.txt,save,train_log.txt,train.py,save/label_encoder.txt]"
-Speaker_recognition,minilibrispeech,templates/speaker_id/train.py,templates/speaker_id/train.yaml,templates/speaker_id/mini_librispeech_prepare.py,templates/speaker_id/README.md,,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_dev.json --test_annotation=tests/samples/annotation/ASR_dev.json --number_of_epochs=2 --n_classes=2 --emb_dim=42 --skip_prep=True --rir_folder=tests/tmp,"file_exists=[env.log,hyperparams.yaml,log.txt,save,train_log.txt,train.py,save/label_encoder.txt]"
-Tokenizer,LibriSpeech,recipes/LibriSpeech/Tokenizer/train.py,recipes/LibriSpeech/Tokenizer/hparams/1K_unigram_subword_bpe.yaml,recipes/LibriSpeech/Tokenizer/librispeech_prepare.py,recipes/LibriSpeech/Tokenizer/README.md,https://www.dropbox.com/sh/xyifwhyq2o7g8u8/AACVHHgXUsRUZIfrzHOccLP7a?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True --token_output=23,"file_exists=[23_unigram.model,23_unigram.vocab,log.txt,ASR_train.txt,env.log,train.py,hyperparams.yaml]"
-Tokenizer,LibriSpeech,recipes/LibriSpeech/Tokenizer/train.py,recipes/LibriSpeech/Tokenizer/hparams/5K_unigram_subword_bpe.yaml,recipes/LibriSpeech/Tokenizer/librispeech_prepare.py,recipes/LibriSpeech/Tokenizer/README.md,https://www.dropbox.com/sh/xyifwhyq2o7g8u8/AACVHHgXUsRUZIfrzHOccLP7a?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True --token_output=23,"file_exists=[23_unigram.model,23_unigram.vocab,log.txt,ASR_train.txt,env.log,train.py,hyperparams.yaml]"
-self-supervised-learning,LibriSpeech,recipes/LibriSpeech/self-supervised-learning/wav2vec2/train_sb_wav2vec2.py,recipes/LibriSpeech/self-supervised-learning/wav2vec2/hparams/wav2vec2_base.yaml,recipes/LibriSpeech/self-supervised-learning/wav2vec2/librispeech_prepare.py,recipes/LibriSpeech/self-supervised-learning/wav2vec2/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True,"file_exists=[train_stage_log.txt,train_sb_wav2vec2.py,log.txt,env.log,hyperparams.yaml]"
-ASR,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_whisper.py,recipes/LibriSpeech/ASR/CTC/hparams/train_hf_whisper_encoder.yaml,recipes/LibriSpeech/ASR/CTC/librispeech_prepare.py,recipes/LibriSpeech/ASR/CTC/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=2 --skip_prep=True --whisper_folder=tests/tmp/whisper_checkpoint,"file_exists=[train_with_whisper.py,wer_ASR_train.txt,train_log.txt,log.txt,env.log,hyperparams.yaml,save/29_char.vocab,save/29_char.model,save/ASR_train.txt]"
-ASR,LibriSpeech,recipes/LibriSpeech/ASR/transformer/train_with_whisper.py,recipes/LibriSpeech/ASR/transformer/hparams/train_hf_whisper.yaml,recipes/LibriSpeech/ASR/transformer/librispeech_prepare.py,recipes/LibriSpeech/ASR/transformer/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=2 --skip_prep=True --whisper_folder=tests/tmp/whisper_checkpoint,"file_exists=[train_with_whisper.py,wer_ASR_train.txt,train_log.txt,log.txt,env.log,hyperparams.yaml]"
-ASR,LibriSpeech,recipes/LibriSpeech/ASR/transformer/train.py,recipes/LibriSpeech/ASR/transformer/hparams/conformer_large.yaml,recipes/LibriSpeech/ASR/transformer/librispeech_prepare.py,recipes/LibriSpeech/ASR/transformer/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt,save/lm.ckpt] performance_check=[train_log.txt, train loss, <500, epoch: 10]"
-ASR,LibriSpeech,recipes/LibriSpeech/ASR/transformer/train.py,recipes/LibriSpeech/ASR/transformer/hparams/branchformer_large.yaml,recipes/LibriSpeech/ASR/transformer/librispeech_prepare.py,recipes/LibriSpeech/ASR/transformer/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt,save/lm.ckpt] performance_check=[train_log.txt, train loss, <500, epoch: 10]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+ASR-CTC,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py,recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec.yaml,recipes/LibriSpeech/ASR/CTC/librispeech_prepare.py,recipes/LibriSpeech/ASR/CTC/README.md,https://www.dropbox.com/sh/qj2ps85g8oiicrj/AAAxlkQw5Pfo0M9EyHMi8iAra?dl=0,https://huggingface.co/speechbrain/asr-wav2vec2-librispeech,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/label_encoder.txt] performance_check=[train_log.txt, train loss, <3.5, epoch: 10]",Test_clean-WER=1.65% Test_other-WER=3.67%
+ASR-CTC,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py,recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec_transformer_rescoring.yaml,recipes/LibriSpeech/ASR/CTC/librispeech_prepare.py,recipes/LibriSpeech/ASR/CTC/README.md,https://www.dropbox.com/sh/ijqalvre7mm08ng/AAD_hsN-8dBneUMMkELsOOxga?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --output_neurons=21 --number_of_epochs=2 --skip_prep=True,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/label_encoder.txt]",Test_clean-WER=1.57% Test_other-WER=3.37%
+ASR-CTC,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py,recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec_rnn_rescoring.yaml,recipes/LibriSpeech/ASR/CTC/librispeech_prepare.py,recipes/LibriSpeech/ASR/CTC/README.md,https://www.dropbox.com/sh/k4ixa211yp5b1tm/AAD85sgYw2CH7NKk_qKMO9Tja?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --output_neurons=21 --number_of_epochs=2 --skip_prep=True,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/label_encoder.txt]",
+ASR-CTC,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_whisper.py,recipes/LibriSpeech/ASR/CTC/hparams/train_hf_whisper_encoder.yaml,recipes/LibriSpeech/ASR/CTC/librispeech_prepare.py,recipes/LibriSpeech/ASR/CTC/README.md,https://www.dropbox.com/sh/zmtp13huxn02fot/AADyKL5q0MwRhEG1-WbSXDWda?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=2 --skip_prep=True --whisper_folder=tests/tmp/whisper_checkpoint,"file_exists=[train_with_whisper.py,wer_ASR_train.txt,train_log.txt,log.txt,env.log,hyperparams.yaml,save/29_char.vocab,save/29_char.model,save/ASR_train.txt]",
+ASR-CTC,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py,recipes/LibriSpeech/ASR/CTC/hparams/train_sb_wav2vec.yaml,recipes/LibriSpeech/ASR/CTC/librispeech_prepare.py,recipes/LibriSpeech/ASR/CTC/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --output_neurons=21 --number_of_epochs=2 --skip_prep=True --wav2vec2_hub=speechbrain/ssl-wav2vec2-base-librispeech,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/label_encoder.txt,save/extractor.ckpt,save/encoder_wrapper.ckpt]",
+ASR-CTC,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py,recipes/LibriSpeech/ASR/CTC/hparams/downsampled/train_hf_wavlm_signal_downsampling.yaml,recipes/LibriSpeech/ASR/CTC/librispeech_prepare.py,recipes/LibriSpeech/ASR/CTC/README.md,,,--downsampling_factor 2 --data_folder=tests/samples/ASR/ --skip_prep=True --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/label_encoder.txt]",
+ASR-CTC,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py,recipes/LibriSpeech/ASR/CTC/hparams/downsampled/train_hf_wavlm_conv_downsampling.yaml,recipes/LibriSpeech/ASR/CTC/librispeech_prepare.py,recipes/LibriSpeech/ASR/CTC/README.md,,,--downsampling_factor 2 --data_folder=tests/samples/ASR/ --skip_prep=True --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/label_encoder.txt]",
+ASR-CTC,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py,recipes/LibriSpeech/ASR/CTC/hparams/downsampled/train_hf_wavlm_average_downsampling.yaml,recipes/LibriSpeech/ASR/CTC/librispeech_prepare.py,recipes/LibriSpeech/ASR/CTC/README.md,,,--downsampling_factor 2 --data_folder=tests/samples/ASR/ --skip_prep=True --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/label_encoder.txt]",
+ASR-CTC,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py,recipes/LibriSpeech/ASR/CTC/hparams/downsampled/train_hf_wavlm_average_downsampling.yaml,recipes/LibriSpeech/ASR/CTC/librispeech_prepare.py,recipes/LibriSpeech/ASR/CTC/README.md,,,--downsampling_factor 3 --upsampling=True --ctc_neurons=58 --data_folder=tests/samples/ASR/ --skip_prep=True --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=2,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/label_encoder.txt]",
+ASR-CTC,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py,recipes/LibriSpeech/ASR/CTC/hparams/downsampled/train_hf_wavlm_conv_downsampling.yaml,recipes/LibriSpeech/ASR/CTC/librispeech_prepare.py,recipes/LibriSpeech/ASR/CTC/README.md,,,--downsampling_factor 3 --upsampling=True --ctc_neurons=58 --data_folder=tests/samples/ASR/ --skip_prep=True --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=2,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,train_with_wav2vec.py,env.log,hyperparams.yaml,save/label_encoder.txt]",
+ASR-Seq2Seq,LibriSpeech,recipes/LibriSpeech/ASR/seq2seq/train.py,recipes/LibriSpeech/ASR/seq2seq/hparams/train_BPE_1000.yaml,recipes/LibriSpeech/ASR/seq2seq/librispeech_prepare.py,recipes/LibriSpeech/ASR/seq2seq/README.md,https://www.dropbox.com/sh/1ycv07gyxdq8hdl/AABUDYzza4SLYtY45RcGf2_0a?dl=0 https://www.dropbox.com/sh/a39wq3h60luv552/AABBnCM2Uf-CNax_cgMWdqDda?dl=0,https://huggingface.co/speechbrain/asr-crdnn-rnnlm-librispeech,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt,save/lm.ckpt] performance_check=[train_log.txt, train loss, <=18.5, epoch: 10]",
+ASR-Seq2Seq,LibriSpeech,recipes/LibriSpeech/ASR/seq2seq/train.py,recipes/LibriSpeech/ASR/seq2seq/hparams/train_BPE_5000.yaml,recipes/LibriSpeech/ASR/seq2seq/librispeech_prepare.py,recipes/LibriSpeech/ASR/seq2seq/README.md,https://www.dropbox.com/sh/1ycv07gyxdq8hdl/AABUDYzza4SLYtY45RcGf2_0a?dl=0,https://huggingface.co/speechbrain/asr-crdnn-transformerlm-librispeech,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt,save/lm.ckpt] performance_check=[train_log.txt, train loss, <=18.5, epoch: 10]",Test_clean-WER=2.89% Test_other-WER=8.09%
+ASR-Seq2Seq,LibriSpeech,recipes/LibriSpeech/ASR/seq2seq/train.py,recipes/LibriSpeech/ASR/seq2seq/hparams/train_BPE_1000_sligru.yaml,recipes/LibriSpeech/ASR/seq2seq/librispeech_prepare.py,recipes/LibriSpeech/ASR/seq2seq/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt,save/lm.ckpt] performance_check=[train_log.txt, train loss, <=18.5, epoch: 10]",
+ASR-Transducers,LibriSpeech,recipes/LibriSpeech/ASR/transducer/train.py,recipes/LibriSpeech/ASR/transducer/hparams/conformer_transducer.yaml,recipes/LibriSpeech/ASR/transducer/librispeech_prepare.py,recipes/LibriSpeech/ASR/transducer/README.md,https://drive.google.com/drive/folders/1QtQz1Bkd_QPYnf3CyxhJ57ovbSZC2EhN?usp=sharing,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True --beam_size=1,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt,save/lm.ckpt] performance_check=[train_log.txt, train loss, <=1000, epoch: 10]",Test_clean-WER=2.72% Test_other-WER=6.47%
+ASR-Transformers,LibriSpeech,recipes/LibriSpeech/ASR/transformer/train.py,recipes/LibriSpeech/ASR/transformer/hparams/conformer_small.yaml,recipes/LibriSpeech/ASR/transformer/librispeech_prepare.py,recipes/LibriSpeech/ASR/transformer/README.md,https://www.dropbox.com/sh/s0x6ni124858b8i/AAALaCH6sGTMRUVTjh8Tm8Jwa?dl=0,https://huggingface.co/speechbrain/asr-conformersmall-transformerlm-librispeech,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt,save/lm.ckpt] performance_check=[train_log.txt, train loss, <350, epoch: 10]",Test_clean-WER=2.49% Test_other-WER=6.10%
+ASR-Transformers,LibriSpeech,recipes/LibriSpeech/ASR/transformer/train.py,recipes/LibriSpeech/ASR/transformer/hparams/transformer.yaml,recipes/LibriSpeech/ASR/transformer/librispeech_prepare.py,recipes/LibriSpeech/ASR/transformer/README.md,https://www.dropbox.com/sh/653kq8h2k87md4p/AAByAaAryXtQKpRzYtzV9ih5a?dl=0,https://huggingface.co/speechbrain/asr-transformer-transformerlm-librispeech,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt,save/lm.ckpt] performance_check=[train_log.txt, train loss, <=350, epoch: 10]",Test_clean-WER=2.27% Test_other-WER=5.53%
+ASR-Transformers,LibriSpeech,recipes/LibriSpeech/ASR/transformer/train.py,recipes/LibriSpeech/ASR/transformer/hparams/conformer_large.yaml,recipes/LibriSpeech/ASR/transformer/librispeech_prepare.py,recipes/LibriSpeech/ASR/transformer/README.md,https://www.dropbox.com/sh/ef3chrau8i45ip1/AAD9un8oabOB1a9OiSomZEhZa?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt,save/lm.ckpt] performance_check=[train_log.txt, train loss, <500, epoch: 10]",Test_clean-WER=2.01% Test_other-WER=4.52%
+ASR-Transformers,LibriSpeech,recipes/LibriSpeech/ASR/transformer/train.py,recipes/LibriSpeech/ASR/transformer/hparams/branchformer_large.yaml,recipes/LibriSpeech/ASR/transformer/librispeech_prepare.py,recipes/LibriSpeech/ASR/transformer/README.md,https://www.dropbox.com/sh/gxkye4efa6hvl2c/AADO85EkkfbIGe5KjBAU6BrEa?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True,"file_exists=[wer_ASR_train.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt,save/lm.ckpt] performance_check=[train_log.txt, train loss, <500, epoch: 10]",Test_clean-WER=2.04% Test_other-WER=4.12%
+ASR-Transformers,LibriSpeech,recipes/LibriSpeech/ASR/transformer/train.py,recipes/LibriSpeech/ASR/transformer/hparams/hyperconformer_22M.yaml,recipes/LibriSpeech/ASR/transformer/librispeech_prepare.py,recipes/LibriSpeech/ASR/transformer/README.md,https://www.dropbox.com/sh/gxkye4efa6hvl2c/AADO85EkkfbIGe5KjBAU6BrEa?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True,"file_exists=[env.log,hyperparams.yaml,log.txt,train_log.txt,train.py,wer_ASR_train.txt,save/lm.ckpt,save/tokenizer.ckpt] performance_check=[train_log.txt, train loss, <500, epoch: 10]",Test_clean-WER=2.23% Test_other-WER=4.54%
+ASR-Transformers,LibriSpeech,recipes/LibriSpeech/ASR/transformer/train.py,recipes/LibriSpeech/ASR/transformer/hparams/hyperconformer_8M.yaml,recipes/LibriSpeech/ASR/transformer/librispeech_prepare.py,recipes/LibriSpeech/ASR/transformer/README.md,https://www.dropbox.com/sh/gxkye4efa6hvl2c/AADO85EkkfbIGe5KjBAU6BrEa?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True,"file_exists=[env.log,hyperparams.yaml,log.txt,train_log.txt,train.py,wer_ASR_train.txt,save/lm.ckpt,save/tokenizer.ckpt] performance_check=[train_log.txt, train loss, <500, epoch: 10]",Test_clean-WER=2.55% Test_other-WER=6.61%
+ASR-Transformers,LibriSpeech,recipes/LibriSpeech/ASR/transformer/train.py,recipes/LibriSpeech/ASR/transformer/hparams/hyperbranchformer_25M.yaml,recipes/LibriSpeech/ASR/transformer/librispeech_prepare.py,recipes/LibriSpeech/ASR/transformer/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True,"file_exists=[env.log,hyperparams.yaml,log.txt,train_log.txt,train.py,wer_ASR_train.txt,save/lm.ckpt,save/tokenizer.ckpt] performance_check=[train_log.txt, train loss, <500, epoch: 10]",Test_clean-WER=2.36% Test_other-WER=6.89%
+ASR-Transformers,LibriSpeech,recipes/LibriSpeech/ASR/transformer/train.py,recipes/LibriSpeech/ASR/transformer/hparams/hyperbranchformer_13M.yaml,recipes/LibriSpeech/ASR/transformer/librispeech_prepare.py,recipes/LibriSpeech/ASR/transformer/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True,"file_exists=[env.log,hyperparams.yaml,log.txt,train_log.txt,train.py,wer_ASR_train.txt,save/lm.ckpt,save/tokenizer.ckpt] performance_check=[train_log.txt, train loss, <500, epoch: 10]",Test_clean-WER=2.54% Test_other-WER=6.58%
+ASR-Transformers,LibriSpeech,recipes/LibriSpeech/ASR/transformer/train_with_whisper.py,recipes/LibriSpeech/ASR/transformer/hparams/train_hf_whisper.yaml,recipes/LibriSpeech/ASR/transformer/librispeech_prepare.py,recipes/LibriSpeech/ASR/transformer/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=2 --skip_prep=True --whisper_folder=tests/tmp/whisper_checkpoint,"file_exists=[train_with_whisper.py,wer_ASR_train.txt,train_log.txt,log.txt,env.log,hyperparams.yaml]",-
+ASR-Transformers,LibriSpeech,recipes/LibriSpeech/ASR/transformer/train_bayesspeech.py,recipes/LibriSpeech/ASR/transformer/hparams/bayesspeech.yaml,recipes/LibriSpeech/ASR/transformer/librispeech_prepare.py,recipes/LibriSpeech/ASR/transformer/README.md,https://www.dropbox.com/scl/fo/cdken4jqfj96ev1v84jxm/h?rlkey=25eu1ytgm5ac51zqj8p65zwxd&dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True,"file_exists=[env.log,hyperparams.yaml,log.txt,train_log.txt,train_bayesspeech.py,wer_ASR_train.txt,save/lm.ckpt,save/tokenizer.ckpt] performance_check=[train_log.txt, train loss, <500, epoch: 10]",Test_clean-WER=2.84% Test_other-WER=6.27%
+ASR,minilibrispeech,templates/speech_recognition/ASR/train.py,templates/speech_recognition/ASR/train.yaml,templates/speech_recognition/ASR/mini_librispeech_prepare.py,templates/speech_recognition/ASR/README.md,,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_dev.json --test_annotation=tests/samples/annotation/ASR_dev.json --number_of_epochs=2 --number_of_ctc_epochs=1 --skip_prep=True,"file_exists=[env.log,hyperparams.yaml,log.txt,save/lm.ckpt,save/model.ckpt,save/tokenizer.ckpt,train_log.txt,train.py,wer_test.txt,save/CKPT+latest/brain.ckpt,save/CKPT+latest/CKPT.yaml,save/CKPT+latest/counter.ckpt,save/CKPT+latest/dataloader-TRAIN.ckpt,save/CKPT+latest/model.ckpt,save/CKPT+latest/normalizer.ckpt,save/CKPT+latest/optimizer.ckpt,save/CKPT+latest/scheduler.ckpt]",
+Enhancement,minilibrispeech,templates/enhancement/train.py,templates/enhancement/train.yaml,templates/enhancement/mini_librispeech_prepare.py,templates/enhancement/README.md,,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhancement_train.json --valid_annotation=tests/samples/annotation/enhancement_dev.json --test_annotation=tests/samples/annotation/enhancement_dev.json --skip_prep=True,"file_exists=[env.log,hyperparams.yaml,log.txt,save,train_log.txt,train.py]",
+G2P,LibriSpeech,recipes/LibriSpeech/G2P/train.py,recipes/LibriSpeech/G2P/hparams/hparams_g2p_rnn.yaml,recipes/LibriSpeech/G2P/librispeech_prepare.py,recipes/LibriSpeech/G2P/README.md,https://www.dropbox.com/sh/qmcl1obp8pxqaap/AAC3yXvjkfJ3mL-RKyAUxPdNa?dl=0,,--data_folder=tests/samples/ASR/ --tokenizer_train_data=tests/samples/annotation/ASR_train.json --tokenizer_valid_data=tests/samples/annotation/ASR_dev.json --lexicon_epochs=2 --skip_prep=True --phn_token_output=42 --use_tensorboard=False --debug,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/wer_sentence.txt,save/homograph_stats.txt,save/wer_homograph.txt,save/wer_lexicon.txt,save/pretrained_models/ctc_lin.ckpt,save/pretrained_models/model.ckpt,reports/lexicon/1/wer_lexicon.txt,reports/lexicon/2/wer_lexicon.txt,reports/homograph/1/homograph_stats.txt,reports/homograph/1/wer_homograph.txt,reports/homograph/2/homograph_stats.txt,reports/homograph/2/wer_homograph.txt,reports/sentence/1/wer_sentence.txt,reports/sentence/2/wer_sentence.txt]",PER-Test=2.72%
+G2P,LibriSpeech,recipes/LibriSpeech/G2P/train.py,recipes/LibriSpeech/G2P/hparams/hparams_g2p_transformer.yaml,recipes/LibriSpeech/G2P/librispeech_prepare.py,recipes/LibriSpeech/G2P/README.md,https://www.dropbox.com/sh/zhrxg7anuhje7e8/AADTeJtdsja_wClkE2DsF9Ewa?dl=0,https://huggingface.co/speechbrain/soundchoice-g2p,--data_folder=tests/samples/ASR/ --tokenizer_train_data=tests/samples/annotation/ASR_train.json --tokenizer_valid_data=tests/samples/annotation/ASR_dev.json --lexicon_epochs=2 --skip_prep=True --phn_token_output=42 --use_tensorboard=False --debug,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/wer_sentence.txt,save/homograph_stats.txt,save/wer_homograph.txt,save/wer_lexicon.txt,save/pretrained_models/ctc_lin.ckpt,save/pretrained_models/model.ckpt,reports/lexicon/1/wer_lexicon.txt,reports/lexicon/2/wer_lexicon.txt,reports/homograph/1/homograph_stats.txt,reports/homograph/1/wer_homograph.txt,reports/homograph/2/homograph_stats.txt,reports/homograph/2/wer_homograph.txt,reports/sentence/1/wer_sentence.txt,reports/sentence/2/wer_sentence.txt]",PER-Test=2.89%
+G2P,LibriSpeech,recipes/LibriSpeech/G2P/train_lm.py,recipes/LibriSpeech/G2P/hparams/hparams_lm_rnn.yaml,recipes/LibriSpeech/G2P/librispeech_prepare.py,recipes/LibriSpeech/G2P/README.md,https://www.dropbox.com/sh/pig0uk80xxii7cg/AACQ1rrRLYthvpNZ5FadPLtRa?dl=0,,--data_folder=tests/samples/ASR/ --tokenizer_train_data=tests/samples/annotation/ASR_train.json --tokenizer_valid_data=tests/samples/annotation/ASR_dev.json --train_data=tests/samples/annotation/ASR_train.json --valid_data=tests/samples/annotation/ASR_dev.json --test_data=tests/samples/annotation/ASR_dev.json --number_of_epochs=2 --phn_token_output=42,"file_exists=[train_lm.py,train_log.txt,log.txt,env.log,hyperparams.yaml,save/tokenizer_annotation_train.json,save/tokenizer_annotation_valid.json,save/tokenizer_annotation_test.json,save/phoneme_annotations.txt,save/phoneme_tokenizer/42_unigram.model,save/phoneme_tokenizer/42_unigram.vocab]",
+G2P,LibriSpeech,recipes/LibriSpeech/G2P/train_lm.py,recipes/LibriSpeech/G2P/hparams/hparams_lm_transformer.yaml,recipes/LibriSpeech/G2P/librispeech_prepare.py,recipes/LibriSpeech/G2P/README.md,https://www.dropbox.com/sh/tkf6di10edpz4i6/AAArnGAkE0bEEOvOGfc6KWuma?dl=0,,--data_folder=tests/samples/ASR/ --train_data=tests/samples/annotation/ASR_train.json --valid_data=tests/samples/annotation/ASR_dev.json --test_data=tests/samples/annotation/ASR_dev.json --number_of_epochs=2 --emb_dim=64 --debug,"file_exists=[train_lm.py,train_log.txt,log.txt,env.log,hyperparams.yaml]",
+LM,LibriSpeech,recipes/LibriSpeech/LM/train.py,recipes/LibriSpeech/LM/hparams/RNNLM.yaml,recipes/LibriSpeech/LM/librispeech_prepare.py,recipes/LibriSpeech/LM/README.md,https://www.dropbox.com/sh/8xpybezuv70ibcg/AAByv2NuNv_ZFXuDdG89-MVPa?dl=0 https://www.dropbox.com/sh/8462ef441wvava2/AABNfHr07J_0SsdaM1yO5qkxa?dl=0 https://www.dropbox.com/sh/6uwqlw2tvv3kiy6/AACgvTR5jihyMrugBrpZPFNha?dl=0,,--data_folder=tests/samples/annotation/ --lm_corpus_path=tests/samples/annotation/LM_train.txt.gz --train_transcripts_pattern=LM_train.txt --dev_transcripts_pattern=LM_dev.txt --test_transcripts_pattern=LM_dev.txt --number_of_epochs=2,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt]",
+LM,LibriSpeech,recipes/LibriSpeech/LM/train.py,recipes/LibriSpeech/LM/hparams/transformer.yaml,recipes/LibriSpeech/LM/librispeech_prepare.py,recipes/LibriSpeech/LM/README.md,https://www.dropbox.com/sh/8xpybezuv70ibcg/AAByv2NuNv_ZFXuDdG89-MVPa?dl=0 https://www.dropbox.com/sh/8462ef441wvava2/AABNfHr07J_0SsdaM1yO5qkxa?dl=0 https://www.dropbox.com/sh/6uwqlw2tvv3kiy6/AACgvTR5jihyMrugBrpZPFNha?dl=0,,--data_folder=tests/samples/annotation/ --lm_corpus_path=tests/samples/annotation/LM_train.txt.gz --train_transcripts_pattern=LM_train.txt --dev_transcripts_pattern=LM_dev.txt --test_transcripts_pattern=LM_dev.txt --number_of_epochs=2,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizer.ckpt]",
+Tokenizer,minilibrispeech,templates/speech_recognition/Tokenizer/train.py,templates/speech_recognition/Tokenizer/tokenizer.yaml,templates/speech_recognition/Tokenizer/mini_librispeech_prepare.py,templates/speech_recognition/Tokenizer/README.md,,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_train.json --test_annotation=tests/samples/annotation/ASR_train.json --skip_prep=True --token_output=24 --annotation_read=wrd,"file_exists=[env.log,hyperparams.yaml,log.txt,train.py,24_unigram.model,24_unigram.vocab,ASR_train.txt]",
+LM,minilibrispeech,templates/speech_recognition/LM/train.py,templates/speech_recognition/LM/RNNLM.yaml,templates/speech_recognition/mini_librispeech_prepare.py,templates/speech_recognition/README.md,,,--data_folder=tests/samples/ASR/ --lm_train_data=tests/samples/annotation/LM_train.txt --lm_valid_data=tests/samples/annotation/LM_dev.txt --lm_test_data=tests/samples/annotation/LM_dev.txt --number_of_epochs=2 --emb_dim=64 --rnn_size=128 --tokenizer_file=tests/tmp/LibriSpeech_row_34/24_unigram.model,"file_exists=[env.log,hyperparams.yaml,log.txt,train_log.txt,train.py]",
+Speaker_recognition,minilibrispeech,templates/hyperparameter_optimization_speaker_id/train.py,templates/hyperparameter_optimization_speaker_id/train.yaml,templates/hyperparameter_optimization_speaker_id/mini_librispeech_prepare.py,templates/hyperparameter_optimization_speaker_id/README.md,,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_dev.json --test_annotation=tests/samples/annotation/ASR_dev.json --number_of_epochs=2 --n_classes=2 --emb_dim=42 --skip_prep=True,"file_exists=[env.log,hyperparams.yaml,log.txt,save,train_log.txt,train.py,save/label_encoder.txt]",
+Speaker_recognition,minilibrispeech,templates/speaker_id/train.py,templates/speaker_id/train.yaml,templates/speaker_id/mini_librispeech_prepare.py,templates/speaker_id/README.md,,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_dev.json --test_annotation=tests/samples/annotation/ASR_dev.json --number_of_epochs=2 --n_classes=2 --emb_dim=42 --skip_prep=True,"file_exists=[env.log,hyperparams.yaml,log.txt,save,train_log.txt,train.py,save/label_encoder.txt]",
+Tokenizer,LibriSpeech,recipes/LibriSpeech/Tokenizer/train.py,recipes/LibriSpeech/Tokenizer/hparams/1K_unigram_subword_bpe.yaml,recipes/LibriSpeech/Tokenizer/librispeech_prepare.py,recipes/LibriSpeech/Tokenizer/README.md,https://www.dropbox.com/sh/xyifwhyq2o7g8u8/AACVHHgXUsRUZIfrzHOccLP7a?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True --token_output=23,"file_exists=[23_unigram.model,23_unigram.vocab,log.txt,ASR_train.txt,env.log,train.py,hyperparams.yaml]",
+Tokenizer,LibriSpeech,recipes/LibriSpeech/Tokenizer/train.py,recipes/LibriSpeech/Tokenizer/hparams/5K_unigram_subword_bpe.yaml,recipes/LibriSpeech/Tokenizer/librispeech_prepare.py,recipes/LibriSpeech/Tokenizer/README.md,https://www.dropbox.com/sh/xyifwhyq2o7g8u8/AACVHHgXUsRUZIfrzHOccLP7a?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True --token_output=23,"file_exists=[23_unigram.model,23_unigram.vocab,log.txt,ASR_train.txt,env.log,train.py,hyperparams.yaml]",
+self-supervised-learning,LibriSpeech,recipes/LibriSpeech/self-supervised-learning/wav2vec2/train_sb_wav2vec2.py,recipes/LibriSpeech/self-supervised-learning/wav2vec2/hparams/wav2vec2_base.yaml,recipes/LibriSpeech/self-supervised-learning/wav2vec2/librispeech_prepare.py,recipes/LibriSpeech/self-supervised-learning/wav2vec2/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True,"file_exists=[train_stage_log.txt,train_sb_wav2vec2.py,log.txt,env.log,hyperparams.yaml]",
+quantization,LibriSpeech,recipes/LibriSpeech/quantization/train.py,recipes/LibriSpeech/quantization/hparams/train_with_hubert.yaml,recipes/LibriSpeech/quantization/librispeech_prepare.py,recipes/LibriSpeech/quantization/README.md,https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True,"file_exists=[log.txt,train.py,env.log,hyperparams.yaml]",
+quantization,LibriSpeech,recipes/LibriSpeech/quantization/train.py,recipes/LibriSpeech/quantization/hparams/train_with_wav2vec.yaml,recipes/LibriSpeech/quantization/librispeech_prepare.py,recipes/LibriSpeech/quantization/README.md,https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True,"file_exists=[log.txt,train.py,env.log,hyperparams.yaml]",
+quantization,LibriSpeech,recipes/LibriSpeech/quantization/train.py,recipes/LibriSpeech/quantization/hparams/train_with_wavlm.yaml,recipes/LibriSpeech/quantization/librispeech_prepare.py,recipes/LibriSpeech/quantization/README.md,https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True,"file_exists=[log.txt,train.py,env.log,hyperparams.yaml]",
+ASR-CTC,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_wav2vec_k2.py,recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec_k2.yaml,recipes/LibriSpeech/ASR/CTC/librispeech_prepare.py,recipes/LibriSpeech/ASR/CTC/README.md,,,--data_folder=tests/samples/ASR/ --skip_prep=True --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=2,"file_exists=[metric_ASR_train/wer_HL_1best.txt,train_log.txt,log.txt,train_with_wav2vec_k2.py,env.log,hyperparams.yaml]",
+LM,LibriSpeech,recipes/LibriSpeech/LM/train_ngram.py,recipes/LibriSpeech/LM/hparams/train_ngram.yaml,recipes/LibriSpeech/LM/librispeech_prepare.py,recipes/LibriSpeech/LM/README.md,,,--data_folder=tests/samples/ASR/ --skip_prep=True --train_csv=tests/samples/annotation/ASR_train.csv,"file_exists=[env.log,hyperparams.yaml,log.txt,lang/words.txt,libri_lm_corpus.txt,train_ngram.py]",
diff --git a/tests/recipes/LibriTTS.csv b/tests/recipes/LibriTTS.csv
index afa44648e34403ae69e6e6483824f31869178f79..8bebec42ff433af9a2b2737e556e11eaa60ce1b3 100644
--- a/tests/recipes/LibriTTS.csv
+++ b/tests/recipes/LibriTTS.csv
@@ -1,2 +1,3 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-TTS,LibriTTS,recipes/LibriTTS/vocoder/hifigan/train.py,recipes/LibriTTS/vocoder/hifigan/hparams/train.yaml,recipes/LibriTTS/libritts_prepare.py,recipes/LibriTTS/README.md,https://www.dropbox.com/sh/gjs1kslxkxz819q/AABPriN4dOoD1qL7NoIyVk0Oa?dl=0 ,https://huggingface.co/speechbrain/tts-hifigan-libritts-16kHz,--epochs=2 --data_folder=tests/samples/ASR --train_json=tests/samples/annotation/ASR_train.json --valid_json=tests/samples/annotation/ASR_dev.json --test_json=tests/samples/annotation/ASR_dev.json --skip_prep=True --sample_rate=16000,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
+TTS,LibriTTS,recipes/LibriTTS/vocoder/hifigan/train.py,recipes/LibriTTS/vocoder/hifigan/hparams/train.yaml,recipes/LibriTTS/libritts_prepare.py,recipes/LibriTTS/README.md,https://www.dropbox.com/sh/gjs1kslxkxz819q/AABPriN4dOoD1qL7NoIyVk0Oa?dl=0 ,https://huggingface.co/speechbrain/tts-hifigan-libritts-16kHz,--epochs=2 --data_folder=tests/samples/ASR --train_json=tests/samples/annotation/ASR_train.json --valid_json=tests/samples/annotation/ASR_dev.json --test_json=tests/samples/annotation/ASR_dev.json --skip_prep=True --sample_rate=16000,
+TTS,LibriTTS,recipes/LibriTTS/TTS/mstacotron2/train.py,recipes/LibriTTS/TTS/mstacotron2/hparams/train.yaml,recipes/LibriTTS/libritts_prepare.py,recipes/LibriTTS/README.md,https://www.dropbox.com/sh/ti2vk7sce8f9fgd/AABcDGWCrBvLX_ZQs76mlJRYa?dl=0,,--batch_size=1 --epochs=2 --data_folder=tests/samples/TTS --train_json=tests/samples/annotation/TTS_train.json --valid_json=tests/samples/annotation/TTS_train.json --test_json=tests/samples/annotation/TTS_train.json --skip_prep=True --sample_rate=16000,
diff --git a/tests/recipes/MEDIA.csv b/tests/recipes/MEDIA.csv
index 0bb6ab9593834d26a41cb12e3a93c0781b941db7..0be4a41617f0dc7dacb1e6d1ec8ca3f33d3e6bd9 100644
--- a/tests/recipes/MEDIA.csv
+++ b/tests/recipes/MEDIA.csv
@@ -1,4 +1,4 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-SLU,MEDIA,recipes/MEDIA/SLU/CTC/train_hf_wav2vec.py,recipes/MEDIA/SLU/CTC/hparams/train_hf_wav2vec_full.yaml,recipes/MEDIA/media_prepare.py,recipes/MEDIA/SLU/CTC/README.md,,,--data_folder=tests/samples/ASR/ --channels_path=Null --concepts_path=Null --csv_train=tests/samples/annotation/ASR_train.csv --csv_valid=tests/samples/annotation/ASR_train.csv --csv_test=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True,"file_exists=[train_hf_wav2vec.py,cver_test.txt,coer_test.txt,train_log.txt,cer_test.txt,log.txt,env.log,ctc_test.txt,hyperparams.yaml,save/labelencoder.txt]"
-SLU,MEDIA,recipes/MEDIA/SLU/CTC/train_hf_wav2vec.py,recipes/MEDIA/SLU/CTC/hparams/train_hf_wav2vec_relax.yaml,recipes/MEDIA/media_prepare.py,recipes/MEDIA/SLU/CTC/README.md,,,--data_folder=tests/samples/ASR/ --channels_path=Null --concepts_path=Null --csv_train=tests/samples/annotation/ASR_train.csv --csv_valid=tests/samples/annotation/ASR_train.csv --csv_test=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True,"file_exists=[train_hf_wav2vec.py,cver_test.txt,coer_test.txt,train_log.txt,cer_test.txt,log.txt,env.log,ctc_test.txt,hyperparams.yaml,save/labelencoder.txt]"
-ASR,MEDIA,recipes/MEDIA/ASR/CTC/train_hf_wav2vec.py,recipes/MEDIA/ASR/CTC/hparams/train_hf_wav2vec.yaml,recipes/MEDIA/media_prepare.py,recipes/MEDIA/ASR/CTC/README.md,,,--data_folder=tests/samples/ASR/ --channels_path=Null --concepts_path=Null --csv_train=tests/samples/annotation/ASR_train.csv --csv_valid=tests/samples/annotation/ASR_train.csv --csv_test=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True,"file_exists=[train_hf_wav2vec.py,train_log.txt,cer_test.txt,log.txt,env.log,ctc_test.txt,hyperparams.yaml,save/labelencoder.txt]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+SLU,MEDIA,recipes/MEDIA/SLU/CTC/train_hf_wav2vec.py,recipes/MEDIA/SLU/CTC/hparams/train_hf_wav2vec_full.yaml,recipes/MEDIA/media_prepare.py,recipes/MEDIA/SLU/CTC/README.md,,https://huggingface.co/speechbrain/slu-wav2vec2-ctc-MEDIA-relax,--data_folder=tests/samples/ASR/ --channels_path=Null --concepts_path=Null --csv_train=tests/samples/annotation/ASR_train.csv --csv_valid=tests/samples/annotation/ASR_train.csv --csv_test=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True,"file_exists=[train_hf_wav2vec.py,cver_test.txt,coer_test.txt,train_log.txt,cer_test.txt,log.txt,env.log,ctc_test.txt,hyperparams.yaml,save/labelencoder.txt]",Test-ChER=7.46% Test-CER=20.10% Test-CVER=31.41%
+SLU,MEDIA,recipes/MEDIA/SLU/CTC/train_hf_wav2vec.py,recipes/MEDIA/SLU/CTC/hparams/train_hf_wav2vec_relax.yaml,recipes/MEDIA/media_prepare.py,recipes/MEDIA/SLU/CTC/README.md,,https://huggingface.co/speechbrain/slu-wav2vec2-ctc-MEDIA-full,--data_folder=tests/samples/ASR/ --channels_path=Null --concepts_path=Null --csv_train=tests/samples/annotation/ASR_train.csv --csv_valid=tests/samples/annotation/ASR_train.csv --csv_test=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True,"file_exists=[train_hf_wav2vec.py,cver_test.txt,coer_test.txt,train_log.txt,cer_test.txt,log.txt,env.log,ctc_test.txt,hyperparams.yaml,save/labelencoder.txt]",Test-ChER=7.78% Test-CER=24.88% Test-CVER=35.77%
+ASR,MEDIA,recipes/MEDIA/ASR/CTC/train_hf_wav2vec.py,recipes/MEDIA/ASR/CTC/hparams/train_hf_wav2vec.yaml,recipes/MEDIA/media_prepare.py,recipes/MEDIA/ASR/CTC/README.md,,https://huggingface.co/speechbrain/asr-wav2vec2-ctc-MEDIA,--data_folder=tests/samples/ASR/ --channels_path=Null --concepts_path=Null --csv_train=tests/samples/annotation/ASR_train.csv --csv_valid=tests/samples/annotation/ASR_train.csv --csv_test=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True,"file_exists=[train_hf_wav2vec.py,train_log.txt,cer_test.txt,log.txt,env.log,ctc_test.txt,hyperparams.yaml,save/labelencoder.txt]",Test-ChER=7.78% Test-CER=4.78%
diff --git a/tests/recipes/MultiWOZ.csv b/tests/recipes/MultiWOZ.csv
new file mode 100644
index 0000000000000000000000000000000000000000..9a30cc639ed6d5b26eaf625a40b3038245c3da48
--- /dev/null
+++ b/tests/recipes/MultiWOZ.csv
@@ -0,0 +1,3 @@
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+Response-Generation,MultiWOZ,recipes/MultiWOZ/response_generation/gpt/train_with_gpt.py,recipes/MultiWOZ/response_generation/gpt/hparams/train_gpt.yaml,recipes/MultiWOZ/response_generation/gpt/multiwoz_prepare.py,recipes/MultiWOZ/response_generation/README.md,https://www.dropbox.com/sh/vm8f5iavohr4zz9/AACrkOxXuxsrvJy4Cjpih9bQa?dl=0,https://huggingface.co/speechbrain/MultiWOZ-GPT-Response_Generation,--data_folder=tests/samples/ASR/  --train_annotation=tests/samples/annotation/response_generation_train_multiwoz.json --valid_annotation=tests/samples/annotation/response_generation_train_multiwoz.json --test_annotation=tests/samples/annotation/response_generation_train_multiwoz.json --number_of_epochs=2 --skip_prep=True,,Test-PPL=4.01 Test_BLEU-4=2.54e-04
+Response-Generation,MultiWOZ,recipes/MultiWOZ/response_generation/llama2/train_with_llama2.py,recipes/MultiWOZ/response_generation/llama2/hparams/train_llama2.yaml,recipes/MultiWOZ/response_generation/llama2/multiwoz_prepare.py,recipes/MultiWOZ/response_generation/README.md,https://www.dropbox.com/sh/d093vsje1d7ijj9/AAA-nHEd_MwNEFJfBGLmXxJra?dl=0,https://huggingface.co/speechbrain/MultiWOZ-Llama2-Response_Generation,--data_folder=tests/samples/ASR/  --train_annotation=tests/samples/annotation/response_generation_train_multiwoz.json --valid_annotation=tests/samples/annotation/response_generation_train_multiwoz.json --test_annotation=tests/samples/annotation/response_generation_train_multiwoz.json --number_of_epochs=2 --skip_prep=True,,Test-PPL=2.90 Test_BLEU-4=7.45e-04
diff --git a/tests/recipes/README.md b/tests/recipes/README.md
index e255c79394f6ddf28adade0e22b9c348c42643a6..3ef63ed354f8188c24340e15cc867d9ed19ff92e 100644
--- a/tests/recipes/README.md
+++ b/tests/recipes/README.md
@@ -6,7 +6,7 @@ Key point: **specify `test_debug_flags` and make sure testing your recipe works
 
 For GPU testing, install all extra requirements:
 ```
-find recipes | grep extra | xargs cat | sort -u | grep -v \# | xargs -I {} pip install {}
+find recipes | grep extra_requirements.txt | xargs cat | sort -u | grep -v \# | xargs -I {} pip install {}
 ```
 
 ---
@@ -16,6 +16,18 @@ If you like to test for all recipes belonging to one dataset:
 python -c 'from tests.utils.recipe_tests import run_recipe_tests; print("TEST FAILED!") if not(run_recipe_tests(filters_fields=["Dataset"], filters=[["CommonLanguage", "LibriSpeech"]], do_checks=False, run_opts="--device=cuda")) else print("TEST PASSED")'
 ```
 
+You can run the recipe on the CPU just by setting the run_opts properly:
+```
+python -c 'from tests.utils.recipe_tests import run_recipe_tests; print("TEST FAILED!") if not(run_recipe_tests(filters_fields=["Dataset"], filters=[["CommonLanguage", "LibriSpeech"]], do_checks=False, run_opts="--device=cpu")) else print("TEST PASSED")'
+```
+
+In some cases, you might want to test the recipe on a non-default GPU (e.g, cuda:1). This helps detecting issues in recipes where the device was hard-coded. You can do that simply with:
+
+```
+python -c 'from tests.utils.recipe_tests import run_recipe_tests; print("TEST FAILED!") if not(run_recipe_tests(filters_fields=["Dataset"], filters=[["CommonLanguage", "LibriSpeech"]], do_checks=False, run_opts="--device=cuda:0")) else print("TEST PASSED")'
+```
+
+
 To target a specific recipe (here by its hparam yaml):
 ```
 python -c 'from tests.utils.recipe_tests import run_recipe_tests; print("TEST FAILED!") if not(run_recipe_tests(filters_fields=["Hparam_file"], filters=[["recipes/TIMIT/ASR/transducer/hparams/train_wav2vec.yaml"]], do_checks=False, run_opts="--device=cuda")) else print("TEST PASSED")'
@@ -59,7 +71,7 @@ Let's take a look at recipes: their structural outline & their testing definitio
 * `recipes/DATASET/TASK/METHOD/README.md` – a _Readme_file_, which points to
   * some GDrive url – a _Result_url_ [optional]
   * some HuggingFace url – a _HF_repo_ [optional], which has
-    * pretrained model – `hyperparameters.yaml` to be loaded either by [a pretrained interface](https://github.com/speechbrain/speechbrain/tree/develop/speechbrain/pretrained) or a custom interface
+    * pretrained model – `hyperparameters.yaml` to be loaded either by [a pretrained interface](https://github.com/speechbrain/speechbrain/tree/develop/speechbrain/inference) or a custom interface
     * code snippets, for demonstration
   * additional references, incl. further URLs
   > _Note: all URLs references (in .py, .md & .txt files) are checked to be valid._
diff --git a/tests/recipes/REAL-M.csv b/tests/recipes/REAL-M.csv
index a174bf54caa9465d5d80bdc4b6a8645099b3269b..643f7bdf48fdf0df380c7e2d6181bc66a32a3a6c 100644
--- a/tests/recipes/REAL-M.csv
+++ b/tests/recipes/REAL-M.csv
@@ -1,2 +1,2 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-Sisnr-estimation,REAL-M,recipes/REAL-M/sisnr-estimation/train.py,recipes/REAL-M/sisnr-estimation/hparams/pool_sisnrestimator.yaml,,recipes/REAL-M/sisnr-estimation/README.md,https://www.dropbox.com/sh/n55lm8i5z51pbm1/AABHfByOEy__UP_bmT4GJvSba?dl=0,https://huggingface.co/speechbrain/REAL-M-sisnr-estimator https://huggingface.co/speechbrain/REAL-M-sisnr-estimator-training,--rir_path=tests/tmp/RIRS_NOISES/pointsource_noises --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --whamr_data_folder=data_folder=tests/samples/separation --train_whamr_data=tests/samples/annotation/separation_train.csv --base_folder_dm_whamr=tests/samples/separation,"file_exists=[test_results_wsj.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/whamr_rirs.csv]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+Sisnr-estimation,REAL-M,recipes/REAL-M/sisnr-estimation/train.py,recipes/REAL-M/sisnr-estimation/hparams/pool_sisnrestimator.yaml,,recipes/REAL-M/sisnr-estimation/README.md,https://www.dropbox.com/sh/n55lm8i5z51pbm1/AABHfByOEy__UP_bmT4GJvSba?dl=0,https://huggingface.co/speechbrain/REAL-M-sisnr-estimator,--rir_path=tests/tmp/RIRS_NOISES/pointsource_noises --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --whamr_data_folder=data_folder=tests/samples/separation --train_whamr_data=tests/samples/annotation/separation_train.csv --base_folder_dm_whamr=tests/samples/separation --use_wavedrop=True,"file_exists=[test_results_wsj.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/whamr_rirs.csv]",L1-Error=1.71dB
diff --git a/tests/recipes/RescueSpeech.csv b/tests/recipes/RescueSpeech.csv
index 6fd30e324b62ff74ce0fcec58108e4e1f1b4ff2e..c3df4c99acf0ac33156168ab6a7dc9b88d24b16e 100644
--- a/tests/recipes/RescueSpeech.csv
+++ b/tests/recipes/RescueSpeech.csv
@@ -1,2 +1,2 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-ASR+enhancement,RescueSpeech,recipes/RescueSpeech/ASR/noise-robust/train.py,recipes/RescueSpeech/ASR/noise-robust/hparams/robust_asr_16k.yaml,recipes/RescueSpeech/rescuespeech_prepare.py,recipes/RescueSpeech/README.md,https://www.dropbox.com/sh/nk55xm0saa5qfly/AAC6yHgJnQalFMePgKFZqVfPa?dl=0,https://huggingface.co/sangeet2020/noisy-whisper-resucespeech,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,"file_exists=[test_results.csv,train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,enhanced_wavs/spk1_snt1/enhanced.wav,enhanced_wavs/spk1_snt1/noisy.wav,enhanced_wavs/spk1_snt1/clean.wav,enhanced_wavs/spk1_snt2/enhanced.wav,enhanced_wavs/spk1_snt2/noisy.wav,enhanced_wavs/spk1_snt2/clean.wav,save/encoder.ckpt,save/masknet.ckpt,save/whisper.ckpt,save/decoder.ckpt]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+ASR+enhancement,RescueSpeech,recipes/RescueSpeech/ASR/noise-robust/train.py,recipes/RescueSpeech/ASR/noise-robust/hparams/robust_asr_16k.yaml,recipes/RescueSpeech/rescuespeech_prepare.py,recipes/RescueSpeech/README.md,https://www.dropbox.com/sh/kqs2ld14fm20cxl/AACiobSLdNtXhm-4Y3IIbTeia?dl=0,https://huggingface.co/sangeet2020/noisy-whisper-resucespeech,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,,SISNRi=7.482 SDRi=8.011 PESQ=2.083 STOI=0.854 WER=45.29%
diff --git a/tests/recipes/SLURP.csv b/tests/recipes/SLURP.csv
index d3f253617d422bedca232cd126fd033b26f2b1a0..f4fe80ba8d4e9d0e892f128be79caab806b7628b 100644
--- a/tests/recipes/SLURP.csv
+++ b/tests/recipes/SLURP.csv
@@ -1,5 +1,5 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-SLU,SLURP,recipes/SLURP/NLU/train.py,recipes/SLURP/NLU/hparams/train.yaml,recipes/SLURP/NLU/prepare.py,recipes/SLURP/README.md,https://www.dropbox.com/sh/c2rnjads9gfd7k2/AADtVZi616qH_jb_owKK7b9Ra?dl=0 https://www.dropbox.com/sh/uygyiir8rfajcmu/AACXbjhM34ZDy2UprWfg-uyVa?dl=0 ,,--data_folder=tests/samples/ASR/ --csv_train=tests/samples/annotation/ASR_train.csv --csv_valid=tests/samples/annotation/ASR_train.csv --csv_test=tests/samples/annotation/ASR_train.csv --skip_prep=True --number_of_epochs=2,"file_exists=[wer_test_real.txt,train_log.txt,predictions.jsonl,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizers/slu_tokenizer.ckpt,save/tokenizers/asr_tokenizer.ckpt]"
-SLU,SLURP,recipes/SLURP/direct/train.py,recipes/SLURP/direct/hparams/train.yaml,recipes/SLURP/direct/prepare.py,recipes/SLURP/README.md,https://www.dropbox.com/sh/c2rnjads9gfd7k2/AADtVZi616qH_jb_owKK7b9Ra?dl=0 https://www.dropbox.com/sh/uygyiir8rfajcmu/AACXbjhM34ZDy2UprWfg-uyVa?dl=0,,--data_folder_rirs=tests/tmp --data_folder=tests/samples/ASR/ --csv_train=tests/samples/annotation/ASR_train.csv --csv_valid=tests/samples/annotation/ASR_train.csv --csv_test=tests/samples/annotation/ASR_train.csv --skip_prep=True --number_of_epochs=2,"file_exists=[wer_test_real.txt,train_log.txt,predictions.jsonl,log.txt,env.log,train.py,hyperparams.yaml,save/SLURM_tokenizer/tokenizer.ckpt]"
-SLU,SLURP,recipes/SLURP/direct/train_with_wav2vec2.py,recipes/SLURP/direct/hparams/train_with_wav2vec2.yaml,recipes/SLURP/direct/prepare.py,recipes/SLURP/README.md,https://www.dropbox.com/sh/c2rnjads9gfd7k2/AADtVZi616qH_jb_owKK7b9Ra?dl=0 https://www.dropbox.com/sh/uygyiir8rfajcmu/AACXbjhM34ZDy2UprWfg-uyVa?dl=0,https://huggingface.co/speechbrain/SLU-direct-SLURP-hubert-enc,--data_folder=tests/samples/ASR/ --csv_train=tests/samples/annotation/ASR_train.csv --csv_valid=tests/samples/annotation/ASR_train.csv --csv_test=tests/samples/annotation/ASR_train.csv --skip_prep=True --number_of_epochs=2,"file_exists=[train_with_wav2vec2.py,wer_test_real.txt,train_log.txt,predictions.jsonl,log.txt,env.log,hyperparams.yaml,save/SLURP_tokenizer/tokenizer.ckpt]"
-Tokenizer,SLURP,recipes/SLURP/Tokenizer/train.py,recipes/SLURP/Tokenizer/hparams/tokenizer_bpe58.yaml,recipes/SLURP/Tokenizer/prepare.py,recipes/SLURP/README.md,https://www.dropbox.com/sh/c2rnjads9gfd7k2/AADtVZi616qH_jb_owKK7b9Ra?dl=0 https://www.dropbox.com/sh/uygyiir8rfajcmu/AACXbjhM34ZDy2UprWfg-uyVa?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True --token_output=17,"file_exists=[17_unigram.vocab,log.txt,17_unigram.model,ASR_train.txt,env.log,train.py,hyperparams.yaml]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+SLU,SLURP,recipes/SLURP/NLU/train.py,recipes/SLURP/NLU/hparams/train.yaml,recipes/SLURP/NLU/prepare.py,recipes/SLURP/README.md,https://www.dropbox.com/scl/fo/c0rm2ja8oxus8q27om8ve/h?rlkey=irxzl1ea8g7e6ipk0vuc288zh&dl=0 ,,--data_folder=tests/samples/ASR/ --csv_train=tests/samples/annotation/ASR_train.csv --csv_valid=tests/samples/annotation/ASR_train.csv --csv_test=tests/samples/annotation/ASR_train.csv --skip_prep=True --number_of_epochs=2,"file_exists=[wer_test_real.txt,train_log.txt,predictions.jsonl,log.txt,env.log,train.py,hyperparams.yaml,save/tokenizers/slu_tokenizer.ckpt,save/tokenizers/asr_tokenizer.ckpt]",scenario-accuracy=90.81% action-accuracy=88.29% intent-accuracy=87.28%
+SLU,SLURP,recipes/SLURP/direct/train.py,recipes/SLURP/direct/hparams/train.yaml,recipes/SLURP/direct/prepare.py,recipes/SLURP/README.md,https://www.dropbox.com/scl/fo/c0rm2ja8oxus8q27om8ve/h?rlkey=irxzl1ea8g7e6ipk0vuc288zh&dl=0 ,,--data_folder=tests/samples/ASR/ --csv_train=tests/samples/annotation/ASR_train.csv --csv_valid=tests/samples/annotation/ASR_train.csv --csv_test=tests/samples/annotation/ASR_train.csv --skip_prep=True --number_of_epochs=2,"file_exists=[wer_test_real.txt,train_log.txt,predictions.jsonl,log.txt,env.log,train.py,hyperparams.yaml,save/SLURM_tokenizer/tokenizer.ckpt]",scenario-accuracy=81.73% action-accuracy=77.11% intent-accuracy=75.05%
+SLU,SLURP,recipes/SLURP/direct/train_with_wav2vec2.py,recipes/SLURP/direct/hparams/train_with_wav2vec2.yaml,recipes/SLURP/direct/prepare.py,recipes/SLURP/README.md,https://www.dropbox.com/scl/fo/c0rm2ja8oxus8q27om8ve/h?rlkey=irxzl1ea8g7e6ipk0vuc288zh&dl=0 ,https://huggingface.co/speechbrain/SLU-direct-SLURP-hubert-enc,--data_folder=tests/samples/ASR/ --csv_train=tests/samples/annotation/ASR_train.csv --csv_valid=tests/samples/annotation/ASR_train.csv --csv_test=tests/samples/annotation/ASR_train.csv --skip_prep=True --number_of_epochs=2,"file_exists=[train_with_wav2vec2.py,wer_test_real.txt,train_log.txt,predictions.jsonl,log.txt,env.log,hyperparams.yaml,save/SLURP_tokenizer/tokenizer.ckpt]",scenario-accuracy=91.24% action-accuracy=88.47% intent-accuracy=87.55%
+Tokenizer,SLURP,recipes/SLURP/Tokenizer/train.py,recipes/SLURP/Tokenizer/hparams/tokenizer_bpe58.yaml,recipes/SLURP/Tokenizer/prepare.py,recipes/SLURP/README.md,https://www.dropbox.com/scl/fo/c0rm2ja8oxus8q27om8ve/h?rlkey=irxzl1ea8g7e6ipk0vuc288zh&dl=0 ,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True --token_output=17,"file_exists=[17_unigram.vocab,log.txt,17_unigram.model,ASR_train.txt,env.log,train.py,hyperparams.yaml]",
diff --git a/tests/recipes/Switchboard.csv b/tests/recipes/Switchboard.csv
index e9ce9143bf5b5191bfbf6c9bb16fa5f42fabe017..164717e3c2b7491f7fbb6f0bbd4b52c144f282fc 100644
--- a/tests/recipes/Switchboard.csv
+++ b/tests/recipes/Switchboard.csv
@@ -1,8 +1,8 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-ASR,Switchboard,recipes/Switchboard/ASR/CTC/train_with_wav2vec.py,recipes/Switchboard/ASR/CTC/hparams/train_with_wav2vec.yaml,recipes/Switchboard/ASR/CTC/switchboard_prepare.py,recipes/Switchboard/ASR/CTC/README.md,,https://huggingface.co/speechbrain/asr-wav2vec2-switchboard,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train_stereo.csv --valid_csv=tests/samples/annotation/ASR_train_stereo.csv --test_csv=[tests/samples/annotation/ASR_train_stereo.csv] --number_of_epochs=2 --skip_prep=True --output_neurons=27 --dnn_neurons=128 --normalize_words=False --train_tokenizer_csv=tests/samples/annotation/ASR_train_stereo.csv,"file_exists=[train_log.txt,log.txt,train_with_wav2vec.py,env.log,wer_ASR_train_stereo.txt,hyperparams.yaml,save/27_unigram.model,save/ASR_train_stereo.txt,save/27_unigram.vocab]"
-Tokenizer,Switchboard,recipes/Switchboard/Tokenizer/train.py,recipes/Switchboard/Tokenizer/hparams/2K_unigram_subword_bpe.yaml,recipes/Switchboard/Tokenizer/switchboard_prepare.py,recipes/Switchboard/Tokenizer/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train_stereo.csv --valid_csv=tests/samples/annotation/ASR_train_stereo.csv --skip_prep=True --token_output=23,"file_exists=[23_unigram.model,23_unigram.vocab,log.txt,ASR_train_stereo.txt,env.log,train.py,hyperparams.yaml]"
-LM,Switchboard,recipes/Switchboard/LM/train.py,recipes/Switchboard/LM/hparams/transformer.yaml,recipes/Switchboard/LM/switchboard_prepare.py,recipes/Switchboard/LM/README.md,,,--data_folder=tests/samples/annotation/ --train_csv=tests/samples/annotation/LM_train.csv --valid_csv=tests/samples/annotation/LM_dev.csv --test_csv=tests/samples/annotation/LM_dev.csv --number_of_epochs=2 --tokenizer_file=tests/tmp/Switchboard_row_3/23_unigram.model --skip_prep=True --output_neurons=23,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-LM,Switchboard,recipes/Switchboard/LM/train.py,recipes/Switchboard/LM/hparams/transformer_finetune.yaml,recipes/Switchboard/LM/switchboard_prepare.py,recipes/Switchboard/LM/README.md,,,--data_folder=tests/samples/annotation/ --train_csv=tests/samples/annotation/LM_train.csv --valid_csv=tests/samples/annotation/LM_dev.csv --test_csv=tests/samples/annotation/LM_dev.csv --number_of_epochs=2 --skip_prep=True,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/model.ckpt,save/tokenizer.ckpt]"
-ASR,Switchboard,recipes/Switchboard/ASR/seq2seq/train.py,recipes/Switchboard/ASR/seq2seq/hparams/train_BPE_2000.yaml,recipes/Switchboard/ASR/seq2seq/switchboard_prepare.py,recipes/Switchboard/ASR/seq2seq/README.md,,https://huggingface.co/speechbrain/asr-crdnn-switchboard,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train_stereo.csv --valid_csv=tests/samples/annotation/ASR_train_stereo.csv --test_csv=[tests/samples/annotation/ASR_train_stereo.csv] --number_of_epochs=10 --skip_prep=True --data_folder_rirs=tests/tmp --normalize_words=False --pretrained_tokenizer_path= --tokenizer_file=tests/tmp/Switchboard_row_3/23_unigram.model --output_neurons=23,"file_exists=[train_log.txt,log.txt,env.log,train.py,wer_ASR_train_stereo.txt,hyperparams.yaml]"
-ASR,Switchboard,recipes/Switchboard/ASR/transformer/train.py,recipes/Switchboard/ASR/transformer/hparams/transformer.yaml,recipes/Switchboard/ASR/transformer/switchboard_prepare.py,recipes/Switchboard/ASR/transformer/README.md,,https://huggingface.co/speechbrain/asr-transformer-switchboard,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train_stereo.csv --valid_csv=tests/samples/annotation/ASR_train_stereo.csv --test_csv=[tests/samples/annotation/ASR_train_stereo.csv] --number_of_epochs=10 --skip_prep=True --normalize_words=False --output_neurons=23 --pretrained_lm_tokenizer_path= --tokenizer_file=tests/tmp/Switchboard_row_3/23_unigram.model --lm_file=tests/tmp/Switchboard_row_4/save/CKPT+latest/model.ckpt,"file_exists=[train_log.txt,log.txt,env.log,train.py,wer_ASR_train_stereo.txt,hyperparams.yaml]"
-ASR,Switchboard,recipes/Switchboard/ASR/transformer/train.py,recipes/Switchboard/ASR/transformer/hparams/transformer_finetuned_LM.yaml,recipes/Switchboard/ASR/transformer/switchboard_prepare.py,recipes/Switchboard/ASR/transformer/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train_stereo.csv --valid_csv=tests/samples/annotation/ASR_train_stereo.csv --test_csv=[tests/samples/annotation/ASR_train_stereo.csv] --number_of_epochs=10 --skip_prep=True --normalize_words=False --pretrained_lm_tokenizer_path= --tokenizer_file=tests/tmp/Switchboard_row_5/save/tokenizer.ckpt --lm_file=tests/tmp/Switchboard_row_5/save/CKPT+latest/model.ckpt,"file_exists=[train_log.txt,log.txt,env.log,train.py,wer_ASR_train_stereo.txt,hyperparams.yaml]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+ASR,Switchboard,recipes/Switchboard/ASR/CTC/train_with_wav2vec.py,recipes/Switchboard/ASR/CTC/hparams/train_with_wav2vec.yaml,recipes/Switchboard/ASR/CTC/switchboard_prepare.py,recipes/Switchboard/ASR/CTC/README.md,,https://huggingface.co/speechbrain/asr-wav2vec2-switchboard,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train_stereo.csv --valid_csv=tests/samples/annotation/ASR_train_stereo.csv --test_csv=[tests/samples/annotation/ASR_train_stereo.csv] --number_of_epochs=2 --skip_prep=True --output_neurons=27 --dnn_neurons=128 --normalize_words=False --train_tokenizer_csv=tests/samples/annotation/ASR_train_stereo.csv,"file_exists=[train_log.txt,log.txt,train_with_wav2vec.py,env.log,wer_ASR_train_stereo.txt,hyperparams.yaml,save/27_unigram.model,save/ASR_train_stereo.txt,save/27_unigram.vocab]",Swbd-WER=8.76% Callhome-WER=14.67% Eval2000-WER=11.78%
+Tokenizer,Switchboard,recipes/Switchboard/Tokenizer/train.py,recipes/Switchboard/Tokenizer/hparams/2K_unigram_subword_bpe.yaml,recipes/Switchboard/Tokenizer/switchboard_prepare.py,recipes/Switchboard/Tokenizer/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train_stereo.csv --valid_csv=tests/samples/annotation/ASR_train_stereo.csv --skip_prep=True --token_output=23,"file_exists=[23_unigram.model,23_unigram.vocab,log.txt,ASR_train_stereo.txt,env.log,train.py,hyperparams.yaml]",
+LM,Switchboard,recipes/Switchboard/LM/train.py,recipes/Switchboard/LM/hparams/transformer.yaml,recipes/Switchboard/LM/switchboard_prepare.py,recipes/Switchboard/LM/README.md,,,--data_folder=tests/samples/annotation/ --train_csv=tests/samples/annotation/LM_train.csv --valid_csv=tests/samples/annotation/LM_dev.csv --test_csv=tests/samples/annotation/LM_dev.csv --number_of_epochs=2 --tokenizer_file=tests/tmp/Switchboard_row_03/23_unigram.model --skip_prep=True --output_neurons=23,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",
+LM,Switchboard,recipes/Switchboard/LM/train.py,recipes/Switchboard/LM/hparams/transformer_finetune.yaml,recipes/Switchboard/LM/switchboard_prepare.py,recipes/Switchboard/LM/README.md,,,--data_folder=tests/samples/annotation/ --train_csv=tests/samples/annotation/LM_train.csv --valid_csv=tests/samples/annotation/LM_dev.csv --test_csv=tests/samples/annotation/LM_dev.csv --number_of_epochs=2 --skip_prep=True,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/model.ckpt,save/tokenizer.ckpt]",
+ASR,Switchboard,recipes/Switchboard/ASR/seq2seq/train.py,recipes/Switchboard/ASR/seq2seq/hparams/train_BPE_2000.yaml,recipes/Switchboard/ASR/seq2seq/switchboard_prepare.py,recipes/Switchboard/ASR/seq2seq/README.md,,https://huggingface.co/speechbrain/asr-crdnn-switchboard,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train_stereo.csv --valid_csv=tests/samples/annotation/ASR_train_stereo.csv --test_csv=[tests/samples/annotation/ASR_train_stereo.csv] --number_of_epochs=10 --skip_prep=True --normalize_words=False --pretrained_tokenizer_path= --tokenizer_file=tests/tmp/Switchboard_row_03/23_unigram.model --output_neurons=23,"file_exists=[train_log.txt,log.txt,env.log,train.py,wer_ASR_train_stereo.txt,hyperparams.yaml]",Swbd-WER=16.90% Callhome-WER=25.12% Eval2000-WER=20.71%
+ASR,Switchboard,recipes/Switchboard/ASR/transformer/train.py,recipes/Switchboard/ASR/transformer/hparams/transformer.yaml,recipes/Switchboard/ASR/transformer/switchboard_prepare.py,recipes/Switchboard/ASR/transformer/README.md,,https://huggingface.co/speechbrain/asr-transformer-switchboard,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train_stereo.csv --valid_csv=tests/samples/annotation/ASR_train_stereo.csv --test_csv=[tests/samples/annotation/ASR_train_stereo.csv] --number_of_epochs=10 --skip_prep=True --normalize_words=False --output_neurons=23 --pretrained_lm_tokenizer_path= --tokenizer_file=tests/tmp/Switchboard_row_03/23_unigram.model --lm_file=tests/tmp/Switchboard_row_04/save/CKPT+latest/model.ckpt,"file_exists=[train_log.txt,log.txt,env.log,train.py,wer_ASR_train_stereo.txt,hyperparams.yaml]",Swbd-WER=9.80% Callhome-WER=17.89% Eval2000-WER=13.94%
+ASR,Switchboard,recipes/Switchboard/ASR/transformer/train.py,recipes/Switchboard/ASR/transformer/hparams/transformer_finetuned_LM.yaml,recipes/Switchboard/ASR/transformer/switchboard_prepare.py,recipes/Switchboard/ASR/transformer/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train_stereo.csv --valid_csv=tests/samples/annotation/ASR_train_stereo.csv --test_csv=[tests/samples/annotation/ASR_train_stereo.csv] --number_of_epochs=10 --skip_prep=True --normalize_words=False --pretrained_lm_tokenizer_path= --tokenizer_file=tests/tmp/Switchboard_row_05/save/tokenizer.ckpt --lm_file=tests/tmp/Switchboard_row_05/save/CKPT+latest/model.ckpt,"file_exists=[train_log.txt,log.txt,env.log,train.py,wer_ASR_train_stereo.txt,hyperparams.yaml]",
diff --git a/tests/recipes/TIMIT.csv b/tests/recipes/TIMIT.csv
index c3a3f5c2370918db43c3100699e9003399d99b50..72a18c879a4f7b9b18acf993e84bda3106c5d327 100644
--- a/tests/recipes/TIMIT.csv
+++ b/tests/recipes/TIMIT.csv
@@ -1,19 +1,7 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-ASR,TIMIT,recipes/TIMIT/ASR/CTC/train.py,recipes/TIMIT/ASR/CTC/hparams/train.yaml,recipes/TIMIT/ASR/CTC/timit_prepare.py,recipes/TIMIT/ASR/CTC/README.md,https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0,,--data_folder=tests/samples/ASR/  --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_train.json --test_annotation=tests/samples/annotation/ASR_train.json --output_neurons=44 --number_of_epochs=10 --skip_prep=True --open_rir_folder=tests/tmp,"file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,save/label_encoder.txt] performance_check=[train_log.txt, train loss, <=12, epoch: 10]"
-ASR,TIMIT,recipes/TIMIT/ASR/seq2seq/train.py,recipes/TIMIT/ASR/seq2seq/hparams/train.yaml,recipes/TIMIT/ASR/seq2seq/timit_prepare.py,recipes/TIMIT/ASR/seq2seq/README.md,https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_train.json --test_annotation=tests/samples/annotation/ASR_train.json --output_neurons=47 --number_of_epochs=10 --skip_prep=True,"file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,save/label_encoder.txt] performance_check=[train_log.txt, train loss, <=6, epoch: 10]"
-ASR,TIMIT,recipes/TIMIT/ASR/seq2seq/train_with_wav2vec2.py,recipes/TIMIT/ASR/seq2seq/hparams/train_with_wav2vec2.yaml,recipes/TIMIT/ASR/seq2seq/timit_prepare.py,recipes/TIMIT/ASR/seq2seq/README.md,https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_train.json --test_annotation=tests/samples/annotation/ASR_train.json --output_neurons=47 --number_of_epochs=10 --skip_prep=True,"file_exists=[train_with_wav2vec2.py,train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml,save/label_encoder.txt] performance_check=[train_log.txt, train loss, <=4, epoch: 10]"
-ASR,TIMIT,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/train_teacher.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea0.yaml,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/timit_prepare.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/README.md,https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train_39p.json --valid_annotation=tests/samples/annotation/ASR_train_39p.json --test_annotation=tests/samples/annotation/ASR_train_39p.json --number_of_epochs=2 --skip_prep=True,"file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml,train_teacher.py,save/label_encoder.txt]"
-ASR,TIMIT,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/train_teacher.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea1.yaml,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/timit_prepare.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/README.md,https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train_39p.json --valid_annotation=tests/samples/annotation/ASR_train_39p.json --test_annotation=tests/samples/annotation/ASR_train_39p.json --number_of_epochs=2 --skip_prep=True,"file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml,train_teacher.py,save/label_encoder.txt]"
-ASR,TIMIT,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/train_teacher.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea2.yaml,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/timit_prepare.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/README.md,https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train_39p.json --valid_annotation=tests/samples/annotation/ASR_train_39p.json --test_annotation=tests/samples/annotation/ASR_train_39p.json --number_of_epochs=2 --skip_prep=True,"file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml,train_teacher.py,save/label_encoder.txt]"
-ASR,TIMIT,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/train_teacher.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea3.yaml,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/timit_prepare.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/README.md,https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train_39p.json --valid_annotation=tests/samples/annotation/ASR_train_39p.json --test_annotation=tests/samples/annotation/ASR_train_39p.json --number_of_epochs=2 --skip_prep=True,"file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml,train_teacher.py,save/label_encoder.txt]"
-ASR,TIMIT,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/train_teacher.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea4.yaml,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/timit_prepare.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/README.md,https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train_39p.json --valid_annotation=tests/samples/annotation/ASR_train_39p.json --test_annotation=tests/samples/annotation/ASR_train_39p.json --number_of_epochs=2 --skip_prep=True,"file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml,train_teacher.py,save/label_encoder.txt]"
-ASR,TIMIT,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/train_teacher.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea5.yaml,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/timit_prepare.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/README.md,https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train_39p.json --valid_annotation=tests/samples/annotation/ASR_train_39p.json --test_annotation=tests/samples/annotation/ASR_train_39p.json --number_of_epochs=2 --skip_prep=True,"file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml,train_teacher.py,save/label_encoder.txt]"
-ASR,TIMIT,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/train_teacher.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea6.yaml,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/timit_prepare.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/README.md,https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train_39p.json --valid_annotation=tests/samples/annotation/ASR_train_39p.json --test_annotation=tests/samples/annotation/ASR_train_39p.json --number_of_epochs=2 --skip_prep=True,"file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml,train_teacher.py,save/label_encoder.txt]"
-ASR,TIMIT,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/train_teacher.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea7.yaml,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/timit_prepare.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/README.md,https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train_39p.json --valid_annotation=tests/samples/annotation/ASR_train_39p.json --test_annotation=tests/samples/annotation/ASR_train_39p.json --number_of_epochs=2 --skip_prep=True,"file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml,train_teacher.py,save/label_encoder.txt]"
-ASR,TIMIT,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/train_teacher.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea8.yaml,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/timit_prepare.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/README.md,https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train_39p.json --valid_annotation=tests/samples/annotation/ASR_train_39p.json --test_annotation=tests/samples/annotation/ASR_train_39p.json --number_of_epochs=2 --skip_prep=True,"file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml,train_teacher.py,save/label_encoder.txt]"
-ASR,TIMIT,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/train_teacher.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/teachers/tea9.yaml,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/timit_prepare.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/README.md,https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train_39p.json --valid_annotation=tests/samples/annotation/ASR_train_39p.json --test_annotation=tests/samples/annotation/ASR_train_39p.json --number_of_epochs=2 --skip_prep=True,"file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml,train_teacher.py,save/label_encoder.txt]"
-ASR,TIMIT,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/save_teachers.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/save_teachers.yaml,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/timit_prepare.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/README.md,https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train_39p.json --valid_annotation=tests/samples/annotation/ASR_train_39p.json --test_annotation=tests/samples/annotation/ASR_train_39p.json --skip_prep=True --tea_models_dir=tests/tmp/TIMIT_row_15/tea_models.txt,"file_exists=[save_teachers.py,tea_infer_8batch.hdf5,tea_models.txt,log.txt,env.log,hyperparams.yaml]"
-ASR,TIMIT,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/train_kd.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/hparams/train_kd.yaml,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/timit_prepare.py,recipes/TIMIT/ASR/seq2seq_knowledge_distillation/README.md,https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train_39p.json --valid_annotation=tests/samples/annotation/ASR_train_39p.json --test_annotation=tests/samples/annotation/ASR_train_39p.json --number_of_epochs=2 --beam_size=1 --skip_prep=True --tea_infer_dir=tests/tmp/TIMIT_row_15 --pretrain_st_dir=tests/tmp/TIMIT_row_5,"file_exists=[train_kd.py,train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml]"
-ASR,TIMIT,recipes/TIMIT/ASR/transducer/train.py,recipes/TIMIT/ASR/transducer/hparams/train.yaml,recipes/TIMIT/ASR/transducer/timit_prepare.py,recipes/TIMIT/ASR/transducer/README.md,https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_train.json --test_annotation=tests/samples/annotation/ASR_train.json --output_neurons=44 --number_of_epochs=10 --beam_size=1 --skip_prep=True --openrir_folder=tests/tmp,"file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,save/label_encoder.txt] performance_check=[train_log.txt, train loss, <=300, epoch: 10]"
-ASR,TIMIT,recipes/TIMIT/ASR/transducer/train_wav2vec.py,recipes/TIMIT/ASR/transducer/hparams/train_wav2vec.yaml,recipes/TIMIT/ASR/transducer/timit_prepare.py,recipes/TIMIT/ASR/transducer/README.md,https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_train.json --test_annotation=tests/samples/annotation/ASR_train.json --output_neurons=44 --number_of_epochs=2 --beam_size=1 --skip_prep=True,"file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,train_wav2vec.py,hyperparams.yaml,save/label_encoder.txt]"
-Alignment,TIMIT,recipes/TIMIT/Alignment/train.py,recipes/TIMIT/Alignment/hparams/train.yaml,recipes/TIMIT/Alignment/timit_prepare.py,recipes/TIMIT/Alignment/README.md ,https://www.dropbox.com/sh/dcicuz1r6v7iitt/AAB1BpaMjfhUDBsEsxjAuaHVa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_train.json --test_annotation=tests/samples/annotation/ASR_train.json --number_of_epochs=2 --skip_prep=True,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/label_encoder.txt]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+ASR,TIMIT,recipes/TIMIT/ASR/CTC/train.py,recipes/TIMIT/ASR/CTC/hparams/train.yaml,recipes/TIMIT/ASR/CTC/timit_prepare.py,recipes/TIMIT/ASR/CTC/README.md,https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0,,--data_folder=tests/samples/ASR/  --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_train.json --test_annotation=tests/samples/annotation/ASR_train.json --output_neurons=44 --number_of_epochs=10 --skip_prep=True,"file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,save/label_encoder.txt] performance_check=[train_log.txt, train loss, <=12, epoch: 10]",Test-PER=14.78%
+ASR,TIMIT,recipes/TIMIT/ASR/seq2seq/train.py,recipes/TIMIT/ASR/seq2seq/hparams/train.yaml,recipes/TIMIT/ASR/seq2seq/timit_prepare.py,recipes/TIMIT/ASR/seq2seq/README.md,https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_train.json --test_annotation=tests/samples/annotation/ASR_train.json --output_neurons=47 --number_of_epochs=10 --skip_prep=True,"file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,save/label_encoder.txt] performance_check=[train_log.txt, train loss, <=6, epoch: 10]",Test-PER=14.07%
+ASR,TIMIT,recipes/TIMIT/ASR/seq2seq/train_with_wav2vec2.py,recipes/TIMIT/ASR/seq2seq/hparams/train_with_wav2vec2.yaml,recipes/TIMIT/ASR/seq2seq/timit_prepare.py,recipes/TIMIT/ASR/seq2seq/README.md,https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_train.json --test_annotation=tests/samples/annotation/ASR_train.json --output_neurons=45 --number_of_epochs=10 --skip_prep=True,"file_exists=[train_with_wav2vec2.py,train_log.txt,log.txt,wer_test.txt,env.log,hyperparams.yaml,save/label_encoder.txt] performance_check=[train_log.txt, train loss, <=4, epoch: 10]",Test-PER=8.04%
+ASR,TIMIT,recipes/TIMIT/ASR/transducer/train.py,recipes/TIMIT/ASR/transducer/hparams/train.yaml,recipes/TIMIT/ASR/transducer/timit_prepare.py,recipes/TIMIT/ASR/transducer/README.md,https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_train.json --test_annotation=tests/samples/annotation/ASR_train.json --output_neurons=43 --number_of_epochs=10 --beam_size=1 --skip_prep=True,"file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,save/label_encoder.txt] performance_check=[train_log.txt, train loss, <=300, epoch: 10]",Test-PER=14.12%
+ASR,TIMIT,recipes/TIMIT/ASR/transducer/train_wav2vec.py,recipes/TIMIT/ASR/transducer/hparams/train_wav2vec.yaml,recipes/TIMIT/ASR/transducer/timit_prepare.py,recipes/TIMIT/ASR/transducer/README.md,https://www.dropbox.com/sh/059jnwdass8v45u/AADTjh5DYdYKuZsgH9HXGx0Sa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_train.json --test_annotation=tests/samples/annotation/ASR_train.json --output_neurons=43 --number_of_epochs=2 --beam_size=1 --skip_prep=True --precision=fp32,"file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,train_wav2vec.py,hyperparams.yaml,save/label_encoder.txt]",Test-PER=8.91%
+Alignment,TIMIT,recipes/TIMIT/Alignment/train.py,recipes/TIMIT/Alignment/hparams/train.yaml,recipes/TIMIT/Alignment/timit_prepare.py,recipes/TIMIT/Alignment/README.md ,https://www.dropbox.com/sh/dcicuz1r6v7iitt/AAB1BpaMjfhUDBsEsxjAuaHVa?dl=0,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_train.json --test_annotation=tests/samples/annotation/ASR_train.json --number_of_epochs=2 --skip_prep=True,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/label_encoder.txt]",
diff --git a/tests/recipes/Tedlium2.csv b/tests/recipes/Tedlium2.csv
new file mode 100644
index 0000000000000000000000000000000000000000..634bcaf3fd70b330e83d6a0d46c108a5999bfa27
--- /dev/null
+++ b/tests/recipes/Tedlium2.csv
@@ -0,0 +1,3 @@
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+Tokenizer,Tedlium2,recipes/Tedlium2/Tokenizer/train.py,recipes/Tedlium2/Tokenizer/hparams/tedlium2_500_bpe.yaml,recipes/Tedlium2/Tokenizer/tedlium2_prepare.py,recipes/Tedlium2/Tokenizer/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True --token_output=23 --clipped_utt_folder=None,,
+ASR,Tedlium2,recipes/Tedlium2/ASR/transformer/train.py,recipes/Tedlium2/ASR/transformer/hparams/branchformer_large.yaml,recipes/Tedlium2/Tokenizer/tedlium2_prepare.py,recipes/Tedlium2/ASR/transformer/README.md,https://www.dropbox.com/sh/el523uofs96czfi/AADgTd838pKo2aR8fhqVOh-Oa?dl=0,https://huggingface.co/speechbrain/asr-branchformer-large-tedlium2,--data_folder=. --clipped_utt_folder=. --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --output_neurons=23 --number_of_epochs=10 --skip_prep=True --test_beam_size=1 --valid_beam_size=1 --pretrained_tokenizer_file=tests/tmp/Tedlium2_row_02/23_bpe.model,"file_exists=[env.log,hyperparams.yaml,log.txt,train_log.txt,train.py,wer_ASR_train.txt,save/tokenizer.ckpt] performance_check=[train_log.txt, train loss, <500, epoch: 10]",Test-WER_No_LM=8.11%
diff --git a/tests/recipes/UrbanSound8k.csv b/tests/recipes/UrbanSound8k.csv
index 0fe67d07e1550dd141306c7ed4bb1079ab99f1e3..aa81f39da332e281dd65bd818fe65cd2df94a4ad 100644
--- a/tests/recipes/UrbanSound8k.csv
+++ b/tests/recipes/UrbanSound8k.csv
@@ -1,2 +1,2 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-SoundClassification,UrbanSound8k,recipes/UrbanSound8k/SoundClassification/train.py,recipes/UrbanSound8k/SoundClassification/hparams/train_ecapa_tdnn.yaml,recipes/UrbanSound8k/SoundClassification/urbansound8k_prepare.py,recipes/UrbanSound8k/README.md,https://www.dropbox.com/sh/f61325e3w8h5yy2/AADm3E3PXFi1NYA7-QW3H-Ata?dl=0 ,https://huggingface.co/speechbrain/urbansound8k_ecapa,--open_rir_folder=tests/tmp --number_of_epochs=2 --skip_manifest_creation=True --data_folder=../ --audio_data_folder=tests/samples/ASR --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_dev.json --test_annotation=tests/samples/annotation/ASR_dev.json --train_fold_nums=[1] --valid_fold_nums=[1] --test_fold_nums=[1] --use_tensorboard=False --out_n_neurons=2,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/label_encoder.txt]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+SoundClassification,UrbanSound8k,recipes/UrbanSound8k/SoundClassification/train.py,recipes/UrbanSound8k/SoundClassification/hparams/train_ecapa_tdnn.yaml,recipes/UrbanSound8k/SoundClassification/urbansound8k_prepare.py,recipes/UrbanSound8k/README.md,https://www.dropbox.com/sh/f61325e3w8h5yy2/AADm3E3PXFi1NYA7-QW3H-Ata?dl=0 ,https://huggingface.co/speechbrain/urbansound8k_ecapa, --number_of_epochs=2 --skip_manifest_creation=True --data_folder=../ --audio_data_folder=tests/samples/ASR --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_dev.json --test_annotation=tests/samples/annotation/ASR_dev.json --train_fold_nums=[1] --valid_fold_nums=[1] --test_fold_nums=[1] --use_tensorboard=False --out_n_neurons=2,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/label_encoder.txt]",Accuracy=75.4%
diff --git a/tests/recipes/Voicebank.csv b/tests/recipes/Voicebank.csv
index cfaeda433c41dee85150f3fdc104c2a6d10217be..3978a02a9065c277ffcbefcfccb1958e485a85d8 100644
--- a/tests/recipes/Voicebank.csv
+++ b/tests/recipes/Voicebank.csv
@@ -1,12 +1,12 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-ASR+enhancement,Voicebank,recipes/Voicebank/MTL/ASR_enhance/train.py,recipes/Voicebank/MTL/ASR_enhance/hparams/pretrain_perceptual.yaml,recipes/Voicebank/MTL/ASR_enhance/voicebank_prepare.py,recipes/Voicebank/MTL/ASR_enhance/README.md,https://www.dropbox.com/sh/azvcbvu8g5hpgm1/AACDc6QxtNMGZ3IoZLrDiU0Va?dl=0,https://huggingface.co/speechbrain/mtl-mimic-voicebank,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2 --output_neurons=17,"file_exists=[valid_stats.txt,test_stats.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-ASR+enhancement,Voicebank,recipes/Voicebank/MTL/ASR_enhance/train.py,recipes/Voicebank/MTL/ASR_enhance/hparams/robust_asr.yaml,recipes/Voicebank/MTL/ASR_enhance/voicebank_prepare.py,recipes/Voicebank/MTL/ASR_enhance/README.md,https://www.dropbox.com/sh/azvcbvu8g5hpgm1/AACDc6QxtNMGZ3IoZLrDiU0Va?dl=0,https://huggingface.co/speechbrain/mtl-mimic-voicebank,--data_folder_rirs=tests/tmp --data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2,"file_exists=[valid_stats.txt,test_stats.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/normalizer.ckpt,save/asr_model.ckpt,save/enhance_model.ckpt,save/tokenizer.ckpt,save/lm.ckpt]"
-ASR,Voicebank,recipes/Voicebank/ASR/CTC/train.py,recipes/Voicebank/ASR/CTC/hparams/train.yaml,recipes/Voicebank/ASR/CTC/voicebank_prepare.py,recipes/Voicebank/ASR/CTC/README.md,https://www.dropbox.com/sh/w4j0auezgmmo005/AAAjKcoJMdLDp0Pqe3m7CLVaa?dl=0,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --output_neurons=18 --number_of_epochs=2,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,per.txt,save/label_encoder.txt]"
-Enhancement,Voicebank,recipes/Voicebank/MTL/ASR_enhance/train.py,recipes/Voicebank/MTL/ASR_enhance/hparams/enhance_mimic.yaml,recipes/Voicebank/MTL/ASR_enhance/voicebank_prepare.py,recipes/Voicebank/MTL/ASR_enhance/README.md,https://www.dropbox.com/sh/azvcbvu8g5hpgm1/AACDc6QxtNMGZ3IoZLrDiU0Va?dl=0,https://huggingface.co/speechbrain/mtl-mimic-voicebank,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2,"file_exists=[valid_stats.txt,test_stats.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/perceptual_model.ckpt]"
-Enhancement,Voicebank,recipes/Voicebank/dereverb/MetricGAN-U/train.py,recipes/Voicebank/dereverb/MetricGAN-U/hparams/train_dereverb.yaml,recipes/Voicebank/dereverb/MetricGAN-U/voicebank_revb_prepare.py,recipes/Voicebank/dereverb/MetricGAN-U/README.md,https://www.dropbox.com/sh/r94qn1f5lq9r3p7/AAAZfisBhhkS8cwpzy1O5ADUa?dl=0 ,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2 --tensorboard_train_logger=None --use_tensorboard=False --target_metric=srmr --calculate_dnsmos_on_validation_set=False,"file_exists=[historical.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,enhanced_wavs/enh0@2.wav,enhanced_wavs/enh2.wav,enhanced_wavs/enh0@1.wav,enhanced_wavs/enh1@2.wav,enhanced_wavs/enh1@1.wav]"
-Enhancement,Voicebank,recipes/Voicebank/dereverb/spectral_mask/train.py,recipes/Voicebank/dereverb/spectral_mask/hparams/train.yaml,recipes/Voicebank/dereverb/spectral_mask/voicebank_revb_prepare.py,recipes/Voicebank/dereverb/spectral_mask/README.md,https://www.dropbox.com/sh/pw8aer8gcsrdbx7/AADknh7plHF5GBeTRK9VkIKga?dl=0 ,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2 --tensorboard_train_logger=None --use_tensorboard=False,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,enhanced_wavs/enh2.wav]"
-Enhancement,Voicebank,recipes/Voicebank/enhance/MetricGAN-U/train.py,recipes/Voicebank/enhance/MetricGAN-U/hparams/train_dnsmos.yaml,recipes/Voicebank/enhance/MetricGAN-U/voicebank_prepare.py,recipes/Voicebank/enhance/MetricGAN-U/README.md,https://www.dropbox.com/sh/h9akxmyel17sc8y/AAAP3Oz5MbXDfMlEXVjOBWV0a?dl=0 ,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2 --use_tensorboard=False --target_metric=srmr --calculate_dnsmos_on_validation_set=False,"file_exists=[historical.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,enhanced_wavs/enh0@2.wav,enhanced_wavs/enh2.wav,enhanced_wavs/enh0@1.wav,enhanced_wavs/enh1@2.wav,enhanced_wavs/enh1@1.wav]"
-Enhancement,Voicebank,recipes/Voicebank/enhance/MetricGAN/train.py,recipes/Voicebank/enhance/MetricGAN/hparams/train.yaml,recipes/Voicebank/enhance/MetricGAN/voicebank_prepare.py,recipes/Voicebank/enhance/MetricGAN/README.md,https://www.dropbox.com/sh/n5q9vjn0yn1qvk6/AAB-S7i2-XzVm6ux0MrXCvqya?dl=0 ,https://huggingface.co/speechbrain/metricgan-plus-voicebank,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2 --tensorboard_train_logger=None --use_tensorboard=False,"file_exists=[historical.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,enhanced_wavs/enh0@2.wav,enhanced_wavs/enh2.wav,enhanced_wavs/enh0@1.wav,enhanced_wavs/enh1@2.wav,enhanced_wavs/enh1@1.wav]"
-Enhancement,Voicebank,recipes/Voicebank/enhance/SEGAN/train.py,recipes/Voicebank/enhance/SEGAN/hparams/train.yaml,recipes/Voicebank/enhance/SEGAN/voicebank_prepare.py,recipes/Voicebank/enhance/SEGAN/README.md,https://www.dropbox.com/sh/ez0folswdbqiad4/AADDasepeoCkneyiczjCcvaOa?dl=0 ,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,enhanced/enh2.wav]"
-Enhancement,Voicebank,recipes/Voicebank/enhance/spectral_mask/train.py,recipes/Voicebank/enhance/spectral_mask/hparams/train.yaml,recipes/Voicebank/enhance/spectral_mask/voicebank_prepare.py,recipes/Voicebank/enhance/spectral_mask/README.md,https://www.dropbox.com/sh/n5q9vjn0yn1qvk6/AAB-S7i2-XzVm6ux0MrXCvqya?dl=0 ,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2 --use_tensorboard=False,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,enhanced_wavs/enh2.wav]"
-Enhancement,Voicebank,recipes/Voicebank/enhance/waveform_map/train.py,recipes/Voicebank/enhance/waveform_map/hparams/train.yaml,recipes/Voicebank/enhance/waveform_map/voicebank_prepare.py,recipes/Voicebank/enhance/waveform_map/README.md,,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2 --tensorboard_train_logger=None --use_tensorboard=False,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,enhanced/enh2.wav]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+ASR+enhancement,Voicebank,recipes/Voicebank/MTL/ASR_enhance/train.py,recipes/Voicebank/MTL/ASR_enhance/hparams/pretrain_perceptual.yaml,recipes/Voicebank/MTL/ASR_enhance/voicebank_prepare.py,recipes/Voicebank/MTL/ASR_enhance/README.md,https://www.dropbox.com/sh/azvcbvu8g5hpgm1/AACDc6QxtNMGZ3IoZLrDiU0Va?dl=0,https://huggingface.co/speechbrain/mtl-mimic-voicebank,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2 --output_neurons=17,"file_exists=[valid_stats.txt,test_stats.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",
+ASR+enhancement,Voicebank,recipes/Voicebank/MTL/ASR_enhance/train.py,recipes/Voicebank/MTL/ASR_enhance/hparams/robust_asr.yaml,recipes/Voicebank/MTL/ASR_enhance/voicebank_prepare.py,recipes/Voicebank/MTL/ASR_enhance/README.md,https://www.dropbox.com/sh/azvcbvu8g5hpgm1/AACDc6QxtNMGZ3IoZLrDiU0Va?dl=0,https://huggingface.co/speechbrain/mtl-mimic-voicebank,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2,"file_exists=[valid_stats.txt,test_stats.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/normalizer.ckpt,save/asr_model.ckpt,save/enhance_model.ckpt,save/tokenizer.ckpt,save/lm.ckpt]",PESQ=3.05 COVL=3.74 test-WER=2.80
+ASR,Voicebank,recipes/Voicebank/ASR/CTC/train.py,recipes/Voicebank/ASR/CTC/hparams/train.yaml,recipes/Voicebank/ASR/CTC/voicebank_prepare.py,recipes/Voicebank/ASR/CTC/README.md,https://www.dropbox.com/sh/w4j0auezgmmo005/AAAjKcoJMdLDp0Pqe3m7CLVaa?dl=0,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --output_neurons=18 --number_of_epochs=2,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,per.txt,save/label_encoder.txt]",Test-PER=10.12%
+Enhancement,Voicebank,recipes/Voicebank/MTL/ASR_enhance/train.py,recipes/Voicebank/MTL/ASR_enhance/hparams/enhance_mimic.yaml,recipes/Voicebank/MTL/ASR_enhance/voicebank_prepare.py,recipes/Voicebank/MTL/ASR_enhance/README.md,https://www.dropbox.com/sh/azvcbvu8g5hpgm1/AACDc6QxtNMGZ3IoZLrDiU0Va?dl=0,https://huggingface.co/speechbrain/mtl-mimic-voicebank,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2,"file_exists=[valid_stats.txt,test_stats.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/perceptual_model.ckpt]",
+Dereverberation,Voicebank,recipes/Voicebank/dereverb/MetricGAN-U/train.py,recipes/Voicebank/dereverb/MetricGAN-U/hparams/train_dereverb.yaml,recipes/Voicebank/dereverb/MetricGAN-U/voicebank_revb_prepare.py,recipes/Voicebank/dereverb/MetricGAN-U/README.md,https://www.dropbox.com/sh/r94qn1f5lq9r3p7/AAAZfisBhhkS8cwpzy1O5ADUa?dl=0 ,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2 --tensorboard_train_logger=None --use_tensorboard=False --target_metric=srmr --calculate_dnsmos_on_validation_set=False,"file_exists=[historical.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,enhanced_wavs/enh0@2.wav,enhanced_wavs/enh2.wav,enhanced_wavs/enh0@1.wav,enhanced_wavs/enh1@2.wav,enhanced_wavs/enh1@1.wav]",PESQ=2.07
+Dereverberation,Voicebank,recipes/Voicebank/dereverb/spectral_mask/train.py,recipes/Voicebank/dereverb/spectral_mask/hparams/train.yaml,recipes/Voicebank/dereverb/spectral_mask/voicebank_revb_prepare.py,recipes/Voicebank/dereverb/spectral_mask/README.md,https://www.dropbox.com/sh/pw8aer8gcsrdbx7/AADknh7plHF5GBeTRK9VkIKga?dl=0 ,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2 --tensorboard_train_logger=None --use_tensorboard=False,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,enhanced_wavs/enh2.wav]",PESQ=2.35
+Enhancement,Voicebank,recipes/Voicebank/enhance/MetricGAN-U/train.py,recipes/Voicebank/enhance/MetricGAN-U/hparams/train_dnsmos.yaml,recipes/Voicebank/enhance/MetricGAN-U/voicebank_prepare.py,recipes/Voicebank/enhance/MetricGAN-U/README.md,https://www.dropbox.com/sh/h9akxmyel17sc8y/AAAP3Oz5MbXDfMlEXVjOBWV0a?dl=0 ,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2 --use_tensorboard=False --target_metric=srmr --calculate_dnsmos_on_validation_set=False,"file_exists=[historical.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,enhanced_wavs/enh0@2.wav,enhanced_wavs/enh2.wav,enhanced_wavs/enh0@1.wav,enhanced_wavs/enh1@2.wav,enhanced_wavs/enh1@1.wav]",
+Enhancement,Voicebank,recipes/Voicebank/enhance/MetricGAN/train.py,recipes/Voicebank/enhance/MetricGAN/hparams/train.yaml,recipes/Voicebank/enhance/MetricGAN/voicebank_prepare.py,recipes/Voicebank/enhance/MetricGAN/README.md,https://www.dropbox.com/sh/n5q9vjn0yn1qvk6/AAB-S7i2-XzVm6ux0MrXCvqya?dl=0 ,https://huggingface.co/speechbrain/metricgan-plus-voicebank,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2 --tensorboard_train_logger=None --use_tensorboard=False,"file_exists=[historical.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,enhanced_wavs/enh0@2.wav,enhanced_wavs/enh2.wav,enhanced_wavs/enh0@1.wav,enhanced_wavs/enh1@2.wav,enhanced_wavs/enh1@1.wav]",PESQ=3.15
+Enhancement,Voicebank,recipes/Voicebank/enhance/SEGAN/train.py,recipes/Voicebank/enhance/SEGAN/hparams/train.yaml,recipes/Voicebank/enhance/SEGAN/voicebank_prepare.py,recipes/Voicebank/enhance/SEGAN/README.md,https://www.dropbox.com/sh/ez0folswdbqiad4/AADDasepeoCkneyiczjCcvaOa?dl=0 ,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,enhanced/enh2.wav]",PESQ=2.38
+Enhancement,Voicebank,recipes/Voicebank/enhance/spectral_mask/train.py,recipes/Voicebank/enhance/spectral_mask/hparams/train.yaml,recipes/Voicebank/enhance/spectral_mask/voicebank_prepare.py,recipes/Voicebank/enhance/spectral_mask/README.md,https://www.dropbox.com/sh/n5q9vjn0yn1qvk6/AAB-S7i2-XzVm6ux0MrXCvqya?dl=0 ,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2 --use_tensorboard=False,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,enhanced_wavs/enh2.wav]",PESQ=2.65
+Enhancement,Voicebank,recipes/Voicebank/enhance/waveform_map/train.py,recipes/Voicebank/enhance/waveform_map/hparams/train.yaml,recipes/Voicebank/enhance/waveform_map/voicebank_prepare.py,recipes/Voicebank/enhance/waveform_map/README.md,,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2 --tensorboard_train_logger=None --use_tensorboard=False,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,enhanced/enh2.wav]",
diff --git a/tests/recipes/VoxCeleb.csv b/tests/recipes/VoxCeleb.csv
index e1858b638c756f3690bf3421d42dd11df35b4928..32643dee67511bec1b4d9ec2f61fc934df27056c 100644
--- a/tests/recipes/VoxCeleb.csv
+++ b/tests/recipes/VoxCeleb.csv
@@ -1,7 +1,8 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-Speaker_recognition,VoxCeleb,recipes/VoxCeleb/SpeakerRec/train_speaker_embeddings.py,recipes/VoxCeleb/SpeakerRec/hparams/train_ecapa_tdnn.yaml,recipes/VoxCeleb/SpeakerRec/voxceleb_prepare.py,recipes/VoxCeleb/SpeakerRec/README.md,https://www.dropbox.com/sh/mau2nrt6i81ctfc/AAAUkAECzVaVWUMjD3mytjgea?dl=0 https://www.dropbox.com/sh/ab1ma1lnmskedo8/AADsmgOLPdEjSF6wV3KyhNG1a?dl=0,https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb,--rir_folder=tests/tmp --data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.csv --valid_annotation=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --sentence_len=0.5,"file_exists=[train_speaker_embeddings.py,train_log.txt,log.txt,env.log,hyperparams.yaml,save/veri_test2.txt,save/label_encoder.txt]"
-Speaker_recognition,VoxCeleb,recipes/VoxCeleb/SpeakerRec/speaker_verification_cosine.py,recipes/VoxCeleb/SpeakerRec/hparams/verification_ecapa.yaml,recipes/VoxCeleb/SpeakerRec/voxceleb_prepare.py,recipes/VoxCeleb/SpeakerRec/README.md,,,--data_folder=tests/samples/ASR/ --verification_file=tests/samples/annotation/verification.txt --train_data=tests/samples/annotation/ASR_train.csv --enrol_data=tests/samples/annotation/ASR_train.csv --test_data=tests/samples/annotation/ASR_train.csv --skip_prep=True --cohort_size=2,"file_exists=[speaker_verification_cosine.py,log.txt,scores.txt,env.log,hyperparams.yaml,save/verification.txt,save/embedding_model.ckpt]"
-Speaker_recognition,VoxCeleb,recipes/VoxCeleb/SpeakerRec/train_speaker_embeddings.py,recipes/VoxCeleb/SpeakerRec/hparams/train_x_vectors.yaml,recipes/VoxCeleb/SpeakerRec/voxceleb_prepare.py,recipes/VoxCeleb/SpeakerRec/README.md,https://www.dropbox.com/sh/mau2nrt6i81ctfc/AAAUkAECzVaVWUMjD3mytjgea?dl=0 https://www.dropbox.com/sh/ab1ma1lnmskedo8/AADsmgOLPdEjSF6wV3KyhNG1a?dl=0  https://www.dropbox.com/sh/ab1ma1lnmskedo8/AADsmgOLPdEjSF6wV3KyhNG1a?dl=0,https://huggingface.co/speechbrain/spkrec-xvect-voxceleb,--rir_folder=tests/tmp --data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.csv --valid_annotation=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --sentence_len=0.5 --emb_dim=2,"file_exists=[train_speaker_embeddings.py,train_log.txt,log.txt,env.log,hyperparams.yaml,save/veri_test2.txt,save/label_encoder.txt]"
-Speaker_recognition,VoxCeleb,recipes/VoxCeleb/SpeakerRec/speaker_verification_plda.py,recipes/VoxCeleb/SpeakerRec/hparams/verification_plda_xvector.yaml,recipes/VoxCeleb/SpeakerRec/voxceleb_prepare.py,recipes/VoxCeleb/SpeakerRec/README.md,,,--data_folder=tests/samples/ASR/ --verification_file=tests/samples/annotation/verification.txt --train_data=tests/samples/annotation/ASR_train_plda.csv --enrol_data=tests/samples/annotation/ASR_train.csv --test_data=tests/samples/annotation/ASR_train.csv --skip_prep=True --rank_f=2 --emb_dim=2 --pretrain_path=tests/tmp/VoxCeleb_row_4,"file_exists=[speaker_verification_plda.py,log.txt,env.log,hyperparams.yaml,save/verification.txt,save/VoxCeleb1_train_embeddings_stat_obj.pkl,save/stat_enrol.pkl,save/ndx.pkl,save/stat_test.pkl]"
-Speaker_recognition,VoxCeleb,recipes/VoxCeleb/SpeakerRec/train_speaker_embeddings.py,recipes/VoxCeleb/SpeakerRec/hparams/train_resnet.yaml,recipes/VoxCeleb/SpeakerRec/voxceleb_prepare.py,recipes/VoxCeleb/SpeakerRec/README.md,https://www.dropbox.com/sh/mau2nrt6i81ctfc/AAAUkAECzVaVWUMjD3mytjgea?dl=0 https://www.dropbox.com/sh/yvqn7tn6iqztx9k/AAAhhhbOCUJ47C0LbcpUlzYUa?dl=0,https://huggingface.co/speechbrain/spkrec-resnet-voxceleb,--rir_folder=tests/tmp --data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.csv --valid_annotation=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --sentence_len=0.5,"file_exists=[train_speaker_embeddings.py,train_log.txt,log.txt,env.log,hyperparams.yaml,save/veri_test2.txt,save/label_encoder.txt]"
-Speaker_recognition,VoxCeleb,recipes/VoxCeleb/SpeakerRec/speaker_verification_cosine.py,recipes/VoxCeleb/SpeakerRec/hparams/verification_resnet.yaml,recipes/VoxCeleb/SpeakerRec/voxceleb_prepare.py,recipes/VoxCeleb/SpeakerRec/README.md,,,--data_folder=tests/samples/ASR/ --verification_file=tests/samples/annotation/verification.txt --train_data=tests/samples/annotation/ASR_train.csv --enrol_data=tests/samples/annotation/ASR_train.csv --test_data=tests/samples/annotation/ASR_train.csv --skip_prep=True --cohort_size=2,"file_exists=[speaker_verification_cosine.py,log.txt,scores.txt,env.log,hyperparams.yaml,save/verification.txt,save/embedding_model.ckpt]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+Speaker_recognition,VoxCeleb,recipes/VoxCeleb/SpeakerRec/train_speaker_embeddings.py,recipes/VoxCeleb/SpeakerRec/hparams/train_ecapa_tdnn.yaml,recipes/VoxCeleb/SpeakerRec/voxceleb_prepare.py,recipes/VoxCeleb/SpeakerRec/README.md,https://www.dropbox.com/sh/ab1ma1lnmskedo8/AADsmgOLPdEjSF6wV3KyhNG1a?dl=0,https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.csv --valid_annotation=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --sentence_len=0.5,"file_exists=[train_speaker_embeddings.py,train_log.txt,log.txt,env.log,hyperparams.yaml,save/veri_test2.txt,save/label_encoder.txt]",EER=0.80%
+Speaker_recognition,VoxCeleb,recipes/VoxCeleb/SpeakerRec/speaker_verification_cosine.py,recipes/VoxCeleb/SpeakerRec/hparams/verification_ecapa.yaml,recipes/VoxCeleb/SpeakerRec/voxceleb_prepare.py,recipes/VoxCeleb/SpeakerRec/README.md,,,--data_folder=tests/samples/ASR/ --verification_file=tests/samples/annotation/verification.txt --train_data=tests/samples/annotation/ASR_train.csv --enrol_data=tests/samples/annotation/ASR_train.csv --test_data=tests/samples/annotation/ASR_train.csv --skip_prep=True --cohort_size=2,"file_exists=[speaker_verification_cosine.py,log.txt,scores.txt,env.log,hyperparams.yaml,save/verification.txt,save/embedding_model.ckpt]",
+Speaker_recognition,VoxCeleb,recipes/VoxCeleb/SpeakerRec/train_speaker_embeddings.py,recipes/VoxCeleb/SpeakerRec/hparams/train_x_vectors.yaml,recipes/VoxCeleb/SpeakerRec/voxceleb_prepare.py,recipes/VoxCeleb/SpeakerRec/README.md,https://www.dropbox.com/sh/ab1ma1lnmskedo8/AADsmgOLPdEjSF6wV3KyhNG1a?dl=0,https://huggingface.co/speechbrain/spkrec-xvect-voxceleb,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.csv --valid_annotation=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --sentence_len=0.5 --emb_dim=2,"file_exists=[train_speaker_embeddings.py,train_log.txt,log.txt,env.log,hyperparams.yaml,save/veri_test2.txt,save/label_encoder.txt]",EER=3.23%
+Speaker_recognition,VoxCeleb,recipes/VoxCeleb/SpeakerRec/speaker_verification_plda.py,recipes/VoxCeleb/SpeakerRec/hparams/verification_plda_xvector.yaml,recipes/VoxCeleb/SpeakerRec/voxceleb_prepare.py,recipes/VoxCeleb/SpeakerRec/README.md,,,--data_folder=tests/samples/ASR/ --verification_file=tests/samples/annotation/verification.txt --train_data=tests/samples/annotation/ASR_train_plda.csv --enrol_data=tests/samples/annotation/ASR_train.csv --test_data=tests/samples/annotation/ASR_train.csv --skip_prep=True --rank_f=2 --emb_dim=2 --pretrain_path=tests/tmp/VoxCeleb_row_04,"file_exists=[speaker_verification_plda.py,log.txt,env.log,hyperparams.yaml,save/verification.txt,save/VoxCeleb1_train_embeddings_stat_obj.pkl,save/stat_enrol.pkl,save/ndx.pkl,save/stat_test.pkl]",
+Speaker_recognition,VoxCeleb,recipes/VoxCeleb/SpeakerRec/train_speaker_embeddings.py,recipes/VoxCeleb/SpeakerRec/hparams/train_resnet.yaml,recipes/VoxCeleb/SpeakerRec/voxceleb_prepare.py,recipes/VoxCeleb/SpeakerRec/README.md,https://www.dropbox.com/sh/ab1ma1lnmskedo8/AADsmgOLPdEjSF6wV3KyhNG1a?dl=0,https://huggingface.co/speechbrain/spkrec-resnet-voxceleb,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.csv --valid_annotation=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --sentence_len=0.5,"file_exists=[train_speaker_embeddings.py,train_log.txt,log.txt,env.log,hyperparams.yaml,save/veri_test2.txt,save/label_encoder.txt]",EER=0.95%
+Speaker_recognition,VoxCeleb,recipes/VoxCeleb/SpeakerRec/speaker_verification_cosine.py,recipes/VoxCeleb/SpeakerRec/hparams/verification_resnet.yaml,recipes/VoxCeleb/SpeakerRec/voxceleb_prepare.py,recipes/VoxCeleb/SpeakerRec/README.md,,,--data_folder=tests/samples/ASR/ --verification_file=tests/samples/annotation/verification.txt --train_data=tests/samples/annotation/ASR_train.csv --enrol_data=tests/samples/annotation/ASR_train.csv --test_data=tests/samples/annotation/ASR_train.csv --skip_prep=True --cohort_size=2,"file_exists=[speaker_verification_cosine.py,log.txt,scores.txt,env.log,hyperparams.yaml,save/verification.txt,save/embedding_model.ckpt]",
+Speaker_recognition,VoxCeleb,recipes/VoxCeleb/SpeakerRec/train_speaker_embeddings.py,recipes/VoxCeleb/SpeakerRec/hparams/train_ecapa_tdnn_mel_spec.yaml,recipes/VoxCeleb/SpeakerRec/voxceleb_prepare.py,recipes/VoxCeleb/SpeakerRec/README.md,,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.csv --valid_annotation=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --sentence_len=0.5 --sample_rate=16000,,
diff --git a/tests/recipes/VoxLingua107.csv b/tests/recipes/VoxLingua107.csv
index e1fec0930f1064f67d1886517467281ea0fbb414..ed54a38f7dd90af2781da572bbe977bb3a0ffc66 100644
--- a/tests/recipes/VoxLingua107.csv
+++ b/tests/recipes/VoxLingua107.csv
@@ -1,2 +1,2 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-Language-id,VoxLingua107,recipes/VoxLingua107/lang_id/train.py,recipes/VoxLingua107/lang_id/hparams/train_ecapa.yaml,recipes/VoxLingua107/lang_id/create_wds_shards.py,recipes/VoxLingua107/lang_id/README.md,https://www.dropbox.com/sh/72gpuic5m4x8ztz/AAB5R-RVIEsXJtRH8SGkb_oCa?dl=0 ,https://huggingface.co/speechbrain/lang-id-voxlingua107-ecapa,--data_folder=tests/samples/single-mic --rir_folder=tests/tmp --train_meta=tests/samples/lang-shards/meta.json --val_meta=tests/samples/lang-shards/meta.json --number_of_epochs=2 --shards_url= --train_shards=tests/samples/lang-shards/shard-000000.tar --val_shards=tests/samples/lang-shards/shard-000000.tar,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/label_encoder.txt]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+Language-id,VoxLingua107,recipes/VoxLingua107/lang_id/train.py,recipes/VoxLingua107/lang_id/hparams/train_ecapa.yaml,recipes/VoxLingua107/lang_id/create_wds_shards.py,recipes/VoxLingua107/lang_id/README.md,https://www.dropbox.com/sh/72gpuic5m4x8ztz/AAB5R-RVIEsXJtRH8SGkb_oCa?dl=0 ,https://huggingface.co/speechbrain/lang-id-voxlingua107-ecapa,--data_folder=tests/samples/single-mic --train_meta=tests/samples/lang-shards/meta.json --val_meta=tests/samples/lang-shards/meta.json --number_of_epochs=1 --batch_size=3 --shards_url= --train_shards=tests/samples/lang-shards/shard-000000.tar --val_shards=tests/samples/lang-shards/shard-000000.tar,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/label_encoder.txt]",Accuracy=93.3%
diff --git a/tests/recipes/WHAMandWHAMR.csv b/tests/recipes/WHAMandWHAMR.csv
index 00b454de6e64c13a83e9c59d7f9531ef61891655..67a0adb4d16a872f1c8335f2d6c931727380a735 100644
--- a/tests/recipes/WHAMandWHAMR.csv
+++ b/tests/recipes/WHAMandWHAMR.csv
@@ -1,12 +1,12 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-Enhancement,WHAMandWHAMR,recipes/WHAMandWHAMR/enhancement/train.py,recipes/WHAMandWHAMR/enhancement/hparams/cnntransformer-wham-DM.yaml,recipes/WHAMandWHAMR/enhancement/dynamic_mixing.py,recipes/WHAMandWHAMR/enhancement/README.md,,,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation_processed_16k --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --sample_rate=16000 --dynamic_mixing=False,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Enhancement,WHAMandWHAMR,recipes/WHAMandWHAMR/enhancement/train.py,recipes/WHAMandWHAMR/enhancement/hparams/cnntransformer-whamr-DM.yaml,recipes/WHAMandWHAMR/enhancement/dynamic_mixing.py,recipes/WHAMandWHAMR/enhancement/README.md,,,--rir_path=tests/tmp --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation_processed_16k --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --sample_rate=16000 --dynamic_mixing=False,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Enhancement,WHAMandWHAMR,recipes/WHAMandWHAMR/enhancement/train.py,recipes/WHAMandWHAMR/enhancement/hparams/convtasnet-whamr-DM.yaml,recipes/WHAMandWHAMR/enhancement/dynamic_mixing.py,recipes/WHAMandWHAMR/enhancement/README.md,,,--rir_path=tests/tmp --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation_processed_16k --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --sample_rate=16000 --dynamic_mixing False,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Enhancement,WHAMandWHAMR,recipes/WHAMandWHAMR/enhancement/train.py,recipes/WHAMandWHAMR/enhancement/hparams/dprnn-whamr-DM.yaml,recipes/WHAMandWHAMR/enhancement/dynamic_mixing.py,recipes/WHAMandWHAMR/enhancement/README.md,,,--rir_path=tests/tmp --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation_processed_16k --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --sample_rate=16000 --dynamic_mixing False,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Enhancement,WHAMandWHAMR,recipes/WHAMandWHAMR/enhancement/train.py,recipes/WHAMandWHAMR/enhancement/hparams/sepformer-wham.yaml,recipes/WHAMandWHAMR/enhancement/dynamic_mixing.py,recipes/WHAMandWHAMR/enhancement/README.md,https://www.dropbox.com/sh/pxz2xbj76ijd5ci/AAD3c3dHyszk4oHJaa26K1_ha?dl=0,https://huggingface.co/speechbrain/sepformer-wham-enhancement,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation_processed_16k --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --sample_rate=16000 --dynamic_mixing=False,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Enhancement,WHAMandWHAMR,recipes/WHAMandWHAMR/enhancement/train.py,recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr-16k-DM.yaml,recipes/WHAMandWHAMR/enhancement/dynamic_mixing.py,recipes/WHAMandWHAMR/enhancement/README.md,,,--rir_path=tests/tmp --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation_processed_16k --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --sample_rate=16000 --dynamic_mixing=False,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Enhancement,WHAMandWHAMR,recipes/WHAMandWHAMR/enhancement/train.py,recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr-16k.yaml,recipes/WHAMandWHAMR/enhancement/dynamic_mixing.py,recipes/WHAMandWHAMR/enhancement/README.md,,,--rir_path=tests/tmp --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation_processed_16k --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --sample_rate=16000,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Enhancement,WHAMandWHAMR,recipes/WHAMandWHAMR/enhancement/train.py,recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr-DM.yaml,recipes/WHAMandWHAMR/enhancement/dynamic_mixing.py,recipes/WHAMandWHAMR/enhancement/README.md,,,--rir_path=tests/tmp --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation_processed_16k --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --sample_rate=16000 --dynamic_mixing=False,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Enhancement,WHAMandWHAMR,recipes/WHAMandWHAMR/enhancement/train.py,recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr.yaml,recipes/WHAMandWHAMR/enhancement/dynamic_mixing.py,recipes/WHAMandWHAMR/enhancement/README.md,https://www.dropbox.com/sh/kb0xrvi5k168ou2/AAAPB2U6HyyUT1gMoUH8gxQCa?dl=0,https://huggingface.co/speechbrain/sepformer-whamr-enhancement,--rir_path=tests/tmp --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation_processed_16k --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --sample_rate=16000,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Separation,WHAMandWHAMR,recipes/WHAMandWHAMR/separation/train.py,recipes/WHAMandWHAMR/separation/hparams/sepformer-wham.yaml,recipes/WHAMandWHAMR/enhancement/dynamic_mixing.py,recipes/WHAMandWHAMR/separation/README.md,https://www.dropbox.com/sh/sfrgb3xivri432e/AACQodNmiDIKrB9vCeCFUDWUa?dl=0,https://huggingface.co/speechbrain/sepformer-whamr https://huggingface.co/speechbrain/sepformer-whamr16k,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation_processed_16k --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --sample_rate=16000,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Separation,WHAMandWHAMR,recipes/WHAMandWHAMR/separation/train.py,recipes/WHAMandWHAMR/separation/hparams/sepformer-whamr.yaml,recipes/WHAMandWHAMR/enhancement/dynamic_mixing.py,recipes/WHAMandWHAMR/separation/README.md,https://www.dropbox.com/sh/1sia32z01xbfgvu/AADditsqaTyfN3N6tzfEFPica?dl=0,https://huggingface.co/speechbrain/sepformer-wham,--rir_path=tests/tmp --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation_processed_16k --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --sample_rate=16000,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+Enhancement,WHAMandWHAMR,recipes/WHAMandWHAMR/enhancement/train.py,recipes/WHAMandWHAMR/enhancement/hparams/cnntransformer-wham-DM.yaml,recipes/WHAMandWHAMR/enhancement/dynamic_mixing.py,recipes/WHAMandWHAMR/enhancement/README.md,,,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation_processed_16k --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --sample_rate=16000 --dynamic_mixing=False --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",
+Enhancement,WHAMandWHAMR,recipes/WHAMandWHAMR/enhancement/train.py,recipes/WHAMandWHAMR/enhancement/hparams/cnntransformer-whamr-DM.yaml,recipes/WHAMandWHAMR/enhancement/dynamic_mixing.py,recipes/WHAMandWHAMR/enhancement/README.md,,,--rir_path=tests/tmp --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation_processed_16k --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --sample_rate=16000 --dynamic_mixing=False --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",
+Enhancement,WHAMandWHAMR,recipes/WHAMandWHAMR/enhancement/train.py,recipes/WHAMandWHAMR/enhancement/hparams/convtasnet-whamr-DM.yaml,recipes/WHAMandWHAMR/enhancement/dynamic_mixing.py,recipes/WHAMandWHAMR/enhancement/README.md,,,--rir_path=tests/tmp --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation_processed_16k --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --sample_rate=16000 --dynamic_mixing False --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",
+Enhancement,WHAMandWHAMR,recipes/WHAMandWHAMR/enhancement/train.py,recipes/WHAMandWHAMR/enhancement/hparams/dprnn-whamr-DM.yaml,recipes/WHAMandWHAMR/enhancement/dynamic_mixing.py,recipes/WHAMandWHAMR/enhancement/README.md,,,--rir_path=tests/tmp --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation_processed_16k --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --sample_rate=16000 --dynamic_mixing False --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",
+Enhancement,WHAMandWHAMR,recipes/WHAMandWHAMR/enhancement/train.py,recipes/WHAMandWHAMR/enhancement/hparams/sepformer-wham.yaml,recipes/WHAMandWHAMR/enhancement/dynamic_mixing.py,recipes/WHAMandWHAMR/enhancement/README.md,https://www.dropbox.com/sh/pxz2xbj76ijd5ci/AAD3c3dHyszk4oHJaa26K1_ha?dl=0,https://huggingface.co/speechbrain/sepformer-wham-enhancement,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation_processed_16k --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --sample_rate=16000 --dynamic_mixing=False --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",SI-SNR=14.4 PESQ=3.05
+Enhancement,WHAMandWHAMR,recipes/WHAMandWHAMR/enhancement/train.py,recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr-16k-DM.yaml,recipes/WHAMandWHAMR/enhancement/dynamic_mixing.py,recipes/WHAMandWHAMR/enhancement/README.md,,,--rir_path=tests/tmp --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation_processed_16k --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --sample_rate=16000 --dynamic_mixing=False --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",
+Enhancement,WHAMandWHAMR,recipes/WHAMandWHAMR/enhancement/train.py,recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr-16k.yaml,recipes/WHAMandWHAMR/enhancement/dynamic_mixing.py,recipes/WHAMandWHAMR/enhancement/README.md,,,--rir_path=tests/tmp --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation_processed_16k --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --sample_rate=16000 --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",
+Enhancement,WHAMandWHAMR,recipes/WHAMandWHAMR/enhancement/train.py,recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr-DM.yaml,recipes/WHAMandWHAMR/enhancement/dynamic_mixing.py,recipes/WHAMandWHAMR/enhancement/README.md,,,--rir_path=tests/tmp --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation_processed_16k --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --sample_rate=16000 --dynamic_mixing=False --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",
+Enhancement,WHAMandWHAMR,recipes/WHAMandWHAMR/enhancement/train.py,recipes/WHAMandWHAMR/enhancement/hparams/sepformer-whamr.yaml,recipes/WHAMandWHAMR/enhancement/dynamic_mixing.py,recipes/WHAMandWHAMR/enhancement/README.md,https://www.dropbox.com/sh/kb0xrvi5k168ou2/AAAPB2U6HyyUT1gMoUH8gxQCa?dl=0,https://huggingface.co/speechbrain/sepformer-whamr-enhancement,--rir_path=tests/tmp --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation_processed_16k --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --sample_rate=16000 --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",SI-SNR=10.6 PESQ=2.84
+Separation,WHAMandWHAMR,recipes/WHAMandWHAMR/separation/train.py,recipes/WHAMandWHAMR/separation/hparams/sepformer-wham.yaml,recipes/WHAMandWHAMR/enhancement/dynamic_mixing.py,recipes/WHAMandWHAMR/separation/README.md,https://www.dropbox.com/sh/sfrgb3xivri432e/AACQodNmiDIKrB9vCeCFUDWUa?dl=0,https://huggingface.co/speechbrain/sepformer-whamr,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation_processed_16k --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --sample_rate=16000 --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",SI-SNR=16.5
+Separation,WHAMandWHAMR,recipes/WHAMandWHAMR/separation/train.py,recipes/WHAMandWHAMR/separation/hparams/sepformer-whamr.yaml,recipes/WHAMandWHAMR/enhancement/dynamic_mixing.py,recipes/WHAMandWHAMR/separation/README.md,https://www.dropbox.com/sh/1sia32z01xbfgvu/AADditsqaTyfN3N6tzfEFPica?dl=0,https://huggingface.co/speechbrain/sepformer-wham,--rir_path=tests/tmp --data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation_processed_16k --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --sample_rate=16000 --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",SI-SNR=14.0
diff --git a/tests/recipes/WSJ0Mix.csv b/tests/recipes/WSJ0Mix.csv
index 17d3d708fd600aaf9d7c0e285e32d5acae11c5ec..bde91ea405169d9da96f122c726b40c05dd9f527 100644
--- a/tests/recipes/WSJ0Mix.csv
+++ b/tests/recipes/WSJ0Mix.csv
@@ -1,8 +1,8 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-Separation,WSJ0Mix,recipes/WSJ0Mix/separation/train.py,recipes/WSJ0Mix/separation/hparams/convtasnet.yaml,recipes/WSJ0Mix/separation/dynamic_mixing.py,recipes/WSJ0Mix/separation/README.md,https://www.dropbox.com/sh/hdpxj47signsay7/AABbDjGoyQesnFxjg0APxl7qa?dl=0,,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Separation,WSJ0Mix,recipes/WSJ0Mix/separation/train.py,recipes/WSJ0Mix/separation/hparams/dprnn.yaml,recipes/WSJ0Mix/separation/dynamic_mixing.py,recipes/WSJ0Mix/separation/README.md,https://www.dropbox.com/sh/o8fohu5s07h4bnw/AADPNyR1E3Q4aRobg3FtXTwVa?dl=0,,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Separation,WSJ0Mix,recipes/WSJ0Mix/separation/train.py,recipes/WSJ0Mix/separation/hparams/resepformer.yaml,recipes/WSJ0Mix/separation/dynamic_mixing.py,recipes/WSJ0Mix/separation/README.md,https://www.dropbox.com/sh/obnu87zhubn1iia/AAAbn_jzqzIfeqaE9YQ7ujyQa?dl=0,https://huggingface.co/speechbrain/resepformer-wsj02mix,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Separation,WSJ0Mix,recipes/WSJ0Mix/separation/train.py,recipes/WSJ0Mix/separation/hparams/sepformer-conformerintra.yaml,recipes/WSJ0Mix/separation/dynamic_mixing.py,recipes/WSJ0Mix/separation/README.md,https://www.dropbox.com/sh/w27rbdfnrtntrc9/AABCMFFvnxxYkKTInYXtsow3a?dl=0 ,,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --num_layers_inter=2 --num_layers_intra=2,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Separation,WSJ0Mix,recipes/WSJ0Mix/separation/train.py,recipes/WSJ0Mix/separation/hparams/sepformer-customdataset.yaml,recipes/WSJ0Mix/separation/dynamic_mixing.py,recipes/WSJ0Mix/separation/README.md,,,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Separation,WSJ0Mix,recipes/WSJ0Mix/separation/train.py,recipes/WSJ0Mix/separation/hparams/sepformer.yaml,recipes/WSJ0Mix/separation/dynamic_mixing.py,recipes/WSJ0Mix/separation/README.md,https://www.dropbox.com/sh/9klsqadkhin6fw1/AADEqGdT98rcqxVgFlfki7Gva?dl=0 ,https://huggingface.co/speechbrain/sepformer-wsj02mix https://huggingface.co/speechbrain/sepformer-wsj03mix,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
-Separation,WSJ0Mix,recipes/WSJ0Mix/separation/train.py,recipes/WSJ0Mix/separation/hparams/skim.yaml,recipes/WSJ0Mix/separation/dynamic_mixing.py,recipes/WSJ0Mix/separation/README.md,https://www.dropbox.com/sh/zy0l5rc8abxdfp3/AAA2ngB74fugqpWXmjZo5v3wa?dl=0,https://huggingface.co/speechbrain/resepformer-wsj02mix ,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+Separation (2mix),WSJ0Mix,recipes/WSJ0Mix/separation/train.py,recipes/WSJ0Mix/separation/hparams/convtasnet.yaml,recipes/WSJ0Mix/separation/dynamic_mixing.py,recipes/WSJ0Mix/separation/README.md,https://www.dropbox.com/sh/hdpxj47signsay7/AABbDjGoyQesnFxjg0APxl7qa?dl=0,,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",SI-SNRi=14.8dB
+Separation (2mix),WSJ0Mix,recipes/WSJ0Mix/separation/train.py,recipes/WSJ0Mix/separation/hparams/dprnn.yaml,recipes/WSJ0Mix/separation/dynamic_mixing.py,recipes/WSJ0Mix/separation/README.md,https://www.dropbox.com/sh/o8fohu5s07h4bnw/AADPNyR1E3Q4aRobg3FtXTwVa?dl=0,,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",SI-SNRi=18.5dB
+Separation (2mix),WSJ0Mix,recipes/WSJ0Mix/separation/train.py,recipes/WSJ0Mix/separation/hparams/resepformer.yaml,recipes/WSJ0Mix/separation/dynamic_mixing.py,recipes/WSJ0Mix/separation/README.md,https://www.dropbox.com/sh/obnu87zhubn1iia/AAAbn_jzqzIfeqaE9YQ7ujyQa?dl=0,https://huggingface.co/speechbrain/resepformer-wsj02mix,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",SI-SNRi=18.6dB
+Separation (2mix),WSJ0Mix,recipes/WSJ0Mix/separation/train.py,recipes/WSJ0Mix/separation/hparams/sepformer-conformerintra.yaml,recipes/WSJ0Mix/separation/dynamic_mixing.py,recipes/WSJ0Mix/separation/README.md,https://www.dropbox.com/sh/w27rbdfnrtntrc9/AABCMFFvnxxYkKTInYXtsow3a?dl=0 ,,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --num_layers_inter=2 --num_layers_intra=2 --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",
+Separation (2mix),WSJ0Mix,recipes/WSJ0Mix/separation/train.py,recipes/WSJ0Mix/separation/hparams/sepformer-customdataset.yaml,recipes/WSJ0Mix/separation/dynamic_mixing.py,recipes/WSJ0Mix/separation/README.md,,,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",
+Separation (2mix),WSJ0Mix,recipes/WSJ0Mix/separation/train.py,recipes/WSJ0Mix/separation/hparams/sepformer.yaml,recipes/WSJ0Mix/separation/dynamic_mixing.py,recipes/WSJ0Mix/separation/README.md,https://www.dropbox.com/sh/9klsqadkhin6fw1/AADEqGdT98rcqxVgFlfki7Gva?dl=0 ,https://huggingface.co/speechbrain/sepformer-wsj02mix,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",SI-SNRi=22.4dB
+Separation (2mix),WSJ0Mix,recipes/WSJ0Mix/separation/train.py,recipes/WSJ0Mix/separation/hparams/skim.yaml,recipes/WSJ0Mix/separation/dynamic_mixing.py,recipes/WSJ0Mix/separation/README.md,https://www.dropbox.com/sh/zy0l5rc8abxdfp3/AAA2ngB74fugqpWXmjZo5v3wa?dl=0,https://huggingface.co/speechbrain/resepformer-wsj02mix ,--data_folder=tests/samples/separation --base_folder_dm=tests/samples/separation --train_data=tests/samples/annotation/separation_train.csv --valid_data=tests/samples/annotation/separation_dev.csv --test_data=tests/samples/annotation/separation_dev.csv --skip_prep=True --N_epochs=2 --use_wavedrop=True,"file_exists=[test_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]",SI-SNRi=18.1dB
diff --git a/tests/recipes/ZaionEmotionDataset.csv b/tests/recipes/ZaionEmotionDataset.csv
index cef6a8a6b4d0e90b823c94dcb8b95afb99f8756f..784a83b2c27ffc948c68e76fc24e5ec51845e4d1 100644
--- a/tests/recipes/ZaionEmotionDataset.csv
+++ b/tests/recipes/ZaionEmotionDataset.csv
@@ -1,2 +1,2 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-Emotion_Diarization,ZaionEmotionDataset,recipes/ZaionEmotionDataset/emotion_diarization/train.py,recipes/ZaionEmotionDataset/emotion_diarization/hparams/train.yaml,recipes/ZaionEmotionDataset/emotion_diarization/zed_prepare.py,recipes/ZaionEmotionDataset/README.md,https://www.dropbox.com/sh/woudm1v31a7vyp5/AADAMxpQOXaxf8E_1hX202GJa?dl=0,https://huggingface.co/speechbrain/emotion-diarization-wavlm-large,--emovdb_folder=tests/samples/ASR/ --esd_folder=tests/samples/ASR/ --iemocap_folder=tests/samples/ASR/ --jlcorpus_folder=tests/samples/ASR/ --ravdess_folder=tests/samples/ASR/ --zed_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_dev.json --test_annotation=tests/samples/annotation/ASR_dev.json --number_of_epochs=2 --skip_prep=True --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[eder.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/label_encoder.txt]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+Emotion_Diarization,ZaionEmotionDataset,recipes/ZaionEmotionDataset/emotion_diarization/train.py,recipes/ZaionEmotionDataset/emotion_diarization/hparams/train.yaml,recipes/ZaionEmotionDataset/emotion_diarization/zed_prepare.py,recipes/ZaionEmotionDataset/README.md,https://www.dropbox.com/sh/woudm1v31a7vyp5/AADAMxpQOXaxf8E_1hX202GJa?dl=0,https://huggingface.co/speechbrain/emotion-diarization-wavlm-large,--emovdb_folder=tests/samples/ASR/ --esd_folder=tests/samples/ASR/ --iemocap_folder=tests/samples/ASR/ --jlcorpus_folder=tests/samples/ASR/ --ravdess_folder=tests/samples/ASR/ --zed_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_dev.json --test_annotation=tests/samples/annotation/ASR_dev.json --number_of_epochs=2 --skip_prep=True --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,"file_exists=[eder.txt,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml,save/label_encoder.txt]",EDER=30.2%
diff --git a/tests/recipes/fluent-speech-commands.csv b/tests/recipes/fluent-speech-commands.csv
index 589aed4cb335834089b416e5dc78c2835d10872b..2bf2740aa6e6077d3cc7ba24156152144b76411a 100644
--- a/tests/recipes/fluent-speech-commands.csv
+++ b/tests/recipes/fluent-speech-commands.csv
@@ -1,3 +1,3 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-SLU,fluent-speech-commands,recipes/fluent-speech-commands/direct/train.py,recipes/fluent-speech-commands/direct/hparams/train.yaml,recipes/fluent-speech-commands/direct/prepare.py,recipes/fluent-speech-commands/README.md,https://www.dropbox.com/sh/wal9ap0go9f66qw/AADBVlGs_E2pEU4vYJgEe3Fba?dl=0,,--rir_folder=tests/tmp --data_folder=tests/samples/ASR/ --csv_train=tests/samples/annotation/ASR_train.csv --csv_valid=tests/samples/annotation/ASR_train.csv --csv_test=tests/samples/annotation/ASR_train.csv --skip_prep=True,"file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,save/FSC_tokenizer/tokenizer.ckpt]"
-Tokenizer,fluent-speech-commands,recipes/fluent-speech-commands/Tokenizer/train.py,recipes/fluent-speech-commands/Tokenizer/hparams/tokenizer_bpe51.yaml,recipes/fluent-speech-commands/Tokenizer/prepare.py,recipes/fluent-speech-commands/README.md,https://www.dropbox.com/sh/wal9ap0go9f66qw/AADBVlGs_E2pEU4vYJgEe3Fba?dl=0,https://huggingface.co/speechbrain/slu-direct-fluent-speech-commands-librispeech-asr,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True --token_output=17,"file_exists=[17_unigram.vocab,log.txt,17_unigram.model,ASR_train.txt,env.log,train.py,hyperparams.yaml]"
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+SLU,fluent-speech-commands,recipes/fluent-speech-commands/direct/train.py,recipes/fluent-speech-commands/direct/hparams/train.yaml,recipes/fluent-speech-commands/direct/prepare.py,recipes/fluent-speech-commands/README.md,https://www.dropbox.com/sh/wal9ap0go9f66qw/AADBVlGs_E2pEU4vYJgEe3Fba?dl=0,,--data_folder=tests/samples/ASR/ --csv_train=tests/samples/annotation/ASR_train.csv --csv_valid=tests/samples/annotation/ASR_train.csv --csv_test=tests/samples/annotation/ASR_train.csv --skip_prep=True,"file_exists=[train_log.txt,log.txt,wer_test.txt,env.log,train.py,hyperparams.yaml,save/FSC_tokenizer/tokenizer.ckpt]",Test-accuracy=99.60%
+Tokenizer,fluent-speech-commands,recipes/fluent-speech-commands/Tokenizer/train.py,recipes/fluent-speech-commands/Tokenizer/hparams/tokenizer_bpe51.yaml,recipes/fluent-speech-commands/Tokenizer/prepare.py,recipes/fluent-speech-commands/README.md,https://www.dropbox.com/sh/wal9ap0go9f66qw/AADBVlGs_E2pEU4vYJgEe3Fba?dl=0,https://huggingface.co/speechbrain/slu-direct-fluent-speech-commands-librispeech-asr,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True --token_output=17,"file_exists=[17_unigram.vocab,log.txt,17_unigram.model,ASR_train.txt,env.log,train.py,hyperparams.yaml]",
diff --git a/tests/recipes/full_inference.csv b/tests/recipes/full_inference.csv
index 8901f8440c6861895a6d353776764abf5666d3c9..d8c9a404fcdd734f04f35aecb752bee2e008ebf1 100644
--- a/tests/recipes/full_inference.csv
+++ b/tests/recipes/full_inference.csv
@@ -1,3 +1,6 @@
 Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,test_download, test_message
-full_inference,LibriSpeech,recipes/LibriSpeech/ASR/transformer/train.py,recipes/LibriSpeech/ASR/transformer/hparams/transformer.yaml,recipes/LibriSpeech/librispeech_prepare.py,recipes/LibriSpeech/ASR/transformer/README.md,https://www.dropbox.com/sh/653kq8h2k87md4p/AAByAaAryXtQKpRzYtzV9ih5a?dl=0,https://huggingface.co/speechbrain/asr-transformer-transformerlm-librispeech,--data_folder=tests/download/LibriSpeech/LibriSpeech  --output_folder=tests/download/LibriSpeech_ASR_transformer --train_splits=[] --dev_splits=[] --test_splits=['test-clean'] --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_test_librispeech_clean.csv] --test_only,"file_exists=[env.log,hyperparams.yaml,log.txt,train_log.txt,train.py,wer_test-clean.txt] performance_check=[train_log.txt, test WER, <=3.70, epoch:-1]","download_file('https://www.openslr.org/resources/12/test-clean.tar.gz', 'tests/download/test-clean.tar.gz', unpack=True, dest_unpack='tests/download/LibriSpeech',write_permissions=True);download_file('https://www.dropbox.com/sh/wy6ktfdxqt0dmou/AABY4ty3zy3xPef0zwRx949Oa?dl=1', 'tests/download/LibriSpeech_ASR_transformer.zip', unpack=True, dest_unpack='tests/download/LibriSpeech_ASR_transformer', write_permissions=True)", "Running an Inference Test."
-full_inference,LibriSpeech,recipes/LibriSpeech/ASR/seq2seq/train.py,recipes/LibriSpeech/ASR/seq2seq/hparams/train_BPE_5000.yaml,recipes/LibriSpeech/librispeech_prepare.py,recipes/LibriSpeech/ASR/seq2seq/README.md,https://www.dropbox.com/sh/1ycv07gyxdq8hdl/AABUDYzza4SLYtY45RcGf2_0a?dl=0 https://www.dropbox.com/sh/a39wq3h60luv552/AABBnCM2Uf-CNax_cgMWdqDda?dl=0,https://huggingface.co/speechbrain/asr-crdnn-transformerlm-librispeech,--data_folder=tests/download/LibriSpeech/LibriSpeech  --output_folder=tests/download/LibriSpeech_ASR_seq2seq --train_splits=[] --dev_splits=[] --test_splits=['test-clean'] --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_test_librispeech_clean.csv]  --data_folder_rirs=tests/tmp --test_only,"file_exists=[env.log,hyperparams.yaml,log.txt,train_log.txt,train.py,wer_test-clean.txt] performance_check=[train_log.txt, test WER, <=4.17, epoch:-1]","download_file('https://www.openslr.org/resources/12/test-clean.tar.gz', 'tests/download/test-clean.tar.gz', unpack=True, dest_unpack='tests/download/LibriSpeech', write_permissions=True);download_file('https://www.dropbox.com/sh/vs1os5mwh4necp0/AAAs7rMl4IO4N4s__BbVon6qa?dl=1', 'tests/download/LibriSpeech_ASR_seq2seq.zip', unpack=True, dest_unpack='tests/download/LibriSpeech_ASR_seq2seq', write_permissions=True)", "Running an Inference Test."
+full_inference,LibriSpeech,recipes/LibriSpeech/ASR/transformer/train.py,recipes/LibriSpeech/ASR/transformer/hparams/transformer.yaml,recipes/LibriSpeech/librispeech_prepare.py,recipes/LibriSpeech/ASR/transformer/README.md,https://www.dropbox.com/sh/653kq8h2k87md4p/AAByAaAryXtQKpRzYtzV9ih5a?dl=0,https://huggingface.co/speechbrain/asr-transformer-transformerlm-librispeech,--data_folder=tests/download/LibriSpeech/LibriSpeech  --output_folder=tests/tmp/LibriSpeech_ASR_transformer --train_splits=[] --dev_splits=[] --test_splits=['test-clean'] --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_test_librispeech_clean.csv] --test_only,"file_exists=[env.log,hyperparams.yaml,log.txt,train_log.txt,train.py,wer_test-clean.txt] performance_check=[train_log.txt, test WER, <=3.70, epoch:-1]","download_file('https://www.openslr.org/resources/12/test-clean.tar.gz', 'tests/download/test-clean.tar.gz', unpack=True, dest_unpack='tests/download/LibriSpeech',write_permissions=True);download_file('https://www.dropbox.com/sh/wy6ktfdxqt0dmou/AABY4ty3zy3xPef0zwRx949Oa?dl=1', 'tests/download/LibriSpeech_ASR_transformer.zip', unpack=True, dest_unpack='tests/tmp/LibriSpeech_ASR_transformer', write_permissions=True)", "Running an Inference Test."
+full_inference,LibriSpeech,recipes/LibriSpeech/ASR/seq2seq/train.py,recipes/LibriSpeech/ASR/seq2seq/hparams/train_BPE_5000.yaml,recipes/LibriSpeech/librispeech_prepare.py,recipes/LibriSpeech/ASR/seq2seq/README.md,https://www.dropbox.com/sh/1ycv07gyxdq8hdl/AABUDYzza4SLYtY45RcGf2_0a?dl=0 https://www.dropbox.com/sh/a39wq3h60luv552/AABBnCM2Uf-CNax_cgMWdqDda?dl=0,https://huggingface.co/speechbrain/asr-crdnn-transformerlm-librispeech,--data_folder=tests/download/LibriSpeech/LibriSpeech  --output_folder=tests/tmp/LibriSpeech_ASR_seq2seq --train_splits=[] --dev_splits=[] --test_splits=['test-clean'] --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_test_librispeech_clean.csv] --test_only,"file_exists=[env.log,hyperparams.yaml,log.txt,train_log.txt,train.py,wer_test-clean.txt] performance_check=[train_log.txt, test WER, <=4.17, epoch:-1]","download_file('https://www.openslr.org/resources/12/test-clean.tar.gz', 'tests/download/test-clean.tar.gz', unpack=True, dest_unpack='tests/download/LibriSpeech', write_permissions=True);download_file('https://www.dropbox.com/sh/vs1os5mwh4necp0/AAAs7rMl4IO4N4s__BbVon6qa?dl=1', 'tests/download/LibriSpeech_ASR_seq2seq.zip', unpack=True, dest_unpack='tests/tmp/LibriSpeech_ASR_seq2seq', write_permissions=True)", "Running an Inference Test."
+full_inference,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py,recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec.yaml,recipes/LibriSpeech/librispeech_prepare.py,recipes/LibriSpeech/ASR/CTC/README.md,https://www.dropbox.com/sh/al3qzi1b4cgvvfo/AAAWf0b-uMIPcfaU3DgjPcbMa?dl=0 https://www.dropbox.com/sh/a39wq3h60luv552/AABBnCM2Uf-CNax_cgMWdqDda?dl=0,,--data_folder=tests/download/LibriSpeech/LibriSpeech   --output_folder=tests/tmp/LibriSpeech_ASR_CTC_wav2vec --train_splits=[] --dev_splits=[] --test_splits=['test-clean'] --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_test_librispeech_clean.csv] --test_only,"file_exists=[env.log,hyperparams.yaml,log.txt,train_log.txt,train_with_wav2vec.py] performance_check=[train_log.txt, test WER, <=5.09, epoch:-1]","download_file('https://www.openslr.org/resources/12/test-clean.tar.gz', 'tests/download/test-clean.tar.gz', unpack=True, dest_unpack='tests/download/LibriSpeech', write_permissions=True);download_file('https://www.dropbox.com/sh/al3qzi1b4cgvvfo/AAAWf0b-uMIPcfaU3DgjPcbMa?dl=1', 'tests/download/LibriSpeech_ASR_CTC_wav2vec.zip', unpack=True, dest_unpack='tests/tmp/LibriSpeech_ASR_CTC_wav2vec', write_permissions=True)", "Running an Inference Test."
+full_inference,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py,recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec_transformer_rescoring.yaml,recipes/LibriSpeech/librispeech_prepare.py,recipes/LibriSpeech/ASR/CTC/README.md,https://www.dropbox.com/sh/jvi4cvdt2cnxs8s/AAAjRqdNUkcMXbBuoNqK0iZ8a?dl=0 https://www.dropbox.com/sh/a39wq3h60luv552/AABBnCM2Uf-CNax_cgMWdqDda?dl=0,,--data_folder=tests/download/LibriSpeech/LibriSpeech   --output_folder=tests/tmp/LibriSpeech_ASR_CTC_wav2vec_transformer_rescoring --train_splits=[] --dev_splits=[] --test_splits=['test-clean'] --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_test_librispeech_clean.csv] --test_only,"file_exists=[env.log,hyperparams.yaml,log.txt,train_log.txt,train_with_wav2vec.py] performance_check=[train_log.txt, test WER, <=4.17, epoch:-1]","download_file('https://www.openslr.org/resources/12/test-clean.tar.gz', 'tests/download/test-clean.tar.gz', unpack=True, dest_unpack='tests/download/LibriSpeech', write_permissions=True);download_file('https://www.dropbox.com/sh/jvi4cvdt2cnxs8s/AAAjRqdNUkcMXbBuoNqK0iZ8a?dl=1', 'tests/download/LibriSpeech_ASR_CTC_wav2vec_transformer_rescoring.zip', unpack=True, dest_unpack='tests/tmp/LibriSpeech_ASR_CTC_wav2vec_transformer_rescoring', write_permissions=True)", "Running an Inference Test."
+full_inference,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py,recipes/LibriSpeech/ASR/CTC/hparams/train_hf_wav2vec_rnn_rescoring.yaml,recipes/LibriSpeech/librispeech_prepare.py,recipes/LibriSpeech/ASR/CTC/README.md,https://www.dropbox.com/sh/xmn1pdkhiog6dd3/AAANM22MPrI7DTt1jYF8hwOWa?dl=0 https://www.dropbox.com/sh/a39wq3h60luv552/AABBnCM2Uf-CNax_cgMWdqDda?dl=0,,--data_folder=tests/download/LibriSpeech/LibriSpeech   --output_folder=tests/tmp/LibriSpeech_ASR_CTC_wav2vec_rnn_rescoring --train_splits=[] --dev_splits=[] --test_splits=['test-clean'] --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_test_librispeech_clean.csv] --test_only,"file_exists=[env.log,hyperparams.yaml,log.txt,train_log.txt,train_with_wav2vec.py] performance_check=[train_log.txt, test WER, <=4.17, epoch:-1]","download_file('https://www.openslr.org/resources/12/test-clean.tar.gz', 'tests/download/test-clean.tar.gz', unpack=True, dest_unpack='tests/download/LibriSpeech', write_permissions=True);download_file('https://www.dropbox.com/sh/xmn1pdkhiog6dd3/AAANM22MPrI7DTt1jYF8hwOWa?dl=1', 'tests/download/LibriSpeech_ASR_CTC_wav2vec_rnn_rescoring.zip', unpack=True, dest_unpack='tests/tmp/LibriSpeech_ASR_CTC_wav2vec_rnn_rescoring', write_permissions=True)", "Running an Inference Test."
diff --git a/tests/recipes/setup/recipes_AMI_Diarization_experiment/hparams_ecapa_tdnn b/tests/recipes/setup/recipes_AMI_Diarization_experiment/hparams_ecapa_tdnn
index 0e8d1778f8073bd1011efe46e9c72c475e58f7d1..345514506c1c4b1e7975d4c9a40030d50f9ad63f 100755
--- a/tests/recipes/setup/recipes_AMI_Diarization_experiment/hparams_ecapa_tdnn
+++ b/tests/recipes/setup/recipes_AMI_Diarization_experiment/hparams_ecapa_tdnn
@@ -1,6 +1,6 @@
-mkdir -p tests/tmp/AMI_row_2/save/ref_rttms
-mkdir -p tests/tmp/AMI_row_2/save/metadata
-mkdir -p tests/tmp/AMI_row_2/save/sys_rttms/Mix-Headset/AMI_dev/est_cos_SC/
-touch tests/tmp/AMI_row_2/save//sys_rttms/Mix-Headset/AMI_dev/est_cos_SC/sys_output.rttm
-ln -s `realpath tests/tmp/AMI_row_2/save/sys_rttms/Mix-Headset/AMI_dev/est_cos_SC/sys_output.rttm` tests/tmp/AMI_row_2/save/ref_rttms/fullref_ami_dev.rttm
-ln -s `realpath tests/tmp/AMI_row_2/save/ref_rttms/fullref_ami_dev.rttm` tests/tmp/AMI_row_2/save/ref_rttms/fullref_ami_eval.rttm
+mkdir -p tests/tmp/AMI_row_02/save/ref_rttms
+mkdir -p tests/tmp/AMI_row_02/save/metadata
+mkdir -p tests/tmp/AMI_row_02/save/sys_rttms/Mix-Headset/AMI_dev/est_cos_SC/
+touch tests/tmp/AMI_row_02/save//sys_rttms/Mix-Headset/AMI_dev/est_cos_SC/sys_output.rttm
+ln -s `realpath tests/tmp/AMI_row_02/save/sys_rttms/Mix-Headset/AMI_dev/est_cos_SC/sys_output.rttm` tests/tmp/AMI_row_02/save/ref_rttms/fullref_ami_dev.rttm
+ln -s `realpath tests/tmp/AMI_row_02/save/ref_rttms/fullref_ami_dev.rttm` tests/tmp/AMI_row_02/save/ref_rttms/fullref_ami_eval.rttm
diff --git a/tests/recipes/setup/recipes_AMI_Diarization_experiment/hparams_xvectors b/tests/recipes/setup/recipes_AMI_Diarization_experiment/hparams_xvectors
index 5c02188d21fff09d4b3f6530efbc6e9fd968bd72..4263af2edf172521beca45ca74107f76caf981d8 100755
--- a/tests/recipes/setup/recipes_AMI_Diarization_experiment/hparams_xvectors
+++ b/tests/recipes/setup/recipes_AMI_Diarization_experiment/hparams_xvectors
@@ -1,6 +1,6 @@
-mkdir -p tests/tmp/AMI_row_3/save/ref_rttms
-mkdir -p tests/tmp/AMI_row_3/save/metadata
-mkdir -p tests/tmp/AMI_row_3/save/sys_rttms/Mix-Headset/AMI_dev/est_cos_SC/
+mkdir -p tests/tmp/AMI_row_03/save/ref_rttms
+mkdir -p tests/tmp/AMI_row_03/save/metadata
+mkdir -p tests/tmp/AMI_row_03/save/sys_rttms/Mix-Headset/AMI_dev/est_cos_SC/
 touch tests/tmp/AMI_row_3/save/sys_rttms/Mix-Headset/AMI_dev/est_cos_SC/sys_output.rttm
-ln -s `realpath tests/tmp/AMI_row_3/save/sys_rttms/Mix-Headset/AMI_dev/est_cos_SC/sys_output.rttm` tests/tmp/AMI_row_3/save/ref_rttms/fullref_ami_dev.rttm
-ln -s `realpath tests/tmp/AMI_row_3/save/ref_rttms/fullref_ami_dev.rttm` tests/tmp/AMI_row_3/save/ref_rttms/fullref_ami_eval.rttm
+ln -s `realpath tests/tmp/AMI_row_03/save/sys_rttms/Mix-Headset/AMI_dev/est_cos_SC/sys_output.rttm` tests/tmp/AMI_row_03/save/ref_rttms/fullref_ami_dev.rttm
+ln -s `realpath tests/tmp/AMI_row_03/save/ref_rttms/fullref_ami_dev.rttm` tests/tmp/AMI_row_03/save/ref_rttms/fullref_ami_eval.rttm
diff --git a/tests/recipes/setup/recipes_TIMIT_ASR_seq2seq_knowledge_distillation_save_teachers/hparams_save_teachers b/tests/recipes/setup/recipes_TIMIT_ASR_seq2seq_knowledge_distillation_save_teachers/hparams_save_teachers
index b5c602c8d3c55d5c335033da9ed8034b309bbafe..211434a61ff044d4282a0a17af8a3fe0d1ddbaf5 100755
--- a/tests/recipes/setup/recipes_TIMIT_ASR_seq2seq_knowledge_distillation_save_teachers/hparams_save_teachers
+++ b/tests/recipes/setup/recipes_TIMIT_ASR_seq2seq_knowledge_distillation_save_teachers/hparams_save_teachers
@@ -1,10 +1,10 @@
 mkdir -p tests/tmp/TIMIT_row_15/
 rm -f tests/tmp/TIMIT_row_15/tea_models.txt
-echo `find tests/tmp/TIMIT_row_5/save -type d | tail -n1`/model.ckpt >> tests/tmp/TIMIT_row_15/tea_models.txt
-echo `find tests/tmp/TIMIT_row_6/save -type d | tail -n1`/model.ckpt >> tests/tmp/TIMIT_row_15/tea_models.txt
-echo `find tests/tmp/TIMIT_row_7/save -type d | tail -n1`/model.ckpt >> tests/tmp/TIMIT_row_15/tea_models.txt
-echo `find tests/tmp/TIMIT_row_8/save -type d | tail -n1`/model.ckpt >> tests/tmp/TIMIT_row_15/tea_models.txt
-echo `find tests/tmp/TIMIT_row_9/save -type d | tail -n1`/model.ckpt >> tests/tmp/TIMIT_row_15/tea_models.txt
+echo `find tests/tmp/TIMIT_row_05/save -type d | tail -n1`/model.ckpt >> tests/tmp/TIMIT_row_15/tea_models.txt
+echo `find tests/tmp/TIMIT_row_06/save -type d | tail -n1`/model.ckpt >> tests/tmp/TIMIT_row_15/tea_models.txt
+echo `find tests/tmp/TIMIT_row_07/save -type d | tail -n1`/model.ckpt >> tests/tmp/TIMIT_row_15/tea_models.txt
+echo `find tests/tmp/TIMIT_row_08/save -type d | tail -n1`/model.ckpt >> tests/tmp/TIMIT_row_15/tea_models.txt
+echo `find tests/tmp/TIMIT_row_09/save -type d | tail -n1`/model.ckpt >> tests/tmp/TIMIT_row_15/tea_models.txt
 echo `find tests/tmp/TIMIT_row_10/save -type d | tail -n1`/model.ckpt >> tests/tmp/TIMIT_row_15/tea_models.txt
 echo `find tests/tmp/TIMIT_row_11/save -type d | tail -n1`/model.ckpt >> tests/tmp/TIMIT_row_15/tea_models.txt
 echo `find tests/tmp/TIMIT_row_12/save -type d | tail -n1`/model.ckpt >> tests/tmp/TIMIT_row_15/tea_models.txt
diff --git a/tests/recipes/setup/recipes_TIMIT_ASR_seq2seq_knowledge_distillation_train_kd/hparams_train_kd b/tests/recipes/setup/recipes_TIMIT_ASR_seq2seq_knowledge_distillation_train_kd/hparams_train_kd
index 852e2aa5c29c52dceb1da3e9362ec85b96836e04..f2fc8b65a9f7b41dbba3df82c000d5d4d6bf1b4a 100755
--- a/tests/recipes/setup/recipes_TIMIT_ASR_seq2seq_knowledge_distillation_train_kd/hparams_train_kd
+++ b/tests/recipes/setup/recipes_TIMIT_ASR_seq2seq_knowledge_distillation_train_kd/hparams_train_kd
@@ -1,2 +1,2 @@
-fpath=`find tests/tmp/TIMIT_row_5 | grep model.ckpt | tail -n 1`
-ln -s `realpath $fpath` tests/tmp/TIMIT_row_5/model.ckpt
+fpath=`find tests/tmp/TIMIT_row_05 | grep model.ckpt | tail -n 1`
+ln -s `realpath $fpath` tests/tmp/TIMIT_row_05/model.ckpt
diff --git a/tests/recipes/setup/recipes_VoxCeleb_SpeakerRec_speaker_verification_plda/hparams_verification_plda_xvector b/tests/recipes/setup/recipes_VoxCeleb_SpeakerRec_speaker_verification_plda/hparams_verification_plda_xvector
index 284b2f02fbc9f484b83f9f44d062820836cf363b..6f97f6da9eb621f7139c71b76db8a905629ff1c6 100755
--- a/tests/recipes/setup/recipes_VoxCeleb_SpeakerRec_speaker_verification_plda/hparams_verification_plda_xvector
+++ b/tests/recipes/setup/recipes_VoxCeleb_SpeakerRec_speaker_verification_plda/hparams_verification_plda_xvector
@@ -1 +1 @@
-ln -s `realpath tests/tmp/VoxCeleb_row_4/save/CK*/embedding_model*` tests/tmp/VoxCeleb_row_4/embedding_model.ckpt
+ln -s `realpath tests/tmp/VoxCeleb_row_04/save/CK*/embedding_model*` tests/tmp/VoxCeleb_row_04/embedding_model.ckpt
diff --git a/tests/recipes/timers-and-such.csv b/tests/recipes/timers-and-such.csv
index 37a878d86c27168cb94dab41496e72555e938a3e..2d451ea0b4a4078ba65c96197b9bf36002b61eb9 100644
--- a/tests/recipes/timers-and-such.csv
+++ b/tests/recipes/timers-and-such.csv
@@ -1,9 +1,9 @@
-Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
-LM,timers-and-such,recipes/timers-and-such/LM/train.py,recipes/timers-and-such/LM/hparams/train.yaml,recipes/timers-and-such/LM/prepare.py,recipes/timers-and-such/README.md,https://www.dropbox.com/sh/gmmum179ig9wz0x/AAAOSOi11yVymGXHp9LzYNrqa?dl=0 ,,--data_folder=tests/samples/ASR --csv_train=tests/samples/annotation/ASR_train.csv --csv_valid=tests/samples/annotation/ASR_train.csv --csv_test_synth=tests/samples/annotation/ASR_train.csv --csv_test_real=tests/samples/annotation/ASR_train.csv --skip_prep=True --number_of_epochs=2,
-SLU,timers-and-such,recipes/timers-and-such/decoupled/train.py,recipes/timers-and-such/decoupled/hparams/train_LS_LM.yaml,recipes/timers-and-such/decoupled/prepare.py,recipes/timers-and-such/README.md,https://www.dropbox.com/sh/gmmum179ig9wz0x/AAAOSOi11yVymGXHp9LzYNrqa?dl=0,,--data_folder=tests/samples/ASR --csv_train=tests/samples/annotation/ASR_train.csv --csv_dev_real=tests/samples/annotation/ASR_train.csv --csv_dev_synth=tests/samples/annotation/ASR_train.csv --csv_all_real=tests/samples/annotation/ASR_train.csv --csv_test_synth=tests/samples/annotation/ASR_train.csv --csv_test_real=tests/samples/annotation/ASR_train.csv --skip_prep=True --number_of_epochs=2,"file_exists=[test_real_wer.txt,test_synth_wer.txt]"
-SLU,timers-and-such,recipes/timers-and-such/decoupled/train.py,recipes/timers-and-such/decoupled/hparams/train_TAS_LM.yaml,recipes/timers-and-such/decoupled/prepare.py,recipes/timers-and-such/README.md,https://www.dropbox.com/sh/gmmum179ig9wz0x/AAAOSOi11yVymGXHp9LzYNrqa?dl=0,,--data_folder=tests/samples/ASR --csv_train=tests/samples/annotation/ASR_train.csv --csv_dev_real=tests/samples/annotation/ASR_train.csv --csv_dev_synth=tests/samples/annotation/ASR_train.csv --csv_all_real=tests/samples/annotation/ASR_train.csv --csv_test_synth=tests/samples/annotation/ASR_train.csv --csv_test_real=tests/samples/annotation/ASR_train.csv --skip_prep=True --number_of_epochs=2,"file_exists=[test_real_wer.txt,test_synth_wer.txt]"
-SLU,timers-and-such,recipes/timers-and-such/direct/train.py,recipes/timers-and-such/direct/hparams/train.yaml,recipes/timers-and-such/direct/prepare.py,recipes/timers-and-such/README.md,https://www.dropbox.com/sh/gmmum179ig9wz0x/AAAOSOi11yVymGXHp9LzYNrqa?dl=0,https://huggingface.co/speechbrain/slu-timers-and-such-direct-librispeech-asr,--data_folder_rirs=tests/tmp --data_folder=tests/samples/ASR --csv_train=tests/samples/annotation/ASR_train.csv --csv_dev_real=tests/samples/annotation/ASR_train.csv --csv_dev_synth=tests/samples/annotation/ASR_train.csv --csv_all_real=tests/samples/annotation/ASR_train.csv --csv_test_synth=tests/samples/annotation/ASR_train.csv --csv_test_real=tests/samples/annotation/ASR_train.csv --skip_prep=True --number_of_epochs=2,"file_exists=[test_real_wer.txt,test_synth_wer.txt]"
-SLU,timers-and-such,recipes/timers-and-such/direct/train_with_wav2vec2.py,recipes/timers-and-such/direct/hparams/train_with_wav2vec2.yaml,recipes/timers-and-such/direct/prepare.py,recipes/timers-and-such/README.md,https://www.dropbox.com/sh/gmmum179ig9wz0x/AAAOSOi11yVymGXHp9LzYNrqa?dl=0,,--data_folder=tests/samples/ASR --csv_train=tests/samples/annotation/ASR_train.csv --csv_dev_real=tests/samples/annotation/ASR_train.csv --csv_dev_synth=tests/samples/annotation/ASR_train.csv --csv_all_real=tests/samples/annotation/ASR_train.csv --csv_test_synth=tests/samples/annotation/ASR_train.csv --csv_test_real=tests/samples/annotation/ASR_train.csv --skip_prep=True --number_of_epochs=2,"file_exists=[test_real_wer.txt,test_synth_wer.txt]"
-SLU,timers-and-such,recipes/timers-and-such/multistage/train.py,recipes/timers-and-such/multistage/hparams/train_LS_LM.yaml,recipes/timers-and-such/multistage/prepare.py,recipes/timers-and-such/README.md,https://www.dropbox.com/sh/gmmum179ig9wz0x/AAAOSOi11yVymGXHp9LzYNrqa?dl=0,,--data_folder_rirs=tests/tmp --data_folder=tests/samples/ASR --csv_train=tests/samples/annotation/ASR_train.csv --csv_dev_real=tests/samples/annotation/ASR_train.csv --csv_dev_synth=tests/samples/annotation/ASR_train.csv --csv_all_real=tests/samples/annotation/ASR_train.csv --csv_test_synth=tests/samples/annotation/ASR_train.csv --csv_test_real=tests/samples/annotation/ASR_train.csv --skip_prep=True --number_of_epochs=2,"file_exists=[test_real_wer.txt,test_synth_wer.txt]"
-SLU,timers-and-such,recipes/timers-and-such/multistage/train.py,recipes/timers-and-such/multistage/hparams/train_TAS_LM.yaml,recipes/timers-and-such/multistage/prepare.py,recipes/timers-and-such/README.md,https://www.dropbox.com/sh/gmmum179ig9wz0x/AAAOSOi11yVymGXHp9LzYNrqa?dl=0,,--data_folder_rirs=tests/tmp --data_folder=tests/samples/ASR --csv_train=tests/samples/annotation/ASR_train.csv --csv_dev_real=tests/samples/annotation/ASR_train.csv --csv_dev_synth=tests/samples/annotation/ASR_train.csv --csv_all_real=tests/samples/annotation/ASR_train.csv --csv_test_synth=tests/samples/annotation/ASR_train.csv --csv_test_real=tests/samples/annotation/ASR_train.csv --skip_prep=True --number_of_epochs=2,"file_exists=[test_real_wer.txt,test_synth_wer.txt]"
-Tokenizer,timers-and-such,recipes/timers-and-such/Tokenizer/train.py,recipes/timers-and-such/Tokenizer/hparams/tokenizer_bpe51.yaml,recipes/timers-and-such/Tokenizer/prepare.py,recipes/timers-and-such/README.md,https://www.dropbox.com/sh/gmmum179ig9wz0x/AAAOSOi11yVymGXHp9LzYNrqa?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True --token_output=17,
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance
+LM,timers-and-such,recipes/timers-and-such/LM/train.py,recipes/timers-and-such/LM/hparams/train.yaml,recipes/timers-and-such/LM/prepare.py,recipes/timers-and-such/README.md,https://www.dropbox.com/sh/gmmum179ig9wz0x/AAAOSOi11yVymGXHp9LzYNrqa?dl=0 ,,--data_folder=tests/samples/ASR --csv_train=tests/samples/annotation/ASR_train.csv --csv_valid=tests/samples/annotation/ASR_train.csv --csv_test_synth=tests/samples/annotation/ASR_train.csv --csv_test_real=tests/samples/annotation/ASR_train.csv --skip_prep=True --number_of_epochs=2,,
+SLU,timers-and-such,recipes/timers-and-such/decoupled/train.py,recipes/timers-and-such/decoupled/hparams/train_LS_LM.yaml,recipes/timers-and-such/decoupled/prepare.py,recipes/timers-and-such/README.md,https://www.dropbox.com/sh/gmmum179ig9wz0x/AAAOSOi11yVymGXHp9LzYNrqa?dl=0,,--data_folder=tests/samples/ASR --csv_train=tests/samples/annotation/ASR_train.csv --csv_dev_real=tests/samples/annotation/ASR_train.csv --csv_dev_synth=tests/samples/annotation/ASR_train.csv --csv_all_real=tests/samples/annotation/ASR_train.csv --csv_test_synth=tests/samples/annotation/ASR_train.csv --csv_test_real=tests/samples/annotation/ASR_train.csv --skip_prep=True --number_of_epochs=2,"file_exists=[test_real_wer.txt,test_synth_wer.txt]",
+SLU,timers-and-such,recipes/timers-and-such/decoupled/train.py,recipes/timers-and-such/decoupled/hparams/train_TAS_LM.yaml,recipes/timers-and-such/decoupled/prepare.py,recipes/timers-and-such/README.md,https://www.dropbox.com/sh/gmmum179ig9wz0x/AAAOSOi11yVymGXHp9LzYNrqa?dl=0,,--data_folder=tests/samples/ASR --csv_train=tests/samples/annotation/ASR_train.csv --csv_dev_real=tests/samples/annotation/ASR_train.csv --csv_dev_synth=tests/samples/annotation/ASR_train.csv --csv_all_real=tests/samples/annotation/ASR_train.csv --csv_test_synth=tests/samples/annotation/ASR_train.csv --csv_test_real=tests/samples/annotation/ASR_train.csv --skip_prep=True --number_of_epochs=2,"file_exists=[test_real_wer.txt,test_synth_wer.txt]",Accuracy-Test_real=46.8%
+SLU,timers-and-such,recipes/timers-and-such/direct/train.py,recipes/timers-and-such/direct/hparams/train.yaml,recipes/timers-and-such/direct/prepare.py,recipes/timers-and-such/README.md,https://www.dropbox.com/sh/gmmum179ig9wz0x/AAAOSOi11yVymGXHp9LzYNrqa?dl=0,https://huggingface.co/speechbrain/slu-timers-and-such-direct-librispeech-asr,--data_folder=tests/samples/ASR --csv_train=tests/samples/annotation/ASR_train.csv --csv_dev_real=tests/samples/annotation/ASR_train.csv --csv_dev_synth=tests/samples/annotation/ASR_train.csv --csv_all_real=tests/samples/annotation/ASR_train.csv --csv_test_synth=tests/samples/annotation/ASR_train.csv --csv_test_real=tests/samples/annotation/ASR_train.csv --skip_prep=True --number_of_epochs=2,"file_exists=[test_real_wer.txt,test_synth_wer.txt]",Accuracy-Test_real=77.5%
+SLU,timers-and-such,recipes/timers-and-such/direct/train_with_wav2vec2.py,recipes/timers-and-such/direct/hparams/train_with_wav2vec2.yaml,recipes/timers-and-such/direct/prepare.py,recipes/timers-and-such/README.md,https://www.dropbox.com/sh/gmmum179ig9wz0x/AAAOSOi11yVymGXHp9LzYNrqa?dl=0,,--data_folder=tests/samples/ASR --csv_train=tests/samples/annotation/ASR_train.csv --csv_dev_real=tests/samples/annotation/ASR_train.csv --csv_dev_synth=tests/samples/annotation/ASR_train.csv --csv_all_real=tests/samples/annotation/ASR_train.csv --csv_test_synth=tests/samples/annotation/ASR_train.csv --csv_test_real=tests/samples/annotation/ASR_train.csv --skip_prep=True --number_of_epochs=2,"file_exists=[test_real_wer.txt,test_synth_wer.txt]",Accuracy-Test_real=94.0%
+SLU,timers-and-such,recipes/timers-and-such/multistage/train.py,recipes/timers-and-such/multistage/hparams/train_LS_LM.yaml,recipes/timers-and-such/multistage/prepare.py,recipes/timers-and-such/README.md,https://www.dropbox.com/sh/gmmum179ig9wz0x/AAAOSOi11yVymGXHp9LzYNrqa?dl=0,,--data_folder=tests/samples/ASR --csv_train=tests/samples/annotation/ASR_train.csv --csv_dev_real=tests/samples/annotation/ASR_train.csv --csv_dev_synth=tests/samples/annotation/ASR_train.csv --csv_all_real=tests/samples/annotation/ASR_train.csv --csv_test_synth=tests/samples/annotation/ASR_train.csv --csv_test_real=tests/samples/annotation/ASR_train.csv --skip_prep=True --number_of_epochs=2,"file_exists=[test_real_wer.txt,test_synth_wer.txt]",
+SLU,timers-and-such,recipes/timers-and-such/multistage/train.py,recipes/timers-and-such/multistage/hparams/train_TAS_LM.yaml,recipes/timers-and-such/multistage/prepare.py,recipes/timers-and-such/README.md,https://www.dropbox.com/sh/gmmum179ig9wz0x/AAAOSOi11yVymGXHp9LzYNrqa?dl=0,,--data_folder=tests/samples/ASR --csv_train=tests/samples/annotation/ASR_train.csv --csv_dev_real=tests/samples/annotation/ASR_train.csv --csv_dev_synth=tests/samples/annotation/ASR_train.csv --csv_all_real=tests/samples/annotation/ASR_train.csv --csv_test_synth=tests/samples/annotation/ASR_train.csv --csv_test_real=tests/samples/annotation/ASR_train.csv --skip_prep=True --number_of_epochs=2,"file_exists=[test_real_wer.txt,test_synth_wer.txt]",Accuracy-Test_real=72.6%
+Tokenizer,timers-and-such,recipes/timers-and-such/Tokenizer/train.py,recipes/timers-and-such/Tokenizer/hparams/tokenizer_bpe51.yaml,recipes/timers-and-such/Tokenizer/prepare.py,recipes/timers-and-such/README.md,https://www.dropbox.com/sh/gmmum179ig9wz0x/AAAOSOi11yVymGXHp9LzYNrqa?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True --token_output=17,,
diff --git a/tests/samples/TTS/codes/LJ050-0131.npy b/tests/samples/TTS/codes/LJ050-0131.npy
new file mode 100644
index 0000000000000000000000000000000000000000..a9e12540ef272eddf93923c4a241e7fa1d7313ab
Binary files /dev/null and b/tests/samples/TTS/codes/LJ050-0131.npy differ
diff --git a/tests/samples/annotation/TTS_train.json b/tests/samples/annotation/TTS_train.json
index 6be20323b245c629b318ed75d058b6061571680e..48189e466070255975ee992d380de83634ff7887 100644
--- a/tests/samples/annotation/TTS_train.json
+++ b/tests/samples/annotation/TTS_train.json
@@ -1,67 +1,18 @@
 {
-  "LJ050-01310": {
+  "LJ050-0131": {
     "uttid": "LJ050-0131",
     "wav": "{data_root}/LJ050-0131.wav",
     "label": "unless a system is established for the frequent formal review of activities thereunder. in this regard",
     "segment": true,
     "label_phoneme": "AH N L EH S AH S IH S T AH M IH Z IH S T AE B L IH SH T spn F ER DH AH F R IY K W AH N T F AO R M AH L R IY V Y UW spn AH V AE K T IH V IH T IY Z DH EH R AH N D ER spn IH N DH IH S R AH G AA R D",
+    "phonemes": ["AH", "N", "L", "EH", "S", " ", "AH", " ", "S", "IH", "S", "T", "AH", "M", " ", "IH", "Z", " ", "IH", "S", "T", "AE", "B", "L", "IH", "SH", "T", " ", "F", "AO", "R", " ", "DH", "AH", " ", "F", "R", "IY", "K", "W", "AH", "N", "T", " ", "F", "AO", "R", "M", "AH", "L", " ", "R", "IY", "V", "Y", "UW", " ", "AH", "V", " ", "AE", "K", "T", "IH", "V", "IH", "T", "IY", "Z", " ", "DH", "EH", "R", "AH", "N", "D", "ER", ".", " ", "IH", "N", " ", "DH", "IH", "S", " ", "R", "AH", "G", "AA", "R", "D"],
     "spn_labels": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     "start": 0.0,
     "end": 7.58,
+    "duration": 7.58,
     "durations": "{data_root}/d_LJ050-0131.npy",
     "pitch": "{data_root}/LJ050-0131.npy",
-    "last_phoneme_flags": [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1]
-  },
-  "LJ050-01311": {
-    "uttid": "LJ050-0131",
-    "wav": "{data_root}/LJ050-0131.wav",
-    "label": "unless a system is established for the frequent formal review of activities thereunder. in this regard",
-    "segment": true,
-    "label_phoneme": "AH N L EH S AH S IH S T AH M IH Z IH S T AE B L IH SH T spn F ER DH AH F R IY K W AH N T F AO R M AH L R IY V Y UW spn AH V AE K T IH V IH T IY Z DH EH R AH N D ER spn IH N DH IH S R AH G AA R D",
-    "spn_labels": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-    "start": 0.0,
-    "end": 7.58,
-    "durations": "{data_root}/d_LJ050-0131.npy",
-    "pitch": "{data_root}/LJ050-0131.npy",
-    "last_phoneme_flags": [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1]
-  },
-  "LJ050-01312": {
-    "uttid": "LJ050-0131",
-    "wav": "{data_root}/LJ050-0131.wav",
-    "label": "unless a system is established for the frequent formal review of activities thereunder. in this regard",
-    "segment": true,
-    "label_phoneme": "AH N L EH S AH S IH S T AH M IH Z IH S T AE B L IH SH T spn F ER DH AH F R IY K W AH N T F AO R M AH L R IY V Y UW spn AH V AE K T IH V IH T IY Z DH EH R AH N D ER spn IH N DH IH S R AH G AA R D",
-    "spn_labels": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-    "start": 0.0,
-    "end": 7.58,
-    "durations": "{data_root}/d_LJ050-0131.npy",
-    "pitch": "{data_root}/LJ050-0131.npy",
-    "last_phoneme_flags": [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1]
-  },
-  "LJ050-01313": {
-    "uttid": "LJ050-0131",
-    "wav": "{data_root}/LJ050-0131.wav",
-    "label": "unless a system is established for the frequent formal review of activities thereunder. in this regard",
-    "segment": true,
-    "label_phoneme": "AH N L EH S AH S IH S T AH M IH Z IH S T AE B L IH SH T spn F ER DH AH F R IY K W AH N T F AO R M AH L R IY V Y UW spn AH V AE K T IH V IH T IY Z DH EH R AH N D ER spn IH N DH IH S R AH G AA R D",
-    "spn_labels": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-    "start": 0.0,
-    "end": 7.58,
-    "durations": "{data_root}/d_LJ050-0131.npy",
-    "pitch": "{data_root}/LJ050-0131.npy",
-    "last_phoneme_flags": [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1]
-  },
-  "LJ050-01314": {
-    "uttid": "LJ050-0131",
-    "wav": "{data_root}/LJ050-0131.wav",
-    "label": "unless a system is established for the frequent formal review of activities thereunder. in this regard",
-    "segment": true,
-    "label_phoneme": "AH N L EH S AH S IH S T AH M IH Z IH S T AE B L IH SH T spn F ER DH AH F R IY K W AH N T F AO R M AH L R IY V Y UW spn AH V AE K T IH V IH T IY Z DH EH R AH N D ER spn IH N DH IH S R AH G AA R D",
-    "spn_labels": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-    "start": 0.0,
-    "end": 7.58,
-    "durations": "{data_root}/d_LJ050-0131.npy",
-    "pitch": "{data_root}/LJ050-0131.npy",
-    "last_phoneme_flags": [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1]
+    "last_phoneme_flags": [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1],
+    "spk_id": "0000"
   }
-}
\ No newline at end of file
+}
diff --git a/tests/samples/annotation/response_generation_train_multiwoz.json b/tests/samples/annotation/response_generation_train_multiwoz.json
new file mode 100644
index 0000000000000000000000000000000000000000..696a103a7a4417487ee9139a7f2660217f327d97
--- /dev/null
+++ b/tests/samples/annotation/response_generation_train_multiwoz.json
@@ -0,0 +1,74 @@
+{
+    "PMUL0698.json_1": {
+        "history": [
+            "i am looking for a local place to dine in the centre that serves chinese food ."
+        ],
+        "reply": "i have restaurants matching your criteria in all price ranges . do you have a preference on price ?",
+        "length": 178
+    },
+    "PMUL0698.json_3": {
+        "history": [
+            "i am looking for a local place to dine in the centre that serves chinese food .",
+            "i have restaurants matching your criteria in all price ranges . do you have a preference on price ?",
+            "i need the address , postcode and the price range ."
+        ],
+        "reply": "ok how about charlie chan , located at regent street city centre . postcode is cb21db with a cheap price . can i help you further today ?",
+        "length": 213.33333333333331
+    },
+    "PMUL0698.json_5": {
+        "history": [
+            "i am looking for a local place to dine in the centre that serves chinese food .",
+            "i have restaurants matching your criteria in all price ranges . do you have a preference on price ?",
+            "i need the address , postcode and the price range .",
+            "ok how about charlie chan , located at regent street city centre . postcode is cb21db with a cheap price . can i help you further today ?",
+            "i also need a train . the train should leave after 16:15 and should leave on sunday ."
+        ],
+        "reply": "can i have more information for the train you are needing ? where are you departing from and arriving to ?",
+        "length": 196.2
+    },
+    "PMUL0698.json_7": {
+        "history": [
+            "i am looking for a local place to dine in the centre that serves chinese food .",
+            "i have restaurants matching your criteria in all price ranges . do you have a preference on price ?",
+            "i need the address , postcode and the price range .",
+            "ok how about charlie chan , located at regent street city centre . postcode is cb21db with a cheap price . can i help you further today ?",
+            "i also need a train . the train should leave after 16:15 and should leave on sunday .",
+            "can i have more information for the train you are needing ? where are you departing from and arriving to ?",
+            "i am leaving from cambridge and going to norwich ."
+        ],
+        "reply": "i have train tr1840 leaving at 16:36 is that okay ?",
+        "length": 137.71428571428572
+    },
+    "PMUL0698.json_9": {
+        "history": [
+            "i am looking for a local place to dine in the centre that serves chinese food .",
+            "i have restaurants matching your criteria in all price ranges . do you have a preference on price ?",
+            "i need the address , postcode and the price range .",
+            "ok how about charlie chan , located at regent street city centre . postcode is cb21db with a cheap price . can i help you further today ?",
+            "i also need a train . the train should leave after 16:15 and should leave on sunday .",
+            "can i have more information for the train you are needing ? where are you departing from and arriving to ?",
+            "i am leaving from cambridge and going to norwich .",
+            "i have train tr1840 leaving at 16:36 is that okay ?",
+            "book for 5 people and get me the reference number"
+        ],
+        "reply": "you are all set . reference number is njb87pap . is there anything else i can help you with today ?",
+        "length": 177.55555555555554
+    },
+    "PMUL0698.json_11": {
+        "history": [
+            "i am looking for a local place to dine in the centre that serves chinese food .",
+            "i have restaurants matching your criteria in all price ranges . do you have a preference on price ?",
+            "i need the address , postcode and the price range .",
+            "ok how about charlie chan , located at regent street city centre . postcode is cb21db with a cheap price . can i help you further today ?",
+            "i also need a train . the train should leave after 16:15 and should leave on sunday .",
+            "can i have more information for the train you are needing ? where are you departing from and arriving to ?",
+            "i am leaving from cambridge and going to norwich .",
+            "i have train tr1840 leaving at 16:36 is that okay ?",
+            "book for 5 people and get me the reference number",
+            "you are all set . reference number is njb87pap . is there anything else i can help you with today ?",
+            "no , this is all i will need . thank you ."
+        ],
+        "reply": "thank for calling us today . i hope you have a good trip .",
+        "length": 135.0909090909091
+    }
+}
\ No newline at end of file
diff --git a/tests/templates/fetching_ddp_dynbatch_finetuning/ASR.yaml b/tests/templates/fetching_ddp_dynbatch_finetuning/ASR.yaml
index 85a09d5d8ce16af8d2821d27872d658283539ffc..a61f332b29697b531892a64050c643d9eb2ef99d 100644
--- a/tests/templates/fetching_ddp_dynbatch_finetuning/ASR.yaml
+++ b/tests/templates/fetching_ddp_dynbatch_finetuning/ASR.yaml
@@ -6,7 +6,7 @@
 #
 # Goal: more clarity on when, where & how to use
 #    speechbrain.utils.parameter_transfer.Pretrainer
-#    speechbrain.pretrained.interfaces.Pretrained
+#    speechbrain.inference.interfaces.Pretrained
 #
 # Authors:  Andreas Nautsch 2023
 # # ############################################################################
@@ -132,26 +132,58 @@ lm_model: !new:speechbrain.lobes.models.RNNLM.RNNLM
     dnn_neurons: 512
     return_hidden: True  # For inference
 
+# Define your scorers for beam searcher
+
+# If the lm_scorer is set, a language model
+# is applied (with a weight specified in scorer).
+rnnlm_scorer: !new:speechbrain.decoders.scorer.RNNLMScorer
+    language_model: !ref <lm_model>
+    temperature: !ref <temperature_lm>
+
+# If ctc_scorer is set, the decoder uses CTC + attention beamsearch. This
+# improves the performance, but slows down decoding.
+ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
+    eos_index: !ref <eos_index>
+    blank_index: !ref <blank_index>
+    ctc_fc: !ref <ctc_lin>
+
+# If coverage_scorer is set, coverage penalty is applied based on accumulated
+# attention weights during beamsearch.
+coverage_scorer: !new:speechbrain.decoders.scorer.CoverageScorer
+    vocab_size: !ref <output_neurons>
+
+# Gathering all scorers in a scorer instance for beamsearch:
+# - full_scorers are scorers which score on full vocab set, while partial_scorers
+# are scorers which score on pruned tokens.
+# - The number of pruned tokens is decided by scorer_beam_scale * beam_size.
+# - For some scorers like ctc_scorer, ngramlm_scorer, putting them
+# into full_scorers list would be too heavy. partial_scorers are more
+# efficient because they score on pruned tokens at little cost of
+# performance drop. For other scorers, please see the speechbrain.decoders.scorer.
+test_scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
+    scorer_beam_scale: 1.5
+    full_scorers: [
+        !ref <rnnlm_scorer>,
+        !ref <coverage_scorer>]
+    partial_scorers: [!ref <ctc_scorer>]
+    weights:
+        rnnlm: !ref <lm_weight>
+        coverage: !ref <coverage_penalty>
+        ctc: !ref <ctc_weight_decode>
+
 # The final decoding on the test set can be more computationally demanding.
 # In this case, we use the LM + CTC probabilities during decoding as well.
 # Please, remove this part if you need a faster decoder.
-test_search: !new:speechbrain.decoders.S2SRNNBeamSearchLM
+test_search: !new:speechbrain.decoders.S2SRNNBeamSearcher
     embedding: !ref <embedding>
     decoder: !ref <decoder>
     linear: !ref <seq_lin>
-    ctc_linear: !ref <ctc_lin>
-    language_model: !ref <lm_model>
     bos_index: !ref <bos_index>
     eos_index: !ref <eos_index>
-    blank_index: !ref <blank_index>
     min_decode_ratio: !ref <min_decode_ratio>
     max_decode_ratio: !ref <max_decode_ratio>
     beam_size: !ref <test_beam_size>
     eos_threshold: !ref <eos_threshold>
     using_max_attn_shift: !ref <using_max_attn_shift>
     max_attn_shift: !ref <max_attn_shift>
-    coverage_penalty: !ref <coverage_penalty>
-    lm_weight: !ref <lm_weight>
-    ctc_weight: !ref <ctc_weight_decode>
     temperature: !ref <temperature>
-    temperature_lm: !ref <temperature_lm>
diff --git a/tests/templates/fetching_ddp_dynbatch_finetuning/README.md b/tests/templates/fetching_ddp_dynbatch_finetuning/README.md
index 636bcd762d50c03730c93fc625a5347c561aa8cf..40c411d6326735a4cc2eddac860455d4a88ceb15 100644
--- a/tests/templates/fetching_ddp_dynbatch_finetuning/README.md
+++ b/tests/templates/fetching_ddp_dynbatch_finetuning/README.md
@@ -37,7 +37,7 @@ provides symlinks in `source_pretrained` to an expected CKPT that is created dur
 
 How to run with DDP:
 ```shell
-CUDA_VISIBLE_DEVICES=0,1 PYTHONPATH=../../.. python3 -m torch.distributed.launch --nproc_per_node=2 finetune.py finetune.yaml --distributed_launch --distributed_backend='nccl'
+CUDA_VISIBLE_DEVICES=0,1 PYTHONPATH=../../.. torchrun --nproc_per_node=2 finetune.py finetune.yaml
 ```
 
 ## Sanity check: standard model card from one of our HuggingFace repos
@@ -58,7 +58,7 @@ cd ../../.. && PYTHONPATH=. python tests/templates/fetching_ddp_dynbatch_finetun
 
 How to run with DDP:
 ```shell
-CUDA_VISIBLE_DEVICES=0,1 PYTHONPATH=../../.. python3 -m torch.distributed.launch --nproc_per_node=2 finetune_fetch_once.py finetune_fetch_once.yaml --distributed_launch --distributed_backend='nccl'
+CUDA_VISIBLE_DEVICES=0,1 PYTHONPATH=../../.. torchrun --nproc_per_node=2 finetune_fetch_once.py finetune_fetch_once.yaml
 ```
 
 
diff --git a/tests/templates/fetching_ddp_dynbatch_finetuning/finetune.py b/tests/templates/fetching_ddp_dynbatch_finetuning/finetune.py
index d8c2a07d4fb7a15a8dc6f70c3a6404485f6e51a6..5a9d367a7945c9fee86f0f9786199327bb1b19d6 100644
--- a/tests/templates/fetching_ddp_dynbatch_finetuning/finetune.py
+++ b/tests/templates/fetching_ddp_dynbatch_finetuning/finetune.py
@@ -16,7 +16,7 @@ from copy import deepcopy
 from tqdm.contrib import tqdm
 from torch.utils.data import DataLoader
 from hyperpyyaml import load_hyperpyyaml
-from speechbrain.pretrained import EncoderDecoderASR
+from speechbrain.inference.ASR import EncoderDecoderASR
 from speechbrain.utils.distributed import run_on_main, ddp_barrier
 from speechbrain.utils.data_utils import batch_pad_right
 from speechbrain.dataio.dataset import DynamicItemDataset
@@ -67,11 +67,11 @@ def eval_reporting(reports, single_node=False):
 def eval_test_use_recipe_dataio(
     encoder_decoder_asr, test_set, test_kwargs, reporter, single_node=False
 ):
-    """Bypassing speechbrain.pretrained.Pretrained.load_audio with recipe dataio (speechbrain.dataio.dataio.read_audio).
+    """Bypassing speechbrain.inference.interfaces.Pretrained.load_audio with recipe dataio (speechbrain.dataio.dataio.read_audio).
 
     Parameters
     ----------
-    encoder_decoder_asr: speechbrain.pretrained.EncoderDecoderASR
+    encoder_decoder_asr: speechbrain.inference.ASR.EncoderDecoderASR
         Pretrained interface (other interfaces will require other functions to be called; this is an example).
     test_set: dict
         Data loader options for testing.
@@ -118,7 +118,7 @@ def eval_test_batch_from_scratch(
 
     Parameters
     ----------
-    encoder_decoder_asr: speechbrain.pretrained.EncoderDecoderASR
+    encoder_decoder_asr: speechbrain.inference.ASR.EncoderDecoderASR
         Pretrained interface (other interfaces will require other functions to be called; this is an example).
     test_set: Dataset, DataLoader
         If a DataLoader is given, it is iterated directly. Otherwise passed to `sb.dataio.dataloader.make_dataloader()`.
diff --git a/tests/templates/fetching_ddp_dynbatch_finetuning/finetune.yaml b/tests/templates/fetching_ddp_dynbatch_finetuning/finetune.yaml
index 3b423c9685cce75a249e3d4e34f53d61e10ef786..ae2faa3b1cbc17a63516c11071d756d391140467 100644
--- a/tests/templates/fetching_ddp_dynbatch_finetuning/finetune.yaml
+++ b/tests/templates/fetching_ddp_dynbatch_finetuning/finetune.yaml
@@ -6,7 +6,7 @@
 #
 # Goal: more clarity on when, where & how to use
 #    speechbrain.utils.parameter_transfer.Pretrainer
-#    speechbrain.pretrained.interfaces.Pretrained
+#    speechbrain.inference.interfaces.Pretrained
 #
 # Authors:  Andreas Nautsch 2023
 # # ############################################################################
@@ -32,6 +32,11 @@ train_annotation: ../../../templates/speech_recognition/train.json
 valid_annotation: ../../../templates/speech_recognition/valid.json
 test_annotation: ../../../templates/speech_recognition/test.json
 
+# Data for augmentation
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+noise_annotation: !ref <save_folder>/noise.csv #The data manifest files are created by the data preparation script
+
 # Training parameters
 number_of_epochs: 2
 number_of_ctc_epochs: 1
@@ -74,15 +79,74 @@ dynamic_batch_sampler:
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-# Added noise and reverb come from OpenRIR dataset, automatically
-# downloaded and prepared with this Environmental Corruption class.
-env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <data_folder_rirs>
-    babble_prob: 0.0
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
+
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+
+# Add noise to input signal
+snr_low: 0  # Min SNR for noise augmentation
+snr_high: 15  # Max SNR for noise augmentation
+
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: !ref <snr_low>
+    snr_high: !ref <snr_high>
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    parallel_augment: False
+    concat_original: True
+    repeat_augment: 1
+    shuffle_augmentations: False
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <add_noise>,
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
 
 # Loads the ASR brain for training/fine-tuning mode
 params: !include:ASR.yaml
@@ -100,11 +164,6 @@ test_search: !ref <params[test_search]>
 bos_index: !ref <params[bos_index]>
 eos_index: !ref <params[eos_index]>
 
-# Adds speech change + time and frequency dropouts (time-domain implementation).
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <params[sample_rate]>
-    speeds: [95, 100, 105]
-
 # Objects in "modules" dict will have their parameters moved to the correct
 # device, as well as having train()/eval() called on them by the Brain class
 modules:
@@ -114,7 +173,6 @@ modules:
     ctc_lin: !ref <ctc_lin>
     seq_lin: !ref <seq_lin>
     normalize: !ref <normalize>
-    env_corrupt: !ref <env_corrupt>
     lm_model: !ref <lm_model>
 
 # Gathering all the submodels in a single model object.
@@ -168,14 +226,14 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
         counter: !ref <epoch_counter>
 
 # Pretraining diversified
-pretrained_path_ASR: !new:speechbrain.pretrained.fetching.FetchSource
-    - !name:speechbrain.pretrained.fetching.FetchFrom.LOCAL
+pretrained_path_ASR: !new:speechbrain.utils.fetching.FetchSource
+    - !name:speechbrain.utils.fetching.FetchFrom.LOCAL
     - speechbrain/asr-crdnn-rnnlm-librispeech/model.ckpt
-pretrained_path_LM: !new:speechbrain.pretrained.fetching.FetchSource
-    - !name:speechbrain.pretrained.fetching.FetchFrom.HUGGING_FACE
+pretrained_path_LM: !new:speechbrain.utils.fetching.FetchSource
+    - !name:speechbrain.utils.fetching.FetchFrom.HUGGING_FACE
     - speechbrain/asr-crdnn-rnnlm-librispeech/lm.ckpt
-pretrained_path_tokenizer: !new:speechbrain.pretrained.fetching.FetchSource
-    - !name:speechbrain.pretrained.fetching.FetchFrom.URI
+pretrained_path_tokenizer: !new:speechbrain.utils.fetching.FetchSource
+    - !name:speechbrain.utils.fetching.FetchFrom.URI
     - https://huggingface.co/speechbrain/asr-crdnn-rnnlm-librispeech/resolve/main/tokenizer.ckpt
 
 pretrainer_ASR: !new:speechbrain.utils.parameter_transfer.Pretrainer
diff --git a/tests/templates/fetching_ddp_dynbatch_finetuning/finetune_fetch_once.py b/tests/templates/fetching_ddp_dynbatch_finetuning/finetune_fetch_once.py
index f72e4677926461b13d17c314bc0b5079d9b05ac0..1f2c8c1213c7970b369048bce905229fb9176ccb 100644
--- a/tests/templates/fetching_ddp_dynbatch_finetuning/finetune_fetch_once.py
+++ b/tests/templates/fetching_ddp_dynbatch_finetuning/finetune_fetch_once.py
@@ -16,7 +16,7 @@ from copy import deepcopy
 from tqdm.contrib import tqdm
 from torch.utils.data import DataLoader
 from hyperpyyaml import load_hyperpyyaml
-from speechbrain.pretrained import EncoderDecoderASR
+from speechbrain.inference.ASR import EncoderDecoderASR
 from speechbrain.utils.distributed import run_on_main, ddp_barrier
 from speechbrain.utils.data_utils import batch_pad_right
 from speechbrain.dataio.dataset import DynamicItemDataset
@@ -67,11 +67,11 @@ def eval_reporting(reports, single_node=False):
 def eval_test_use_recipe_dataio(
     encoder_decoder_asr, test_set, test_kwargs, reporter, single_node=False
 ):
-    """Bypassing speechbrain.pretrained.Pretrained.load_audio with recipe dataio (speechbrain.dataio.dataio.read_audio).
+    """Bypassing speechbrain.inference.interfaces.Pretrained.load_audio with recipe dataio (speechbrain.dataio.dataio.read_audio).
 
     Parameters
     ----------
-    encoder_decoder_asr: speechbrain.pretrained.EncoderDecoderASR
+    encoder_decoder_asr: speechbrain.inference.ASR.EncoderDecoderASR
         Pretrained interface (other interfaces will require other functions to be called; this is an example).
     test_set: dict
         Data loader options for testing.
@@ -118,7 +118,7 @@ def eval_test_batch_from_scratch(
 
     Parameters
     ----------
-    encoder_decoder_asr: speechbrain.pretrained.EncoderDecoderASR
+    encoder_decoder_asr: speechbrain.inference.ASR.EncoderDecoderASR
         Pretrained interface (other interfaces will require other functions to be called; this is an example).
     test_set: Dataset, DataLoader
         If a DataLoader is given, it is iterated directly. Otherwise passed to `sb.dataio.dataloader.make_dataloader()`.
diff --git a/tests/templates/fetching_ddp_dynbatch_finetuning/finetune_fetch_once.yaml b/tests/templates/fetching_ddp_dynbatch_finetuning/finetune_fetch_once.yaml
index 186cb408a1f8bd5ed7b0c911cfe0ac58c1aabbdf..fe89b0bdea1434d8d2d2ac78483d84fdc425f8da 100644
--- a/tests/templates/fetching_ddp_dynbatch_finetuning/finetune_fetch_once.yaml
+++ b/tests/templates/fetching_ddp_dynbatch_finetuning/finetune_fetch_once.yaml
@@ -6,7 +6,7 @@
 #
 # Goal: more clarity on when, where & how to use
 #    speechbrain.utils.parameter_transfer.Pretrainer
-#    speechbrain.pretrained.interfaces.Pretrained
+#    speechbrain.inference.interfaces.Pretrained
 #
 # Authors:  Andreas Nautsch 2023
 # # ############################################################################
@@ -32,6 +32,11 @@ train_annotation: ../../../templates/speech_recognition/train.json
 valid_annotation: ../../../templates/speech_recognition/valid.json
 test_annotation: ../../../templates/speech_recognition/test.json
 
+# Data for augmentation
+data_folder_noise: !ref <data_folder>/noise # The noisy sequencies for data augmentation will automatically be downloaded here.
+NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
+noise_annotation: !ref <save_folder>/noise.csv #The data manifest files are created by the data preparation script
+
 # Training parameters
 number_of_epochs: 2
 number_of_ctc_epochs: 1
@@ -74,15 +79,73 @@ dynamic_batch_sampler:
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-# Added noise and reverb come from OpenRIR dataset, automatically
-# downloaded and prepared with this Environmental Corruption class.
-env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt
-    openrir_folder: !ref <data_folder_rirs>
-    babble_prob: 0.0
-    reverb_prob: 0.0
-    noise_prob: 1.0
-    noise_snr_low: 0
-    noise_snr_high: 15
+# Download and prepare the dataset of noisy sequences for augmentation
+prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
+    URL: !ref <NOISE_DATASET_URL>
+    dest_folder: !ref <data_folder_noise>
+    ext: wav
+    csv_file: !ref <noise_annotation>
+
+
+# Add noise to input signal
+snr_low: 0  # Min SNR for noise augmentation
+snr_high: 15  # Max SNR for noise augmentation
+
+add_noise: !new:speechbrain.augment.time_domain.AddNoise
+    csv_file: !ref <noise_annotation>
+    snr_low: !ref <snr_low>
+    snr_high: !ref <snr_high>
+    noise_sample_rate: !ref <sample_rate>
+    clean_sample_rate: !ref <sample_rate>
+    num_workers: !ref <num_workers>
+
+# Speed perturbation
+speed_changes: [95, 100, 105]  # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0  # Min frequency band dropout probability
+drop_freq_high: 1  # Max frequency band dropout probability
+drop_freq_count_low: 1  # Min number of frequency bands to drop
+drop_freq_count_high: 3  # Max number of frequency bands to drop
+drop_freq_width: 0.05  # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1  # Min number of audio chunks to drop
+drop_chunk_count_high: 5  # Max number of audio chunks to drop
+drop_chunk_length_low: 1000  # Min length of audio chunks to drop
+drop_chunk_length_high: 2000  # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    parallel_augment: False
+    concat_original: True
+    repeat_augment: 1
+    shuffle_augmentations: False
+    min_augmentations: 4
+    max_augmentations: 4
+    augment_prob: 1.0
+    augmentations: [
+        !ref <add_noise>,
+        !ref <speed_perturb>,
+        !ref <drop_freq>,
+        !ref <drop_chunk>]
 
 # Loads the ASR brain for training/fine-tuning mode
 params: !include:ASR.yaml
@@ -100,11 +163,6 @@ test_search: !ref <params[test_search]>
 bos_index: !ref <params[bos_index]>
 eos_index: !ref <params[eos_index]>
 
-# Adds speech change + time and frequency dropouts (time-domain implementation).
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <params[sample_rate]>
-    speeds: [95, 100, 105]
-
 # Objects in "modules" dict will have their parameters moved to the correct
 # device, as well as having train()/eval() called on them by the Brain class
 modules:
@@ -114,7 +172,6 @@ modules:
     ctc_lin: !ref <ctc_lin>
     seq_lin: !ref <seq_lin>
     normalize: !ref <normalize>
-    env_corrupt: !ref <env_corrupt>
     lm_model: !ref <lm_model>
 
 # Gathering all the submodels in a single model object.
@@ -168,14 +225,14 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
         counter: !ref <epoch_counter>
 
 # Pretraining diversified
-pretrained_path_ASR: !new:speechbrain.pretrained.fetching.FetchSource
-    - !name:speechbrain.pretrained.fetching.FetchFrom.LOCAL
+pretrained_path_ASR: !new:speechbrain.utils.fetching.FetchSource
+    - !name:speechbrain.utils.fetching.FetchFrom.LOCAL
     - speechbrain/asr-crdnn-rnnlm-librispeech/model.ckpt
-pretrained_path_LM: !new:speechbrain.pretrained.fetching.FetchSource
-    - !name:speechbrain.pretrained.fetching.FetchFrom.HUGGING_FACE
+pretrained_path_LM: !new:speechbrain.utils.fetching.FetchSource
+    - !name:speechbrain.utils.fetching.FetchFrom.HUGGING_FACE
     - speechbrain/asr-crdnn-rnnlm-librispeech/lm.ckpt
-pretrained_path_tokenizer: !new:speechbrain.pretrained.fetching.FetchSource
-    - !name:speechbrain.pretrained.fetching.FetchFrom.URI
+pretrained_path_tokenizer: !new:speechbrain.utils.fetching.FetchSource
+    - !name:speechbrain.utils.fetching.FetchFrom.URI
     - https://huggingface.co/speechbrain/asr-crdnn-rnnlm-librispeech/resolve/main/tokenizer.ckpt
 
 pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
diff --git a/tests/templates/fetching_ddp_dynbatch_finetuning/multisource_mini_recipe.py b/tests/templates/fetching_ddp_dynbatch_finetuning/multisource_mini_recipe.py
index 89ae89e3423aae64d284a0981a432492292b38f9..9ee76ff4f698fb5bacc9a1f93d4357b2d646f698 100644
--- a/tests/templates/fetching_ddp_dynbatch_finetuning/multisource_mini_recipe.py
+++ b/tests/templates/fetching_ddp_dynbatch_finetuning/multisource_mini_recipe.py
@@ -66,10 +66,7 @@ class SLU(sb.Brain):
         p_seq = self.hparams.log_softmax(logits)
 
         # Compute outputs
-        if (
-            stage == sb.Stage.TRAIN
-            and self.batch_count % show_results_every != 0
-        ):
+        if stage == sb.Stage.TRAIN and self.step % show_results_every != 0:
             return p_seq, asr_tokens_lens
         else:
             p_tokens, scores = self.hparams.beam_searcher(
@@ -80,10 +77,7 @@ class SLU(sb.Brain):
     def compute_objectives(self, predictions, batch, stage):
         """Computes the loss (NLL) given predictions and targets."""
 
-        if (
-            stage == sb.Stage.TRAIN
-            and self.batch_count % show_results_every != 0
-        ):
+        if stage == sb.Stage.TRAIN and self.step % show_results_every != 0:
             p_seq, asr_tokens_lens = predictions
         else:
             p_seq, asr_tokens_lens, predicted_tokens = predictions
@@ -99,9 +93,7 @@ class SLU(sb.Brain):
         # (No ctc loss)
         loss = loss_seq
 
-        if (stage != sb.Stage.TRAIN) or (
-            self.batch_count % show_results_every == 0
-        ):
+        if (stage != sb.Stage.TRAIN) or (self.step % show_results_every == 0):
             # Decode token terms to words
             predicted_semantics = [
                 tokenizer.decode_ids(utt_seq).split(" ")
@@ -125,26 +117,8 @@ class SLU(sb.Brain):
 
         return loss
 
-    def fit_batch(self, batch):
-        """Train the parameters given a single batch in input"""
-        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
-        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
-        loss.backward()
-        if self.check_gradients(loss):
-            self.optimizer.step()
-        self.optimizer.zero_grad()
-        self.batch_count += 1
-        return loss.detach()
-
-    def evaluate_batch(self, batch, stage):
-        """Computations needed for validation/test batches"""
-        predictions = self.compute_forward(batch, stage=stage)
-        loss = self.compute_objectives(predictions, batch, stage=stage)
-        return loss.detach()
-
     def on_stage_start(self, stage, epoch):
         """Gets called at the beginning of each epoch"""
-        self.batch_count = 0
 
         if stage != sb.Stage.TRAIN:
 
@@ -316,7 +290,6 @@ if __name__ == "__main__":
 
     show_results_every = 100  # plots results every N iterations
 
-    # If --distributed_launch then
     # create ddp_group with the right communication protocol
     sb.utils.distributed.ddp_init_group(run_opts)
 
@@ -356,7 +329,7 @@ if __name__ == "__main__":
 
     # We download and pretrain the tokenizer
     run_on_main(hparams["pretrainer"].collect_files)
-    hparams["pretrainer"].load_collected(device=run_opts["device"])
+    hparams["pretrainer"].load_collected()
 
     # Brain class initialization
     slu_brain = SLU(
diff --git a/tests/templates/fetching_ddp_dynbatch_finetuning/multisource_mini_recipe.yaml b/tests/templates/fetching_ddp_dynbatch_finetuning/multisource_mini_recipe.yaml
index df0796ea1969e8fb33a63705c14ab73665b51162..c0644957041540e7b5ce04e7a0dc4576e3a218da 100644
--- a/tests/templates/fetching_ddp_dynbatch_finetuning/multisource_mini_recipe.yaml
+++ b/tests/templates/fetching_ddp_dynbatch_finetuning/multisource_mini_recipe.yaml
@@ -65,7 +65,7 @@ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
 
 
 # Models
-asr_model: !apply:speechbrain.pretrained.EncoderDecoderASR.from_hparams
+asr_model: !apply:speechbrain.inference.ASR.EncoderDecoderASR.from_hparams
     # source: speechbrain/asr-crdnn-rnnlm-librispeech  # could create a local path issue; specific to this testing folder
     source: speechbrain/asr-crdnn-transformerlm-librispeech
     run_opts: {"device":"cuda:0"}
diff --git a/tests/templates/fetching_ddp_dynbatch_finetuning/single_node_pretrained.py b/tests/templates/fetching_ddp_dynbatch_finetuning/single_node_pretrained.py
index ac3e5bcbde0cac71f18f8c79d3c26e773ad5ea4a..e8cda44b2548c94160153e9a83fbe9bd2d5e9bac 100644
--- a/tests/templates/fetching_ddp_dynbatch_finetuning/single_node_pretrained.py
+++ b/tests/templates/fetching_ddp_dynbatch_finetuning/single_node_pretrained.py
@@ -10,8 +10,8 @@ Authors:
 import logging
 import speechbrain as sb
 from copy import deepcopy
-from speechbrain.pretrained import EncoderDecoderASR
-from speechbrain.pretrained.fetching import FetchFrom, FetchSource
+from speechbrain.inference.ASR import EncoderDecoderASR
+from speechbrain.utils.fetching import FetchFrom, FetchSource
 
 
 logger = logging.getLogger(__name__)
diff --git a/tests/templates/fetching_ddp_dynbatch_finetuning/source_pretrained/pretrained.yaml b/tests/templates/fetching_ddp_dynbatch_finetuning/source_pretrained/pretrained.yaml
index 819324d8735b92e710a57a4d3aea8b5a0c260960..bc9003f6e8ea5bd6623dd5ad06ed7588682ad76f 100644
--- a/tests/templates/fetching_ddp_dynbatch_finetuning/source_pretrained/pretrained.yaml
+++ b/tests/templates/fetching_ddp_dynbatch_finetuning/source_pretrained/pretrained.yaml
@@ -6,7 +6,7 @@
 #
 # Goal: more clarity on when, where & how to use
 #    speechbrain.utils.parameter_transfer.Pretrainer
-#    speechbrain.pretrained.interfaces.Pretrained
+#    speechbrain.inference.interfaces.Pretrained
 #
 # Authors:  Andreas Nautsch 2023
 # # ############################################################################
diff --git a/tests/unittests/test_augment.py b/tests/unittests/test_augment.py
index e4e245471a41682157264279887ab06af791b0cf..b9c6994f61fec876ac038162e500d32954ec17fb 100644
--- a/tests/unittests/test_augment.py
+++ b/tests/unittests/test_augment.py
@@ -1,10 +1,11 @@
 import os
 import torch
+import torchaudio
 from speechbrain.dataio.dataio import write_audio
 
 
 def test_add_noise(tmpdir, device):
-    from speechbrain.processing.speech_augmentation import AddNoise
+    from speechbrain.augment.time_domain import AddNoise
 
     # Test concatenation of batches
     wav_a = torch.sin(torch.arange(8000.0, device=device)).unsqueeze(0)
@@ -37,8 +38,6 @@ def test_add_noise(tmpdir, device):
         w.write(f"1, 1.0, {noisefile}, wav,\n")
 
     # Edge cases
-    no_noise = AddNoise(mix_prob=0.0).to(device)
-    assert no_noise(test_waveform, wav_lens).allclose(test_waveform)
     no_noise = AddNoise(snr_low=1000, snr_high=1000)
     assert no_noise(test_waveform, wav_lens).allclose(test_waveform)
     all_noise = AddNoise(csv_file=csv, snr_low=-1000, snr_high=-1000)
@@ -51,12 +50,11 @@ def test_add_noise(tmpdir, device):
 
 
 def test_add_reverb(tmpdir, device):
-    from speechbrain.processing.speech_augmentation import AddReverb
+    from speechbrain.augment.time_domain import AddReverb
 
     test_waveform = torch.sin(torch.arange(16000.0, device=device)).unsqueeze(0)
     impulse_response = torch.zeros(1, 8000, device=device)
     impulse_response[0, 0] = 1.0
-    wav_lens = torch.ones(1, device=device)
 
     # Put ir waveform into temporary file
     ir1 = os.path.join(tmpdir, "ir1.wav")
@@ -83,28 +81,22 @@ def test_add_reverb(tmpdir, device):
         w.write(f"2, 0.5, {ir2}, wav,\n")
         w.write(f"3, 0.5, {ir3}, wav,\n")
 
-    # Edge case
-    no_reverb = AddReverb(csv, reverb_prob=0.0).to(device)
-    assert no_reverb(test_waveform, wav_lens).allclose(test_waveform)
-
     # Normal cases
     add_reverb = AddReverb(csv, sorting="original")
-    reverbed = add_reverb(test_waveform, wav_lens)[:, 0:1000]
+    reverbed = add_reverb(test_waveform)[:, 0:1000]
     assert reverbed.allclose(test_waveform[:, 0:1000], atol=1e-1)
-    reverbed = add_reverb(test_waveform, wav_lens)[:, 0:1000]
+    reverbed = add_reverb(test_waveform)[:, 0:1000]
     assert reverbed.allclose(test_waveform[:, 0:1000], atol=1e-1)
-    reverbed = add_reverb(test_waveform, wav_lens)[:, 0:1000]
+    reverbed = add_reverb(test_waveform)[:, 0:1000]
     assert reverbed.allclose(ir3_result[:, 0:1000], atol=2e-1)
 
 
 def test_speed_perturb(device):
-    from speechbrain.processing.speech_augmentation import SpeedPerturb
+    from speechbrain.augment.time_domain import SpeedPerturb
 
     test_waveform = torch.sin(torch.arange(16000.0, device=device)).unsqueeze(0)
 
     # Edge cases
-    no_perturb = SpeedPerturb(16000, perturb_prob=0.0).to(device)
-    assert no_perturb(test_waveform).allclose(test_waveform)
     no_perturb = SpeedPerturb(16000, speeds=[100]).to(device)
     assert no_perturb(test_waveform).allclose(test_waveform)
 
@@ -113,43 +105,18 @@ def test_speed_perturb(device):
     assert half_speed(test_waveform).allclose(test_waveform[:, ::2], atol=3e-1)
 
 
-def test_babble(device):
-    from speechbrain.processing.speech_augmentation import AddBabble
-
-    test_waveform = torch.stack(
-        (
-            torch.sin(torch.arange(16000.0, device=device)),
-            torch.cos(torch.arange(16000.0, device=device)),
-        )
-    )
-    lengths = torch.ones(2, device=device)
-
-    # Edge cases
-    no_babble = AddBabble(mix_prob=0.0).to(device)
-    assert no_babble(test_waveform, lengths).allclose(test_waveform)
-    no_babble = AddBabble(speaker_count=1, snr_low=1000, snr_high=1000)
-    assert no_babble(test_waveform, lengths).allclose(test_waveform)
-
-    # One babbler just averages the two speakers
-    babble = AddBabble(speaker_count=1).to(device)
-    expected = (test_waveform + test_waveform.roll(1, 0)) / 2
-    assert babble(test_waveform, lengths).allclose(expected, atol=1e-4)
-
-
 def test_drop_freq(device):
-    from speechbrain.processing.speech_augmentation import DropFreq
+    from speechbrain.augment.time_domain import DropFreq
 
     test_waveform = torch.sin(torch.arange(16000.0, device=device)).unsqueeze(0)
 
     # Edge cases
-    no_drop = DropFreq(drop_prob=0.0).to(device)
-    assert no_drop(test_waveform).allclose(test_waveform)
-    no_drop = DropFreq(drop_count_low=0, drop_count_high=0)
+    no_drop = DropFreq(drop_freq_count_low=0, drop_freq_count_high=0)
     assert no_drop(test_waveform).allclose(test_waveform)
 
     # Check case where frequency range *does not* include signal frequency
     drop_diff_freq = DropFreq(drop_freq_low=0.5, drop_freq_high=0.9)
-    assert drop_diff_freq(test_waveform).allclose(test_waveform, atol=1e-1)
+    assert drop_diff_freq(test_waveform).allclose(test_waveform, atol=5e-1)
 
     # Check case where frequency range *does* include signal frequency
     drop_same_freq = DropFreq(drop_freq_low=0.28, drop_freq_high=0.28)
@@ -159,14 +126,12 @@ def test_drop_freq(device):
 
 
 def test_drop_chunk(device):
-    from speechbrain.processing.speech_augmentation import DropChunk
+    from speechbrain.augment.time_domain import DropChunk
 
     test_waveform = torch.sin(torch.arange(16000.0, device=device)).unsqueeze(0)
     lengths = torch.ones(1, device=device)
 
     # Edge cases
-    no_drop = DropChunk(drop_prob=0.0).to(device)
-    assert no_drop(test_waveform, lengths).allclose(test_waveform)
     no_drop = DropChunk(drop_length_low=0, drop_length_high=0).to(device)
     assert no_drop(test_waveform, lengths).allclose(test_waveform)
     no_drop = DropChunk(drop_count_low=0, drop_count_high=0).to(device)
@@ -196,18 +161,488 @@ def test_drop_chunk(device):
     assert drop_amplitude.allclose(orig_amplitude, atol=1e-2)
 
 
+def test_fast_drop_chunk():
+    from speechbrain.augment.time_domain import FastDropChunk
+
+    test_waveform = torch.ones([8, 200, 12])
+
+    # Edge cases
+    no_drop = FastDropChunk(drop_length_low=0, drop_length_high=0)
+    assert no_drop(test_waveform).allclose(test_waveform)
+    no_drop = FastDropChunk(drop_count_low=0, drop_count_high=0)
+    assert no_drop(test_waveform).allclose(test_waveform)
+    no_drop = FastDropChunk(drop_start=0, drop_end=0)
+    assert no_drop(test_waveform).allclose(test_waveform)
+
+
 def test_clip(device):
-    from speechbrain.processing.speech_augmentation import DoClip
+    from speechbrain.augment.time_domain import DoClip
 
     test_waveform = torch.sin(torch.arange(16000.0, device=device)).unsqueeze(0)
 
     # Edge cases
-    no_clip = DoClip(clip_prob=0.0).to(device)
-    assert no_clip(test_waveform).allclose(test_waveform)
     no_clip = DoClip(clip_low=1, clip_high=1).to(device)
     assert no_clip(test_waveform).allclose(test_waveform)
 
     # Sort of a reimplementation of clipping, but its one function call.
-    expected = test_waveform.clamp(min=-0.5, max=0.5)
+    expected = 2 * test_waveform.clamp(min=-0.5, max=0.5)
     half_clip = DoClip(clip_low=0.5, clip_high=0.5).to(device)
     assert half_clip(test_waveform).allclose(expected)
+
+
+def test_rand_amp():
+    from speechbrain.augment.time_domain import RandAmp
+
+    rand_amp = RandAmp(amp_low=0, amp_high=0)
+    signal = torch.rand(4, 500)
+    output = rand_amp(signal)
+    assert output.mean().mean(0) == 0
+
+
+def test_channel_drop():
+    from speechbrain.augment.time_domain import ChannelDrop
+
+    signal = torch.rand(4, 256, 8)
+    ch_drop = ChannelDrop(drop_rate=0.5)
+    output = ch_drop(signal)
+    assert signal.shape == output.shape
+
+    signal = torch.rand(4, 256, 8)
+    ch_drop = ChannelDrop(drop_rate=0.0)
+    output = ch_drop(signal)
+    assert torch.equal(signal, output)
+
+    signal = torch.rand(4, 256, 8)
+    ch_drop = ChannelDrop(drop_rate=1.0)
+    output = ch_drop(signal)
+    assert torch.sum(output) == 0
+
+
+def test_channel_swap():
+    from speechbrain.augment.time_domain import ChannelSwap
+
+    signal = torch.rand(4, 256, 8)
+    ch_swap = ChannelSwap(min_swap=1, max_swap=5)
+    output = ch_swap(signal)
+    assert signal.shape == output.shape
+
+    signal = torch.rand(4, 256, 8)
+    ch_swap = ChannelSwap(min_swap=0, max_swap=0)
+    output = ch_swap(signal)
+    assert torch.equal(signal, output)
+
+
+def test_rand_shift():
+    from speechbrain.augment.freq_domain import RandomShift
+
+    signal = torch.rand(4, 256, 8)
+    lengths = torch.tensor([0.1, 0.2, 0.9, 1.0])
+    rand_shift = RandomShift(min_shift=10, max_shift=50, dim=1)
+    output, lengths = rand_shift(signal, lengths)
+    assert signal.shape == output.shape
+    assert torch.equal(signal, output) == 0
+
+    signal = torch.rand(4, 256, 8)
+    rand_shift = RandomShift(min_shift=1, max_shift=2, dim=2)
+    output, lengths = rand_shift(signal, lengths)
+    assert signal.shape == output.shape
+    assert torch.equal(signal, output) == 0
+
+    signal = torch.rand(4, 256)
+    rand_shift = RandomShift(min_shift=10, max_shift=50, dim=1)
+    output, lengths = rand_shift(signal, lengths)
+    assert signal.shape == output.shape
+    assert torch.equal(signal, output) == 0
+
+    signal = torch.rand(4, 256, 8)
+    rand_shift = RandomShift(min_shift=0, max_shift=0, dim=1)
+    output, lengths = rand_shift(signal, lengths)
+    assert torch.equal(signal, output)
+
+    signal = torch.Tensor([1, 0, 0])
+    rand_shift = RandomShift(min_shift=1, max_shift=1, dim=0)
+    output, lengths = rand_shift(signal, lengths)
+    assert torch.equal(output, torch.Tensor([0, 1, 0]))
+
+
+def test_pink_noise():
+    from speechbrain.augment.time_domain import pink_noise_like
+
+    signal = torch.rand(4, 256)
+    noise = pink_noise_like(signal)
+    assert signal.shape == noise.shape
+
+    signal = torch.rand(4, 256, 8)
+    noise = pink_noise_like(signal)
+    assert signal.shape == noise.shape
+
+    signal = torch.rand(4, 257, 8)
+    noise = pink_noise_like(signal)
+    assert signal.shape == noise.shape
+
+    noise_fft = torch.fft.fft(noise, dim=1)
+    mean_first_fft_points = noise_fft.abs()[:, 0:10, :].mean()
+    mean_last_fft_points = noise_fft.abs()[:, 118:128, :].mean()
+    assert torch.all(mean_first_fft_points > mean_last_fft_points)
+
+    # Test blue noise
+    noise = pink_noise_like(signal, alpha_low=-1.0, alpha_high=-1.0)
+    noise_fft = torch.fft.fft(noise, dim=1)
+    mean_first_fft_points = noise_fft.abs()[:, 0:10, :].mean()
+    mean_last_fft_points = noise_fft.abs()[:, 118:128, :].mean()
+    assert torch.all(mean_first_fft_points < mean_last_fft_points)
+
+
+def test_SpectrogramDrop():
+    from speechbrain.augment.freq_domain import SpectrogramDrop
+
+    spectrogram = torch.rand(4, 100, 40)
+    mean = spectrogram.mean()
+    drop = SpectrogramDrop(
+        drop_length_low=1,
+        drop_length_high=15,
+        drop_count_low=1,
+        drop_count_high=3,
+        replace="zeros",
+        dim=1,
+    )
+    output = drop(spectrogram)
+    assert mean > output.mean()
+    assert spectrogram.shape == output.shape
+
+    from speechbrain.augment.freq_domain import SpectrogramDrop
+
+    spectrogram = torch.rand(4, 100, 40)
+    mean = spectrogram.mean()
+    drop = SpectrogramDrop(
+        drop_length_low=1,
+        drop_length_high=15,
+        drop_count_low=1,
+        drop_count_high=3,
+        replace="zeros",
+        dim=2,
+    )
+    output = drop(spectrogram)
+    assert mean > output.mean()
+    assert spectrogram.shape == output.shape
+
+    from speechbrain.augment.freq_domain import SpectrogramDrop
+
+    spectrogram = torch.rand(4, 100, 40)
+    drop = SpectrogramDrop(
+        drop_length_low=1,
+        drop_length_high=15,
+        drop_count_low=1,
+        drop_count_high=3,
+        replace="mean",
+        dim=1,
+    )
+    output = drop(spectrogram.clone())
+    assert spectrogram.shape == output.shape
+    assert not torch.equal(spectrogram, output)
+
+    from speechbrain.augment.freq_domain import SpectrogramDrop
+
+    spectrogram = torch.rand(4, 100, 40)
+    drop = SpectrogramDrop(
+        drop_length_low=1,
+        drop_length_high=15,
+        drop_count_low=1,
+        drop_count_high=3,
+        replace="mean",
+        dim=2,
+    )
+    output = drop(spectrogram.clone())
+    assert spectrogram.shape == output.shape
+    assert not torch.equal(spectrogram, output)
+
+    from speechbrain.augment.freq_domain import SpectrogramDrop
+
+    spectrogram = torch.rand(4, 100, 40)
+    drop = SpectrogramDrop(
+        drop_length_low=1,
+        drop_length_high=15,
+        drop_count_low=1,
+        drop_count_high=3,
+        replace="cutcat",
+        dim=1,
+    )
+    output = drop(spectrogram.clone())
+    assert spectrogram.shape == output.shape
+    assert not torch.equal(spectrogram, output)
+
+    from speechbrain.augment.freq_domain import SpectrogramDrop
+
+    spectrogram = torch.rand(4, 100, 40)
+    drop = SpectrogramDrop(
+        drop_length_low=1,
+        drop_length_high=15,
+        drop_count_low=1,
+        drop_count_high=3,
+        replace="cutcat",
+        dim=2,
+    )
+    output = drop(spectrogram.clone())
+    assert spectrogram.shape == output.shape
+    assert not torch.equal(spectrogram, output)
+
+    from speechbrain.augment.freq_domain import SpectrogramDrop
+
+    spectrogram = torch.rand(4, 100, 40)
+    drop = SpectrogramDrop(
+        drop_length_low=1,
+        drop_length_high=15,
+        drop_count_low=1,
+        drop_count_high=3,
+        replace="swap",
+        dim=1,
+    )
+    output = drop(spectrogram.clone())
+    assert spectrogram.shape == output.shape
+    assert not torch.equal(spectrogram, output)
+
+    from speechbrain.augment.freq_domain import SpectrogramDrop
+
+    spectrogram = torch.rand(4, 100, 40)
+    drop = SpectrogramDrop(
+        drop_length_low=1,
+        drop_length_high=15,
+        drop_count_low=1,
+        drop_count_high=3,
+        replace="swap",
+        dim=2,
+    )
+    output = drop(spectrogram.clone())
+    assert spectrogram.shape == output.shape
+    assert not torch.equal(spectrogram, output)
+    assert torch.allclose(spectrogram.mean(), output.mean())
+
+    # Important: understand why sometimes with random selection, spectrogram and output are the same....
+    from speechbrain.augment.freq_domain import SpectrogramDrop
+
+    spectrogram = torch.rand(4, 100, 40)
+    drop = SpectrogramDrop(
+        drop_length_low=1,
+        drop_length_high=15,
+        drop_count_low=1,
+        drop_count_high=3,
+        replace="random_selection",
+        dim=1,
+    )
+    output = drop(spectrogram.clone())
+    assert spectrogram.shape == output.shape
+    assert not torch.equal(spectrogram, output)
+
+    from speechbrain.augment.freq_domain import SpectrogramDrop
+
+    spectrogram = torch.rand(4, 100, 40)
+    drop = SpectrogramDrop(
+        drop_length_low=1,
+        drop_length_high=15,
+        drop_count_low=1,
+        drop_count_high=3,
+        replace="random_selection",
+        dim=2,
+    )
+    output = drop(spectrogram.clone())
+    assert spectrogram.shape == output.shape
+    assert not torch.equal(spectrogram, output)
+
+    from speechbrain.augment.codec import CodecAugment
+
+    if "ffmpeg" in torchaudio.list_audio_backends():
+        waveform = torch.rand(4, 16000)
+        augmenter = CodecAugment(16000)
+        output_waveform = augmenter(waveform)
+        assert not torch.allclose(waveform, output_waveform)
+
+    from speechbrain.augment.time_domain import DropBitResolution
+
+    dropper = DropBitResolution()
+    signal = torch.rand(4, 16000)
+    signal_dropped = dropper(signal)
+    assert not torch.equal(signal, signal_dropped)
+    assert signal.shape == signal_dropped.shape
+
+    from speechbrain.augment.time_domain import DropBitResolution
+
+    dropper = DropBitResolution(target_dtype="int8")
+    signal = torch.rand(4, 16000)
+    signal_dropped = dropper(signal)
+    assert not torch.equal(signal, signal_dropped)
+    assert signal.shape == signal_dropped.shape
+
+    from speechbrain.augment.time_domain import DropBitResolution
+
+    dropper = DropBitResolution(target_dtype="int16")
+    signal = torch.rand(4, 16000)
+    signal_dropped = dropper(signal)
+    assert not torch.equal(signal, signal_dropped)
+    assert signal.shape == signal_dropped.shape
+
+    from speechbrain.augment.time_domain import DropBitResolution
+
+    dropper = DropBitResolution(target_dtype="float16")
+    signal = torch.rand(4, 16000)
+    signal_dropped = dropper(signal)
+    assert not torch.equal(signal, signal_dropped)
+    assert signal.shape == signal_dropped.shape
+
+
+def test_augment_pipeline():
+    from speechbrain.augment.time_domain import DropFreq, DropChunk
+    from speechbrain.augment.augmenter import Augmenter
+
+    freq_dropper = DropFreq()
+    chunk_dropper = DropChunk(drop_start=100, drop_end=16000, noise_factor=0)
+    augment = Augmenter(
+        parallel_augment=False,
+        concat_original=False,
+        min_augmentations=2,
+        max_augmentations=2,
+        augmentations=[freq_dropper, chunk_dropper],
+    )
+    signal = torch.rand([4, 16000])
+    output_signal, lenghts = augment(
+        signal, lengths=torch.tensor([0.2, 0.5, 0.7, 1.0])
+    )
+    assert len(output_signal) == 4
+    assert len(lenghts) == 4
+
+    freq_dropper = DropFreq()
+    chunk_dropper = DropChunk(drop_start=100, drop_end=16000, noise_factor=0)
+    augment = Augmenter(
+        parallel_augment=True,
+        concat_original=True,
+        min_augmentations=1,
+        max_augmentations=2,
+        augment_prob=0,
+        augmentations=[freq_dropper, chunk_dropper],
+    )
+    signal = torch.rand([4, 16000])
+    output_signal, lenghts = augment(
+        signal, lengths=torch.tensor([0.2, 0.5, 0.7, 1.0])
+    )
+    assert torch.equal(signal, output_signal)
+
+    freq_dropper = DropFreq()
+    chunk_dropper = DropChunk(drop_start=100, drop_end=16000, noise_factor=0)
+    augment = Augmenter(
+        parallel_augment=True,
+        concat_original=True,
+        min_augmentations=1,
+        max_augmentations=2,
+        augment_prob=1.0,
+        augmentations=[freq_dropper, chunk_dropper],
+        enable_augmentations=[False, False],
+    )
+    signal = torch.rand([4, 16000])
+    output_signal, lenghts = augment(
+        signal, lengths=torch.tensor([0.2, 0.5, 0.7, 1.0])
+    )
+    assert torch.equal(signal, output_signal)
+
+    freq_dropper = DropFreq()
+    chunk_dropper = DropChunk(drop_start=100, drop_end=16000, noise_factor=0)
+    augment = Augmenter(
+        parallel_augment=True,
+        concat_original=True,
+        min_augmentations=2,
+        max_augmentations=2,
+        augment_prob=1.0,
+        augmentations=[freq_dropper, chunk_dropper],
+        enable_augmentations=[True, False],
+    )
+    signal = torch.rand([4, 16000])
+    output_signal, lenghts = augment(
+        signal, lengths=torch.tensor([0.2, 0.5, 0.7, 1.0])
+    )
+    assert output_signal.shape[0] == signal.shape[0] * 2
+
+    augment = Augmenter(
+        parallel_augment=False,
+        concat_original=True,
+        min_augmentations=2,
+        max_augmentations=2,
+        augmentations=[freq_dropper, chunk_dropper],
+    )
+    output_signal, lenghts = augment(
+        signal, lengths=torch.tensor([0.2, 0.5, 0.7, 1.0])
+    )
+    assert len(output_signal) == 8
+    assert len(lenghts) == 8
+    assert torch.equal(output_signal[0:4], signal[0:4])
+
+    augment = Augmenter(
+        parallel_augment=True,
+        concat_original=False,
+        min_augmentations=2,
+        max_augmentations=2,
+        augmentations=[freq_dropper, chunk_dropper],
+    )
+    output_signal, lenghts = augment(
+        signal, lengths=torch.tensor([0.2, 0.5, 0.7, 1.0])
+    )
+    assert len(output_signal) == 8
+    assert len(lenghts) == 8
+
+    augment = Augmenter(
+        parallel_augment=True,
+        concat_original=True,
+        min_augmentations=2,
+        max_augmentations=2,
+        augmentations=[freq_dropper, chunk_dropper],
+    )
+    output_signal, lenghts = augment(
+        signal, lengths=torch.tensor([0.2, 0.5, 0.7, 1.0])
+    )
+    assert len(output_signal) == 12
+    assert len(lenghts) == 12
+    assert torch.equal(output_signal[0:4], signal[0:4])
+
+    augment = Augmenter(
+        parallel_augment=True,
+        concat_original=True,
+        min_augmentations=2,
+        max_augmentations=2,
+        repeat_augment=2,
+        shuffle_augmentations=True,
+        augmentations=[freq_dropper, chunk_dropper],
+    )
+
+    output_signal, lenghts = augment(
+        signal, lengths=torch.tensor([0.2, 0.5, 0.7, 1.0])
+    )
+    assert len(output_signal) == 20
+    assert len(lenghts) == 20
+    assert torch.equal(output_signal[0:4], signal[0:4])
+
+    augment = Augmenter(
+        parallel_augment=True,
+        concat_original=True,
+        min_augmentations=0,
+        max_augmentations=0,
+        repeat_augment=2,
+        shuffle_augmentations=True,
+        augmentations=[freq_dropper, chunk_dropper],
+    )
+
+    output_signal, lenghts = augment(
+        signal, lengths=torch.tensor([0.2, 0.5, 0.7, 1.0])
+    )
+    assert torch.equal(output_signal, signal)
+
+    augment = Augmenter(
+        parallel_augment=True,
+        concat_original=True,
+        min_augmentations=1,
+        max_augmentations=2,
+        repeat_augment=0,
+        shuffle_augmentations=True,
+        augmentations=[freq_dropper, chunk_dropper],
+    )
+
+    output_signal, lenghts = augment(
+        signal, lengths=torch.tensor([0.2, 0.5, 0.7, 1.0])
+    )
+    assert torch.equal(output_signal, signal)
diff --git a/tests/unittests/test_checkpoints.py b/tests/unittests/test_checkpoints.py
index 623f5188c38e1e730054970cd921e79fcdace9fc..f9743a10683c785d36dde268a76f33491bc054b7 100644
--- a/tests/unittests/test_checkpoints.py
+++ b/tests/unittests/test_checkpoints.py
@@ -97,8 +97,8 @@ def test_checkpointer(tmpdir, device):
     assert other.param.data == torch.tensor([10.0], device=device)
 
     # Make sure checkpoints can't be name saved by the same name
-    with pytest.raises(FileExistsError):
-        recoverer.save_checkpoint(name="ep1")
+    # with pytest.raises(FileExistsError):
+    #    recoverer.save_checkpoint(name="ep1")
 
 
 def test_recovery_custom_io(tmpdir):
@@ -118,9 +118,8 @@ def test_recovery_custom_io(tmpdir):
                 fo.write(str(self.param))
 
         @mark_as_loader
-        def load(self, path, end_of_epoch, device):
+        def load(self, path, end_of_epoch):
             del end_of_epoch  # Unused
-            del device
             with open(path) as fi:
                 self.param = int(fi.read())
 
@@ -252,6 +251,49 @@ def test_multiple_ckpts_and_criteria(tmpdir):
     assert found_ckpts == [fifth_ckpt, fourth_ckpt]
 
 
+def test_average_ckpts(tmpdir):
+    from speechbrain.utils.checkpoints import Checkpointer, average_checkpoints
+
+    class Recoverable(torch.nn.Module):
+        def __init__(self, param):
+            super().__init__()
+            self.param = torch.nn.Parameter(torch.tensor([param]))
+
+        def forward(self, x):
+            return x * self.param
+
+    N_avg = 2
+    recoverable = Recoverable(1.0)
+    recoverables = {"recoverable": recoverable}
+    recoverer = Checkpointer(tmpdir, recoverables)
+
+    # save first checkpoint
+    recoverer.save_and_keep_only(
+        meta={"error": 5},
+        min_keys=["error"],
+        keep_recent=True,
+        num_to_keep=N_avg,
+    )
+
+    # Save another checkpoint
+    recoverable.param = torch.nn.Parameter(torch.tensor([3.0]))
+
+    recoverer.save_and_keep_only(
+        meta={"error": 4},
+        min_keys=["error"],
+        keep_recent=True,
+        num_to_keep=N_avg,
+    )
+
+    recoverer.recover_if_possible()
+
+    checkpoints = recoverer.find_checkpoints(max_num_checkpoints=N_avg,)
+
+    model_state_dict = average_checkpoints(checkpoints, "recoverable")
+
+    assert model_state_dict["param"] == 2.0
+
+
 def test_torch_meta(tmpdir, device):
     from speechbrain.utils.checkpoints import Checkpointer
 
@@ -293,7 +335,7 @@ def test_checkpoint_hook_register(tmpdir):
                 fo.write(str(self.param))
 
         @mark_as_loader
-        def load(self, path, end_of_epoch, device):
+        def load(self, path, end_of_epoch):
             del end_of_epoch  # Unused
             with open(path) as fi:
                 self.param = int(fi.read())
@@ -317,8 +359,7 @@ def test_checkpoint_hook_register(tmpdir):
                     fo.write(str(self.param))
 
             @mark_as_loader
-            def load(self, path, end_of_epoch):  # MISSING device
-                del end_of_epoch  # Unused
+            def load(self, path):  # MISSING end_of_epoch
                 with open(path) as fi:
                     self.param = int(fi.read())
 
@@ -333,7 +374,7 @@ def test_checkpoint_hook_register(tmpdir):
                 with open(path, "w") as fo:
                     fo.write(str(self.param))
 
-            def load(self, path, end_of_epoch, device):
+            def load(self, path, end_of_epoch):
                 del end_of_epoch  # Unused
                 with open(path) as fi:
                     self.param = int(fi.read())
diff --git a/tests/unittests/test_conformer.py b/tests/unittests/test_conformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b35d9adb28c62d82c1ab101126728356733a1aa
--- /dev/null
+++ b/tests/unittests/test_conformer.py
@@ -0,0 +1,98 @@
+import torch
+
+
+@torch.no_grad
+def test_streaming_conformer_layer(device):
+    """Test whether the Conformer encoder layer masking code path (used at train
+    time) is equivalent to a real streaming scenario."""
+
+    from speechbrain.lobes.models.transformer.Conformer import (
+        ConformerEncoderLayer,
+    )
+    from speechbrain.nnet.attention import RelPosEncXL
+    from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
+    from speechbrain.lobes.models.transformer.TransformerASR import (
+        make_transformer_src_mask,
+    )
+
+    TOLERATED_MEAN_ERROR = 1.0e-6
+
+    bs, seq_len, num_feats = input_shape = 1, 24, 16
+    config = DynChunkTrainConfig(chunk_size=8, left_context_size=1)
+
+    assert (
+        seq_len % config.chunk_size == 0
+    ), "For this test, we assume the sequence length can evenly be divided"
+    num_chunks = seq_len // config.chunk_size
+
+    torch.manual_seed(1337)
+
+    module = ConformerEncoderLayer(
+        d_model=num_feats, d_ffn=num_feats * 2, nhead=1, kernel_size=5
+    ).to(device=device)
+    module.eval()
+
+    pos_encoder = RelPosEncXL(num_feats).to(device=device)
+
+    # build inputs
+    test_input = torch.randn(input_shape, device=device)
+
+    test_input_chunked = test_input.unfold(
+        1, size=config.chunk_size, step=config.chunk_size
+    )
+    test_input_chunked = test_input_chunked.transpose(1, 3)
+    assert test_input_chunked.shape == (
+        bs,
+        config.chunk_size,
+        num_feats,
+        num_chunks,
+    ), "Test bug: incorrect shape for the chunked input?"
+
+    # build the transformer mask for masked inference (dynchunktrain_config does
+    # not suffice)
+    src_mask = make_transformer_src_mask(
+        test_input, dynchunktrain_config=config
+    )
+
+    # masked inference
+    pos_embs_full = pos_encoder(test_input)
+    out_mask_path, _out_attn = module(
+        test_input,
+        src_mask=src_mask,
+        pos_embs=pos_embs_full,
+        dynchunktrain_config=config,
+    )
+
+    # streaming inference
+    mutable_ctx = module.make_streaming_context(
+        config.left_context_size * config.chunk_size
+    )
+    output_chunks = []
+
+    for i in range(num_chunks):
+        chunk_in = test_input_chunked[..., i]
+
+        # HACK due to pos embeddings
+        pos_embs_dummy_input = chunk_in
+        if mutable_ctx.mha_left_context is not None:
+            pos_embs_dummy_input = torch.empty(
+                (
+                    bs,
+                    config.chunk_size + mutable_ctx.mha_left_context.size(1),
+                    num_feats,
+                ),
+                device=device,
+            )
+
+        pos_embs_chunk = pos_encoder(pos_embs_dummy_input)
+        chunk_out, _chunk_attn = module.forward_streaming(
+            chunk_in, mutable_ctx, pos_embs=pos_embs_chunk
+        )
+        output_chunks.append(chunk_out)
+
+    out_stream_path = torch.cat(output_chunks, dim=1)
+
+    # check output embedding differences
+    abs_diff = (out_mask_path - out_stream_path).abs()
+
+    assert torch.mean(abs_diff).item() < TOLERATED_MEAN_ERROR
diff --git a/tests/unittests/test_ctc_segmentation.py b/tests/unittests/test_ctc_segmentation.py
index c6df1a0b493e48910939c2c430783a609cc4827e..4c3315e3b5251b990ec21daeca1ecb98ae1d63e9 100644
--- a/tests/unittests/test_ctc_segmentation.py
+++ b/tests/unittests/test_ctc_segmentation.py
@@ -1,4 +1,4 @@
-from speechbrain.pretrained import EncoderDecoderASR
+from speechbrain.inference.ASR import EncoderDecoderASR
 import pytest
 
 pytest.importorskip(
@@ -17,6 +17,9 @@ def asr_model():
     return asr_model
 
 
+@pytest.mark.skip(
+    reason="interface refactoring still pending (incl. YAML files on HuggingFace)"
+)
 def test_CTCSegmentation(asr_model: EncoderDecoderASR):
     """Test CTC segmentation.
 
diff --git a/tests/unittests/test_dataloader.py b/tests/unittests/test_dataloader.py
index 0e5d427586cde47d43a03c95c55a078770ca5154..67c1cb84156a4a49cf7e7b6a5d69a4bbc91ac887 100644
--- a/tests/unittests/test_dataloader.py
+++ b/tests/unittests/test_dataloader.py
@@ -17,7 +17,7 @@ def test_saveable_dataloader(tmpdir, device):
     assert second_item == dataset[1]
     # Now make a new dataloader and recover:
     new_dataloader = SaveableDataLoader(dataset, collate_fn=None)
-    new_dataloader._speechbrain_load(save_file, end_of_epoch=False, device=None)
+    new_dataloader._speechbrain_load(save_file, end_of_epoch=False)
     new_data_iterator = iter(new_dataloader)
     second_second_item = next(new_data_iterator)
     assert second_second_item == second_item
@@ -48,7 +48,7 @@ def test_saveable_dataloader_multiprocess(tmpdir):
             dataset, num_workers=num_parallel, collate_fn=None
         )
         new_dataloader._speechbrain_load(
-            save_file, end_of_epoch=False, device=None
+            save_file, end_of_epoch=False,
         )
         new_data_iterator = iter(new_dataloader)
         second_second_item = next(new_data_iterator)
@@ -80,7 +80,7 @@ def test_looped_loader(tmpdir):
         next(data_iterator)
     # Now make a new dataloader and recover:
     new_dataloader = LoopedLoader(data, epoch_length=2)
-    new_dataloader.load(save_file, end_of_epoch=False, device=None)
+    new_dataloader.load(save_file, end_of_epoch=False)
     new_data_iterator = iter(new_dataloader)
     next(new_data_iterator)
     with pytest.raises(StopIteration):
diff --git a/tests/unittests/test_dynamic_chunk_training.py b/tests/unittests/test_dynamic_chunk_training.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c74d10d1bff558960297ae7f8a7ca76de2af54d
--- /dev/null
+++ b/tests/unittests/test_dynamic_chunk_training.py
@@ -0,0 +1,38 @@
+def test_dynchunktrain_sampler():
+    from speechbrain.core import Stage
+    from speechbrain.utils.dynamic_chunk_training import (
+        DynChunkTrainConfig,
+        DynChunkTrainConfigRandomSampler,
+    )
+
+    # sanity check and cover for the random smapler
+
+    valid_cfg = DynChunkTrainConfig(16, 32)
+    test_cfg = DynChunkTrainConfig(16, 32)
+
+    sampler = DynChunkTrainConfigRandomSampler(
+        chunkwise_prob=1.0,
+        chunk_size_min=8,
+        chunk_size_max=8,
+        limited_left_context_prob=1.0,
+        left_context_chunks_min=16,
+        left_context_chunks_max=16,
+        test_config=valid_cfg,
+        valid_config=test_cfg,
+    )
+
+    sampled_train_config = sampler(Stage.TRAIN)
+    assert sampled_train_config.chunk_size == 8
+    assert sampled_train_config.left_context_size == 16
+
+    assert sampler(Stage.VALID) == valid_cfg
+    assert sampler(Stage.TEST) == test_cfg
+
+
+def test_dynchunktrain():
+    from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
+
+    assert DynChunkTrainConfig(chunk_size=16).is_infinite_left_context()
+    assert not DynChunkTrainConfig(
+        chunk_size=16, left_context_size=4
+    ).is_infinite_left_context()
diff --git a/tests/unittests/test_filter_analysis.py b/tests/unittests/test_filter_analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..020d2b4f2e2ccfa6f1192798bb5dfa2cda7f268f
--- /dev/null
+++ b/tests/unittests/test_filter_analysis.py
@@ -0,0 +1,17 @@
+from speechbrain.utils.filter_analysis import FilterProperties
+
+
+def test_simple_filter_stacks():
+    assert FilterProperties(window_size=3, stride=2).with_on_top(
+        FilterProperties(window_size=3, stride=2)
+    ) == FilterProperties(window_size=7, stride=4)
+
+    assert FilterProperties(window_size=3, stride=1).with_on_top(
+        FilterProperties(window_size=3, stride=1)
+    ) == FilterProperties(window_size=5, stride=1)
+
+
+def test_causal_filter_properties():
+    assert FilterProperties(
+        3, 1, causal=True
+    ).get_noncausal_equivalent() == FilterProperties(5, 1)
diff --git a/tests/unittests/test_k2.py b/tests/unittests/test_k2.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b2afaff4439d40d3bc548bac66020b585044bb5
--- /dev/null
+++ b/tests/unittests/test_k2.py
@@ -0,0 +1,419 @@
+import os
+import pytest
+import tempfile
+import shutil
+import torch
+from pathlib import Path
+from tempfile import TemporaryDirectory
+import logging
+from speechbrain.k2_integration import k2
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture
+def tmp_csv_file(tmp_path):
+    csv_file = tmp_path / "train.csv"
+    with open(csv_file, "w") as f:
+        f.write("ID,duration,wav,spk_id,wrd\n")
+        f.write("1,1,1,1,hello world\n")
+        f.write("2,0.5,1,1,hello\n")
+    return csv_file
+
+
+def test_get_lexicon(tmp_path, tmp_csv_file):
+    # Define the inputs
+    lang_dir = tmp_path
+    csv_files = [tmp_csv_file]
+    vocab_files = []  # This list is empty for simplicity in this test.
+
+    # Call the function
+    from speechbrain.k2_integration.lexicon import prepare_char_lexicon
+
+    prepare_char_lexicon(
+        lang_dir, vocab_files, csv_files, add_word_boundary=False
+    )
+
+    # Read the output and assert its content
+    with open(lang_dir / "lexicon.txt", "r") as f:
+        assert f.read() == "<UNK> <unk>\nhello h e l l o\nworld w o r l d\n"
+
+
+def test_get_lexicon_with_boundary(tmp_path, tmp_csv_file):
+    # Define the inputs
+    lang_dir = tmp_path
+    csv_files = [tmp_csv_file]
+    vocab_files = []
+
+    # Call the function with word boundaries
+    from speechbrain.k2_integration.lexicon import prepare_char_lexicon
+
+    prepare_char_lexicon(
+        lang_dir, vocab_files, csv_files, add_word_boundary=True
+    )
+
+    # Read the output and assert its content
+    with open(lang_dir / "lexicon.txt", "r") as f:
+        assert (
+            f.read()
+            == "<UNK> <unk>\nhello h e l l o <eow>\nworld w o r l d <eow>\n"
+        )
+
+
+@pytest.fixture
+def mock_lexicon_file(tmp_path):
+    lexicon_content = "hello h e l l o\nworld w o r l d\n"
+    lexicon_file = tmp_path / "mock_lexicon.txt"
+    with open(lexicon_file, "w") as f:
+        f.write(lexicon_content)
+    return lexicon_file
+
+
+def test_read_lexicon(mock_lexicon_file):
+    expected_output = [
+        ("hello", ["h", "e", "l", "l", "o"]),
+        ("world", ["w", "o", "r", "l", "d"]),
+    ]
+
+    from speechbrain.k2_integration.lexicon import read_lexicon
+
+    output = read_lexicon(mock_lexicon_file)
+    assert output == expected_output
+
+
+def test_write_lexicon(tmp_path):
+    # Sample lexicon data.
+    lexicon_data = [
+        ("hello", ["h", "e", "l", "l", "o"]),
+        ("world", ["w", "o", "r", "l", "d"]),
+    ]
+
+    # Path to save the lexicon file.
+    lexicon_file = tmp_path / "test_lexicon.txt"
+
+    # Use the function to write lexicon to the file.
+    from speechbrain.k2_integration.lexicon import write_lexicon
+
+    write_lexicon(lexicon_file, lexicon_data)
+
+    # Expected content of the lexicon file.
+    expected_content = "hello h e l l o\nworld w o r l d\n"
+
+    # Read back the content of the file and assert its correctness.
+    with open(lexicon_file, "r") as f:
+        assert f.read() == expected_content
+
+
+def test_get_tokens_basic():
+    # Prepare a mock lexicon
+    lexicon = [
+        ("hello", ["h", "e", "l", "l", "o"]),
+        ("world", ["w", "o", "r", "l", "d"]),
+    ]
+    from speechbrain.k2_integration.prepare_lang import get_tokens
+
+    tokens = get_tokens(lexicon)
+    expected_tokens = ["d", "e", "h", "l", "o", "r", "w"]
+    assert tokens == expected_tokens
+
+
+def test_get_tokens_with_sil():
+    # Prepare a mock lexicon
+    lexicon = [
+        ("hello", ["h", "e", "l", "l", "o"]),
+        ("world", ["w", "o", "r", "l", "d", "SIL"]),
+    ]
+    with pytest.raises(AssertionError):
+        from speechbrain.k2_integration.prepare_lang import get_tokens
+
+        get_tokens(lexicon)
+
+
+def test_get_tokens_manually_add_sil():
+    # Prepare a mock lexicon
+    lexicon = [
+        ("hello", ["h", "e", "l", "l", "o"]),
+        ("world", ["w", "o", "r", "l", "d"]),
+    ]
+    from speechbrain.k2_integration.prepare_lang import get_tokens
+
+    tokens = get_tokens(lexicon, manually_add_sil_to_tokens=True)
+    expected_tokens = ["SIL", "d", "e", "h", "l", "o", "r", "w"]
+    assert tokens == expected_tokens
+
+
+def test_unique_pronunciations():
+    lexicon = [
+        ("hello", ["h", "e", "l", "l", "o"]),
+        ("world", ["w", "o", "r", "l", "d"]),
+    ]
+    from speechbrain.k2_integration.prepare_lang import add_disambig_symbols
+
+    new_lexicon, max_disambig = add_disambig_symbols(lexicon)
+    assert new_lexicon == lexicon
+    assert max_disambig == 0
+
+
+def test_repeated_pronunciations():
+    lexicon = [
+        ("hello", ["h", "e", "l", "l", "o"]),
+        ("greeting", ["h", "e", "l", "l", "o"]),
+    ]
+    from speechbrain.k2_integration.prepare_lang import add_disambig_symbols
+
+    new_lexicon, max_disambig = add_disambig_symbols(lexicon)
+    assert new_lexicon == [
+        ("hello", ["h", "e", "l", "l", "o", "#1"]),
+        ("greeting", ["h", "e", "l", "l", "o", "#2"]),
+    ]
+    assert max_disambig == 2
+
+
+def test_prefix_pronunciations():
+    lexicon = [("he", ["h", "e"]), ("hello", ["h", "e", "l", "l", "o"])]
+    from speechbrain.k2_integration.prepare_lang import add_disambig_symbols
+
+    new_lexicon, max_disambig = add_disambig_symbols(lexicon)
+    assert new_lexicon == [
+        ("he", ["h", "e", "#1"]),
+        ("hello", ["h", "e", "l", "l", "o"]),
+    ]
+    assert max_disambig == 1
+
+
+def test_mixed_pronunciations():
+    lexicon = [
+        ("he", ["h", "e"]),
+        ("hello", ["h", "e", "l", "l", "o"]),
+        ("hey", ["h", "e"]),
+        ("world", ["h", "e", "l", "l", "o"]),
+    ]
+    from speechbrain.k2_integration.prepare_lang import add_disambig_symbols
+
+    new_lexicon, max_disambig = add_disambig_symbols(lexicon)
+    # Correct the expected output based on function behavior
+    assert new_lexicon == [
+        ("he", ["h", "e", "#1"]),
+        ("hello", ["h", "e", "l", "l", "o", "#1"]),
+        ("hey", ["h", "e", "#2"]),
+        ("world", ["h", "e", "l", "l", "o", "#2"]),
+    ]
+    assert max_disambig == 2
+
+
+def test_lexicon_to_fst():
+    # Sample lexicon: Each word maps to a list of tokens
+    lexicon = [
+        ("hello", ["h", "e", "l", "l", "o"]),
+        ("world", ["w", "o", "r", "l", "d"]),
+    ]
+
+    # Maps from token to ID and word to ID
+    token2id = {
+        "<eps>": 0,
+        "h": 1,
+        "e": 2,
+        "l": 3,
+        "o": 4,
+        "w": 5,
+        "r": 6,
+        "d": 7,
+        "SIL": 8,
+        "#0": 9,  # for self-loop
+    }
+
+    word2id = {"<eps>": 0, "hello": 1, "world": 2, "#0": 3}  # for self-loop
+
+    from speechbrain.k2_integration.prepare_lang import lexicon_to_fst
+
+    fsa = lexicon_to_fst(
+        lexicon=lexicon,
+        token2id=token2id,
+        word2id=word2id,
+        sil_token="SIL",
+        sil_prob=0.5,
+        need_self_loops=True,  # Assuming you have the add_self_loops function implemented
+    )
+
+    # Ensure fsa is a valid k2 FSA
+    assert isinstance(fsa, k2.Fsa)
+
+
+def test_lexicon_to_fst_no_sil():
+    # Sample lexicon: Each word maps to a list of tokens
+    lexicon = [
+        ("hello", ["h", "e", "l", "l", "o"]),
+        ("world", ["w", "o", "r", "l", "d"]),
+    ]
+
+    # Maps from token to ID and word to ID
+    token2id = {
+        "<eps>": 0,
+        "h": 1,
+        "e": 2,
+        "l": 3,
+        "o": 4,
+        "w": 5,
+        "r": 6,
+        "d": 7,
+        "#0": 8,  # for self-loop
+    }
+
+    word2id = {"<eps>": 0, "hello": 1, "world": 2, "#0": 3}  # for self-loop
+
+    from speechbrain.k2_integration.prepare_lang import lexicon_to_fst_no_sil
+
+    fsa = lexicon_to_fst_no_sil(
+        lexicon=lexicon,
+        token2id=token2id,
+        word2id=word2id,
+        need_self_loops=True,  # Assuming you have the add_self_loops function implemented
+    )
+
+    # Ensure fsa is a valid k2 FSA
+    assert isinstance(fsa, k2.Fsa)
+
+
+def test_prepare_lang():
+    # Step 1: Setup
+    temp_dir = tempfile.mkdtemp()
+
+    # Create a simple lexicon for testing
+    lexicon_content = """
+    hello h e l l o
+    world w o r l d
+    """
+    with open(os.path.join(temp_dir, "lexicon.txt"), "w") as f:
+        f.write(lexicon_content.strip())
+
+    # Step 2: Run prepare_lang
+    from speechbrain.k2_integration.prepare_lang import prepare_lang
+
+    prepare_lang(temp_dir, sil_token="SIL", sil_prob=0.5)
+
+    # Step 3: Check the output
+    # Check if the expected files are present
+    for expected_file in [
+        "tokens.txt",
+        "words.txt",
+        "L.pt",
+        "L_disambig.pt",
+        "Linv.pt",
+    ]:
+        assert os.path.exists(os.path.join(temp_dir, expected_file))
+
+    # Step 4: Cleanup
+    shutil.rmtree(temp_dir)
+
+
+def test_lexicon_loading_and_conversion():
+    with TemporaryDirectory() as tmpdir:
+        tmpdir_path = Path(tmpdir)
+
+        # Create a small lexicon containing only two words.
+        lexicon_sample = """<UNK> <unk>
+hello h e l l o
+world w o r l d"""
+        lexicon_file = tmpdir_path.joinpath("lexicon.txt")
+        with open(lexicon_file, "w") as f:
+            f.write(lexicon_sample)
+
+        # Create a lang directory with the lexicon and L.pt, L_inv.pt, L_disambig.pt using prepare_lang
+        from speechbrain.k2_integration.prepare_lang import prepare_lang
+
+        prepare_lang(tmpdir_path)
+
+        # Create a lexicon object
+        from speechbrain.k2_integration.lexicon import Lexicon
+
+        lexicon = Lexicon(tmpdir_path)
+
+        # Assert instance types
+        assert isinstance(lexicon.token_table, k2.SymbolTable)
+        assert isinstance(lexicon.word_table, k2.SymbolTable)
+        assert isinstance(lexicon.L, k2.Fsa)
+
+        # Test conversion from texts to token IDs
+        hello_tids = lexicon.word_table["hello"]
+        world_tids = lexicon.word_table["world"]
+        expected_tids = [hello_tids] + [world_tids]
+        assert lexicon.texts_to_word_ids(["hello world"])[0] == expected_tids
+
+        # Test out-of-vocabulary words
+        # Assuming that <UNK> exists in the tokens:
+        unk_tid = lexicon.word_table["<UNK>"]
+        hello_tids = lexicon.word_table["hello"]
+        expected_oov_tids = [hello_tids] + [unk_tid]
+        assert (
+            lexicon.texts_to_word_ids(["hello universe"])[0]
+            == expected_oov_tids
+        )
+
+        # Test with sil_token as separator
+        # Assuming that SIL exists in the tokens:
+        sil_tid = lexicon.token_table["SIL"]
+        hello_tids = lexicon.word_table["hello"]
+        world_tids = lexicon.word_table["world"]
+        expected_sil_tids = [hello_tids] + [sil_tid] + [world_tids]
+        assert (
+            lexicon.texts_to_word_ids(
+                ["hello world"],
+                add_sil_token_as_separator=True,
+                sil_token_id=sil_tid,
+            )[0]
+            == expected_sil_tids
+        )
+
+
+def test_ctc_k2_loss():
+    # Create a random batch of log-probs
+    batch_size = 4
+    log_probs = torch.randn(batch_size, 100, 30).requires_grad_(True)
+    log_probs = torch.nn.functional.log_softmax(log_probs, dim=-1)
+    input_lens = torch.tensor([1, 0.9, 0.8, 0.7])
+
+    # Create a temporary directory for lexicon and other files
+    with TemporaryDirectory() as tmpdir:
+        # Create a small lexicon containing only two words and write it to a file.
+        lexicon_sample = """<UNK> <unk>
+hello h e l l o
+world w o r l d"""
+        lexicon_file_path = f"{tmpdir}/lexicon.txt"
+        with open(lexicon_file_path, "w") as f:
+            f.write(lexicon_sample)
+
+        # Create a lang directory with the lexicon and L.pt, L_inv.pt, L_disambig.pt
+        from speechbrain.k2_integration.prepare_lang import prepare_lang
+
+        prepare_lang(tmpdir)
+
+        # Create a lexicon object
+        from speechbrain.k2_integration.lexicon import Lexicon
+
+        lexicon = Lexicon(tmpdir)
+
+        # Create a graph compiler
+        from speechbrain.k2_integration.graph_compiler import CtcGraphCompiler
+
+        graph_compiler = CtcGraphCompiler(lexicon, device=log_probs.device,)
+
+        # Create a random batch of texts
+        texts = ["hello world", "world hello", "hello", "world"]
+
+        # Compute the loss
+        from speechbrain.k2_integration.losses import ctc_k2
+
+        loss = ctc_k2(
+            log_probs=log_probs,
+            input_lens=input_lens,
+            graph_compiler=graph_compiler,
+            texts=texts,
+            reduction="mean",
+            beam_size=10,
+            use_double_scores=True,
+            is_training=True,
+        )
+
+        # Assertions
+        assert loss.requires_grad
+        assert loss.item() >= 0  # Loss should be non-negative
diff --git a/tests/unittests/test_profiling.py b/tests/unittests/test_profiling.py
deleted file mode 100644
index d3efe0880c776f3a5168c507b33bcaa43236a828..0000000000000000000000000000000000000000
--- a/tests/unittests/test_profiling.py
+++ /dev/null
@@ -1,1024 +0,0 @@
-def test_profile_class(device):
-    import torch
-    from torch.optim import SGD
-    from speechbrain.core import Brain
-    from speechbrain.utils.profiling import profile
-
-    @profile
-    class SimpleBrain(Brain):
-        def compute_forward(self, batch, stage):
-            return self.modules.model(batch[0])
-
-        def compute_objectives(self, predictions, batch, stage):
-            return torch.nn.functional.l1_loss(predictions, batch[1])
-
-    model = torch.nn.Linear(in_features=10, out_features=10, device=device)
-    inputs = torch.rand(10, 10, device=device)
-    targets = torch.rand(10, 10, device=device)
-    train_set = ([inputs, targets],)
-    valid_set = ([inputs, targets],)
-
-    # Profiling: __init__ constructor.
-    brain = SimpleBrain(
-        {"model": model}, lambda x: SGD(x, 0.1), run_opts={"device": device}
-    )
-    assert brain.profiler is not None
-    assert brain.profiler.profiler is not None
-    # assert len(brain.profiler.key_averages()) == 2
-    # assert (brain.profiler.events().total_average().count >= 4)  # == 6  # before; config dependent: 7
-    assert (
-        len(brain.profiler.speechbrain_event_traces) == 1
-    )  # set & filled by the @profile decorator
-    """print(brain.profiler.key_averages().table(sort_by="cpu_time_total", row_limit=10))
-    -------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                           Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-    -------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-        aten::_has_compatible_shallow_copy_type        73.33%      11.000us        73.33%      11.000us       2.750us             4
-                                       aten::to        26.67%       4.000us        26.67%       4.000us       2.000us             2
-    -------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-    Self CPU time total: 15.000us
-    """
-
-    # Profiling: fit() for train operations.
-    # By default, @profile should also annotate fit & evaluate functions; here the fit function is tested only.
-    brain.fit(epoch_counter=range(10), train_set=train_set, valid_set=valid_set)
-    assert brain.profiler is not None
-    # assert len(brain.profiler.key_averages()) >= 60  # 72 with torch==1.10.1
-    # assert brain.profiler.events().total_average().count >= 2000  # 2832 with torch==1.10.1
-    assert len(brain.profiler.speechbrain_event_traces) == 2
-    # assert len(brain.profiler.speechbrain_event_traces[0]) >= 4  # == 6  # before; config dependent: 7
-    # assert len(brain.profiler.speechbrain_event_traces[1]) >= 2000  # 2862 with torch==1.10.1
-    """print(brain.profiler.key_averages().table(sort_by="cpu_time_total", row_limit=10))
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                              aten::l1_loss         2.60%     443.000us        26.15%       4.460ms     111.500us            40
-    enumerate(DataLoader)#_SingleProcessDataLoaderIter._...        12.97%       2.212ms        20.28%       3.459ms      86.475us            40
-                                               aten::linear         1.28%     219.000us        12.97%       2.212ms     110.600us            20
-                                                   aten::to         1.88%     320.000us        10.68%       1.822ms      11.387us           160
-                                             aten::isfinite         1.98%     338.000us         9.82%       1.674ms      55.800us            30
-                                                 aten::mean         1.10%     188.000us         9.31%       1.587ms      79.350us            20
-                                             aten::_to_copy         5.65%     964.000us         8.99%       1.533ms      15.330us           100
-                                                aten::stack         2.23%     380.000us         8.67%       1.479ms      29.580us            50
-                                               aten::matmul         1.62%     277.000us         7.63%       1.301ms      65.050us            20
-                                     aten::l1_loss_backward         1.24%     212.000us         6.78%       1.157ms      57.850us            20
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-    Self CPU time total: 17.054ms
-    """
-
-
-def test_profile_func(device):
-    # import torch
-    # from pytest import raises
-    # from torch.optim import SGD
-    # from speechbrain.core import Brain
-    # from torch.autograd.profiler import record_function
-    from speechbrain.utils.profiling import profile
-
-    # from speechbrain.utils.profiling import events_diff
-
-    """
-    class SimpleBrain(Brain):
-        def compute_forward(self, batch, stage):
-            return self.modules.model(batch[0])
-
-        def compute_objectives(self, predictions, batch, stage):
-            return torch.nn.functional.l1_loss(predictions, batch[1])
-
-    class SimpleBrainNittyGritty(Brain):
-        def compute_forward(self, batch, stage):
-            # example: one way of using torch.autograd.profiler.record_function
-            with record_function("is this faster (?)"):
-                this = self.modules.model(batch[0])
-            return this
-
-        def compute_objectives(self, predictions, batch, stage):
-            # example: one could also think of running comparative testing using record_function
-            with record_function("or that (?)"):
-                that = torch.nn.functional.l1_loss(predictions, batch[1])
-            return that
-    """
-
-    @profile
-    def train(brain, train_set, valid_set):
-        brain.fit(
-            epoch_counter=range(10), train_set=train_set, valid_set=valid_set
-        )
-
-    """
-    model = torch.nn.Linear(in_features=10, out_features=10, device=device)
-    inputs = torch.rand(10, 10, device=device)
-    targets = torch.rand(10, 10, device=device)
-    training_set = ([inputs, targets],)
-    validation_set = ([inputs, targets],)
-    simple_brain = SimpleBrain(
-        {"model": model}, lambda x: SGD(x, 0.1), run_opts={"device": device}
-    )
-
-    prof_simple = train(simple_brain, training_set, validation_set)
-    # print(prof_simple.key_averages().table(sort_by="cpu_time_total"))
-    # assert len(prof_simple.events()) >= 2500  # 2832 with torch==1.10.1
-    # assert len(prof_simple.key_averages()) >= 60  # 72 with torch==1.10.1
-
-    simple_brain_nitty_gritty = SimpleBrainNittyGritty(
-        {"model": model}, lambda x: SGD(x, 0.1), run_opts={"device": device}
-    )
-    prof_nitty_gritty = train(
-        simple_brain_nitty_gritty, training_set, validation_set
-    )
-    # print(prof_nitty_gritty.key_averages().table(sort_by="cpu_time_total"))
-    # assert len(prof_nitty_gritty.events()) >= 2500  # 3030 with torch==1.10.1
-    # assert len(prof_nitty_gritty.key_averages()) >= 60  # 74 with torch==1.10.1
-    """
-
-    # The outputs of this diff are only for visualisation, ``simple_delta._build_tree()`` will throw an error.
-    """
-    simple_delta, nitty_gritty_delta = events_diff(
-        prof_simple.key_averages(), prof_nitty_gritty.key_averages()
-    )
-    # assert len(simple_delta) >= 4  # == 6  # before; config dependent: 7
-    # assert len(nitty_gritty_delta) >= 4  # == 8  # before
-    # assert simple_delta.total_average().count == 582 #Switching off becuase sometimes it fails
-    # assert nitty_gritty_delta.total_average().count == 780 #Switching off becuase sometimes it fails
-    with raises(Exception) as err_tree:
-        simple_delta._build_tree()  # as mentioned.
-    assert err_tree.type == AttributeError
-    with raises(Exception) as err_averages:
-        simple_delta.key_averages()  # as mentioned.
-    assert err_averages.type == AssertionError
-    " ""Both classes have alike numbers of function calls (given the same input data and train function).
-    Sparing where both have the same number of calls:
-
-    print(simple_delta.table(sort_by="cpu_time_total"))
-    ----------------  ------------  ------------  ------------  ------------  ------------  ------------
-                Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-    ----------------  ------------  ------------  ------------  ------------  ------------  ------------
-         aten::zeros        17.44%     240.000us        27.69%     381.000us       6.350us            60
-         aten::empty        26.96%     371.000us        27.11%     373.000us       1.492us           250
-        aten::detach        18.68%     257.000us        25.65%     353.000us       5.694us            62
-          aten::add_        25.00%     344.000us        25.00%     344.000us       5.931us            58
-              detach         6.98%      96.000us         9.45%     130.000us       2.097us            62
-         aten::zero_         4.94%      68.000us         4.94%      68.000us       0.756us            90
-    ----------------  ------------  ------------  ------------  ------------  ------------  ------------
-    Self CPU time total: 1.376ms
-
-    print(nitty_gritty_delta.table(sort_by="cpu_time_total"))
-    ----------------------  ------------  ------------  ------------  ------------  ------------  ------------
-                      Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-    ----------------------  ------------  ------------  ------------  ------------  ------------  ------------
-        is this faster (?)        29.70%       1.024ms        76.19%       2.627ms     131.350us            20
-               or that (?)        29.29%       1.010ms        67.00%       2.310ms     115.500us            20
-               aten::zeros         8.32%     287.000us        15.52%     535.000us       5.350us           100
-               aten::empty        14.01%     483.000us        14.07%     485.000us       1.470us           330
-                aten::add_         9.92%     342.000us         9.92%     342.000us       5.700us            60
-              aten::detach         2.81%      97.000us         6.09%     210.000us       3.500us            60
-                    detach         3.28%     113.000us         4.00%     138.000us       2.300us            60
-               aten::zero_         2.67%      92.000us         2.67%      92.000us       0.708us           130
-    ----------------------  ------------  ------------  ------------  ------------  ------------  ------------
-    Self CPU time total: 3.448ms
-
-    Curiosity doesn't come for free ;-)
-    """
-
-
-def test_scheduler(device):
-    import torch
-    from pytest import raises
-    from torch.optim import SGD
-    from speechbrain.core import Brain
-    from speechbrain.utils.profiling import profile, schedule
-
-    @schedule
-    @profile
-    class SimpleBrain(Brain):
-        def compute_forward(self, batch, stage):
-            return self.modules.model(batch[0])
-
-        def compute_objectives(self, predictions, batch, stage):
-            return torch.nn.functional.l1_loss(predictions, batch[1])
-
-    model = torch.nn.Linear(in_features=10, out_features=10, device=device)
-    inputs = torch.rand(10, 10, device=device)
-    targets = torch.rand(10, 10, device=device)
-    train_set = ([inputs, targets],)
-    valid_set = ([inputs, targets],)
-    test_set = ([inputs, targets],)
-
-    # Profiling: __init__ constructor -- while scheduler: waiting --> nothing to report
-    brain = SimpleBrain(
-        {"model": model}, lambda x: SGD(x, 0.1), run_opts={"device": device}
-    )
-    assert brain.profiler.profiler is None
-    assert len(brain.profiler.speechbrain_event_traces) == 0
-    with raises(Exception) as err:
-        brain.profiler.events()  # Tracing hasn't started, yet, so everything is in err. Scheduler says: wait.
-    assert err.type == AssertionError
-    assert brain.profiler.step_num == 0
-
-    # Profiling: fit() for train operations.
-    brain.fit(epoch_counter=range(10), train_set=train_set, valid_set=valid_set)
-    assert brain.profiler.step_num == 20
-    assert len(brain.profiler.speechbrain_event_traces) == 1
-    # assert len(brain.profiler.events()) >= 250  # 293 with torch==1.10.1
-    # assert len(brain.profiler.key_averages()) >= 60  # 73 with torch==1.10.1
-    """print(brain.profiler.key_averages().table(sort_by="cpu_time_total"))
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                              ProfilerStep*        55.48%       1.504ms        99.00%       2.684ms       1.342ms             2
-                                              aten::l1_loss         1.07%      29.000us         9.30%     252.000us      63.000us             4
-    enumerate(DataLoader)#_SingleProcessDataLoaderIter._...         4.57%     124.000us         7.01%     190.000us      47.500us             4
-                                               aten::linear         2.32%      63.000us         6.93%     188.000us      94.000us             2
-                                             aten::isfinite         0.89%      24.000us         5.35%     145.000us      48.333us             3
-                                                   aten::to         1.00%      27.000us         4.46%     121.000us       7.562us            16
-                                             aten::_to_copy         2.18%      59.000us         3.58%      97.000us       9.700us            10
-                                                aten::stack         0.89%      24.000us         3.28%      89.000us      17.800us             5
-                                     aten::l1_loss_backward         0.55%      15.000us         3.14%      85.000us      42.500us             2
-                                                 aten::mean         0.48%      13.000us         3.10%      84.000us      42.000us             2
-                                               aten::matmul         0.55%      15.000us         2.84%      77.000us      38.500us             2
-                                                   aten::ne         0.77%      21.000us         2.73%      74.000us      24.667us             3
-                                                   aten::mm         2.29%      62.000us         2.36%      64.000us      21.333us             3
-                                                  aten::div         1.36%      37.000us         2.25%      61.000us      20.333us             3
-                                    Optimizer.step#SGD.step         1.88%      51.000us         2.21%      60.000us      60.000us             1
-    autograd::engine::evaluate_function: L1LossBackward0...         0.11%       3.000us         2.07%      56.000us      56.000us             1
-                                                aten::zeros         1.44%      39.000us         1.99%      54.000us       6.750us             8
-                                            L1LossBackward0         0.11%       3.000us         1.95%      53.000us      53.000us             1
-                                                    aten::t         0.89%      24.000us         1.88%      51.000us      10.200us             5
-                                                  aten::cat         0.59%      16.000us         1.81%      49.000us       9.800us             5
-                                                 aten::div_         0.70%      19.000us         1.73%      47.000us      23.500us             2
-           autograd::engine::evaluate_function: MmBackward0         0.11%       3.000us         1.70%      46.000us      46.000us             1
-                                                  aten::abs         1.14%      31.000us         1.59%      43.000us       5.375us             8
-                                                MmBackward0         0.37%      10.000us         1.59%      43.000us      43.000us             1
-                                                  aten::sum         1.00%      27.000us         1.33%      36.000us      12.000us             3
-                                                aten::empty         1.29%      35.000us         1.29%      35.000us       1.207us            29
-          autograd::engine::evaluate_function: AddBackward0         0.63%      17.000us         1.25%      34.000us      34.000us             1
-                                                 aten::_cat         0.85%      23.000us         1.22%      33.000us       6.600us             5
-                                                 aten::add_         1.11%      30.000us         1.11%      30.000us       5.000us             6
-                                            aten::transpose         0.70%      19.000us         0.92%      25.000us       5.000us             5
-          autograd::engine::evaluate_function: DivBackward0         0.18%       5.000us         0.92%      25.000us      25.000us             1
-                                                 aten::norm         0.77%      21.000us         0.89%      24.000us       8.000us             3
-                          Optimizer.zero_grad#SGD.zero_grad         0.63%      17.000us         0.81%      22.000us      22.000us             1
-                                                 aten::item         0.55%      15.000us         0.77%      21.000us       3.000us             7
-                                                aten::copy_         0.77%      21.000us         0.77%      21.000us       2.100us            10
-                                               DivBackward0         0.15%       4.000us         0.74%      20.000us      20.000us             1
-    autograd::engine::evaluate_function: torch::autograd...         0.15%       4.000us         0.74%      20.000us      10.000us             2
-                                                  aten::mul         0.55%      15.000us         0.74%      20.000us       5.000us             4
-                                               aten::detach         0.37%      10.000us         0.70%      19.000us       3.167us             6
-                                           aten::as_strided         0.59%      16.000us         0.63%      17.000us       1.062us            16
-                                            aten::unsqueeze         0.41%      11.000us         0.59%      16.000us       2.667us             6
-                                        aten::empty_strided         0.59%      16.000us         0.59%      16.000us       1.455us            11
-                            torch::autograd::AccumulateGrad         0.30%       8.000us         0.59%      16.000us       8.000us             2
-                                           aten::is_nonzero         0.18%       5.000us         0.59%      16.000us       5.333us             3
-                                                 aten::view         0.55%      15.000us         0.55%      15.000us       3.000us             5
-                                              aten::random_         0.52%      14.000us         0.52%      14.000us       7.000us             2
-    autograd::engine::evaluate_function: UnsafeViewBackw...         0.11%       3.000us         0.52%      14.000us      14.000us             1
-                                                  aten::sub         0.48%      13.000us         0.48%      13.000us       4.333us             3
-                                                  aten::add         0.18%       5.000us         0.48%      13.000us      13.000us             1
-                                            aten::ones_like         0.15%       4.000us         0.44%      12.000us      12.000us             1
-                                                 aten::mul_         0.44%      12.000us         0.44%      12.000us       4.000us             3
-                                                     detach         0.33%       9.000us         0.44%      12.000us       2.000us             6
-                                        UnsafeViewBackward0         0.11%       3.000us         0.41%      11.000us      11.000us             1
-                                                 aten::abs_         0.18%       5.000us         0.37%      10.000us       5.000us             2
-                                           aten::empty_like         0.26%       7.000us         0.37%      10.000us       5.000us             2
-                                                aten::clamp         0.26%       7.000us         0.37%      10.000us      10.000us             1
-                                           aten::zeros_like         0.18%       5.000us         0.33%       9.000us       9.000us             1
-                                                   aten::eq         0.33%       9.000us         0.33%       9.000us       3.000us             3
-                                                aten::zero_         0.30%       8.000us         0.30%       8.000us       0.727us            11
-                                         aten::_unsafe_view         0.22%       6.000us         0.30%       8.000us       4.000us             2
-                                                aten::fill_         0.30%       8.000us         0.30%       8.000us       2.000us             4
-                                              aten::reshape         0.18%       5.000us         0.30%       8.000us       8.000us             1
-                                  aten::_local_scalar_dense         0.22%       6.000us         0.26%       7.000us       1.000us             7
-            autograd::engine::evaluate_function: TBackward0         0.11%       3.000us         0.26%       7.000us       7.000us             1
-                                              aten::resize_         0.22%       6.000us         0.22%       6.000us       1.200us             5
-                                           aten::reciprocal         0.22%       6.000us         0.22%       6.000us       6.000us             1
-                                                 aten::sgn_         0.15%       4.000us         0.15%       4.000us       4.000us             1
-                                                 TBackward0         0.07%       2.000us         0.15%       4.000us       4.000us             1
-                                       aten::_reshape_alias         0.11%       3.000us         0.11%       3.000us       3.000us             1
-                                            aten::clamp_max         0.11%       3.000us         0.11%       3.000us       3.000us             1
-                                         aten::resolve_conj         0.07%       2.000us         0.07%       2.000us       0.333us             6
-                                    aten::broadcast_tensors         0.07%       2.000us         0.07%       2.000us       1.000us             2
-                                               AddBackward0         0.04%       1.000us         0.04%       1.000us       1.000us             1
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-    Self CPU time total: 2.711ms  <===  above: Self CPU time total: 18.451ms (... the impact of warm-up)
-    """
-
-    @schedule
-    @profile
-    def train():
-        # The step() function is executed inside speechbrain.core.brain.fit and is property of the Brain's profiler.
-        # Above profiler and its scheduler are without power here, since prof.step() is not run - at all.
-        brain.fit(
-            epoch_counter=range(10), train_set=train_set, valid_set=valid_set
-        )
-
-    prof = train()
-    # since we used the same brain (which has its own profiler)
-    # assert brain.profiler.step_num == 20  # started again from 0 steps
-    assert len(brain.profiler.speechbrain_event_traces) == 2
-    # assert len(brain.profiler.events()) >= 250  # 293 with torch==1.10.1  # unchanged (overwritten with akin data)
-    # assert len(brain.profiler.key_averages()) >= 60  # 73 with torch==1.10.1 # unchanged (akin data)
-    # now, to the train function's profiler
-    assert (
-        prof.step_num == 0
-    )  # the prof.step() operation wasn't run (not in scope) -> its scheduler is unawaken!
-    assert not hasattr(prof, "speechbrain_event_traces")  # no trace collection
-    with raises(Exception) as err_prof:
-        prof.events()  # No tracing started with this one.
-    assert err_prof.type == AssertionError  # sparing: key_averages()
-
-    # But how to add profiling then if no writing access is there for a class... pretrained, for example:
-    class SimpleBrainUntracked(Brain):
-        def compute_forward(self, batch, stage):
-            return self.modules.model(batch[0])
-
-        def compute_objectives(self, predictions, batch, stage):
-            return torch.nn.functional.l1_loss(predictions, batch[1])
-
-    brain_or_pretrained = SimpleBrainUntracked(
-        {"model": model}, lambda x: SGD(x, 0.1), run_opts={"device": device}
-    )
-
-    # Set-up the profiler and hook it to the model.
-    scheduled_profiler = schedule(profile)
-    scheduled_profiler(brain_or_pretrained)
-
-    # Profiling: still too early for scheduler!
-    brain_or_pretrained.evaluate(test_set=test_set)  # -> step_num=1
-    assert brain_or_pretrained.profiler.step_num == 1
-    assert brain_or_pretrained.profiler.profiler is None
-
-    # Profiling: scheduler warms-up.
-    brain_or_pretrained.evaluate(
-        test_set=(
-            [inputs, targets],  # +1x test_set -> step_num=2
-            [inputs, targets],
-        )
-    )
-    assert brain_or_pretrained.profiler.step_num == 2
-    # brain_or_pretrained.profiler.profiler will be set (not None anymore)
-    # when run on cpu, there are no events - but cuda activities are recorded if existing
-    # see: https://github.com/speechbrain/speechbrain/issues/1469
-    if (
-        torch.profiler.ProfilerActivity.CUDA
-        in brain_or_pretrained.profiler.activities
-    ):
-        assert (
-            len(
-                set(
-                    [
-                        x.name
-                        for x in brain_or_pretrained.profiler.profiler.function_events
-                    ]
-                )
-                - {
-                    "cudaGetDeviceCount",
-                    "cudaGetDeviceProperties",
-                    "cudaDeviceSynchronize",
-                }
-            )
-        ) == 0
-    else:
-        assert len(brain_or_pretrained.profiler.events()) == 0
-
-    # Profiling: scheduler warms-up...
-    brain_or_pretrained.evaluate(
-        test_set=(
-            [inputs, targets],  # +1x test_set
-            [inputs, targets],  # +2x test_set -> step_num=3
-            [inputs, targets],
-        )
-    )
-    assert brain_or_pretrained.profiler.step_num == 3
-    if (
-        torch.profiler.ProfilerActivity.CUDA
-        in brain_or_pretrained.profiler.activities
-    ):
-        assert (
-            len(
-                set(
-                    [
-                        x.name
-                        for x in brain_or_pretrained.profiler.profiler.function_events
-                    ]
-                )
-                - {
-                    "cudaGetDeviceCount",
-                    "cudaGetDeviceProperties",
-                    "cudaDeviceSynchronize",
-                }
-            )
-        ) == 0
-    else:
-        assert len(brain_or_pretrained.profiler.events()) == 0
-
-    # Profiling: first trace!
-    brain_or_pretrained.evaluate(
-        test_set=(
-            [inputs, targets],  # +1x test_set
-            [inputs, targets],  # +2x test_set
-            [inputs, targets],  # +3x test_set -> step_num=4
-            [inputs, targets],
-        )
-    )
-    assert brain_or_pretrained.profiler.step_num == 4
-    # assert len(brain_or_pretrained.profiler.events()) >= 4  # == 10  # before
-    # assert len(brain_or_pretrained.profiler.key_averages()) >= 4  # == 5  # before
-    assert (
-        len(brain_or_pretrained.profiler.events()) >= 1
-    )  # 1 on CPU; more w/ CUDA
-
-
-def test_tracer(device):
-    import torch
-    from torch.optim import SGD
-    from speechbrain.core import Brain
-    from speechbrain.utils.profiling import profile, export
-
-    @export
-    @profile
-    class SimpleBrain(Brain):
-        def compute_forward(self, batch, stage):
-            return self.modules.model(batch[0])
-
-        def compute_objectives(self, predictions, batch, stage):
-            return torch.nn.functional.l1_loss(predictions, batch[1])
-
-    model = torch.nn.Linear(in_features=10, out_features=10, device=device)
-    inputs = torch.rand(10, 10, device=device)
-    targets = torch.rand(10, 10, device=device)
-    train_set = ([inputs, targets],)
-    valid_set = ([inputs, targets],)
-    test_set = ([inputs, targets],)
-
-    # Profiling: __init__ constructor and model training.
-    brain = SimpleBrain(
-        {"model": model}, lambda x: SGD(x, 0.1), run_opts={"device": device}
-    )
-    brain.fit(epoch_counter=range(10), train_set=train_set, valid_set=valid_set)
-
-    # Pretrained example.
-    class SimpleBrainUntracked(Brain):
-        def compute_forward(self, batch, stage):
-            return self.modules.model(batch[0])
-
-        def compute_objectives(self, predictions, batch, stage):
-            return torch.nn.functional.l1_loss(predictions, batch[1])
-
-    # No tracing during __init__
-    brain_or_pretrained = SimpleBrainUntracked(
-        {"model": model}, lambda x: SGD(x, 0.1), run_opts={"device": device}
-    )
-    profile(brain_or_pretrained, on_trace_ready=export(), with_stack=True)
-    brain_or_pretrained.evaluate(test_set=test_set)
-
-    # Set-up the profiler; hook it to the model, and benchmark inference.
-    brain_or_pretrained2 = SimpleBrainUntracked(
-        {"model": model}, lambda x: SGD(x, 0.1), run_opts={"device": device}
-    )
-    logged_profiler = export(profile)
-    assert brain_or_pretrained2.profiler is None
-    logged_profiler(brain_or_pretrained2)
-    brain_or_pretrained2.evaluate(test_set=test_set)
-
-
-def test_aggregated_traces(device):
-    import torch
-    from torch.optim import SGD
-    from speechbrain.core import Brain
-    from speechbrain.utils.profiling import profile
-
-    @profile
-    class SimpleBrain(Brain):
-        def compute_forward(self, batch, stage):
-            return self.modules.model(batch[0])
-
-        def compute_objectives(self, predictions, batch, stage):
-            return torch.nn.functional.l1_loss(predictions, batch[1])
-
-    model = torch.nn.Linear(in_features=10, out_features=10, device=device)
-    inputs = torch.rand(10, 10, device=device)
-    targets = torch.rand(10, 10, device=device)
-    train_set = ([inputs, targets],)
-    valid_set = ([inputs, targets],)
-    test_set = (
-        [inputs, targets],
-        [inputs, targets],
-    )
-
-    # Profiling: __init__ constructor -- while scheduler: waiting --> nothing to report
-    brain = SimpleBrain(
-        {"model": model}, lambda x: SGD(x, 0.1), run_opts={"device": device}
-    )
-
-    # Profiling: empty traces
-    assert len(brain.profiler.speechbrain_event_traces) == 1
-    """
-    init_report = brain.profiler.merge_traces()
-    assert len(init_report) >= 1
-    # assert len(init_report) >= 4  # == 6  # before; config dependent: 7
-    assert len(brain.profiler.speechbrain_event_traces) == 1
-    " ""print(brain.profiler.key_averages().table(sort_by="cpu_time_total", row_limit=10))
-    -------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                           Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-    -------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-        aten::_has_compatible_shallow_copy_type        80.00%      12.000us        80.00%      12.000us       3.000us             4
-                                       aten::to        20.00%       3.000us        20.00%       3.000us       1.500us             2
-    -------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-    Self CPU time total: 15.000us
-
-    print(init_report.key_averages().table(sort_by="cpu_time_total", row_limit=10))
-    -------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                           Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-    -------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-        aten::_has_compatible_shallow_copy_type        80.00%      12.000us        80.00%      12.000us       3.000us             4
-                                       aten::to        20.00%       3.000us        20.00%       3.000us       1.500us             2
-    -------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-    Self CPU time total: 15.000us
-    """
-
-    brain.fit(epoch_counter=range(10), train_set=train_set, valid_set=valid_set)
-    assert len(brain.profiler.speechbrain_event_traces) == 2
-    # assert len(brain.profiler.speechbrain_event_traces[0]) >= 4  # == 6  # before; config dependent: 7
-    # assert len(brain.profiler.speechbrain_event_traces[1]) >= 2500  # 2862 with torch==1.10.1
-    # assert len(brain.profiler.events()) >= 2500  # 2832 with torch==1.10.1
-    # assert len(brain.profiler.events().key_averages()) >= 60  # 72 with torch==1.10.1
-    """print(brain.profiler.events().key_averages().table(sort_by="cpu_time_total", row_limit=10))
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                              aten::l1_loss         2.39%     415.000us        25.28%       4.392ms     109.800us            40
-    enumerate(DataLoader)#_SingleProcessDataLoaderIter._...        12.65%       2.198ms        20.06%       3.485ms      87.125us            40
-                                               aten::linear         1.41%     245.000us        13.24%       2.299ms     114.950us            20
-                                                   aten::to         2.04%     354.000us        10.59%       1.839ms      11.494us           160
-                                             aten::isfinite         2.12%     369.000us         9.87%       1.714ms      57.133us            30
-                                                 aten::mean         1.13%     196.000us         9.15%       1.589ms      79.450us            20
-                                                aten::stack         2.33%     404.000us         8.94%       1.553ms      31.060us            50
-                                             aten::_to_copy         5.67%     985.000us         8.67%       1.506ms      15.060us           100
-                                               aten::matmul         1.57%     273.000us         7.83%       1.360ms      68.000us            20
-                                     aten::l1_loss_backward         1.22%     212.000us         6.67%       1.158ms      57.900us            20
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-    Self CPU time total: 17.370ms
-    """
-
-    # Profiling: aggregate traces
-    """
-    short_report = brain.profiler.merge_traces()
-    assert len(short_report) >= 1
-    # assert len(short_report) >= 2500  # 2838 with torch==1.10.1
-    # assert len(short_report.key_averages()) >= 60  # 73 with torch==1.10.1
-    " ""print(short_report.key_averages().table(sort_by="cpu_time_total", row_limit=10))
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                              aten::l1_loss         2.39%     415.000us        25.26%       4.392ms     109.800us            40
-    enumerate(DataLoader)#_SingleProcessDataLoaderIter._...        12.64%       2.198ms        20.05%       3.485ms      87.125us            40
-                                               aten::linear         1.41%     245.000us        13.22%       2.299ms     114.950us            20
-                                                   aten::to         2.05%     357.000us        10.60%       1.842ms      11.370us           162
-                                             aten::isfinite         2.12%     369.000us         9.86%       1.714ms      57.133us            30
-                                                 aten::mean         1.13%     196.000us         9.14%       1.589ms      79.450us            20
-                                                aten::stack         2.32%     404.000us         8.93%       1.553ms      31.060us            50
-                                             aten::_to_copy         5.67%     985.000us         8.66%       1.506ms      15.060us           100
-                                               aten::matmul         1.57%     273.000us         7.82%       1.360ms      68.000us            20
-                                     aten::l1_loss_backward         1.22%     212.000us         6.66%       1.158ms      57.900us            20
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-    Self CPU time total: 17.385ms
-    """
-
-    brain.evaluate(test_set=test_set)
-    brain.evaluate(test_set=test_set)
-    brain.evaluate(test_set=test_set)
-    assert len(brain.profiler.speechbrain_event_traces) == 5
-    # assert len(brain.profiler.speechbrain_event_traces[0]) >= 4  # == 6  # before; config dependent: 7
-    # assert len(brain.profiler.speechbrain_event_traces[1]) >= 2500  # 2862 with torch==1.10.1
-    # assert len(brain.profiler.speechbrain_event_traces[2]) >= 125  # 143 with torch==1.10.1
-    # assert len(brain.profiler.speechbrain_event_traces[3]) >= 125  # 143 with torch==1.10.1
-    # assert len(brain.profiler.speechbrain_event_traces[4]) >= 125  # 143 with torch==1.10.1
-    # assert len(brain.profiler.events()) >= 125  # 141 with torch==1.10.1
-    # assert len(brain.profiler.events().key_averages()) >= 25  # 42 with torch==1.10.1
-    # the following is only for the last call of the 3x brain.evaluate()
-    """print(brain.profiler.events().key_averages().table(sort_by="cpu_time_total", row_limit=10))
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                              aten::l1_loss         3.54%      23.000us        37.38%     243.000us      60.750us             4
-    enumerate(DataLoader)#_SingleProcessDataLoaderIter._...        16.62%     108.000us        29.38%     191.000us      63.667us             3
-                                               aten::linear         2.77%      18.000us        20.46%     133.000us      66.500us             2
-                                             aten::isfinite         3.54%      23.000us        14.31%      93.000us      46.500us             2
-                                                 aten::mean         1.85%      12.000us        12.62%      82.000us      41.000us             2
-                                                aten::stack         3.23%      21.000us        12.15%      79.000us      19.750us             4
-                                               aten::matmul         2.62%      17.000us        11.69%      76.000us      38.000us             2
-                                                 aten::div_         2.92%      19.000us         7.54%      49.000us      24.500us             2
-                                                   aten::to         1.85%      12.000us         7.08%      46.000us       7.667us             6
-                                                   aten::mm         6.31%      41.000us         6.77%      44.000us      22.000us             2
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-    Self CPU time total: 650.000us
-    """
-
-    # Profiling: putting previous benchmark reporting together.
-    """
-    full_report = brain.profiler.merge_traces()
-    assert len(full_report) >= 1
-    # assert len(full_report.key_averages()) >= 60  # 73 with torch==1.10.1
-    # In this minimal example, only 73 functions matter.
-    # Some events are duplicated (perhaps from wrapping functions):
-    # => they appear stacked & EventList._remove_dup_nodes drops direct child events of same name as their parent.
-    num_events = sum([len(x) for x in brain.profiler.speechbrain_event_traces])
-    assert num_events >= 1
-    # assert num_events >= 3000  # 3297 with torch==1.10.1  # expected: 6 + 2862 + 3x143 = 3297
-    # Apparently, this depends on how this test is run (by its own or as part of the entire file's test suite).
-    # assert (num_events == len(full_report)) or (len(full_report) == len(set([x.id for x in full_report])))
-    # ... not tested, why
-    " ""print(full_report.key_averages().table(sort_by="cpu_time_total", row_limit=10))
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                              aten::l1_loss         2.21%     427.000us        27.58%       5.326ms     102.423us            52
-    enumerate(DataLoader)#_SingleProcessDataLoaderIter._...        11.89%       2.297ms        21.80%       4.210ms      85.918us            49
-                                               aten::linear         1.57%     304.000us        14.61%       2.821ms     108.500us            26
-                                             aten::isfinite         2.53%     488.000us        10.72%       2.071ms      57.528us            36
-                                                   aten::to         2.03%     392.000us        10.42%       2.013ms      11.183us           180
-                                                 aten::mean         1.15%     223.000us         9.81%       1.894ms      72.846us            26
-                                                aten::stack         2.34%     452.000us         9.54%       1.842ms      29.710us            62
-                                             aten::_to_copy         5.49%       1.061ms         8.51%       1.643ms      14.670us           112
-                                               aten::matmul         1.65%     318.000us         8.48%       1.638ms      63.000us            26
-                                                 aten::div_         1.50%     290.000us         6.95%       1.343ms      47.964us            28
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
-    Self CPU time total: 19.311ms
-    """
-    # 19.311ms = 3x ~650.000us + 17.370ms + 15.000us <=> 1st & 2nd call of brain.evaluate() = 1276us = 2x 638us
-    # max([x.time_range.end for x in full_report]) -> 41965 (us)
-
-
-def test_profile_details(device):
-    import torch
-
-    # from copy import deepcopy
-    from torch.optim import SGD
-    from speechbrain.core import Brain
-    from speechbrain.utils.profiling import (
-        profile_analyst,
-        profile_optimiser,
-        export,
-        # events_diff,
-    )
-
-    class SimpleBrain(Brain):
-        def compute_forward(self, batch, stage):
-            return self.modules.model(batch[0])
-
-        def compute_objectives(self, predictions, batch, stage):
-            return torch.nn.functional.l1_loss(predictions, batch[1])
-
-    model = torch.nn.Linear(in_features=10, out_features=10, device=device)
-    inputs = torch.rand(10, 10, device=device)
-    targets = torch.rand(10, 10, device=device)
-    train_set = ([inputs, targets],)
-    valid_set = ([inputs, targets],)
-    test_set = (
-        [inputs, targets],
-        [inputs, targets],
-        [inputs, targets],
-        [inputs, targets],
-        [inputs, targets],
-        [inputs, targets],
-    )
-
-    brain_analyst = profile_analyst(
-        SimpleBrain(
-            {"model": model}, lambda x: SGD(x, 0.1), run_opts={"device": device}
-        )
-    )
-    brain_optimiser = profile_optimiser(
-        SimpleBrain(
-            {"model": model}, lambda x: SGD(x, 0.1), run_opts={"device": device}
-        )
-    )
-
-    assert len(brain_analyst.profiler.speechbrain_event_traces) == 0
-    brain_analyst.fit(
-        epoch_counter=range(10), train_set=train_set, valid_set=valid_set
-    )
-    assert len(brain_analyst.profiler.speechbrain_event_traces) == 1
-    # assert len(brain_analyst.profiler.speechbrain_event_traces[0]) >= 250  # 296 with torch==1.10.1
-    # assert len(brain_analyst.profiler.events()) >= 250  # 293 with torch==1.10.1
-    # assert len(brain_analyst.profiler.events().key_averages()) >= 60  # 73 with torch==1.10.1
-    """print(brain_analyst.profiler.events().key_averages().table(sort_by="cpu_time_total", row_limit=10))
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls   Total FLOPs
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                              ProfilerStep*        36.62%       4.345ms        98.50%      11.686ms       5.843ms         796 b      -1.66 Kb             2            --
-                                              aten::l1_loss         2.50%     297.000us        14.37%       1.705ms     426.250us          16 b        -800 b             4            --
-    enumerate(DataLoader)#_SingleProcessDataLoaderIter._...         3.91%     464.000us        11.01%       1.306ms     326.500us       1.55 Kb         -80 b             4            --
-                                             aten::isfinite         3.10%     368.000us         9.65%       1.145ms     381.667us           3 b         -18 b             3            --
-                                                aten::stack         2.79%     331.000us         8.46%       1.004ms     200.800us       1.57 Kb           0 b             5            --
-                                                   aten::to         1.96%     232.000us         7.25%     860.000us      53.750us          40 b           0 b            16            --
-                                               aten::linear         1.40%     166.000us         6.55%     777.000us     388.500us         800 b           0 b             2            --
-                                     aten::l1_loss_backward         1.58%     188.000us         6.08%     721.000us     360.500us         400 b          -4 b             2            --
-                                             aten::_to_copy         4.28%     508.000us         5.29%     628.000us      62.800us          40 b           0 b            10            --
-                                                 aten::mean         0.89%     105.000us         4.44%     527.000us     263.500us           8 b           8 b             2            --
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
-    Self CPU time total: 11.864ms
-    """
-    # 6-batch inference
-    brain_analyst.evaluate(test_set=test_set)
-    assert len(brain_analyst.profiler.speechbrain_event_traces) == 2
-    # assert len(brain_analyst.profiler.speechbrain_event_traces[0]) >= 250  # 296 with torch==1.10.1
-    # assert len(brain_analyst.profiler.speechbrain_event_traces[1]) >= 125  # 144 with torch==1.10.1
-    # as of evaluate() call
-    # assert len(brain_analyst.profiler.events()) >= 125  # 142 with torch==1.10.1
-    # assert len(brain_analyst.profiler.events().key_averages()) >= 25  # 42 with torch==1.10.1
-    """print(brain_analyst.profiler.events().key_averages().table(sort_by="cpu_time_total", row_limit=10))
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls   Total FLOPs
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                              ProfilerStep*        19.24%       1.129ms        96.92%       5.687ms       2.844ms         796 b      -1.61 Kb             2            --
-                                              aten::l1_loss         5.16%     303.000us        35.50%       2.083ms     520.750us          16 b        -800 b             4            --
-    enumerate(DataLoader)#_SingleProcessDataLoaderIter._...         7.41%     435.000us        25.95%       1.523ms     761.500us       1.55 Kb         -40 b             2            --
-                                                aten::stack         6.15%     361.000us        18.20%       1.068ms     267.000us       1.56 Kb           0 b             4            --
-                                               aten::linear         3.43%     201.000us        15.78%     926.000us     463.000us         800 b           0 b             2            --
-                                                 aten::mean         2.42%     142.000us        11.93%     700.000us     350.000us           8 b           8 b             2            --
-                                             aten::isfinite         3.72%     218.000us        10.84%     636.000us     318.000us           2 b         -12 b             2            --
-                                                  aten::cat         2.68%     157.000us         8.95%     525.000us     131.250us       1.56 Kb           0 b             4            --
-                                               aten::matmul         3.83%     225.000us         8.88%     521.000us     260.500us         800 b           0 b             2            --
-                                                 aten::div_         3.34%     196.000us         6.97%     409.000us     204.500us           0 b          -8 b             2            --
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
-    Self CPU time total: 5.868ms
-    """
-
-    brain_optimiser.fit(
-        epoch_counter=range(10), train_set=train_set, valid_set=valid_set
-    )
-    # key_avg_fit = deepcopy(brain_optimiser.profiler.events().key_averages())
-    """print(brain_optimiser.profiler.events().key_averages().table(sort_by="cpu_time_total", row_limit=10))
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                              ProfilerStep*        50.73%       1.874ms        98.86%       3.652ms       1.826ms         796 b      -1.66 Kb             2
-                                              aten::l1_loss         1.38%      51.000us        11.83%     437.000us     109.250us          16 b        -400 b             4
-                                             aten::isfinite         1.49%      55.000us         8.69%     321.000us     107.000us           3 b         -16 b             3
-    enumerate(DataLoader)#_SingleProcessDataLoaderIter._...         4.03%     149.000us         7.58%     280.000us      70.000us       1.55 Kb         -64 b             4
-                                               aten::linear         0.51%      19.000us         5.66%     209.000us     104.500us         800 b           0 b             2
-                                                  aten::abs         3.19%     118.000us         5.58%     206.000us      25.750us          24 b          12 b             8
-                                     aten::l1_loss_backward         0.92%      34.000us         4.28%     158.000us      79.000us         400 b          -4 b             2
-                                                aten::stack         1.00%      37.000us         3.87%     143.000us      28.600us       1.57 Kb           0 b             5
-                                                aten::empty         3.76%     139.000us         3.76%     139.000us       4.793us         544 b         544 b            29
-                                                   aten::to         0.76%      28.000us         3.76%     139.000us       8.688us          44 b           4 b            16
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
-    Self CPU time total: 3.694ms
-    # to 11.864ms (analyst)
-    """
-
-    brain_optimiser.evaluate(test_set=test_set)
-    """
-    key_avg_evaluate = deepcopy(
-        brain_optimiser.profiler.events().key_averages()
-    )
-    """
-    """print(brain_optimiser.profiler.events().key_averages().table(sort_by="cpu_time_total", row_limit=10))
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                              ProfilerStep*        24.80%     524.000us        96.50%       2.039ms       1.020ms         796 b      -1.61 Kb             2
-                                              aten::l1_loss         2.74%      58.000us        33.65%     711.000us     177.750us          16 b        -800 b             4
-    enumerate(DataLoader)#_SingleProcessDataLoaderIter._...         9.94%     210.000us        21.11%     446.000us     223.000us       1.55 Kb         -40 b             2
-                                             aten::isfinite         3.64%      77.000us        15.76%     333.000us     166.500us           2 b         -12 b             2
-                                               aten::linear         1.04%      22.000us        11.88%     251.000us     125.500us         800 b           0 b             2
-                                                 aten::mean         2.04%      43.000us        11.74%     248.000us     124.000us           8 b           8 b             2
-                                                aten::stack         3.31%      70.000us        10.18%     215.000us      53.750us       1.56 Kb           0 b             4
-                                                 aten::div_         4.83%     102.000us         7.90%     167.000us      83.500us           0 b          -8 b             2
-                                               aten::matmul         1.61%      34.000us         7.38%     156.000us      78.000us         800 b           0 b             2
-                                                   aten::ne         4.54%      96.000us         6.72%     142.000us      71.000us           2 b          -6 b             2
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
-    Self CPU time total: 2.113ms
-    # to 5.868ms (analyst)
-    """
-    # same check as for analyst
-    assert len(brain_optimiser.profiler.speechbrain_event_traces) == 2
-    # assert len(brain_optimiser.profiler.speechbrain_event_traces[0]) >= 250  # 296 with torch==1.10.1
-    # assert len(brain_optimiser.profiler.speechbrain_event_traces[1]) >= 125  # 144 with torch==1.10.1
-    # as of evaluate() call
-    # assert len(brain_optimiser.profiler.events()) >= 125  # 142 with torch==1.10.1
-    # assert len(brain_optimiser.profiler.events().key_averages()) >= 25  # 42 with torch==1.10.1
-    # different config
-    assert (
-        brain_optimiser.profiler.record_shapes
-        != brain_analyst.profiler.record_shapes
-    )
-    assert (
-        brain_optimiser.profiler.with_stack != brain_analyst.profiler.with_stack
-    )
-    assert (
-        brain_optimiser.profiler.with_flops != brain_analyst.profiler.with_flops
-    )
-    # same config
-    assert (
-        brain_optimiser.profiler.with_modules
-        == brain_analyst.profiler.with_modules
-    )
-    assert (
-        brain_optimiser.profiler.profile_memory
-        == brain_analyst.profiler.profile_memory
-    )
-
-    """
-    # let's take a look at the diff
-    diff_fit, diff_evaluate = events_diff(key_avg_fit, key_avg_evaluate)
-    # assert len(diff_fit) >= 50  # 64 with torch==1.10.1
-    # assert len(diff_evaluate) >= 25  # 33 with torch==1.10.1
-    # assert diff_fit.total_average().count >= 250  # 273 with torch==1.10.1
-    # assert diff_evaluate.total_average().count >= 100  # 122 with torch==1.10.1
-    " ""For curiosity only... the printed FunctionEvents differ by (name, # of Calls)
-    print(diff_fit.table(sort_by="cpu_time_total"))
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                             aten::isfinite         3.35%      55.000us        19.55%     321.000us     107.000us           3 b         -16 b             3
-    enumerate(DataLoader)#_SingleProcessDataLoaderIter._...         9.07%     149.000us        17.05%     280.000us      70.000us       1.55 Kb         -64 b             4
-                                                  aten::abs         7.19%     118.000us        12.55%     206.000us      25.750us          24 b          12 b             8
-                                     aten::l1_loss_backward         2.07%      34.000us         9.62%     158.000us      79.000us         400 b          -4 b             2
-                                                aten::stack         2.25%      37.000us         8.71%     143.000us      28.600us       1.57 Kb           0 b             5
-                                                aten::empty         8.47%     139.000us         8.47%     139.000us       4.793us         544 b         544 b            29
-                                                   aten::to         1.71%      28.000us         8.47%     139.000us       8.688us          44 b           4 b            16
-                                             aten::_to_copy         1.89%      31.000us         7.00%     115.000us      11.500us          40 b           0 b            10
-                                                   aten::mm         6.46%     106.000us         6.58%     108.000us      36.000us       1.17 Kb       1.17 Kb             3
-                                                   aten::ne         4.26%      70.000us         6.46%     106.000us      35.333us           3 b          -9 b             3
-                                                aten::zeros         1.95%      32.000us         6.03%      99.000us      12.375us          32 b           0 b             8
-    autograd::engine::evaluate_function: L1LossBackward0...         0.37%       6.000us         5.97%      98.000us      98.000us         396 b          -4 b             1
-                                            L1LossBackward0         0.24%       4.000us         5.60%      92.000us      92.000us         400 b           0 b             1
-                                                  aten::cat         0.97%      16.000us         5.05%      83.000us      16.600us       1.57 Kb           0 b             5
-                                                  aten::div         2.80%      46.000us         4.81%      79.000us      26.333us          12 b           0 b             3
-                                    Optimizer.step#SGD.step         3.71%      61.000us         4.51%      74.000us      74.000us          -4 b         -20 b             1
-                                                 aten::_cat         2.13%      35.000us         4.08%      67.000us      13.400us       1.57 Kb           0 b             5
-                                        aten::empty_strided         3.41%      56.000us         3.41%      56.000us       5.091us          44 b          44 b            11
-           autograd::engine::evaluate_function: MmBackward0         0.37%       6.000us         3.41%      56.000us      56.000us           0 b        -400 b             1
-                                                    aten::t         2.13%      35.000us         3.11%      51.000us      10.200us           0 b           0 b             5
-                                                 aten::add_         3.05%      50.000us         3.05%      50.000us       8.333us           0 b           0 b             6
-                                                MmBackward0         0.43%       7.000us         3.05%      50.000us      50.000us         400 b           0 b             1
-                                                  aten::mul         2.01%      33.000us         2.62%      43.000us      10.750us           5 b          -3 b             4
-                                                  aten::sum         2.01%      33.000us         2.50%      41.000us      13.667us          40 b          40 b             3
-          autograd::engine::evaluate_function: DivBackward0         0.55%       9.000us         2.38%      39.000us      39.000us          -4 b          -8 b             1
-                                                 aten::norm         2.13%      35.000us         2.38%      39.000us      13.000us          12 b          12 b             3
-                          Optimizer.zero_grad#SGD.zero_grad         1.46%      24.000us         1.89%      31.000us      31.000us          -4 b          -4 b             1
-                                               DivBackward0         0.24%       4.000us         1.83%      30.000us      30.000us           4 b           0 b             1
-                                                 aten::item         1.04%      17.000us         1.77%      29.000us       4.143us           0 b           0 b             7
-          autograd::engine::evaluate_function: AddBackward0         0.43%       7.000us         1.77%      29.000us      29.000us          40 b           0 b             1
-                                              aten::random_         1.71%      28.000us         1.71%      28.000us      14.000us           0 b           0 b             2
-                                                  aten::sub         1.71%      28.000us         1.71%      28.000us       9.333us         800 b         800 b             3
-                                                aten::copy_         1.71%      28.000us         1.71%      28.000us       2.800us           0 b           0 b            10
-    autograd::engine::evaluate_function: torch::autograd...         0.30%       5.000us         1.71%      28.000us      14.000us        -440 b           0 b             2
-                                              aten::resize_         1.64%      27.000us         1.64%      27.000us       5.400us       1.57 Kb       1.57 Kb             5
-                                                  aten::add         0.85%      14.000us         1.64%      27.000us      27.000us           4 b           0 b             1
-                                                   aten::eq         1.58%      26.000us         1.58%      26.000us       8.667us           3 b           3 b             3
-                                            aten::unsqueeze         0.91%      15.000us         1.40%      23.000us       3.833us           0 b           0 b             6
-                            torch::autograd::AccumulateGrad         0.97%      16.000us         1.40%      23.000us      11.500us        -440 b        -440 b             2
-                                           aten::is_nonzero         0.43%       7.000us         1.40%      23.000us       7.667us           0 b           0 b             3
-                                               aten::detach         0.55%       9.000us         1.28%      21.000us       3.500us           0 b           0 b             6
-                                           aten::as_strided         1.16%      19.000us         1.16%      19.000us       1.188us           0 b           0 b            16
-                                                 aten::view         1.04%      17.000us         1.04%      17.000us       3.400us           0 b           0 b             5
-                                            aten::transpose         0.67%      11.000us         0.97%      16.000us       3.200us           0 b           0 b             5
-                                           aten::empty_like         0.37%       6.000us         0.91%      15.000us       7.500us         404 b           0 b             2
-                                                     detach         0.73%      12.000us         0.85%      14.000us       2.333us           0 b           0 b             6
-                                                aten::clamp         0.61%      10.000us         0.85%      14.000us      14.000us           4 b           4 b             1
-                                  aten::_local_scalar_dense         0.73%      12.000us         0.79%      13.000us       1.857us           0 b           0 b             7
-                                                 aten::mul_         0.79%      13.000us         0.79%      13.000us       4.333us           0 b           0 b             3
-                                            aten::ones_like         0.24%       4.000us         0.73%      12.000us      12.000us           4 b           0 b             1
-                                           aten::zeros_like         0.30%       5.000us         0.73%      12.000us      12.000us         400 b           0 b             1
-    autograd::engine::evaluate_function: UnsafeViewBackw...         0.18%       3.000us         0.73%      12.000us      12.000us           0 b           0 b             1
-                                        UnsafeViewBackward0         0.12%       2.000us         0.55%       9.000us       9.000us           0 b           0 b             1
-                                           aten::reciprocal         0.49%       8.000us         0.49%       8.000us       8.000us           4 b           4 b             1
-                                              aten::reshape         0.24%       4.000us         0.43%       7.000us       7.000us           0 b           0 b             1
-            autograd::engine::evaluate_function: TBackward0         0.12%       2.000us         0.43%       7.000us       7.000us           0 b           0 b             1
-                                                aten::zero_         0.37%       6.000us         0.37%       6.000us       0.545us           0 b           0 b            11
-                                                aten::fill_         0.37%       6.000us         0.37%       6.000us       1.500us           0 b           0 b             4
-                                                 aten::sgn_         0.30%       5.000us         0.30%       5.000us       5.000us           0 b           0 b             1
-                                                 TBackward0         0.06%       1.000us         0.30%       5.000us       5.000us           0 b           0 b             1
-                                            aten::clamp_max         0.24%       4.000us         0.24%       4.000us       4.000us           0 b           0 b             1
-                                       aten::_reshape_alias         0.18%       3.000us         0.18%       3.000us       3.000us           0 b           0 b             1
-                                         aten::resolve_conj         0.12%       2.000us         0.12%       2.000us       0.333us           0 b           0 b             6
-                                               AddBackward0         0.06%       1.000us         0.06%       1.000us       1.000us           0 b           0 b             1
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
-    Self CPU time total: 1.642ms
-
-
-    print(diff_evaluate.table(sort_by="cpu_time_total"))
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
-                                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
-    enumerate(DataLoader)#_SingleProcessDataLoaderIter._...        16.03%     210.000us        34.05%     446.000us     223.000us       1.55 Kb         -40 b             2
-                                             aten::isfinite         5.88%      77.000us        25.42%     333.000us     166.500us           2 b         -12 b             2
-                                                aten::stack         5.34%      70.000us        16.41%     215.000us      53.750us       1.56 Kb           0 b             4
-                                                   aten::ne         7.33%      96.000us        10.84%     142.000us      71.000us           2 b          -6 b             2
-                                                aten::empty        10.08%     132.000us        10.08%     132.000us       8.250us          80 b          80 b            16
-                                                   aten::to         1.37%      18.000us         8.63%     113.000us      18.833us          16 b           0 b             6
-                                                aten::zeros         3.13%      41.000us         8.55%     112.000us      28.000us          16 b           0 b             4
-                                                  aten::cat         1.45%      19.000us         8.40%     110.000us      27.500us       1.56 Kb           0 b             4
-                                                   aten::mm         7.02%      92.000us         7.33%      96.000us      48.000us         800 b         800 b             2
-                                                  aten::abs         4.58%      60.000us         7.33%      96.000us      16.000us          16 b           8 b             6
-                                             aten::_to_copy         2.21%      29.000us         7.25%      95.000us      23.750us          16 b           0 b             4
-                                                 aten::_cat         2.82%      37.000us         6.95%      91.000us      22.750us       1.56 Kb           0 b             4
-                                              aten::resize_         3.51%      46.000us         3.51%      46.000us      11.500us       1.56 Kb       1.56 Kb             4
-                                        aten::empty_strided         3.36%      44.000us         3.36%      44.000us      11.000us          16 b          16 b             4
-                                                  aten::sub         2.98%      39.000us         2.98%      39.000us      19.500us         800 b         800 b             2
-                                                  aten::sum         2.14%      28.000us         2.90%      38.000us      19.000us           0 b           0 b             2
-                                                 aten::add_         2.82%      37.000us         2.82%      37.000us      18.500us           0 b           0 b             2
-                                                    aten::t         1.68%      22.000us         2.75%      36.000us      18.000us           0 b           0 b             2
-                                            aten::unsqueeze         1.98%      26.000us         2.67%      35.000us       8.750us           0 b           0 b             4
-                                                   aten::eq         2.37%      31.000us         2.37%      31.000us      15.500us           2 b           2 b             2
-                                                  aten::mul         2.37%      31.000us         2.37%      31.000us      15.500us           2 b           2 b             2
-                                           aten::is_nonzero         0.53%       7.000us         1.76%      23.000us      11.500us           0 b           0 b             2
-                                                 aten::item         0.92%      12.000us         1.76%      23.000us       5.750us           0 b           0 b             4
-                                                aten::copy_         1.68%      22.000us         1.68%      22.000us       5.500us           0 b           0 b             4
-                                           aten::as_strided         1.37%      18.000us         1.37%      18.000us       2.250us           0 b           0 b             8
-                                                 aten::view         1.30%      17.000us         1.30%      17.000us       4.250us           0 b           0 b             4
-                                               aten::detach         0.31%       4.000us         1.15%      15.000us       7.500us           0 b           0 b             2
-                                            aten::transpose         0.69%       9.000us         1.07%      14.000us       7.000us           0 b           0 b             2
-                                                     detach         0.84%      11.000us         0.84%      11.000us       5.500us           0 b           0 b             2
-                                  aten::_local_scalar_dense         0.84%      11.000us         0.84%      11.000us       2.750us           0 b           0 b             4
-                                                aten::fill_         0.46%       6.000us         0.46%       6.000us       3.000us           0 b           0 b             2
-                                                aten::zero_         0.31%       4.000us         0.31%       4.000us       1.000us           0 b           0 b             4
-                                         aten::resolve_conj         0.31%       4.000us         0.31%       4.000us       1.000us           0 b           0 b             4
-    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
-    Self CPU time total: 1.310ms
-    """
-
-    # set hook afterwards
-    brain_analyst_raw = SimpleBrain(
-        {"model": model}, lambda x: SGD(x, 0.1), run_opts={"device": device}
-    )
-    brain_optimiser_raw = SimpleBrain(
-        {"model": model}, lambda x: SGD(x, 0.1), run_opts={"device": device}
-    )
-    brain_analyst_raw.fit(
-        epoch_counter=range(10), train_set=train_set, valid_set=valid_set
-    )
-    profile_analyst(brain_analyst_raw)
-    brain_analyst_raw.evaluate(test_set=test_set)
-    assert getattr(brain_analyst_raw.profiler, "record_shapes") is True
-    assert getattr(brain_analyst_raw.profiler, "with_stack") is True
-    assert getattr(brain_analyst_raw.profiler, "with_flops") is True
-
-    brain_optimiser_raw.fit(
-        epoch_counter=range(10), train_set=train_set, valid_set=valid_set
-    )
-    profile_optimiser(brain_optimiser_raw)
-    brain_optimiser_raw.evaluate(test_set=test_set)
-    assert getattr(brain_optimiser_raw.profiler, "record_shapes") is False
-    assert getattr(brain_optimiser_raw.profiler, "with_stack") is False
-    assert getattr(brain_optimiser_raw.profiler, "with_flops") is False
-
-    # wrap functions
-    @profile_analyst
-    def train_analyst(brain: SimpleBrain):
-        brain.fit(
-            epoch_counter=range(10), train_set=train_set, valid_set=valid_set
-        )
-
-    @export
-    @profile_optimiser
-    def evaluate_optimiser(brain: SimpleBrain):
-        brain.evaluate(test_set=test_set)
-
-    brain_raw = SimpleBrain(
-        {"model": model}, lambda x: SGD(x, 0.1), run_opts={"device": device}
-    )
-    assert brain_raw.profiler is None
-    train_analyst(brain_raw)
-    assert brain_raw.profiler is None
-    evaluate_optimiser(brain_raw)
-    assert brain_raw.profiler is None
-
-    # profile classes
-    @export
-    @profile_analyst
-    class SimpleBrainAnalyst(Brain):
-        def compute_forward(self, batch, stage):
-            return self.modules.model(batch[0])
-
-        def compute_objectives(self, predictions, batch, stage):
-            return torch.nn.functional.l1_loss(predictions, batch[1])
-
-    @profile_optimiser
-    class SimpleBrainOptimiser(Brain):
-        def compute_forward(self, batch, stage):
-            return self.modules.model(batch[0])
-
-        def compute_objectives(self, predictions, batch, stage):
-            return torch.nn.functional.l1_loss(predictions, batch[1])
-
-    simple_brain_analyst = SimpleBrainAnalyst(
-        {"model": model}, lambda x: SGD(x, 0.1), run_opts={"device": device}
-    )
-    assert getattr(simple_brain_analyst.profiler, "record_shapes") is True
-    assert getattr(simple_brain_analyst.profiler, "with_stack") is True
-    assert getattr(simple_brain_analyst.profiler, "with_flops") is True
-    simple_brain_analyst.evaluate(test_set=test_set)
-
-    simple_brain_optimiser = SimpleBrainOptimiser(
-        {"model": model}, lambda x: SGD(x, 0.1), run_opts={"device": device}
-    )
-    assert getattr(simple_brain_optimiser.profiler, "record_shapes") is False
-    assert getattr(simple_brain_optimiser.profiler, "with_stack") is False
-    assert getattr(simple_brain_optimiser.profiler, "with_flops") is False
-    simple_brain_optimiser.fit(
-        epoch_counter=range(10), train_set=train_set, valid_set=valid_set
-    )
diff --git a/tests/unittests/test_rescorer.py b/tests/unittests/test_rescorer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6e1f3bc54d619fefadd7b023fb6154b21ca53f3
--- /dev/null
+++ b/tests/unittests/test_rescorer.py
@@ -0,0 +1,274 @@
+def test_rnnlmrescorer(tmpdir, device):
+    import torch
+    from sentencepiece import SentencePieceProcessor
+    from speechbrain.lobes.models.RNNLM import RNNLM
+    from speechbrain.utils.parameter_transfer import Pretrainer
+
+    source = "speechbrain/asr-crdnn-rnnlm-librispeech"
+    lm_model_path = source + "/lm.ckpt"
+    tokenizer_path = source + "/tokenizer.ckpt"
+
+    # Define your tokenizer and RNNLM from the HF hub
+    tokenizer = SentencePieceProcessor()
+    lm_model = RNNLM(
+        output_neurons=1000,
+        embedding_dim=128,
+        activation=torch.nn.LeakyReLU,
+        dropout=0.0,
+        rnn_layers=2,
+        rnn_neurons=2048,
+        dnn_blocks=1,
+        dnn_neurons=512,
+        return_hidden=True,
+    )
+
+    pretrainer = Pretrainer(
+        collect_in=tmpdir,
+        loadables={"lm": lm_model, "tokenizer": tokenizer},
+        paths={"lm": lm_model_path, "tokenizer": tokenizer_path},
+    )
+
+    pretrainer.collect_files()
+    pretrainer.load_collected()
+
+    from speechbrain.decoders.scorer import RNNLMRescorer, RescorerBuilder
+
+    rnnlm_rescorer = RNNLMRescorer(
+        language_model=lm_model,
+        tokenizer=tokenizer,
+        temperature=1.0,
+        bos_index=0,
+        eos_index=0,
+        pad_index=0,
+    )
+
+    # Define a rescorer builder
+    rescorer = RescorerBuilder(
+        rescorers=[rnnlm_rescorer], weights={"rnnlm": 1.0}
+    )
+
+    # Topk hypotheses
+    topk_hyps = [["HELLO", "HE LLO", "H E L L O"]]
+    topk_scores = [[-2, -2, -2]]
+    rescored_hyps, rescored_scores = rescorer.rescore(topk_hyps, topk_scores)
+
+    # check all hyps are still there
+    for hyp in topk_hyps[0]:
+        assert hyp in rescored_hyps[0]
+
+    # check rescored_scores are sorted
+    for i in range(len(rescored_scores[0]) - 1):
+        assert rescored_scores[0][i] >= rescored_scores[0][i + 1]
+
+    # check normalized_text is working
+    text = "hello"
+    normalized_text = rnnlm_rescorer.normalize_text(text)
+    assert normalized_text == text.upper()
+
+    # check lm is on the right device
+    rnnlm_rescorer.to_device(device)
+    assert rnnlm_rescorer.lm.parameters().__next__().device.type == device
+
+    # check preprocess_func
+    padded_hyps, enc_hyps_length = rnnlm_rescorer.preprocess_func(topk_hyps)
+    assert padded_hyps.shape[0] == 3
+    assert len(padded_hyps) == 3
+
+
+def test_transformerlmrescorer(tmpdir, device):
+    import torch
+    from sentencepiece import SentencePieceProcessor
+    from speechbrain.lobes.models.transformer.TransformerLM import TransformerLM
+    from speechbrain.utils.parameter_transfer import Pretrainer
+
+    source = "speechbrain/asr-transformer-transformerlm-librispeech"
+    lm_model_path = source + "/lm.ckpt"
+    tokenizer_path = source + "/tokenizer.ckpt"
+    tokenizer = SentencePieceProcessor()
+
+    lm_model = TransformerLM(
+        vocab=5000,
+        d_model=768,
+        nhead=12,
+        num_encoder_layers=12,
+        num_decoder_layers=0,
+        d_ffn=3072,
+        dropout=0.0,
+        activation=torch.nn.GELU,
+        normalize_before=False,
+    )
+
+    pretrainer = Pretrainer(
+        collect_in=tmpdir,
+        loadables={"lm": lm_model, "tokenizer": tokenizer},
+        paths={"lm": lm_model_path, "tokenizer": tokenizer_path},
+    )
+
+    _ = pretrainer.collect_files()
+    pretrainer.load_collected()
+
+    from speechbrain.decoders.scorer import (
+        TransformerLMRescorer,
+        RescorerBuilder,
+    )
+
+    transformerlm_rescorer = TransformerLMRescorer(
+        language_model=lm_model,
+        tokenizer=tokenizer,
+        temperature=1.0,
+        bos_index=1,
+        eos_index=2,
+        pad_index=0,
+    )
+
+    rescorer = RescorerBuilder(
+        rescorers=[transformerlm_rescorer], weights={"transformerlm": 1.0}
+    )
+
+    # Topk hypotheses
+    topk_hyps = [["HELLO", "HE LLO", "H E L L O"]]
+    topk_scores = [[-2, -2, -2]]
+    rescored_hyps, rescored_scores = rescorer.rescore(topk_hyps, topk_scores)
+
+    # check all hyps are still there
+    for hyp in topk_hyps[0]:
+        assert hyp in rescored_hyps[0]
+
+    # check rescored_scores are sorted
+    for i in range(len(rescored_scores[0]) - 1):
+        assert rescored_scores[0][i] >= rescored_scores[0][i + 1]
+
+    # check normalized_text is working
+    text = "hello"
+    normalized_text = transformerlm_rescorer.normalize_text(text)
+    assert normalized_text == text.upper()
+
+    # check lm is on the right device
+    transformerlm_rescorer.to_device(device)
+    assert (
+        transformerlm_rescorer.lm.parameters().__next__().device.type == device
+    )
+
+    # check preprocess_func
+    padded_hyps, enc_hyps_length = transformerlm_rescorer.preprocess_func(
+        topk_hyps
+    )
+    assert padded_hyps.shape[0] == 3
+    assert len(padded_hyps) == 3
+
+
+def test_huggingfacelmrescorer(device):
+    from speechbrain.decoders.scorer import (
+        HuggingFaceLMRescorer,
+        RescorerBuilder,
+    )
+
+    source = "gpt2-medium"
+
+    huggingfacelm_rescorer = HuggingFaceLMRescorer(model_name=source,)
+
+    rescorer = RescorerBuilder(
+        rescorers=[huggingfacelm_rescorer], weights={"huggingfacelm": 1.0}
+    )
+
+    # Topk hypotheses
+    topk_hyps = [["HELLO", "HE LLO", "H E L L O"]]
+    topk_scores = [[-2, -2, -2]]
+    rescored_hyps, rescored_scores = rescorer.rescore(topk_hyps, topk_scores)
+
+    # check all hyps are still there
+    for hyp in topk_hyps[0]:
+        assert hyp in rescored_hyps[0]
+
+    # check rescored_scores are sorted
+    for i in range(len(rescored_scores[0]) - 1):
+        assert rescored_scores[0][i] >= rescored_scores[0][i + 1]
+
+    # check normalized_text is working
+    text = "hello"
+    normalized_text = huggingfacelm_rescorer.normalize_text(text)
+    assert normalized_text == text
+
+    # check lm is on the right device
+    huggingfacelm_rescorer.to_device(device)
+    assert huggingfacelm_rescorer.lm.device.type == device
+
+    # check preprocess_func
+    padded_hyps = huggingfacelm_rescorer.preprocess_func(topk_hyps)
+    assert padded_hyps.input_ids.shape[0] == 3
+
+
+def test_rescorerbuilder(tmpdir, device):
+    import torch
+    from sentencepiece import SentencePieceProcessor
+    from speechbrain.lobes.models.RNNLM import RNNLM
+    from speechbrain.utils.parameter_transfer import Pretrainer
+
+    source = "speechbrain/asr-crdnn-rnnlm-librispeech"
+    lm_model_path = source + "/lm.ckpt"
+    tokenizer_path = source + "/tokenizer.ckpt"
+
+    # Define your tokenizer and RNNLM from the HF hub
+    tokenizer = SentencePieceProcessor()
+    lm_model = RNNLM(
+        output_neurons=1000,
+        embedding_dim=128,
+        activation=torch.nn.LeakyReLU,
+        dropout=0.0,
+        rnn_layers=2,
+        rnn_neurons=2048,
+        dnn_blocks=1,
+        dnn_neurons=512,
+        return_hidden=True,
+    )
+
+    pretrainer = Pretrainer(
+        collect_in=tmpdir,
+        loadables={"lm": lm_model, "tokenizer": tokenizer},
+        paths={"lm": lm_model_path, "tokenizer": tokenizer_path},
+    )
+
+    pretrainer.collect_files()
+    pretrainer.load_collected()
+
+    from speechbrain.decoders.scorer import (
+        RNNLMRescorer,
+        RescorerBuilder,
+        HuggingFaceLMRescorer,
+    )
+
+    rnnlm_rescorer = RNNLMRescorer(
+        language_model=lm_model,
+        tokenizer=tokenizer,
+        temperature=1.0,
+        bos_index=0,
+        eos_index=0,
+        pad_index=0,
+    )
+
+    source = "gpt2-medium"
+
+    huggingfacelm_rescorer = HuggingFaceLMRescorer(model_name=source,)
+
+    # check combine both rescorers
+    rescorer = RescorerBuilder(
+        rescorers=[rnnlm_rescorer, huggingfacelm_rescorer],
+        weights={"rnnlm": 1.0, "huggingfacelm": 1.0},
+    )
+    rescorer.move_rescorers_to_device(device)
+    # check lm is on the right device
+    assert rnnlm_rescorer.lm.parameters().__next__().device.type == device
+    assert huggingfacelm_rescorer.lm.device.type == device
+
+    # Topk hypotheses
+    topk_hyps = [["HELLO", "HE LLO", "H E L L O"]]
+    topk_scores = [[-2, -2, -2]]
+    rescored_hyps, rescored_scores = rescorer.rescore(topk_hyps, topk_scores)
+
+    # check all hyps are still there
+    for hyp in topk_hyps[0]:
+        assert hyp in rescored_hyps[0]
+
+    # check rescored_scores are sorted
+    for i in range(len(rescored_scores[0]) - 1):
+        assert rescored_scores[0][i] >= rescored_scores[0][i + 1]
diff --git a/tests/unittests/test_samplers.py b/tests/unittests/test_samplers.py
index c145b0cdc13aec2a0a9a34ef90f8090ccf242688..e7ba03fffff602c2fb07053cea66a33f946e492f 100644
--- a/tests/unittests/test_samplers.py
+++ b/tests/unittests/test_samplers.py
@@ -46,3 +46,70 @@ def test_ConcatDatasetBatchSampler(device):
     non_cat_data = [x[:minlen] for x in non_cat_data]
     non_cat_data = np.array(non_cat_data)
     np.testing.assert_array_equal(non_cat_data.T, concat_data)
+
+    # check DynamicBatchSampler
+    from speechbrain.dataio.sampler import DynamicBatchSampler
+    from speechbrain.dataio.dataset import DynamicItemDataset
+    from speechbrain.dataio.dataloader import SaveableDataLoader
+    from speechbrain.dataio.batch import PaddedBatch
+
+    max_batch_length = 4
+    num_buckets = 5
+
+    item_lengths = [1, 2, 3, 4, 5, 6, 7]
+    items = [[length] * length for length in item_lengths]
+
+    dataset = {
+        "ex_{}".format(length): {"wav": torch.tensor(item), "duration": length}
+        for item, length in zip(items, item_lengths)
+    }
+    dataset = DynamicItemDataset(dataset)
+    dataset.set_output_keys(["wav"])
+
+    bsampler = DynamicBatchSampler(
+        dataset,
+        max_batch_length,
+        num_buckets,
+        lambda x: x["duration"],
+        shuffle=False,
+        batch_ordering="ascending",
+    )
+
+    dataloader = SaveableDataLoader(
+        dataset, batch_sampler=bsampler, collate_fn=PaddedBatch
+    )
+
+    assert next(iter(dataloader))["wav"].data.shape == torch.Size([1, 1])
+
+    bsampler = DynamicBatchSampler(
+        dataset,
+        max_batch_length,
+        num_buckets,
+        lambda x: x["duration"],
+        shuffle=False,
+        batch_ordering="descending",
+    )
+
+    dataloader = SaveableDataLoader(
+        dataset, batch_sampler=bsampler, collate_fn=PaddedBatch
+    )
+    assert next(iter(dataloader))["wav"].data.shape == torch.Size([1, 7])
+
+    max_batch_length = 10
+    num_buckets = 5
+
+    bsampler = DynamicBatchSampler(
+        dataset,
+        max_batch_length,
+        num_buckets,
+        lambda x: x["duration"],
+        shuffle=False,
+        batch_ordering="ascending",
+    )
+
+    dataloader = SaveableDataLoader(
+        dataset, batch_sampler=bsampler, collate_fn=PaddedBatch
+    )
+
+    for b in dataloader:
+        assert b["wav"].data.shape[1] <= max_batch_length
diff --git a/tests/unittests/test_schedulers.py b/tests/unittests/test_schedulers.py
index df81aa4ad245119518c27c9b782ace4cfacf1a54..ce3d76fae5897083a4d7dff5cc84fbce9056ebdc 100755
--- a/tests/unittests/test_schedulers.py
+++ b/tests/unittests/test_schedulers.py
@@ -26,3 +26,28 @@ def test_NewBobScheduler():
     prev_lr, next_lr = scheduler(1.1)
     assert next_lr == 0.4
     assert scheduler.current_patient == 3
+
+
+def test_WarmAndExpDecayLRSchedule():
+
+    from speechbrain.nnet.schedulers import WarmAndExpDecayLRSchedule
+    from speechbrain.nnet.linear import Linear
+    import torch
+
+    model = Linear(input_size=3, n_neurons=4)
+    optim = torch.optim.Adam(model.parameters(), lr=1)
+    scheduler = WarmAndExpDecayLRSchedule(
+        lr=1, n_warmup_steps=2, decay_factor=0.01, total_steps=6
+    )
+
+    scheduler(optim)
+    assert optim.param_groups[0]["lr"] == 0.0
+
+    scheduler(optim)
+    assert optim.param_groups[0]["lr"] == 0.5
+
+    scheduler(optim)
+    assert optim.param_groups[0]["lr"] == 1
+
+    scheduler(optim)
+    assert optim.param_groups[0]["lr"] == 0.31622776601683794
diff --git a/tests/unittests/test_streaming.py b/tests/unittests/test_streaming.py
new file mode 100644
index 0000000000000000000000000000000000000000..aec681b300cdbff69cee69aae464da7af31c4841
--- /dev/null
+++ b/tests/unittests/test_streaming.py
@@ -0,0 +1,47 @@
+import torch
+
+
+def test_streaming_feature_wrapper(device):
+    from speechbrain.lobes.features import StreamingFeatureWrapper
+    from speechbrain.utils.filter_analysis import FilterProperties
+    from speechbrain.utils.streaming import split_fixed_chunks
+
+    # dummy filter that lies about its properties
+    class DummySumModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x: torch.Tensor):
+            return x
+
+    props = FilterProperties(window_size=5, stride=2)
+
+    m = StreamingFeatureWrapper(DummySumModule(), props)
+
+    chunk_size = 3
+    chunk_size_frames = (props.stride - 1) * chunk_size
+
+    x = torch.tensor(
+        [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]], device=device
+    )
+
+    chunks = split_fixed_chunks(x, chunk_size_frames)
+    assert len(chunks) == 3
+
+    ctx = m.make_streaming_context()
+    outs = [m(chunk, ctx) for chunk in chunks]
+
+    # the streaming feature wrapper will truncate output module frames that are
+    # centered on "padding" frames (as described in the code)
+
+    # currently, the expected output is as follows:
+    assert torch.allclose(outs[0], torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0]))
+    assert torch.allclose(outs[1], torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]))
+    assert torch.allclose(outs[2], torch.tensor([4.0, 5.0, 6.0, 7.0, 8.0]))
+
+    # thus we have outputs centered on:
+    # [0, 0, 2]
+    # [1, 3, 5]
+    # [4, 6, 8]
+    # which preserves the filter properties as expected, and the chunk size we
+    # requested.
diff --git a/tests/unittests/test_transformer_src_tgt_masks.py b/tests/unittests/test_transformer_src_tgt_masks.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fde7936bd5168f8e5e0fbde10cfa123a093529b
--- /dev/null
+++ b/tests/unittests/test_transformer_src_tgt_masks.py
@@ -0,0 +1,106 @@
+import torch
+import torch.nn
+
+
+def test_make_transformer_src_tgt_masks(device):
+
+    from speechbrain.lobes.models.transformer.TransformerASR import (
+        make_transformer_src_tgt_masks,
+    )
+    from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
+    from numpy import inf
+
+    config = DynChunkTrainConfig(chunk_size=4, left_context_size=3)
+
+    x = torch.rand(1, 18)
+    tgt = torch.rand(18, 18)
+    tgt[:, 15:] = 0
+
+    (
+        _,
+        tgt_key_padding_mask,
+        src_mask,
+        tgt_mask,
+    ) = make_transformer_src_tgt_masks(x, tgt, dynchunktrain_config=config)
+
+    # fmt: off
+    # flake8: noqa
+    expected_src_mask = torch.tensor(
+        [[False, False, False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True,],[False, False, False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True,],[False, False, False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True,],[False, False, False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True,],[False, False, False, False, False, False, False, False, True, True, True, True, True, True, True, True, True, True,],[False, False, False, False, False, False, False, False, True, True, True, True, True, True, True, True, True, True,],[False, False, False, False, False, False, False, False, True, True, True, True, True, True, True, True, True, True,],[False, False, False, False, False, False, False, False, True, True, True, True, True, True, True, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True,],[True, True, True, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False,],[True, True, True, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False,],]
+    )
+    expected_key_padding_mask = torch.tensor(
+        [[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True,],]
+    )
+
+    expected_tgt_mask = torch.tensor(
+        [[0.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,],[0.0, 0.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,],[0.0, 0.0, 0.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,],[0.0, 0.0, 0.0, 0.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,],[0.0, 0.0, 0.0, 0.0, 0.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,],[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,],[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,],[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,],[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,],[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,],[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf,],[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -inf, -inf, -inf, -inf, -inf, -inf,],[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -inf, -inf, -inf, -inf, -inf,],[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -inf, -inf, -inf, -inf,],[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -inf, -inf, -inf,],[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -inf, -inf,],[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -inf,],[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,],]
+    )
+    # fmt: on
+
+    assert torch.all(torch.eq(src_mask, expected_src_mask))
+    assert torch.all(torch.eq(tgt_key_padding_mask, expected_key_padding_mask))
+    assert torch.all(torch.eq(tgt_mask, expected_tgt_mask))
+
+
+def test_make_transformer_src_mask(device):
+
+    from speechbrain.lobes.models.transformer.TransformerASR import (
+        make_transformer_src_mask,
+    )
+    from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
+
+    x = torch.rand(1, 18)
+
+    config = DynChunkTrainConfig(chunk_size=4, left_context_size=3)
+
+    src_mask = make_transformer_src_mask(x, False, config)
+
+    # fmt: off
+    # flake8: noqa
+    expected = torch.tensor(
+        [[False, False, False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True,],[False, False, False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True,],[False, False, False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True,],[False, False, False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True,],[False, False, False, False, False, False, False, False, True, True, True, True, True, True, True, True, True, True,],[False, False, False, False, False, False, False, False, True, True, True, True, True, True, True, True, True, True,],[False, False, False, False, False, False, False, False, True, True, True, True, True, True, True, True, True, True,],[False, False, False, False, False, False, False, False, True, True, True, True, True, True, True, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True,],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True,],[True, True, True, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False,],[True, True, True, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False,],]
+    )
+    # fmt: on
+
+    assert torch.all(torch.eq(src_mask, expected))
+
+
+def test_get_lookahead_mask(device):
+
+    from speechbrain.lobes.models.transformer.Transformer import (
+        get_lookahead_mask,
+    )
+    from numpy import inf
+
+    # fmt: off
+    # flake8: noqa
+    x = torch.LongTensor([[1, 1, 0], [2, 3, 0], [4, 5, 0]])
+
+    out = get_lookahead_mask(x)
+
+    expected = torch.tensor(
+        [[0.0, -inf, -inf], [0.0, 0.0, -inf], [0.0, 0.0, 0.0]]
+    )
+    # fmt: on
+
+    assert torch.all(torch.eq(out, expected))
+
+
+def test_get_key_padding_mask(device):
+
+    from speechbrain.lobes.models.transformer.Transformer import (
+        get_key_padding_mask,
+    )
+
+    # fmt: off
+    # flake8: noqa
+    x = torch.LongTensor([[1, 1, 0], [2, 3, 0], [4, 5, 0]])
+
+    out = get_key_padding_mask(x, 0)
+
+    expected = torch.tensor(
+        [[False, False, True], [False, False, True], [False, False, True]]
+    )
+    # fmt: on
+
+    assert torch.all(torch.eq(out, expected))
diff --git a/tests/utils/check_url.py b/tests/utils/check_url.py
index bb5159e875a1c9b1d2fbfe0cb0239116c481a24f..cdf31ce8355c22d79b18d173831622f59b3a5c1d 100644
--- a/tests/utils/check_url.py
+++ b/tests/utils/check_url.py
@@ -65,7 +65,7 @@ def get_all_urls(file_lst, avoid_urls):
     for path in file_lst:
         if ".gz" in path:
             continue
-
+        print(path)
         urls = get_url(path)
 
         for url in urls:
diff --git a/tests/utils/check_yaml.py b/tests/utils/check_yaml.py
index ae2894b83c6103a1c97767064f5f8d6b94c422dc..8eb1272bab9708ad7b7de955f4f47ff573fb2be6 100644
--- a/tests/utils/check_yaml.py
+++ b/tests/utils/check_yaml.py
@@ -7,6 +7,7 @@ Authors
 
 import os
 import re
+from speechbrain.core import run_opt_defaults
 
 
 def get_yaml_var(hparam_file):
@@ -38,8 +39,8 @@ def get_yaml_var(hparam_file):
             # Remove trailing characters
             line = line.rstrip()
 
-            # Check for variables (e.g., 'key:' or '- !ref')
-            if line.find(":") != -1 or line.find("- !ref") != -1:
+            # Check for variables (e.g., 'key:' or '!ref')
+            if line.find(":") != -1 or line.find("!ref") != -1:
                 var_name = line[: line.find(":")]
                 # The variables to check are like "key:" (we do not need to check
                 # subvariavles as " key:")
@@ -98,6 +99,11 @@ def detect_script_vars(script_file, var_lst):
                         continue  # no need to go through the other cases for this var
                 # case: hparams[f"{dataset}_annotation"] - only that structure at the moment
                 re_match = re.search(r"\[f.\{.*\}(.*).\]", line)
+                # case: getattr(self.hparams, f"{stage.name}_search".lower())
+                if re_match is None:
+                    re_match = re.search(
+                        r"self\.hparams, f\"\{.*\}(.*)\"", line
+                    )
                 if re_match is not None:
                     if re_match.group(1) in var:
                         print(
@@ -176,29 +182,10 @@ def check_yaml_vs_script(hparam_file, script_file):
     detected_vars_train = detect_script_vars(script_file, var_lst)
 
     # Check which variables are declared but not used
-    default_run_opt_keys = [
-        "debug",
-        "debug_batches",
-        "debug_epochs",
-        "device",
-        "cpu",
-        "data_parallel_backend",
-        "distributed_launch",
-        "distributed_backend",
-        "find_unused_parameters",
-        "jit_module_keys",
-        "compile_module_keys",
-        "--compile_mode",
-        "--compile_using_fullgraph",
-        "--compile_using_dynamic_shape_tracing",
-        "auto_mix_prec",
-        "max_grad_norm",
-        "nonfinite_patience",
-        "noprogressbar",
-        "ckpt_interval_minutes",
-        "grad_accumulation_factor",
-        "optimizer_step_limit",
+    default_run_opt_keys = list(run_opt_defaults.keys()) + [
+        "rescoring_lm_scale"
     ]
+
     unused_vars = list(
         set(var_lst) - set(detected_vars_train) - set(default_run_opt_keys)
     )
diff --git a/tests/utils/recipe_tests.py b/tests/utils/recipe_tests.py
index b24e27cfff580ed2fa74e72cd51e45a10c1b7fad..013fd47ca55f9be1640a079092f75d4e6351720c 100644
--- a/tests/utils/recipe_tests.py
+++ b/tests/utils/recipe_tests.py
@@ -138,7 +138,7 @@ def prepare_test(
         ) as csvf:
             reader = csv.DictReader(csvf, delimiter=",", skipinitialspace=True)
             for row_id, row in enumerate(reader):
-                recipe_id = f"{recipe_csvfile[:-4]}_row_{row_id+2}"
+                recipe_id = f"{recipe_csvfile[:-4]}_row_{row_id+2:02d}"
                 if not (
                     check_row_for_test(row, filters_fields, filters, test_field)
                 ):
@@ -502,7 +502,7 @@ def run_recipe_tests(
 
     # Run  script (check how to get std out, std err and save them in files)
     check = True
-    for i, recipe_id in enumerate(test_script.keys()):
+    for i, recipe_id in enumerate(sorted(test_script.keys())):
 
         # Check if the output folder is specified in test_field
         spec_outfold = False
@@ -762,7 +762,7 @@ def load_yaml_test(
     }
 
     # Read the csv recipe file and detect which tests we have to run
-    test_script, test_hparam, test_flag, test_check = prepare_test(
+    test_script, test_hparam, test_flag, test_check, _, _ = prepare_test(
         recipe_folder,
         script_field,
         hparam_field,
diff --git a/tests/utils/refactoring_checks.py b/tests/utils/refactoring_checks.py
index c9f125cee56bb5b43e40ac60f00521416b3b47b6..010e216e0feac721b236bfd232f9d77204293dc6 100644
--- a/tests/utils/refactoring_checks.py
+++ b/tests/utils/refactoring_checks.py
@@ -24,7 +24,7 @@ from torch.utils.data import DataLoader
 from hyperpyyaml import load_hyperpyyaml
 from speechbrain.utils.distributed import run_on_main  # noqa
 from speechbrain.utils.train_logger import FileTrainLogger
-from speechbrain.pretrained.interfaces import foreign_class  # noqa
+from speechbrain.inference.interfaces import foreign_class  # noqa
 from speechbrain.dataio.dataloader import LoopedLoader, make_dataloader
 
 
@@ -93,7 +93,7 @@ def get_model(repo, values, updates_dir=None, run_opts=None):
 
     Returns
     -------
-    A pretrained model with a speechbrain.pretrained.interface or a custom interface.
+    A pretrained model with a speechbrain.inference.interface or a custom interface.
     """
     # get the pretrained class; model & predictions
     kwargs = {
@@ -132,9 +132,9 @@ def get_model(repo, values, updates_dir=None, run_opts=None):
     print(f"\trepo: {repo}")
     # load pretrained model either via specified pretrained class or custom interface
     if "foreign" not in values.keys():
-        print(f'\tspeechbrain.pretrained.{values["cls"]}')
+        print(f'\tspeechbrain.inference.{values["cls"]}')
         print(f"\tobj.from_hparams({kwargs})")
-        obj = eval(f'speechbrain.pretrained.{values["cls"]}')
+        obj = eval(f'speechbrain.inference.{values["cls"]}')
         model = obj.from_hparams(**kwargs)
     else:
         kwargs["pymodule_file"] = values["foreign"]
diff --git a/tools/g2p.py b/tools/g2p.py
index 88d2b960380cc5c3d4a2c6e4544954e21ddcefae..8e757f2bcb436488b3b8701a0195ad55d77c7bf7 100644
--- a/tools/g2p.py
+++ b/tools/g2p.py
@@ -64,7 +64,7 @@ import traceback
 from cmd import Cmd
 from argparse import ArgumentParser
 from hyperpyyaml import load_hyperpyyaml
-from speechbrain.pretrained.interfaces import GraphemeToPhoneme
+from speechbrain.inference.text import GraphemeToPhoneme
 from tqdm.auto import tqdm
 
 MSG_MODEL_NOT_FOUND = "Model path not found"
@@ -77,7 +77,7 @@ def transcribe_text(g2p, text):
 
     Arguments
     ---------
-    g2p: speechbrain.pretrained.interfaces.GrpahemeToPhoneme
+    g2p: speechbrain.inference.text.GrpahemeToPhoneme
         a pretrained G2P model instance
 
     text: str
@@ -91,7 +91,7 @@ def transcribe_file(g2p, text_file_name, output_file_name=None, batch_size=64):
     """
     Transcribes a file with one example per line
 
-    g2p: speechbrain.pretrained.interfaces.GrpahemeToPhoneme
+    g2p: speechbrain.inference.text.GrpahemeToPhoneme
         a pretrained G2P model instance
 
     text_file_name: str
@@ -146,7 +146,7 @@ def transcribe_stream(g2p, text_file, output_file, batch_size=64, total=None):
 
     Arguments
     ---------
-    g2p: speechbrain.pretrained.interfaces.GrpahemeToPhoneme
+    g2p: speechbrain.inference.text.GrpahemeToPhoneme
         a pretrained G2P model instance
     text_file: file
         a file object from which text samples will be read
@@ -208,7 +208,7 @@ class InteractiveG2P(Cmd):
 
     Arguments
     ---------
-    model: speechbrain.pretrained.interfaces.GrpahemeToPhoneme
+    model: speechbrain.inference.text.GrpahemeToPhoneme
         a pretrained G2P model instance
     """
 
@@ -299,7 +299,7 @@ def load_g2p_checkpoint(
 
     Returns
     -------
-    g2p: speechbrain.pretrained.interfaces.GraphemeToPhoneme
+    g2p: speechbrain.inference.text.GraphemeToPhoneme
         a pretrained G2P model, initialized from a checkpoint
     """
     with open(hparams_file_name) as hparams_file:
diff --git a/tools/profiling/profile.py b/tools/profiling/profile.py
index f448f38d138d87b768ec86342e37137938fba71e..1d61d2aebf75f07a6d4cea1a33a07abc5ea50e37 100644
--- a/tools/profiling/profile.py
+++ b/tools/profiling/profile.py
@@ -21,18 +21,16 @@ from speechbrain.utils.profiling import (
     report_time,
     report_memory,
 )
-from speechbrain.pretrained import (
-    Pretrained,
-    EncoderDecoderASR,
-    EncoderASR,
-    EndToEndSLU,
-    EncoderClassifier,
-    SpeakerRecognition,
-    VAD,
-    SepformerSeparation,
-    SpectralMaskEnhancement,
-    SNREstimator,
-)
+
+from speechbrain.inference.interfaces import Pretrained
+from speechbrain.inference.ASR import EncoderDecoderASR, EncoderASR
+from speechbrain.inference.SLU import EndToEndSLU
+from speechbrain.inference.classifiers import EncoderClassifier
+from speechbrain.inference.speaker import SpeakerRecognition
+from speechbrain.inference.VAD import VAD
+from speechbrain.inference.separation import SepformerSeparation
+from speechbrain.inference.enhancement import SpectralMaskEnhancement
+from speechbrain.inference.metrics import SNREstimator
 from typing import Optional, List
 
 
diff --git a/tools/readme_builder.py b/tools/readme_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d2e6e178076b394b0fd8736948e2288a9cc8e6d
--- /dev/null
+++ b/tools/readme_builder.py
@@ -0,0 +1,155 @@
+#!/usr/bin/env python3
+"""
+Script: readme_builder.py
+
+Description:
+    This script creates the PERFORMANCE.md file, containing tables summarizing the performance
+    of models and tasks available in SpeechBrain. It fetches performance data from
+    the tests/recipes/*.csv files, where a special field called "performance" (e.g., Accuracy=85.7%)
+    is expected.
+
+Usage:
+    python readme_builder.py
+
+Authors:
+    - Mirco Ravanelli 2023
+"""
+
+import csv
+import re
+import argparse
+from speechbrain.utils.data_utils import get_all_files
+
+
+def create_table(fid_w, csv_file):
+    """
+    Reads the input CSV file and adds performance tables to the output file.
+
+    Args:
+        fid_w (file pointer): Pointer to the output performance file.
+        csv_file (str): Path to the recipe CSV file containing recipe information
+                        (e.g., 'tests/recipes/LibriSpeech.csv').
+
+    Returns:
+        None
+    """
+
+    # Read CSV file into a list of dictionaries
+    with open(csv_file, "r") as file:
+        csv_reader = csv.DictReader(file)
+        recipes_lst = [row for row in csv_reader]
+
+        dataset = recipes_lst[0].get("Dataset", "")
+        if not recipes_lst or "performance" not in recipes_lst[0]:
+            return
+
+        print(f"## {dataset} Dataset\n", file=fid_w)
+
+    # Filter recipes
+    recipes = {task: [] for task in set(row["Task"] for row in recipes_lst)}
+
+    for recipe_line in recipes_lst:
+        got_performance = len(recipe_line["performance"].strip()) > 0
+
+        if not got_performance:
+            continue
+
+        task = recipe_line["Task"]
+        recipes[task].append(recipe_line)
+
+    # Creating performance tables for each task
+    for task, recipes_task in recipes.items():
+        if not recipes_task:
+            continue  # Skip empty task
+
+        print(f"### {task} \n", file=fid_w)
+
+        performance_dict = extract_name_value_pairs(
+            recipes_task[0]["performance"]
+        )
+        performance_metrics = performance_dict.keys()
+        performance_metrics = " | ".join(performance_metrics) + " | "
+
+        print(
+            f"| Model | Checkpoints | HuggingFace | {performance_metrics}",
+            file=fid_w,
+        )
+        print(
+            "".join(["| --------"] * (3 + len(performance_dict))) + "|",
+            file=fid_w,
+        )
+
+        for recipe in recipes_task:
+            performance_dict = extract_name_value_pairs(recipe["performance"])
+            performance_values = " | ".join(performance_dict.values()) + " | "
+
+            str_res = (
+                f'[here]({recipe["Result_url"]})'
+                if recipe["Result_url"]
+                else "-"
+            )
+            hf_repo = (
+                f'[here]({recipe["HF_repo"]})' if recipe["HF_repo"] else "-"
+            )
+
+            performance_line = f' | {recipe["Hparam_file"]} | {str_res} | {hf_repo} | {performance_values}'
+            print(performance_line, file=fid_w)
+
+        print("\n", file=fid_w)
+
+
+def extract_name_value_pairs(input_string):
+    """
+    Extracts performance metrics and their values from the performance line.
+
+    Args:
+        input_string (str): The string containing the performance.
+
+    Returns:
+        dict: A dictionary containing the detected performance metrics and their values.
+    """
+    pattern = re.compile(r"(\w+(?:-\w+)?)=(\S+)")
+    matches = pattern.findall(input_string)
+    result = {name: value for name, value in matches}
+    return result
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(
+        description=(
+            "Create the performance file from the recipe info csv files."
+        ),
+    )
+
+    parser.add_argument(
+        "--recipe_info_dir",
+        help="The directory where all the csv files containing the recipe info are stored. "
+        "E.g., tests/recipes/",
+    )
+    parser.add_argument(
+        "--output_file",
+        help="The path to the output performance file to create",
+    )
+
+    args = parser.parse_args()
+
+    file_w = open(args.output_file, "w")
+
+    # List of recipe files
+    recipe_files = get_all_files(
+        args.recipe_info_dir, match_and=[".csv"], exclude_or=["~"]
+    )
+
+    header = """# SpeechBrain Performance Report
+    This document provides an overview of the performance achieved on key datasets and tasks supported by SpeechBrain.
+    """
+
+    print(header, file=file_w)
+
+    for csv_file in sorted(recipe_files):
+        create_table(file_w, csv_file)
+
+    file_w.close()
+
+    print(args.output_file + " CREATED!")