diff --git a/recipes/MultiWOZ/response_generation/README.md b/recipes/MultiWOZ/response_generation/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8af727633c58dbdb30aefaed8ca020c501600ed2
--- /dev/null
+++ b/recipes/MultiWOZ/response_generation/README.md
@@ -0,0 +1,49 @@
+# MultiWOZ Response Generation with GPT2
+This folder contains the scripts to fine-tune a GPT-based system on MultiWOZ for the response generation task.
+MultiWOZ is available at https://github.com/budzianowski/multiwoz, and the recipe will automatically download it to the specified data_folder.
+
+
+## Installing Extra Dependencies
+
+Before proceeding, ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal:
+
+```
+pip install -r extra_requirements.txt
+```
+
+# How to run
+```
+python train_with_gpt.py hparams/train_gpt.yaml --data_folder=/your/data/folder
+```
+The data will be automatically downloaded to the specified data_folder.
+
+
+# Results
+
+| Model | Release | Hyperparams file | Test Cross-entropy Loss | Test PPL | Test BLEU 4| HuggingFace link | Full model link | GPUs |
+|:-------------:|:-------------:|:---------------------------:| :-----:| :-----:| :-----:| :-----:| :--------:|:--------:|
+| GPT2 | 2023-08-15 | train_gpt.yaml |  1.39 |  4.01 | 2.54e-04 |[model](https://huggingface.co/speechbrain/MultiWOZ-GPT-Response_Generation) | [model](https://www.dropbox.com/sh/vm8f5iavohr4zz9/AACrkOxXuxsrvJy4Cjpih9bQa?dl=0) | 1xV100 16GB |
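+
+# Inference
+
+You can also chat with the trained model through the `ResponseGenerator` pretrained interface added in this PR. The snippet below is a minimal sketch: the `savedir` value is arbitrary, and it assumes the HuggingFace repo linked above provides the required `custom.py` and hyperparameters.
+
+```python
+from speechbrain.pretrained import ResponseGenerator
+
+res_gen_model = ResponseGenerator.from_hparams(
+    source="speechbrain/MultiWOZ-GPT-Response_Generation",
+    savedir="pretrained_models/MultiWOZ-GPT-Response_Generation",
+    pymodule_file="custom.py",
+)
+print(res_gen_model.generate_response("I want to book a table for dinner"))
+```
+
+The interface keeps the dialogue history internally, so successive calls to `generate_response` continue the same conversation.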
+
+
+
+
+# **About SpeechBrain**
+- Website: https://speechbrain.github.io/
+- Code: https://github.com/speechbrain/speechbrain/
+- HuggingFace: https://huggingface.co/speechbrain/
+
+# **Citing**
+Please cite SpeechBrain if you use it for your research or business.
+
+```bibtex
+@misc{speechbrain,
+  title={{SpeechBrain}: A General-Purpose Speech Toolkit},
+  author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
+  year={2021},
+  eprint={2106.04624},
+  archivePrefix={arXiv},
+  primaryClass={eess.AS},
+  note={arXiv:2106.04624}
+}
+```
diff --git a/recipes/MultiWOZ/response_generation/extra_requirements.txt b/recipes/MultiWOZ/response_generation/extra_requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..454a21c03558b3210da0efe42581688e9cfc4b1d
--- /dev/null
+++ b/recipes/MultiWOZ/response_generation/extra_requirements.txt
@@ -0,0 +1 @@
+sacrebleu
diff --git a/recipes/MultiWOZ/response_generation/hparams/train_gpt.yaml b/recipes/MultiWOZ/response_generation/hparams/train_gpt.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..583b275c6610fa78250d9de2c15cdcc9ecef9308
--- /dev/null
+++ b/recipes/MultiWOZ/response_generation/hparams/train_gpt.yaml
@@ -0,0 +1,136 @@
+# ########################################
+# Model: GPT2LMHeadModel +  NLL
+# Authors:
+    # Pooneh Mousavi 2023
+    # Simone Alghisi 2023
+# ########################################
+
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1995
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+
+# Dataset will be downloaded to the `data_folder`
+data_folder: !PLACEHOLDER
+output_folder: !ref results/train_with_gpt2/<seed>
+replacements_path: mapping.pair
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+bleu_4_test_file: !ref <output_folder>/bleu_4_test.txt
+bleu_4_valid_file: !ref <output_folder>/bleu_4_valid.txt
+
+# URL for the gpt2 model
+gpt_hub: gpt2
+gpt_folder: !ref <save_folder>/gpt_checkpoint
+
+# Path where data manifest files will be stored
+train_annotation: !ref <output_folder>/train.json
+valid_annotation: !ref <output_folder>/dev.json
+test_annotation: !ref <output_folder>/test.json
+
+skip_prep: False
+
+# The train logger writes training statistics to a file, as well as stdout.
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+# Special tokens
+bos_token: "BOS"
+eos_token: "EOS"
+
+system_token: "SPK_1"
+user_token: "SPK_2"
+
+special_tokens: [
+    !ref <bos_token>,
+    !ref <eos_token>,
+    !ref <system_token>,
+    !ref <user_token>
+]
+
+attr_to_special_tokens:
+    "bos_token": !ref <bos_token>
+    "eos_token": !ref <eos_token>
+    "additional_special_tokens": [!ref <system_token>, !ref <user_token>]
+
+# history_window, i.e. how many user-system exchanges to consider as context.
+max_history: 5
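+# e.g. max_history: 5 yields a window of 2*5+1 = 11 turns
+# (5 full exchanges plus the latest user turn).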
+
+ignore_index: -100
+label_smoothing: 0
+
+# Training parameters
+number_of_epochs: 4
+batch_size: 8
+test_batch_size: 4
+lr: 1.97125e-4
+
+# Freeze the GPT model
+freeze_gptmodel: False
+num_beams: 3
+max_new_tokens: 50
+top_k: 45
+top_p: 0.9
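+# Note: num_beams/top_k/top_p are only used by the "beam" branch of
+# HuggingFaceGPT.generate() (test stage); the greedy decoding used during
+# validation ignores them.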
+
+
+train_dataloader_options:
+    batch_size: !ref <batch_size>
+    shuffle: True
+    num_workers: 2
+    drop_last: False
+
+test_dataloader_options:
+    batch_size: !ref <test_batch_size>
+    shuffle: True
+    num_workers: 2
+    drop_last: True
+
+# Masks
+padding_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_key_padding_mask
+
+# gpt model
+gpt_model: !new:speechbrain.lobes.models.huggingface_gpt.HuggingFaceGPT
+    source: !ref <gpt_hub>
+    freeze: !ref <freeze_gptmodel>
+    save_path: !ref <gpt_folder>
+    max_new_tokens: !ref <max_new_tokens>
+    num_beams: !ref <num_beams>
+    top_k: !ref <top_k>
+    top_p: !ref <top_p>
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+modules:
+    gpt_model: !ref <gpt_model>
+
+model: !new:torch.nn.ModuleList
+    - [!ref <gpt_model>]
+
+
+ce_loss: !new:torch.nn.CrossEntropyLoss
+    ignore_index: !ref <ignore_index>
+    label_smoothing: !ref <label_smoothing>
+
+opt_class: !name:torch.optim.AdamW
+    lr: !ref <lr>
+
+
+lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
+    initial_value: !ref <lr>
+    improvement_threshold: 0.0025
+    annealing_factor: 0.9
+    patient: 0
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        gpt_model: !ref <gpt_model>
+        lr_annealing_output: !ref <lr_annealing>
+        counter: !ref <epoch_counter>
+
+
+bleu_4_computer: !name:speechbrain.utils.bleu.BLEUStats
+    max_ngram_order: 4
+
+bleu_2_computer: !name:speechbrain.utils.bleu.BLEUStats
+    max_ngram_order: 2
diff --git a/recipes/MultiWOZ/response_generation/mapping.pair b/recipes/MultiWOZ/response_generation/mapping.pair
new file mode 100644
index 0000000000000000000000000000000000000000..34df41d01e93ce27039e721e1ffb55bf9267e5a2
--- /dev/null
+++ b/recipes/MultiWOZ/response_generation/mapping.pair
@@ -0,0 +1,83 @@
+it's	it is
+don't	do not
+doesn't	does not
+didn't	did not
+you'd	you would
+you're	you are
+you'll	you will
+i'm	i am
+they're	they are
+that's	that is
+what's	what is
+couldn't	could not
+i've	i have
+we've	we have
+can't	cannot
+i'd	i would
+aren't	are not
+isn't	is not
+wasn't	was not
+weren't	were not
+won't	will not
+there's	there is
+there're	there are
+. .	.
+restaurants	restaurant -s
+hotels	hotel -s
+laptops	laptop -s
+cheaper	cheap -er
+dinners	dinner -s
+lunches	lunch -s
+breakfasts	breakfast -s
+expensively	expensive -ly
+moderately	moderate -ly
+cheaply	cheap -ly
+prices	price -s
+places	place -s
+venues	venue -s
+ranges	range -s
+meals	meal -s
+locations	location -s
+areas	area -s
+policies	policy -s
+children	child -s
+kids	kid -s
+kidfriendly	kid friendly
+cards	card -s
+upmarket	expensive
+inpricey	cheap
+inches	inch -s
+uses	use -s
+dimensions	dimension -s
+driverange	drive range
+includes	include -s
+computers	computer -s
+machines	machine -s
+families	family -s
+ratings	rating -s
+constraints	constraint -s
+pricerange	price range
+batteryrating	battery rating
+requirements	requirement -s
+drives	drive -s
+specifications	specification -s
+weightrange	weight range
+harddrive	hard drive
+batterylife	battery life
+businesses	business -s
+hours	hour -s
+one	1
+two	2
+three	3
+four	4
+five	5
+six	6
+seven	7
+eight	8
+nine	9
+ten	10
+eleven	11
+twelve	12
+anywhere	any where
+good bye	goodbye
diff --git a/recipes/MultiWOZ/response_generation/multiwoz_prepare.py b/recipes/MultiWOZ/response_generation/multiwoz_prepare.py
new file mode 100644
index 0000000000000000000000000000000000000000..201cab3288b1ab9ed71d903f89c92500ffe3cf5a
--- /dev/null
+++ b/recipes/MultiWOZ/response_generation/multiwoz_prepare.py
@@ -0,0 +1,675 @@
+"""
+Data preparation.
+Download: https://github.com/budzianowski/multiwoz/tree/master/data
+
+The original one can be found at:
+https://github.com/jasonwu0731/trade-dst/blob/master/create_data.py
+
+Authors
+-------
+ * Pooneh Mousavi 2023
+ * Simone Alghisi 2023
+"""
+from itertools import product
+from statistics import mean
+from typing import Any, Dict, List, Optional, Set, Tuple
+import json
+import logging
+import os
+import re
+import shutil
+from tqdm import tqdm
+from speechbrain.utils.data_utils import download_file
+
+logger = logging.getLogger(__name__)
+MULTIWOZ_21_DATASET_URL = (
+    "https://github.com/budzianowski/multiwoz/raw/master/data/MultiWOZ_2.1.zip"
+)
+
+
+def prepare_mwoz_21(
+    data_folder: str, save_folder: str, replacements_path: str, skip_prep=False,
+) -> None:
+
+    """
+    This function prepares the JSON files for the MultiWOZ dataset.
+    Data will be automatically downloaded in the data_folder.
+    Download link: https://github.com/budzianowski/multiwoz/tree/master/data
+
+    Arguments
+    ---------
+    data_folder : str
+        Path to the folder where the original MultiWOZ dataset is stored.
+    save_folder : str
+        The directory where to store the JSON files.
+    replacements_path: str
+        File containing (from, to) pairs, one per line for preprocessing the text.
+    skip_prep: bool
+        If True, data preparation is skipped.
+
+
+    Example
+    -------
+    >>> data_folder = 'data/MultiWOZ_2.1'
+    >>> save_folder = 'MultiWOZ_prepared'
+    >>> replacements_path = 'mapping.pair'
+    >>> prepare_mwoz_21(data_folder, save_folder, replacements_path)
+    """
+
+    if skip_prep:
+        return
+
+    # Saving folder
+    if not os.path.exists(save_folder):
+        os.makedirs(save_folder)
+
+    # Setting output files
+    save_train = save_folder + "/train.json"
+    save_dev = save_folder + "/dev.json"
+    save_test = save_folder + "/test.json"
+
+    # If the JSON files already exist, we skip the data preparation
+    if skip(save_train, save_dev, save_test):
+
+        msg = "%s already exists, skipping data preparation!" % (save_train)
+        logger.info(msg)
+
+        msg = "%s already exists, skipping data preparation!" % (save_dev)
+        logger.info(msg)
+
+        msg = "%s already exists, skipping data preparation!" % (save_test)
+        logger.info(msg)
+
+        return
+
+    # Download dataset
+    download_mwoz_21(data_folder)
+    data_folder = os.path.join(data_folder, "MultiWOZ_21")
+
+    # Additional checks to make sure the data folder contains MultiWOZ
+    check_multiwoz_folders(data_folder)
+
+    data_path = os.path.join(data_folder, "data.json")
+    train_split, dev_split, test_split = get_splits(data_folder)
+    # Creating json files for {train, dev, test} data
+    file_pairs = zip(
+        [train_split, dev_split, test_split], [save_train, save_dev, save_test],
+    )
+
+    for split, save_file in file_pairs:
+        build_dialogue_dataset(
+            data_path, split, save_file, replacements_path,
+        )
+
+
+def check_multiwoz_folders(data_folder):
+    """
+    Check if the data folder actually contains the MultiWOZ dataset.
+    If not, raises an error.
+    Returns
+    -------
+    None
+    Raises
+    ------
+    FileNotFoundError
+        If the data folder doesn't contain the MultiWOZ dataset.
+    """
+    files_str = "/data.json"
+    # Checking that data.json exists
+    if not os.path.exists(data_folder + files_str):
+        err_msg = (
+            "the folder %s does not exist (it is expected in "
+            "the MultiWOZ dataset)" % (data_folder + files_str)
+        )
+        raise FileNotFoundError(err_msg)
+
+
+def download_mwoz_21(destination):
+    """ Download the dataset repo, unpack it, and remove unnecessary elements.
+    Arguments
+    ---------
+    destination: str
+        Place to put dataset.
+    """
+    mwoz_21_archive = os.path.join(destination, "MultiWOZ_21.zip")
+    download_file(MULTIWOZ_21_DATASET_URL, mwoz_21_archive)
+    shutil.unpack_archive(mwoz_21_archive, destination)
+    shutil.rmtree(os.path.join(destination, "__MACOSX"))
+
+    mwoz_21 = os.path.join(destination, "MultiWOZ_21")
+    os.makedirs(mwoz_21, exist_ok=True)
+
+    mwoz_21_repo = os.path.join(destination, "MultiWOZ_2.1")
+    for relevant_file in ["data.json", "valListFile.txt", "testListFile.txt"]:
+        shutil.move(
+            os.path.join(mwoz_21_repo, relevant_file),
+            os.path.join(mwoz_21, relevant_file),
+        )
+
+    shutil.rmtree(mwoz_21_repo)
+
+
+def skip(save_train, save_dev, save_test):
+    """
+    Detects if the MultiWOZ data preparation has already been done.
+    If the preparation has been done, we can skip it.
+    Returns
+    -------
+    bool
+        if True, the preparation phase can be skipped.
+        if False, it must be done.
+    """
+
+    # Checking folders and save options
+    skip = False
+
+    if (
+        os.path.isfile(save_train)
+        and os.path.isfile(save_dev)
+        and os.path.isfile(save_test)
+    ):
+        skip = True
+
+    return skip
+
+
+def get_splits(dataset_folder) -> Tuple[List[str], List[str], List[str]]:
+    """Reads data.json and the official valListFile.txt/testListFile.txt
+    and returns the dialogue keys of the train, dev, and test splits."""
+    mwoz_21_dialogues = get_json_object(
+        os.path.join(dataset_folder, "data.json")
+    )
+    dialogues_keys: Set[str] = set(mwoz_21_dialogues.keys())
+    tr_split: List[str] = []
+    with open(os.path.join(dataset_folder, "valListFile.txt")) as f:
+        dev_split: List[str] = [key.strip() for key in f]
+    with open(os.path.join(dataset_folder, "testListFile.txt")) as f:
+        te_split: List[str] = [key.strip() for key in f]
+
+    for key in dialogues_keys:
+        if key not in dev_split and key not in te_split:
+            tr_split.append(key)
+
+    return tr_split, dev_split, te_split
+
+
+def build_dialogue_dataset(
+    data_path: str,
+    data_split: List[str],
+    save_file: str,
+    replacements_path: str,
+) -> None:
+    """
+    Builds and saves the dialogue dataset for the corresponding data_path.
+
+    Arguments
+    ---------
+    data_path: str
+     Path to the folder where the original MultiWOZ dataset is stored.
+    data_split: list of str
+        List of strings containing MultiWOZ 2.1 keys of the dialogues
+        associated with a certain split (train, dev, test).
+    save_file: str
+        Path of the file where the dataset will be saved.
+    replacements_path: str
+        Path to file containing (from, to) pairs, one per line.
+
+    Returns
+    -------
+    None
+        The dataset is saved to `save_file` as a JSON file; its values are
+        dictionaries containing the dialogue history, the system reply, and
+        the mean length.
+    """
+    logger.info(f"Prepare {save_file}")
+    encode_dialogue_dataset(
+        save_file, data_path, data_split, replacements_path,
+    )
+
+
+def encode_dialogue_dataset(
+    save_file: str,
+    data_path: str,
+    data_split: List[str],
+    replacements_path: str,
+) -> None:
+    """
+    Processes the original data with the TRADE pre-processing and saves
+    the resulting dataset at save_file.
+
+    Arguments
+    ---------
+    save_file: str
+        Path of the file where the dataset will be saved.
+    data_path: str
+        Path to the folder where the original MultiWOZ dataset is stored.
+    data_split: list of str
+        List of strings containing MultiWOZ 2.1 keys of the dialogues
+        associated with a certain split (train, dev, test).
+    replacements_path: str
+        Path to file containing (from, to) pairs, one per line.
+    """
+
+    replacements = get_replacements(replacements_path)
+    logger.info(f"Extract dialogues from {data_path}")
+    # custom loading function to return the important elements of a dialogue
+    dialogues = load_dialogues(data_path, data_split, replacements)
+
+    logger.info("Create dataset")
+    dataset = create_dialogue_dataset(dialogues)
+    logger.info(f"Save dataset in {save_file}")
+    save_dialogue_dataset(dataset, save_file)
+
+
+def get_replacements(
+    replacements_path: str = "trade/utils/mapping.pair",
+) -> List[Tuple[str, str]]:
+    """
+    Get the replacements from a given file. Used by trade preprocessing.
+
+    Arguments
+    ---------
+    replacements_path: str
+        File containing from, to pairs, one per line.
+
+    Returns
+    -------
+    replacements: List of replacements, i.e. pairs of str
+        Pairs of elements used to substitute the first element with the second.
+    """
+    replacements = []
+    with open(replacements_path, "r") as fin:
+        for line in fin.readlines():
+            tok_from, tok_to = line.replace("\n", "").split("\t")
+            replacements.append((" " + tok_from + " ", " " + tok_to + " "))
+    return replacements
+
+
+def load_dialogues(
+    data_path: str, data_split: List[str], replacements: List[Tuple[str, str]],
+) -> List[List[Dict[str, Any]]]:
+    """
+    Load dialogues from data_path, apply trade pre-processing, revert the
+    subtokenization, and create a dictionary containing the dialogue id,
+    the turn id, and the corrected sequence.
+
+    Arguments
+    ---------
+    data_path: str
+        Path to the json file containing the data.
+    data_split: list of str
+        List of strings containing MultiWOZ 2.1 keys of the dialogues
+        associated with a certain split (train, dev, test).
+    replacements: list of (str, str) pairs
+        Replacement pairs used by the TRADE normalization.
+
+    Returns
+    -------
+    dialogues: list of list of dict, keys are str, values could be anything
+        List of dialogues. Each dialogue is a list of turns. Each turn is a
+        dict containing dialogue_idx, turn_idx, and the corrected sequence.
+    """
+
+    def get_preprocessed_seq(
+        original_seq: str, replacements: List[Tuple[str, str]]
+    ) -> str:
+        # apply trade normalization
+        trade_seq = normalize(original_seq, replacements)
+        # merge back subtokens
+        sequence = invert_trade_subtokenization(original_seq, trade_seq)
+        return sequence
+
+    dialogues: List[List[Dict[str, Any]]] = []
+
+    data = get_json_object(data_path)
+
+    for dialogue_idx in tqdm(data_split, desc="Load Dialogues"):
+        dial: List[Dict[str, Any]] = []
+        original_dialogue: dict = data[dialogue_idx]
+        turns: dict = original_dialogue["log"]
+        for i, turn in enumerate(turns):
+            sequence = get_preprocessed_seq(turn["text"], replacements)
+            to_save = {
+                "sequence": sequence,
+                "turn_idx": i,
+                "dialogue_idx": dialogue_idx,
+            }
+            dial.append(to_save)
+        dialogues.append(dial)
+    return dialogues
+
+
+def normalize(text, replacements):
+    """Applies the TRADE text normalization: lowercasing, punctuation
+    spacing, contraction/replacement mapping, and number merging."""
+    # lower case every word
+    text = text.lower()
+
+    # replace white spaces in front and end
+    text = re.sub(r"^\s*|\s*$", "", text)
+
+    # hotel domain pfb30
+    text = re.sub(r"b&b", "bed and breakfast", text)
+    text = re.sub(r"b and b", "bed and breakfast", text)
+
+    # weird unicode bug
+    text = re.sub("(\u2018|\u2019)", "'", text)
+
+    # replace st.
+    text = text.replace(";", ",")
+    text = re.sub(r"$\/", "", text)
+    text = text.replace("/", " and ")
+
+    # replace other special characters
+    text = text.replace("-", " ")
+    text = re.sub(r'["\<>@\(\)]', "", text)  # remove
+
+    # insert white space before and after tokens:
+    for token in ["?", ".", ",", "!"]:
+        text = insertSpace(token, text)
+
+    # insert white space for 's
+    text = insertSpace("'s", text)
+
+    # replace it's, does't, you'd ... etc
+    text = re.sub("^'", "", text)
+    text = re.sub(r"'$", "", text)
+    text = re.sub(r"'\s", " ", text)
+    text = re.sub(r"\s'", " ", text)
+    for fromx, tox in replacements:
+        text = " " + text + " "
+        text = text.replace(fromx, tox)[1:-1]
+
+    # remove multiple spaces
+    text = re.sub(" +", " ", text)
+
+    # concatenate numbers
+    tokens = text.split()
+    i = 1
+    while i < len(tokens):
+        if re.match(r"^\d+$", tokens[i]) and re.match(r"\d+$", tokens[i - 1]):
+            tokens[i - 1] += tokens[i]
+            del tokens[i]
+        else:
+            i += 1
+    text = " ".join(tokens)
+    return text
+
+
+def insertSpace(token, text):
+    """Inserts whitespace around every occurrence of `token` in `text`,
+    leaving digit-to-digit contexts (e.g. "1,000") untouched."""
+    sidx = 0
+    while True:
+        sidx = text.find(token, sidx)
+        if sidx == -1:
+            break
+        if (
+            sidx + 1 < len(text)
+            and re.match("[0-9]", text[sidx - 1])
+            and re.match("[0-9]", text[sidx + 1])
+        ):
+            sidx += 1
+            continue
+        if text[sidx - 1] != " ":
+            text = text[:sidx] + " " + text[sidx:]
+            sidx += 1
+        if sidx + len(token) < len(text) and text[sidx + len(token)] != " ":
+            text = text[: sidx + 1] + " " + text[sidx + 1 :]
+        sidx += 1
+    return text
+
+
+TOKEN_EXCEPTIONS = {
+    "childs": "children",
+    "businesss": "businesses",
+    "inchs": "inches",
+}
+PATTERN_EXCEPTIONS = {"breakfasts": "b&bs"}
+
+
+def invert_trade_subtokenization(
+    original_seq: str,
+    trade_seq: str,
+    token_exceptions: Dict[str, str] = TOKEN_EXCEPTIONS,
+    pattern_exceptions: Dict[str, str] = PATTERN_EXCEPTIONS,
+    subtoken_special_chrs: List[str] = [" -", " _"],
+) -> str:
+    """
+    Invert all trade subtokenizations in a string given the original sequence.
+
+    Arguments
+    ---------
+    original_seq: str
+        The original sequence.
+    trade_seq: str
+        The sequence that has been pre-processed by trade.
+    token_exceptions: dict, keys are str, values are str
+        A dictionary to map merged token to their correct counterpart. E.g.
+        child -s is merged into childs, but the correct token is children.
+    pattern_exceptions: dict, keys are str, values are str
+        A dictionary to map patterns to their correct counterpart. E.g.
+        after the pre-processing "b&bs" is mapped to "bed and breakfast -s",
+        making the search of breakfasts impossible if not handled by such
+        exceptions.
+    subtoken_special_chrs: list of str
+        List containing the special characters that are used for subtokens.
+
+    Returns
+    -------
+    corrected_seq: str
+        The sequence corrected, i.e. subtokens replaced by tokens.
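+
+    Example
+    -------
+    >>> invert_trade_subtokenization(
+    ...     "I am looking for cheaper hotels",
+    ...     "i am looking for cheap -er hotel -s",
+    ... )
+    'i am looking for cheaper hotels'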
+    """
+    regex = "|".join(subtoken_special_chrs)
+    subtoken_pieces = re.split(regex, trade_seq, maxsplit=1)
+    search_after: int = 0
+    while len(subtoken_pieces) > 1:
+        # example: 'the wind is moderate -ly strong'
+        # split: ['the wind is moderate ', 'ly strong']
+        # split[0]: 'the wind is moderate' --> split on whitespace ['the', 'wind', 'is', 'moderate']
+        left_side = subtoken_pieces[0].split()
+        subtoken_left = left_side[-1]
+        # split[1]: 'ly strong' --> split on whitespace ['ly', 'strong']
+        right_side = subtoken_pieces[1].split()
+        subtoken_right = right_side[0]
+        # try merging the subtoken parts to form a token, i.e. moderate + ly
+        token = "".join([subtoken_left, subtoken_right])
+
+        if token in token_exceptions:
+            # if you match an exception, replace the token with the exception
+            token = token_exceptions[token]
+
+        # assume there are no tokens on left and right side of the subtokens' pieces
+        left_token = None  # if token is at the beginning
+        right_token = None  # if token is at the end
+        # try looking for them
+        if len(left_side) > 1:
+            left_token = left_side[-2]
+        if len(right_side) > 1:
+            right_token = right_side[1]
+
+        # start from a complete match, and progressively remove left and right
+        # tokens to counter TRADE preprocessing of some tokens
+        # The order is
+        # 1. True, True
+        # 2. True, False
+        # 3. False, True
+        # 4. False, False
+        # basically, at the end you try looking only for the merged token
+        pattern: str = ""
+        idx: int = -1
+        for use_left, use_right in product((True, False), (True, False)):
+            pattern = token
+            if (left_token is not None) and use_left:
+                pattern = " ".join([left_token, pattern])
+            if right_token is not None and use_right:
+                pattern = " ".join([pattern, right_token])
+
+            # check if the pattern is in the exceptions
+            if pattern in pattern_exceptions:
+                pattern = pattern_exceptions[pattern]
+            # Search the pattern
+            idx = original_seq[search_after:].lower().find(pattern)
+            if idx > -1:
+                break
+
+        error: str = f"""
+            Pattern search failed in the following case:
+            PATTERN =  \t{pattern}
+            LEFT SIDE = \t{left_side}
+            RIGHT SIDE = \t{right_side}
+            ORIG SEQ = \t{original_seq[search_after:]}
+
+            This may be due to further TRADE pre-processing, or an incorrect merging operation.
+            To solve this, add a special rule for the token that breaks the code either as a
+            token_exception or a pattern_exception.
+        """
+
+        assert idx > -1, error
+        # move the index to avoid perfect matches with the same token
+        # TODO: it is probably better to advance by len(left_token + token) or
+        # len(token), depending on the match
+        search_after += idx + 1
+        # reconstruct the sentence with the matched pattern
+        trade_seq = " ".join([*left_side[:-1], token, *right_side[1:]])
+
+        # try splitting the sentence again and repeat the process
+        subtoken_pieces = re.split(regex, trade_seq, maxsplit=1)
+    # Good, no subtokens found: return trade seq
+    return trade_seq
+
+
+def get_json_object(data_path: str) -> dict:
+    """
+    A function to read a json object and return the python
+    dictionary associated to it.
+
+    Arguments
+    ---------
+    data_path: str
+        Path to a json file.
+
+    Returns
+    -------
+    loaded_json: dict
+        A loaded json object.
+    """
+    with open(data_path, "r") as data_file:
+        data = json.load(data_file)
+
+    return data
+
+
+def create_dialogue_dataset(
+    dialogues: List[List[Dict[str, Any]]]
+) -> Dict[str, Dict[str, Any]]:
+    """
+    Creates a dialogue dataset starting from a set of dialogues. Each
+    entry of the dataset contains the dialogue history and the system
+    reply in response to that.
+
+    Arguments
+    ---------
+    dialogues: list of list of dict, keys are str, values could be anything
+        List of dialogues. Each dialogue is a list of turns. Each turn is a
+        dict containing dialogue_idx, turn_idx, and the corrected sequence.
+
+    Returns
+    -------
+    dataset: Dict[str, Dict[str, Any]]
+        Dataset, keys are str, values are dictionaries containing the
+        dialogue history and the system reply.
+    """
+
+    def create_dialogue_dataset_entry(
+        turn: Dict[str, Any], history: List[str]
+    ) -> Optional[Dict[str, Any]]:
+        """
+        Creates an entry if the current turn id is odd. An entry is
+        composed of the history, which contains the previous turns
+        of the current dialogue, and the reply of the system.
+
+        Arguments
+        ---------
+        turn: dict, keys are str, values could be anything
+            A dict containing the dialogue id, the turn id, and the sequence.
+        history: list of str
+            The previous turns of the current dialogue; updated in place.
+
+        Returns
+        -------
+        entry: optional dict, keys are str, values could be anything
+            Entry of the dialogue dataset. It is a dict containing the history
+            of the dialogue, i.e. a list of turns, the reply of the system,
+            i.e. a turn, and the mean length.
+        """
+
+        turn_idx = turn["turn_idx"]
+        entry: Optional[Dict[str, Any]] = None
+        if turn_idx % 2 == 0:
+            # user turn, simply append it to the history
+            user_seq: str = turn["sequence"]
+            history.append(user_seq)
+        elif turn_idx % 2 == 1:
+            # system turn: create the dataset entry, then append it to the history
+            system_seq: str = turn["sequence"]
+            history_mean_length = mean([len(turn) for turn in history])
+            entry = {
+                "history": history.copy(),
+                "reply": system_seq,
+                "length": history_mean_length + len(system_seq),
+            }
+            history.append(system_seq)
+        return entry
+
+    dataset: Dict[str, Dict[str, Any]] = {}
+    for dialogue in tqdm(dialogues, desc="Creating dataset"):
+        history: List[str] = []
+        for turn in dialogue:
+            # custom function to create a dataset entry
+            dataset_entry = create_dialogue_dataset_entry(turn, history)
+            # custom function to create a dataset key
+            key = create_entry_key(turn)
+            if dataset_entry is not None:
+                dataset[key] = dataset_entry
+    return dataset
+
+
+def create_entry_key(turn: Dict[str, Any]) -> str:
+    """
+    Creates the entry key for a given entry by considering dialogue id
+    and turn id for the given turn.
+
+    Arguments
+    ---------
+    turn: dict, keys are str, values could be anything
+        A dict containing the dialogue id, the turn id, and the sequence.
+
+    Returns
+    -------
+    key: str
+        The key for the given turn.
+    """
+    dialogue_idx = turn["dialogue_idx"]
+    turn_idx = turn["turn_idx"]
+    return f"{dialogue_idx}_{turn_idx}"
+
+
+def save_dialogue_dataset(
+    dataset: Dict[str, Dict[str, Any]], save_file: str
+) -> None:
+    """
+    Saves the dialogue dataset at dst_folder/file_name as a json file.
+
+    Arguments
+    ---------
+    dataset: Dict[str, Dict[str, Any]]
+        Dataset, keys are str, values are dictionaries containing the
+        dialogue history, the system reply, and the mean length.
+    save_file: str
+        Path of the JSON file where the dataset will be saved.
+    """
+    with open(save_file, "w") as f:
+        json.dump(dataset, f, indent=4)
diff --git a/recipes/MultiWOZ/response_generation/train_with_gpt.py b/recipes/MultiWOZ/response_generation/train_with_gpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..09630bef940ec8622e870189dbcbc65e01f157ae
--- /dev/null
+++ b/recipes/MultiWOZ/response_generation/train_with_gpt.py
@@ -0,0 +1,522 @@
+#!/usr/bin/env python3
+"""
+Recipe for training a GPT-based response generation model with MultiWOZ.
+The system employs GPT2 (https://life-extension.github.io/2020/05/27/GPT%E6%8A%80%E6%9C%AF%E5%88%9D%E6%8E%A2/language-models.pdf).
+This recipe fine-tunes GPT2LMHeadModel on the response generation task with an NLL loss.
+
+To run this recipe, do the following:
+> python train_with_gpt.py hparams/train_gpt.yaml
+
+Authors
+ * Pooneh Mousavi 2023
+ * Simone Alghisi 2023
+"""
+
+
+import sys
+import speechbrain as sb
+import torch
+from itertools import chain
+from hyperpyyaml import load_hyperpyyaml
+from speechbrain.utils.distributed import run_on_main
+from transformers import GPT2Tokenizer
+import math
+from speechbrain.dataio.batch import PaddedBatch
+
+
+class ResGenBrain(sb.Brain):
+    def compute_forward(self, batch, stage):
+        """Computation pipeline based on a gpt decoder.
+        """
+        # Get required data from batch
+        batch = batch.to(self.device)
+        input_ids, _ = batch.input_ids
+        token_type_ids, _ = batch.token_type_ids
+
+        # Forward Pass
+        padding_mask = ~self.hparams.padding_mask(
+            input_ids, pad_idx=tokenizer.unk_token_id
+        )
+        outputs = self.modules.gpt_model(
+            input_ids, token_type_ids, padding_mask
+        ).logits
+
+        return outputs
+
+    def compute_objectives(self, predictions, batch, stage):
+        """Computes the NLL-loss using reply as label.
+        """
+        # Get required data from batch
+        batch = batch.to(self.device)
+        ids = batch.id
+        lm_labels, labels_lens = batch.lm_labels
+        history_bos, history_lens = batch.history_bos
+        reply_eos, reply_lens = batch.reply_eos
+        history_token_type, _ = batch.history_token_type
+
+        loss = self.hparams.ce_loss(
+            predictions.flatten(end_dim=-2), lm_labels.flatten()
+        )
+
+        if stage == sb.Stage.VALID:
+            padding_mask = ~self.hparams.padding_mask(
+                history_bos, pad_idx=tokenizer.unk_token_id
+            )
+            hyps = self.modules.gpt_model.generate(
+                history_bos.detach(),
+                history_token_type.detach(),
+                padding_mask.detach(),
+            )
+        elif stage == sb.Stage.TEST:
+            padding_mask = ~self.hparams.padding_mask(
+                history_bos, pad_idx=tokenizer.unk_token_id
+            )
+            hyps = self.modules.gpt_model.generate(
+                history_bos.detach(),
+                history_token_type.detach(),
+                padding_mask.detach(),
+                "beam",
+            )
+
+        if stage != sb.Stage.TRAIN:
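+            # reply_lens holds relative lengths (fractions of the padded
+            # length) produced by PaddedBatch: recover the absolute token
+            # counts and drop the trailing EOS before decoding.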
+            reply_truncated = [
+                reply_eos[i][
+                    : int(reply_lens[i].item() * reply_eos.shape[1] - 1)
+                ].detach()
+                for i in range(reply_eos.shape[0])
+            ]
+            predicted_words = tokenizer.batch_decode(
+                hyps[:, history_bos.shape[1] :],
+                skip_special_tokens=True,
+                clean_up_tokenization_spaces=True,
+            )
+            target_words = tokenizer.batch_decode(
+                reply_truncated,
+                skip_special_tokens=True,
+                clean_up_tokenization_spaces=True,
+            )
+            self.bleu_4_metric.append(ids, predicted_words, target_words)
+            self.bleu_2_metric.append(ids, predicted_words, target_words)
+            self.hyps.extend(predicted_words)
+            self.references.extend(target_words)
+
+        return loss
+
+    def fit_batch(self, batch):
+        """Trains the parameters given a single batch in input"""
+
+        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
+        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
+        loss.backward()
+        if self.check_gradients(loss):
+            self.optimizer.step()
+        self.optimizer.zero_grad()
+
+        return loss.detach()
+
+    def on_stage_start(self, stage, epoch):
+        """Gets called at the beginning of each epoch"""
+        if stage != sb.Stage.TRAIN:
+            self.bleu_4_metric = self.hparams.bleu_4_computer()
+            self.bleu_2_metric = self.hparams.bleu_2_computer()
+            self.hyps = []
+            self.references = []
+
+    def on_stage_end(self, stage, stage_loss, epoch):
+        """Gets called at the end of an epoch.
+
+        Arguments
+        ---------
+        stage : sb.Stage
+            One of sb.Stage.TRAIN, sb.Stage.VALID, sb.Stage.TEST
+        stage_loss : float
+            The average loss for all of the data processed in this stage.
+        epoch : int
+            The currently-starting epoch. This is passed
+            `None` during the test stage.
+        """
+
+        # Store the train loss until the validation stage.
+        stage_stats = {"loss": stage_loss}
+        stage_stats["PPL"] = math.exp(stage_loss)
+        if stage == sb.Stage.TRAIN:
+            self.train_stats = stage_stats
+        else:
+            stage_stats["BLEU_4"] = self.bleu_4_metric.summarize("BLEU")
+            stage_stats["BLEU_2"] = self.bleu_2_metric.summarize("BLEU")
+        # Perform end-of-iteration things, like annealing, logging, etc.
+        if stage == sb.Stage.VALID:
+            # Update learning rate
+            old_lr, new_lr = self.hparams.lr_annealing(epoch)
+            sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)
+
+            # The train_logger writes a summary to stdout and to the logfile.
+
+            self.hparams.train_logger.log_stats(
+                stats_meta={"epoch": epoch, "lr": old_lr},
+                train_stats=self.train_stats,
+                valid_stats=stage_stats,
+            )
+            # Save the current checkpoint and delete previous checkpoints.
+            self.checkpointer.save_and_keep_only(
+                meta={"PPL": stage_stats["PPL"]}, min_keys=["PPL"],
+            )
+            if epoch == self.hparams.number_of_epochs - 1:
+                with open(self.hparams.bleu_4_valid_file, "w") as w:
+                    self.bleu_4_metric.write_stats(w)
+                    for i in range(len(self.hyps)):
+                        w.write("target: " + str(self.references[i]) + "\n")
+                        w.write("predicted:" + str(self.hyps[i]) + "\n")
+                        w.write(
+                            "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+                        )
+
+        # We also write statistics about test data to stdout and to the logfile.
+        elif stage == sb.Stage.TEST:
+
+            self.hparams.train_logger.log_stats(
+                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
+                test_stats=stage_stats,
+            )
+            with open(self.hparams.bleu_4_test_file, "w") as w:
+                self.bleu_4_metric.write_stats(w)
+                for i in range(len(self.hyps)):
+                    w.write("target: " + str(self.references[i]) + "\n")
+                    w.write("predicted:" + str(self.hyps[i]) + "\n")
+                    w.write(
+                        "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+                    )
+
+    def init_optimizers(self):
+        "Initializes the model optimizer"
+        self.optimizer = self.hparams.opt_class(self.hparams.model.parameters())
+
+        if self.checkpointer is not None:
+            self.checkpointer.add_recoverable("optimizer", self.optimizer)
+
+    def zero_grad(self, set_to_none=False):
+        self.optimizer.zero_grad(set_to_none)
+
+
+def add_special_tokens_(model, tokenizer, attr_to_special_token,) -> None:
+    """Adds special tokens to the tokenizer and resizes the model embeddings
+    to account for the newly added tokens."""
+    orig_num_tokens = len(tokenizer.encoder)
+    num_added_tokens = tokenizer.add_special_tokens(
+        attr_to_special_token  # type: ignore
+    )  # doesn't add if they are already there
+    if num_added_tokens > 0:
+        model.resize_token_embeddings(
+            new_num_tokens=orig_num_tokens + num_added_tokens
+        )
+
+
+def dataio_prep(hparams, tokenizer):
+    """This function prepares the datasets to be used in the brain class.
+    It also defines the data processing pipeline through user-defined
+    functions. We expect `prepare_multiwoz` to have been called before
+    this, so that the `train.json`, `dev.json`,  and `test.json` manifest
+    files are available.
+    Arguments
+    ---------
+    hparams : dict
+        This dictionary is loaded from the `train_gpt.yaml` file, and it
+        includes all the hyperparameters needed for dataset construction
+        and loading.
+    tokenizer : GPT2Tokenizer
+        The tokenizer used to encode the dialogue turns.
+    Returns
+    -------
+    datasets : dict
+        Contains three keys, "train", "valid", and "test", each mapping
+        to the corresponding DynamicItemDataset object.
+    """
+
+    # convert special tokens to their ids
+    bos, eos, system, user = tokenizer.convert_tokens_to_ids(
+        hparams["special_tokens"]
+    )
+    # history_window, i.e. how many user-system exchanges to consider as context (+1 to consider at least the last user turn)
+    history_window = 2 * hparams["max_history"] + 1
+
+    #  Define history pipeline:
+    @sb.utils.data_pipeline.takes("history")
+    @sb.utils.data_pipeline.provides(
+        "history",
+        "history_tokens_lists",
+        "history_ids",
+        "history_bos",
+        "history_token_type",
+    )
+    def history_pipeline(history):
+        yield history
+
+        # encode each turn of the history
+        history_tokens_lists = [tokenizer.encode(turn) for turn in history]
+        yield history_tokens_lists
+
+        # add speaker tokens to the history turns (user is even, system is odd)
+        # BEFORE:  [Hi how are you?], [I'm fine, thanks]
+        # AFTER:   [SPK_1 Hi how are you?], [SPK_2 I'm fine, thanks]
+        history_input_lists = [
+            [user if i % 2 == 0 else system] + encoded_turn
+            for i, encoded_turn in enumerate(history_tokens_lists)
+        ]
+
+        history_ids = history_input_lists[-history_window:]
+        # concatenate every token into a single list
+        # list(chain(*[[1, 2], [3, 4], [5]]))
+        # >>> [1, 2, 3, 4, 5]
+        history_ids = torch.LongTensor(list(chain(*history_ids)))
+        # without bos for lm_labels
+        yield history_ids
+
+        # create bos version for the input
+        history_bos = torch.cat((torch.tensor([bos]), history_ids))
+        yield history_bos
+
+        # create a mapping that associates each token in the input to a speaker
+        # INPUT: [SPK_1 Hi    how   are   you? ], [SPK_2 I'm   fine, thanks]
+        # TYPE:  [SPK_1 SPK_1 SPK_1 SPK_1 SPK_1], [SPK_2 SPK_2 SPK_2 SPK_2 ]
+        history_token_type_lists = [
+            [user if i % 2 == 0 else system] * len(encoded_turn)
+            for i, encoded_turn in enumerate(history_input_lists)
+        ]
+        history_token_type = torch.LongTensor(
+            list(
+                chain(
+                    *([[system]] + history_token_type_lists[-history_window:])
+                )
+            )
+        )
+
+        yield history_token_type
+
+    #  Define reply pipeline:
+    @sb.utils.data_pipeline.takes("reply")
+    @sb.utils.data_pipeline.provides(
+        "reply",
+        "reply_tokens_list",
+        "reply_ids",
+        "reply_eos",
+        "reply_token_type",
+    )
+    def reply_pipeline(reply):
+        yield reply
+
+        reply_tokens_list = tokenizer.encode(reply)
+        yield reply_tokens_list
+
+        # specify that the system will say the reply
+        reply_input_list = [system] + reply_tokens_list
+        reply_ids = torch.LongTensor(reply_input_list)
+        yield reply_ids
+
+        # create eos version of the reply for lm_labels
+        reply_eos = torch.cat((reply_ids, torch.tensor([eos])))
+        yield reply_eos
+
+        # specify the speaker for each token in the reply
+        reply_token_type = torch.LongTensor([system] * len(reply_input_list))
+        yield reply_token_type
+
+    # Define input_and_token_type_pipeline
+    @sb.utils.data_pipeline.takes(
+        "history_ids",
+        "history_bos",
+        "history_token_type",
+        "reply_ids",
+        "reply_eos",
+        "reply_token_type",
+    )
+    @sb.utils.data_pipeline.provides("input_ids", "token_type_ids", "lm_labels")
+    def input_and_token_type_pipeline(
+        history_ids,
+        history_bos,
+        history_token_type,
+        reply_ids,
+        reply_eos,
+        reply_token_type,
+    ):
+
+        # put history and reply together
+        # N.B. input_sequence = history_bos + reply_ids, we don't have eos in the input
+        input_ids = torch.cat((history_bos, reply_ids), -1)
+        yield input_ids
+
+        token_type_ids = torch.cat((history_token_type, reply_token_type), -1)
+        yield token_type_ids
+
+        # create the language model label (ground truth) for the current input
+        # -100 is a special value that is ignored during the loss computation
+        # the idea is to mask everything except the reply (without the speaker token)
+        # N.B. the labels do not include bos, so they are shifted by one
+        # position with respect to input_ids (next-token prediction)
+        lm_labels = (
+            [hparams["ignore_index"]] * history_ids.shape[0]
+            + [hparams["ignore_index"]]
+            + reply_eos[1:].tolist()
+        )
+        lm_labels = torch.LongTensor(lm_labels)
+
+        yield lm_labels
+
+    # Define datasets. We also connect the dataset with the data processing
+    # functions defined above.
+    datasets = {}
+    data_info = {
+        "train": hparams["train_annotation"],
+        "valid": hparams["valid_annotation"],
+        "test": hparams["test_annotation"],
+    }
+    for dataset in data_info:
+        datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
+            json_path=data_info[dataset],
+            replacements={"data_root": hparams["data_folder"]},
+            dynamic_items=[
+                reply_pipeline,
+                history_pipeline,
+                input_and_token_type_pipeline,
+            ],
+            output_keys=[
+                "id",
+                "input_ids",
+                "token_type_ids",
+                "history_bos",
+                "reply_eos",
+                "history_token_type",
+                "reply_token_type",
+                "lm_labels",
+            ],
+        )
+
+    return datasets
+
+
+# RECIPE BEGINS!
+if __name__ == "__main__":
+
+    # Reading command line arguments.
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+
+    # Initialize ddp (useful only for multi-GPU DDP training).
+    sb.utils.distributed.ddp_init_group(run_opts)
+
+    # Load hyperparameters file with command-line overrides.
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    # Dataset prep (parsing MultiWOZ)
+    from multiwoz_prepare import prepare_mwoz_21
+
+    run_on_main(
+        prepare_mwoz_21,
+        kwargs={
+            "data_folder": hparams["data_folder"],
+            "save_folder": hparams["output_folder"],
+            "replacements_path": hparams["replacements_path"],
+            "skip_prep": hparams["skip_prep"],
+        },
+    )
+
+    # Load tokenizer and add special tokens
+    tokenizer = GPT2Tokenizer.from_pretrained(
+        hparams["gpt_hub"], pad_token=None
+    )
+
+    #  Load pretrained GPT
+    hparams["gpt_model"] = hparams["gpt_model"].to(device=run_opts["device"])
+
+    # Add special tokens to the tokenizer and resize model embedding
+    add_special_tokens_(
+        hparams["gpt_model"].model, tokenizer, hparams["attr_to_special_tokens"]
+    )
+
+    class CustomPaddedBatch(PaddedBatch):
+        """PaddedBatch with custom padding values.
+
+        See the documentation of `speechbrain.dataio.batch.PaddedBatch`.
+
+        """
+
+        def __init__(self, examples, *args, **kwargs):
+            _, _, system, _ = tokenizer.convert_tokens_to_ids(
+                hparams["special_tokens"]
+            )
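+            # History tensors are padded on the left (after appending a
+            # trailing SPK_1 token) so that generation starts right after
+            # the dialogue context; all other tensors are padded on the right.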
+            for k in [
+                "input_ids",
+                "history_bos",
+                "lm_labels",
+                "token_type_ids",
+                "history_token_type",
+            ]:
+                max_len = max([len(x[k]) for x in examples])
+                pad_value = 0
+                if k in [
+                    "input_ids",
+                    "history_bos",
+                    "token_type_ids",
+                    "history_token_type",
+                ]:
+                    pad_value = tokenizer.unk_token_id
+                elif k == "lm_labels":
+                    pad_value = hparams["ignore_index"]
+                for example in examples:
+                    x = example[k]
+                    if k in ["history_bos", "history_token_type"]:
+                        x = torch.cat(
+                            (example[k], torch.LongTensor([system])), -1
+                        )
+                        example[k] = torch.nn.functional.pad(
+                            x, [max_len - len(x), 0], value=pad_value
+                        )
+                    else:
+                        example[k] = torch.nn.functional.pad(
+                            x, [0, max_len - len(x)], value=pad_value
+                        )
+            super().__init__(examples, *args, **kwargs)
+
+    hparams["train_dataloader_options"]["collate_fn"] = CustomPaddedBatch
+    hparams["test_dataloader_options"]["collate_fn"] = CustomPaddedBatch
+
+    # Create dataset objects "train", "valid", and "test".
+    datasets = dataio_prep(hparams, tokenizer)
+
+    # Initialize the Brain object to prepare for training.
+    res_gen_brain = ResGenBrain(
+        modules=hparams["modules"],
+        opt_class=hparams["opt_class"],
+        hparams=hparams,
+        run_opts=run_opts,
+        checkpointer=hparams["checkpointer"],
+    )
+
+    # Load a pretrained model if a pretrainer is specified in the hparams
+    if "pretrainer" in hparams.keys():
+        run_on_main(hparams["pretrainer"].collect_files)
+        hparams["pretrainer"].load_collected(res_gen_brain.device)
+
+    # The `fit()` method iterates the training loop, calling the methods
+    # necessary to update the parameters of the model. Since all objects
+    # with changing state are managed by the Checkpointer, training can be
+    # stopped at any point, and will be resumed on next call.
+    res_gen_brain.fit(
+        epoch_counter=res_gen_brain.hparams.epoch_counter,
+        train_set=datasets["train"],
+        valid_set=datasets["valid"],
+        train_loader_kwargs=hparams["train_dataloader_options"],
+        valid_loader_kwargs=hparams["test_dataloader_options"],
+    )
+
+    # Load the best checkpoint for evaluation
+    test_stats = res_gen_brain.evaluate(
+        test_set=datasets["test"],
+        min_key="PPL",
+        test_loader_kwargs=hparams["test_dataloader_options"],
+    )
diff --git a/speechbrain/lobes/models/huggingface_gpt.py b/speechbrain/lobes/models/huggingface_gpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..0de27a6e8c5db5b786a69ffeb07f78d74c07461d
--- /dev/null
+++ b/speechbrain/lobes/models/huggingface_gpt.py
@@ -0,0 +1,154 @@
+"""This lobe enables the integration of huggingface pretrained GPT2LMHeadModel model.
+
+Transformer from HuggingFace needs to be installed:
+https://huggingface.co/transformers/installation.html
+
+Authors
+ * Pooneh Mousavi 2023
+ * Simone Alghisi 2023
+"""
+
+import logging
+from torch import Tensor
+import torch
+import torch.nn as nn
+
+try:
+    from transformers import GPT2LMHeadModel
+except ImportError:
+    MSG = "Please install transformers from HuggingFace to use GPT2\n"
+    MSG += "E.G. run: pip install transformers"
+    raise ImportError(MSG)
+
+logger = logging.getLogger(__name__)
+
+
+class HuggingFaceGPT(nn.Module):
+    """This lobe enables the integration of HuggingFace pretrained GPT model.
+     Source paper whisper:
+        https://life-extension.github.io/2020/05/27/GPT%E6%8A%80%E6%9C%AF%E5%88%9D%E6%8E%A2/language-models.pdf
+    Transformer from HuggingFace needs to be installed:
+        https://huggingface.co/transformers/installation.html
+
+    The model can be fine-tuned. It will automatically download the model
+    from HuggingFace or use a local path.
+
+    Arguments
+    ---------
+    source : str
+        HuggingFace hub name: e.g "gpt2"
+    save_path : str
+        Path (dir) of the downloaded model.
+    freeze : bool (default: False)
+        If True, the model is frozen. If False, the model will be trained
+        alongside with the rest of the pipeline.
+    Example
+    -------
+    >>> model_hub = "gpt2"
+    >>> save_path = "savedir"
+    >>> model = HuggingFaceGPT(model_hub, save_path)
+    >>> tokens = torch.tensor([[1, 1]])
+    >>> tokens_type = torch.tensor([[1, 1]])
+    >>> attention_mask = torch.tensor([[1, 1]])
+    >>> outputs = model(tokens, tokens_type, attention_mask)
+    """
+
+    def __init__(
+        self,
+        source: str,
+        save_path: str,
+        freeze: bool = False,
+        max_new_tokens: int = 200,
+        min_length: int = 1,
+        top_k: int = 45,
+        top_p: float = 0.9,
+        num_beams: int = 8,
+        early_stopping: bool = True,
+    ) -> None:
+        super().__init__()
+        self.freeze = freeze
+        self.max_new_tokens = max_new_tokens
+        self.min_length = min_length
+        self.top_k = top_k
+        self.top_p = top_p
+        self.num_beams = num_beams
+        self.early_stopping = early_stopping
+        self.model = GPT2LMHeadModel.from_pretrained(
+            source, cache_dir=save_path
+        )
+        if self.freeze:
+            logger.warning("huggingface_GPT - GPT is frozen.")
+            self.model.train()  # we keep it in train mode to have dropout and LN computed adequately
+            for param in self.model.parameters():
+                param.requires_grad = False
+
+    def forward(
+        self, input_ids: Tensor, token_type_ids: Tensor, attention_mask: Tensor,
+    ):
+        """ Takes an input a history of conversation and returns its corresponding reply.
+
+        Arguments
+        ---------
+        input_ids : torch.Tensor ()
+            A batch of input-id to transform to features.
+        token_type_ids : torch.Tensor
+            Token Type(Speaker) for each token in input_ids.
+        attention_mask : torch.Tensor ()
+            A batch of attention_mask.
+        """
+        with torch.set_grad_enabled(not self.freeze):
+            output = self.model.forward(
+                input_ids,
+                token_type_ids=token_type_ids,
+                attention_mask=attention_mask,
+            )
+        return output
+
+    def generate(
+        self,
+        input_ids: Tensor,
+        token_type_ids,
+        attention_mask: Tensor,
+        decoder_type="greedy",
+    ):
+        """ Takes an input a history of conversation and returns its corresponding reply.
+
+        Arguments
+        --------
+        input_ids : torch.Tensor ()
+            A batch of input-id   which are dialogue context tokens
+        decoder_type : Str
+            It shows strategy for autoregressive decoding either beam seach or greedy.
+        attention_mask : torch.Tensor ()
+            A batch of attention_mask.
+        """
+
+        with torch.no_grad():
+            if decoder_type == "beam":
+                # beam decoding based on the input_ids which are dialogue context tokens (here only history)
+                hyp = self.model.generate(
+                    input_ids=input_ids,
+                    token_type_ids=token_type_ids,
+                    attention_mask=attention_mask,
+                    do_sample=True,
+                    max_new_tokens=self.max_new_tokens,
+                    min_length=self.min_length,
+                    top_k=self.top_k,
+                    top_p=self.top_p,
+                    num_beams=self.num_beams,
+                    num_return_sequences=1,
+                    # pad_token_id=50258,
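+                    # NOTE: 50258 is assumed to be the id of the custom
+                    # "EOS" token appended after the original 50257-token
+                    # GPT2 vocabulary.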
+                    eos_token_id=50258,
+                    early_stopping=self.early_stopping,
+                )
+            else:
+                # greedy decoding based on the input_ids which are dialogue context tokens (here only history)
+                hyp = self.model.generate(
+                    input_ids,
+                    token_type_ids=token_type_ids,
+                    max_new_tokens=self.max_new_tokens,
+                    # pad_token_id=50258,
+                    eos_token_id=50258,
+                    attention_mask=attention_mask,
+                )
+        return hyp
diff --git a/speechbrain/pretrained/interfaces.py b/speechbrain/pretrained/interfaces.py
index fc68f68618586384baba4ca747e68e08cd2ed1b0..29412ab4651ad9b8eaac464822d71fec990ef8f3 100644
--- a/speechbrain/pretrained/interfaces.py
+++ b/speechbrain/pretrained/interfaces.py
@@ -38,6 +38,7 @@ from speechbrain.utils.callchains import lengths_arg_exists
 from speechbrain.utils.superpowers import import_from_path
 from speechbrain.dataio.dataio import length_to_mask
 from speechbrain.processing.NMF import spectral_phase
+from itertools import chain
 
 logger = logging.getLogger(__name__)
 
@@ -4259,3 +4260,137 @@ class PIQAudioInterpreter(Pretrained):
     def forward(self, wavs, wav_lens=None):
         """Runs the classification"""
         return self.interpret_batch(wavs, wav_lens)
+
+
+class ResponseGenerator(Pretrained):
+    """A ready-to-use Response Generator  model
+
+    The class can be used to generate and continue dialogue given the user input.
+    The given YAML must contain the fields specified in the *_NEEDED[] lists.
+    It needs to be used with custom.py to load the expanded GPT model with the added special tokens (bos, eos, and speaker tokens).
+
+    Example
+    -------
+    >>> from speechbrain.pretrained import ResponseGenerator
+
+    >>> tmpdir = getfixture("tmpdir")
+    >>> res_gen_model = ResponseGenerator.from_hparams(source="speechbrain/MultiWOZ-GPT-Response_Generation",
+    ... savedir=tmpdir,
+    ... pymodule_file="custom.py")
+    >>> response = res_gen_model.generate_response("I want to book a table for dinner")
+    """
+
+    HPARAMS_NEEDED = ["tokenizer"]
+    MODULES_NEEDED = ["gpt-model"]
+
+    def __init__(self, *args, **kwargs):
+
+        super().__init__(*args, **kwargs)
+        # Load model
+        self.model = self.hparams.model
+        # convert special tokens to their ids
+        (
+            self.bos,
+            self.eos,
+            self.system,
+            self.user,
+        ) = self.model.tokenizer.convert_tokens_to_ids(
+            self.hparams.special_tokens
+        )
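+        # keep the last max_history exchanges (two turns each) plus the current user turn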
+        self.history_window = 2 * self.hparams.max_history + 1
+        self.history = []
+
+    def generate_response(self, turn):
+        """Complete a dialogue given the user's input.
+
+        Arguments
+        ---------
+        turn : str
+            User input, i.e. the last turn of the dialogue.
+
+        Returns
+        -------
+        response : str
+            Generated response for the user input based on the dialogue history.
+        """
+
+        self.history.append(turn)
+        history_bos, history_token_type = self.prepare_input()
+        history_bos = history_bos.unsqueeze(0)
+        history_token_type = history_token_type.unsqueeze(0)
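+        # build the attention mask, treating unk_token_id as the padding index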
+        padding_mask = ~self.hparams.padding_mask(
+            history_bos, pad_idx=self.model.tokenizer.unk_token_id
+        )
+        hyps = self.model.generate(
+            history_bos.detach(),
+            history_token_type.detach(),
+            padding_mask.detach(),
+            "beam",
+        )
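+        # keep only the newly generated tokens, i.e. everything after the input context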
+        predicted_words = self.model.tokenizer.batch_decode(
+            hyps[:, history_bos.shape[1] :],
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=True,
+        )
+        response = predicted_words[0]
+        self.history.append(response)
+        return response
+
+    def prepare_input(self):
+        """Convert the user input and previous history to the format expected by the GPT model.
+
+        It concatenates all previous history turns with the new input and truncates
+        the result based on the max_history value. It then tokenizes the input and
+        builds an additional input that marks the speaker type of each token (System or User).
+
+        Returns
+        -------
+        history_bos : torch.LongTensor
+            Tokenized history plus input, with the appropriate speaker token
+            prepended to each turn.
+        history_token_type : torch.LongTensor
+            Type of each token based on the speaker who uttered it (User or System).
+        """
+        history_tokens_lists = [
+            self.model.tokenizer.encode(turn) for turn in self.history
+        ]
+        # add speaker tokens to the history turns (user is even, system is odd)
+        # BEFORE:  [Hi how are you?], [I'm fine, thanks]
+        # AFTER:   [SPK_1 Hi how are you?], [SPK_2 I'm fine, thanks]
+        history_input_lists = [
+            [self.user if i % 2 == 0 else self.system] + encoded_turn
+            for i, encoded_turn in enumerate(history_tokens_lists)
+        ]
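+        # truncate the history to the last history_window turns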
+        history_ids = history_input_lists[-self.history_window :]
+        # concatenate every token into a single list
+        # list(chain(*[[1, 2], [3, 4], [5]]))
+        # >>> [1, 2, 3, 4, 5]
+        history_ids = torch.LongTensor(list(chain(*history_ids)))
+        # create bos version for the input
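+        # the trailing system token cues the model to answer as the system speaker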
+        history_bos = torch.cat(
+            (torch.tensor([self.bos]), history_ids, torch.tensor([self.system]))
+        )
+        # create a mapping that associates each token in the input to a speaker
+        # INPUT: [SPK_1 Hi    how   are   you? ], [SPK_2 I'm   fine, thanks]
+        # TYPE:  [SPK_1 SPK_1 SPK_1 SPK_1 SPK_1], [SPK_2 SPK_2 SPK_2 SPK_2 ]
+        history_token_type_lists = [
+            [self.user if i % 2 == 0 else self.system] * len(encoded_turn)
+            for i, encoded_turn in enumerate(history_input_lists)
+        ]
+        history_token_type = torch.LongTensor(
+            list(
+                chain(
+                    *(
+                        [[self.system]]
+                        + history_token_type_lists[-self.history_window :]
+                        + [[self.system]]
+                    )
+                )
+            )
+        )
+        return history_bos, history_token_type
diff --git a/speechbrain/utils/bleu.py b/speechbrain/utils/bleu.py
index cfde75b84d0bbe1ec9a938cc4108e61b28519a42..0a6dcd51304033cfddad12e5ef54648e8ec1c769 100644
--- a/speechbrain/utils/bleu.py
+++ b/speechbrain/utils/bleu.py
@@ -46,12 +46,18 @@ class BLEUStats(MetricStats):
     0.0
     """
 
-    def __init__(
-        self, lang="en", merge_words=True,
-    ):
+    def __init__(self, lang="en", merge_words=True, max_ngram_order=4):
+        # Check extra-dependency for computing the bleu score
+        try:
+            from sacrebleu.metrics import BLEU
+        except ImportError:
+            raise ImportError(
+                "Please install sacrebleu (https://pypi.org/project/sacrebleu/) in order to use the BLEU metric"
+            )
 
         self.clear()
         self.merge_words = merge_words
+        self.bleu = BLEU(max_ngram_order=max_ngram_order)
 
         self.predicts = []
         self.targets = None
@@ -97,15 +103,7 @@ class BLEUStats(MetricStats):
         * See MetricStats.summarize()
         """
 
-        # Check extra-dependency for computing the bleu score
-        try:
-            import sacrebleu
-        except ImportError:
-            print(
-                "Please install sacrebleu (https://pypi.org/project/sacrebleu/) in order to use the BLEU metric"
-            )
-
-        scores = sacrebleu.corpus_bleu(self.predicts, self.targets)
+        scores = self.bleu.corpus_score(self.predicts, self.targets)
         details = {}
         details["BLEU"] = scores.score
         details["BP"] = scores.bp
diff --git a/tests/recipes/MultiWOZ.csv b/tests/recipes/MultiWOZ.csv
new file mode 100644
index 0000000000000000000000000000000000000000..bdcf31c8123f95c7ae3432b632f6881184afa8c5
--- /dev/null
+++ b/tests/recipes/MultiWOZ.csv
@@ -0,0 +1,2 @@
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
+Response-Generation,MultiWOZ,recipes/MultiWOZ/response_generation/train_with_gpt.py,recipes/MultiWOZ/response_generation/hparams/train_gpt.yaml,recipes/MultiWOZ/response_generation/multiwoz_prepare.py,recipes/MultiWOZ/response_generation/README.md,https://www.dropbox.com/sh/vm8f5iavohr4zz9/AACrkOxXuxsrvJy4Cjpih9bQa?dl=0,https://huggingface.co/speechbrain/MultiWOZ-GPT-Response_Generation,--data_folder=tests/samples/ASR/  --train_annotation=tests/samples/annotation/response_generation_train_multiwoz.json --valid_annotation=tests/samples/annotation/response_generation_train_multiwoz.json --test_annotation=tests/samples/annotation/response_generation_train_multiwoz.json --number_of_epochs=2 --skip_prep=True,,
diff --git a/tests/samples/annotation/response_generation_train_multiwoz.json b/tests/samples/annotation/response_generation_train_multiwoz.json
new file mode 100644
index 0000000000000000000000000000000000000000..696a103a7a4417487ee9139a7f2660217f327d97
--- /dev/null
+++ b/tests/samples/annotation/response_generation_train_multiwoz.json
@@ -0,0 +1,74 @@
+{
+    "PMUL0698.json_1": {
+        "history": [
+            "i am looking for a local place to dine in the centre that serves chinese food ."
+        ],
+        "reply": "i have restaurants matching your criteria in all price ranges . do you have a preference on price ?",
+        "length": 178
+    },
+    "PMUL0698.json_3": {
+        "history": [
+            "i am looking for a local place to dine in the centre that serves chinese food .",
+            "i have restaurants matching your criteria in all price ranges . do you have a preference on price ?",
+            "i need the address , postcode and the price range ."
+        ],
+        "reply": "ok how about charlie chan , located at regent street city centre . postcode is cb21db with a cheap price . can i help you further today ?",
+        "length": 213.33333333333331
+    },
+    "PMUL0698.json_5": {
+        "history": [
+            "i am looking for a local place to dine in the centre that serves chinese food .",
+            "i have restaurants matching your criteria in all price ranges . do you have a preference on price ?",
+            "i need the address , postcode and the price range .",
+            "ok how about charlie chan , located at regent street city centre . postcode is cb21db with a cheap price . can i help you further today ?",
+            "i also need a train . the train should leave after 16:15 and should leave on sunday ."
+        ],
+        "reply": "can i have more information for the train you are needing ? where are you departing from and arriving to ?",
+        "length": 196.2
+    },
+    "PMUL0698.json_7": {
+        "history": [
+            "i am looking for a local place to dine in the centre that serves chinese food .",
+            "i have restaurants matching your criteria in all price ranges . do you have a preference on price ?",
+            "i need the address , postcode and the price range .",
+            "ok how about charlie chan , located at regent street city centre . postcode is cb21db with a cheap price . can i help you further today ?",
+            "i also need a train . the train should leave after 16:15 and should leave on sunday .",
+            "can i have more information for the train you are needing ? where are you departing from and arriving to ?",
+            "i am leaving from cambridge and going to norwich ."
+        ],
+        "reply": "i have train tr1840 leaving at 16:36 is that okay ?",
+        "length": 137.71428571428572
+    },
+    "PMUL0698.json_9": {
+        "history": [
+            "i am looking for a local place to dine in the centre that serves chinese food .",
+            "i have restaurants matching your criteria in all price ranges . do you have a preference on price ?",
+            "i need the address , postcode and the price range .",
+            "ok how about charlie chan , located at regent street city centre . postcode is cb21db with a cheap price . can i help you further today ?",
+            "i also need a train . the train should leave after 16:15 and should leave on sunday .",
+            "can i have more information for the train you are needing ? where are you departing from and arriving to ?",
+            "i am leaving from cambridge and going to norwich .",
+            "i have train tr1840 leaving at 16:36 is that okay ?",
+            "book for 5 people and get me the reference number"
+        ],
+        "reply": "you are all set . reference number is njb87pap . is there anything else i can help you with today ?",
+        "length": 177.55555555555554
+    },
+    "PMUL0698.json_11": {
+        "history": [
+            "i am looking for a local place to dine in the centre that serves chinese food .",
+            "i have restaurants matching your criteria in all price ranges . do you have a preference on price ?",
+            "i need the address , postcode and the price range .",
+            "ok how about charlie chan , located at regent street city centre . postcode is cb21db with a cheap price . can i help you further today ?",
+            "i also need a train . the train should leave after 16:15 and should leave on sunday .",
+            "can i have more information for the train you are needing ? where are you departing from and arriving to ?",
+            "i am leaving from cambridge and going to norwich .",
+            "i have train tr1840 leaving at 16:36 is that okay ?",
+            "book for 5 people and get me the reference number",
+            "you are all set . reference number is njb87pap . is there anything else i can help you with today ?",
+            "no , this is all i will need . thank you ."
+        ],
+        "reply": "thank for calling us today . i hope you have a good trip .",
+        "length": 135.0909090909091
+    }
+}
\ No newline at end of file