import os
from typing import Dict, List, Tuple, Union

import numpy as np
import torch

try:
    import onnxruntime as ort
except ImportError:
    # Deferred: a clear error is raised when an inference session is needed,
    # so the module (and the tensor-splitting helper) stays importable.
    ort = None


class DNSMOS:
    """Score audio with a DNSMOS ONNX model.

    The signal is cut into fixed-length segments of ``SAMPLE_LENGTH`` seconds,
    each segment is fed to the ONNX model (input name ``"input_1"``), and the
    three raw model outputs are averaged.  Each raw output is also passed
    through a second-order polynomial and reported under ``sig``/``bak``/``ovr``.

    NOTE(review): the three outputs are presumably the DNSMOS P.835 signal,
    background and overall scores -- confirm against the model card.
    """

    # Segment duration in seconds; multiplied by the sample rate to get the
    # number of samples fed to the model per inference call.
    SAMPLE_LENGTH = 9.01

    def __init__(
        self,
        dnsmos_path: str = os.path.join(os.path.dirname(__file__), "dnsmos.onnx"),
        device: Union[Tuple[str, int], torch.device] = None,
        sample_rate: int = 16000,
        cache_session: bool = False,
    ):
        """Configure the scorer.

        Args:
            dnsmos_path: Path to the ONNX model file.
            device: ``None`` (CPU), a ``(type, index)`` tuple, or a
                ``torch.device``.  Any non-CPU device selects the CUDA
                execution provider with a CPU fallback.
            sample_rate: Sample rate of the audio passed in; 16000 or 8000.
            cache_session: If true, create one inference session now and
                reuse it; otherwise a new session is created on every call.

        Raises:
            FileNotFoundError: If ``dnsmos_path`` does not exist.
            ValueError: If ``device`` or ``sample_rate`` is not supported.
        """
        if not os.path.exists(dnsmos_path):
            raise FileNotFoundError(f"DNSMOS model not found at {dnsmos_path}")

        if device is None:
            self.device = torch.device("cpu")
        elif isinstance(device, tuple):
            self.device = torch.device(device[0], device[1])
        elif isinstance(device, torch.device):
            self.device = device
        else:
            raise ValueError("Invalid device argument")

        if sample_rate not in [16000, 8000]:
            raise ValueError(
                f"Sample rate {sample_rate} not supported by DNSMOS. "
                "Must be 16000 or 8000"
            )

        if self.device.type == "cpu":
            self.execution_providers = ["CPUExecutionProvider"]
        else:
            # torch.device("cuda") carries index=None; fall back to device 0
            # instead of handing None to the execution provider.
            device_id = self.device.index if self.device.index is not None else 0
            self.execution_providers = [
                ("CUDAExecutionProvider", {"device_id": device_id}),
                "CPUExecutionProvider",
            ]

        self.dnsmos_path = dnsmos_path
        self.sample_rate = sample_rate
        self.cache_session = cache_session
        if self.cache_session:
            self.session = self._create_session()

        # Polynomials applied to the raw model outputs to obtain the
        # reported sig/bak/ovr values.
        self.poly_ovr = np.poly1d([-0.06766283, 1.11546468, 0.04602535])
        self.poly_sig = np.poly1d([-0.08397278, 1.22083953, 0.0052439])
        self.poly_bak = np.poly1d([-0.13166888, 1.60915514, -0.39604546])

    def _create_session(self):
        """Build an ONNX Runtime session for the configured model and providers."""
        if ort is None:
            raise ImportError("onnxruntime is required to run DNSMOS inference")
        return ort.InferenceSession(
            self.dnsmos_path, providers=self.execution_providers
        )

    def _split_tensor(self, denoised: torch.Tensor) -> List[np.ndarray]:
        """Cut a signal into equal-length segments along the time axis.

        A 1D tensor is treated as a single-row 2D tensor.  The last segment,
        if short, is extended by appending the first segment (repeatedly when
        needed) and then truncated to the segment length.

        Args:
            denoised: 1D ``(time,)`` or 2D ``(rows, time)`` tensor.

        Returns:
            NumPy arrays of shape ``(rows, segment_length)``, one per segment.

        Raises:
            ValueError: If the tensor is not 1D/2D or has no samples.
        """
        if denoised.ndim == 1:
            denoised = denoised.unsqueeze(0)
        elif denoised.ndim != 2:
            raise ValueError("Tensor must be 1D or 2D")

        # Padding an empty tail with an empty head would never terminate.
        if denoised.shape[1] == 0:
            raise ValueError("Cannot score an empty signal")

        sample_length = int(self.SAMPLE_LENGTH * self.sample_rate)
        segments = list(torch.split(denoised, sample_length, dim=1))

        # Make the tail the same length as the rest by wrapping around to the
        # start of the signal.
        tail = segments[-1]
        head = segments[0]
        while tail.shape[1] < sample_length:
            tail = torch.cat([tail, head], dim=1)
        segments[-1] = tail[:, :sample_length]

        return [segment.detach().cpu().numpy() for segment in segments]

    def __call__(self, denoised: Union[torch.Tensor, np.ndarray]) -> Dict[str, float]:
        """Score a signal and return the averaged results.

        Args:
            denoised: 1D or 2D tensor / array of audio samples.
                NOTE(review): the dtype is passed to the model unchanged;
                it presumably must match the model's input type -- confirm.

        Returns:
            Dict with keys ``raw_sig``, ``raw_bak``, ``raw_ovr`` (means of the
            raw model outputs) and ``sig``, ``bak``, ``ovr`` (means of the
            polynomial-mapped outputs).  The mean runs over every segment and
            every row of a 2D input.
        """
        session = self.session if self.cache_session else self._create_session()

        if isinstance(denoised, np.ndarray):
            denoised = torch.from_numpy(denoised)

        # One (rows, 3) block per segment; keep every row so that batched
        # input is averaged over all of its items, not just the first.
        per_segment = [
            np.asarray(session.run(None, {"input_1": segment})[0]).reshape(-1, 3)
            for segment in self._split_tensor(denoised)
        ]
        raw = np.concatenate(per_segment, axis=0)
        raw_sig, raw_bak, raw_ovr = raw[:, 0], raw[:, 1], raw[:, 2]

        return {
            "raw_sig": float(np.mean(raw_sig)),
            "raw_bak": float(np.mean(raw_bak)),
            "raw_ovr": float(np.mean(raw_ovr)),
            "sig": float(np.mean(self.poly_sig(raw_sig))),
            "bak": float(np.mean(self.poly_bak(raw_bak))),
            "ovr": float(np.mean(self.poly_ovr(raw_ovr))),
        }
<td>2.276028</td>\n", + " <td>2.047455</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>2.145720</td>\n", + " <td>1.602384</td>\n", + " <td>1.571615</td>\n", + " <td>2.140288</td>\n", + " <td>1.783102</td>\n", + " <td>1.605862</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>3.246001</td>\n", + " <td>2.639835</td>\n", + " <td>2.498782</td>\n", + " <td>3.029038</td>\n", + " <td>2.767745</td>\n", + " <td>2.354629</td>\n", + " </tr>\n", + " <tr>\n", + " <th>...</th>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2696</th>\n", + " <td>3.028775</td>\n", + " <td>2.215597</td>\n", + " <td>2.249891</td>\n", + " <td>2.815935</td>\n", + " <td>2.383634</td>\n", + " <td>2.160208</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2697</th>\n", + " <td>3.218708</td>\n", + " <td>2.586063</td>\n", + " <td>2.498565</td>\n", + " <td>2.979260</td>\n", + " <td>2.751397</td>\n", + " <td>2.358187</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2698</th>\n", + " <td>3.720781</td>\n", + " <td>2.729928</td>\n", + " <td>2.815541</td>\n", + " <td>3.340067</td>\n", + " <td>2.948403</td>\n", + " <td>2.612177</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2699</th>\n", + " <td>3.243291</td>\n", + " <td>2.289145</td>\n", + " <td>2.401669</td>\n", + " <td>2.966480</td>\n", + " <td>2.512433</td>\n", + " <td>2.290272</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2700</th>\n", + " <td>3.730871</td>\n", + " <td>3.455348</td>\n", + " <td>3.125872</td>\n", + " <td>3.391119</td>\n", + " <td>3.571671</td>\n", + " <td>2.869536</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>2701 rows × 6 columns</p>\n", + "</div>" + ], + "text/plain": [ + " raw_sig raw_bak raw_ovr sig bak ovr\n", + "batch \n", + "0 3.145465 2.233926 2.340730 2.897365 2.465761 2.244499\n", + "1 3.845078 2.709487 2.718259 3.454555 2.896134 2.559756\n", + "2 2.835112 2.169795 2.134030 2.668490 
2.276028 2.047455\n", + "3 2.145720 1.602384 1.571615 2.140288 1.783102 1.605862\n", + "4 3.246001 2.639835 2.498782 3.029038 2.767745 2.354629\n", + "... ... ... ... ... ... ...\n", + "2696 3.028775 2.215597 2.249891 2.815935 2.383634 2.160208\n", + "2697 3.218708 2.586063 2.498565 2.979260 2.751397 2.358187\n", + "2698 3.720781 2.729928 2.815541 3.340067 2.948403 2.612177\n", + "2699 3.243291 2.289145 2.401669 2.966480 2.512433 2.290272\n", + "2700 3.730871 3.455348 3.125872 3.391119 3.571671 2.869536\n", + "\n", + "[2701 rows x 6 columns]" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import pandas as pd\n", - "individual = pd.read_csv('~/dnsmostesting/dnsmos_scores_individual.csv', index_col=0)\n", - "batched = pd.read_csv('~/dnsmostesting/dnsmos_scores_batched.csv', index_col=0)\n", + "individual = pd.read_csv('/se/dnsmostesting/dnsmos_scores_individual.csv', index_col=0)\n", + "batched = pd.read_csv('/se/dnsmostesting/dnsmos_scores_batched.csv', index_col=0)\n", "individual_avg = individual.groupby('batch').mean()\n", "individual_avg" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>raw_sig</th>\n", + " <th>raw_bak</th>\n", + " <th>raw_ovr</th>\n", + " <th>sig</th>\n", + " <th>bak</th>\n", + " <th>ovr</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>3.575741</td>\n", + " <td>3.584303</td>\n", + " <td>3.144969</td>\n", + " 
<td>3.296980</td>\n", + " <td>3.680074</td>\n", + " <td>2.884885</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>3.733906</td>\n", + " <td>2.355249</td>\n", + " <td>2.552172</td>\n", + " <td>3.392991</td>\n", + " <td>2.663522</td>\n", + " <td>2.452156</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>3.847642</td>\n", + " <td>2.021054</td>\n", + " <td>2.351104</td>\n", + " <td>3.459435</td>\n", + " <td>2.318322</td>\n", + " <td>2.294580</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>1.670071</td>\n", + " <td>1.710422</td>\n", + " <td>1.319525</td>\n", + " <td>1.809921</td>\n", + " <td>1.971086</td>\n", + " <td>1.400098</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>3.641207</td>\n", + " <td>3.647267</td>\n", + " <td>3.193602</td>\n", + " <td>3.337230</td>\n", + " <td>3.721440</td>\n", + " <td>2.918276</td>\n", + " </tr>\n", + " <tr>\n", + " <th>...</th>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2696</th>\n", + " <td>1.068525</td>\n", + " <td>1.048294</td>\n", + " <td>1.037608</td>\n", + " <td>1.213865</td>\n", + " <td>1.146129</td>\n", + " <td>1.130593</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2697</th>\n", + " <td>3.065434</td>\n", + " <td>2.261132</td>\n", + " <td>2.250670</td>\n", + " <td>2.958564</td>\n", + " <td>2.569281</td>\n", + " <td>2.213821</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2698</th>\n", + " <td>3.856602</td>\n", + " <td>3.040631</td>\n", + " <td>3.061332</td>\n", + " <td>3.464577</td>\n", + " <td>3.279465</td>\n", + " <td>2.826714</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2699</th>\n", + " <td>3.093844</td>\n", + " <td>2.507351</td>\n", + " <td>2.339078</td>\n", + " <td>2.978555</td>\n", + " <td>2.810894</td>\n", + " <td>2.284982</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2700</th>\n", + " <td>3.292927</td>\n", + " <td>3.566518</td>\n", + " 
<td>2.847978</td>\n", + " <td>3.114832</td>\n", + " <td>3.668201</td>\n", + " <td>2.674032</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>2701 rows × 6 columns</p>\n", + "</div>" + ], + "text/plain": [ + " raw_sig raw_bak raw_ovr sig bak ovr\n", + "0 3.575741 3.584303 3.144969 3.296980 3.680074 2.884885\n", + "1 3.733906 2.355249 2.552172 3.392991 2.663522 2.452156\n", + "2 3.847642 2.021054 2.351104 3.459435 2.318322 2.294580\n", + "3 1.670071 1.710422 1.319525 1.809921 1.971086 1.400098\n", + "4 3.641207 3.647267 3.193602 3.337230 3.721440 2.918276\n", + "... ... ... ... ... ... ...\n", + "2696 1.068525 1.048294 1.037608 1.213865 1.146129 1.130593\n", + "2697 3.065434 2.261132 2.250670 2.958564 2.569281 2.213821\n", + "2698 3.856602 3.040631 3.061332 3.464577 3.279465 2.826714\n", + "2699 3.093844 2.507351 2.339078 2.978555 2.810894 2.284982\n", + "2700 3.292927 3.566518 2.847978 3.114832 3.668201 2.674032\n", + "\n", + "[2701 rows x 6 columns]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "batched_avg" + "batched" ] }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 4, "metadata": {}, "outputs": [ {