Skip to content
Snippets Groups Projects
Commit 08274a0e authored by Joseph Omar's avatar Joseph Omar
Browse files

added outputs and dnsmos.py

parent 8b658650
No related branches found
No related tags found
No related merge requests found
import os
from typing import Dict, List, Tuple, Union
import torch
import onnxruntime as ort
import numpy as np
class DNSMOS:
SAMPLE_LENGTH = 9.01
def __init__(self, dnsmos_path: str = os.path.join(os.path.dirname(__file__), "dnsmos.onnx"), device: Union[Tuple[str, int], torch.device] = None, sample_rate: int = 16000, cache_session: bool = False):
if not os.path.exists(dnsmos_path):
raise FileNotFoundError(f"DNSMOS model not found at {dnsmos_path}")
if device is None:
self.device = torch.device("cpu")
self.execution_providers = ["CPUExecutionProvider"]
elif isinstance(device, tuple):
self.device = torch.device(device[0], device[1])
elif isinstance(device, torch.device):
self.device = device
else:
raise ValueError("Invalid device argument")
if sample_rate not in [16000, 8000]:
raise ValueError(f"Sample rate {sample_rate} not supported by DNSMOS. Must be 16000 or 8000")
if self.device.type == "cpu":
self.execution_providers = ["CPUExecutionProvider"]
else:
self.execution_providers = [("CUDAExecutionProvider", {"device_id": self.device.index}), "CPUExecutionProvider"]
self.dnsmos_path = dnsmos_path
self.sample_rate = sample_rate
self.cache_session = cache_session
if self.cache_session:
self.session = ort.InferenceSession(self.dnsmos_path, providers=self.execution_providers)
self.poly_ovr = np.poly1d([-0.06766283, 1.11546468, 0.04602535])
self.poly_sig = np.poly1d([-0.08397278, 1.22083953, 0.0052439])
self.poly_bak = np.poly1d([-0.13166888, 1.60915514, -0.39604546])
def _split_tensor(self, denoised: torch.Tensor) -> List[torch.Tensor]:
if denoised.ndim == 1:
denoised = denoised.unsqueeze(0)
elif denoised.ndim > 2:
raise ValueError("Tensor must be 1D or 2D")
sample_length = int(self.SAMPLE_LENGTH * self.sample_rate)
split_tensor = list(torch.split(denoised, sample_length, dim=1))
# make tail the same length as the rest
end_idx = len(split_tensor) - 1
while split_tensor[end_idx].shape[1] < sample_length:
split_tensor[end_idx] = torch.cat([split_tensor[end_idx],
split_tensor[0]], dim=1)
split_tensor[end_idx] = split_tensor[end_idx][:,:sample_length]
return [sample.detach().cpu().numpy() for sample in split_tensor]
def __call__(self, denoised: Union[torch.Tensor, np.ndarray]) -> Dict[str, float]:
if self.cache_session:
session = self.session
else:
session = ort.InferenceSession(self.dnsmos_path, providers=self.execution_providers)
if isinstance(denoised, np.ndarray):
denoised = torch.from_numpy(denoised)
samples = self._split_tensor(denoised)
scores = {
"raw_sig": [],
"raw_bak": [],
"raw_ovr": [],
"sig": [],
"bak": [],
"ovr": []
}
for sample_number, split_sample in enumerate(samples):
raw_sig, raw_bak, raw_ovr = session.run(None, {"input_1": split_sample})[0][0]
scores["raw_sig"].append(raw_sig)
scores["raw_bak"].append(raw_bak)
scores["raw_ovr"].append(raw_ovr)
scores["sig"].append(self.poly_sig(raw_sig))
scores["bak"].append(self.poly_bak(raw_bak))
scores["ovr"].append(self.poly_ovr(raw_ovr))
for key in scores:
scores[key] = np.mean(scores[key])
return scores
%% Cell type:code id: tags:
``` python
import pandas as pd
individual = pd.read_csv('~/dnsmostesting/dnsmos_scores_individual.csv', index_col=0)
batched = pd.read_csv('~/dnsmostesting/dnsmos_scores_batched.csv', index_col=0)
individual = pd.read_csv('/se/dnsmostesting/dnsmos_scores_individual.csv', index_col=0)
batched = pd.read_csv('/se/dnsmostesting/dnsmos_scores_batched.csv', index_col=0)
individual_avg = individual.groupby('batch').mean()
individual_avg
```
%% Output
raw_sig raw_bak raw_ovr sig bak ovr
batch
0 3.145465 2.233926 2.340730 2.897365 2.465761 2.244499
1 3.845078 2.709487 2.718259 3.454555 2.896134 2.559756
2 2.835112 2.169795 2.134030 2.668490 2.276028 2.047455
3 2.145720 1.602384 1.571615 2.140288 1.783102 1.605862
4 3.246001 2.639835 2.498782 3.029038 2.767745 2.354629
... ... ... ... ... ... ...
2696 3.028775 2.215597 2.249891 2.815935 2.383634 2.160208
2697 3.218708 2.586063 2.498565 2.979260 2.751397 2.358187
2698 3.720781 2.729928 2.815541 3.340067 2.948403 2.612177
2699 3.243291 2.289145 2.401669 2.966480 2.512433 2.290272
2700 3.730871 3.455348 3.125872 3.391119 3.571671 2.869536
[2701 rows x 6 columns]
%% Cell type:code id: tags:
``` python
batched_avg
batched
```
%% Output
raw_sig raw_bak raw_ovr sig bak ovr
0 3.575741 3.584303 3.144969 3.296980 3.680074 2.884885
1 3.733906 2.355249 2.552172 3.392991 2.663522 2.452156
2 3.847642 2.021054 2.351104 3.459435 2.318322 2.294580
3 1.670071 1.710422 1.319525 1.809921 1.971086 1.400098
4 3.641207 3.647267 3.193602 3.337230 3.721440 2.918276
... ... ... ... ... ... ...
2696 1.068525 1.048294 1.037608 1.213865 1.146129 1.130593
2697 3.065434 2.261132 2.250670 2.958564 2.569281 2.213821
2698 3.856602 3.040631 3.061332 3.464577 3.279465 2.826714
2699 3.093844 2.507351 2.339078 2.978555 2.810894 2.284982
2700 3.292927 3.566518 2.847978 3.114832 3.668201 2.674032
[2701 rows x 6 columns]
%% Cell type:code id: tags:
``` python
import matplotlib.pyplot as plt
# Iterate over each column in batched and individual
for column in batched.columns:
# Create a new figure and axes for each plot
fig, ax = plt.subplots(figsize=(6, 6))
# Create a box plot for the column in batched
ax.boxplot(batched[column], positions=[0], widths=0.6, patch_artist=True, boxprops=dict(facecolor='blue'))
# Create a box plot for the column in individual
ax.boxplot(individual[column], positions=[1], widths=0.6, patch_artist=True, boxprops=dict(facecolor='orange'))
# Set the x-axis labels
ax.set_xticks([0, 1])
ax.set_xticklabels(['batched', 'individual'])
# Set the y-axis label
ax.set_ylabel(column)
# Set the title of the figure
ax.set_title(f'Box Plot Comparison: {column}')
# Add minor ticks and gridlines
ax.minorticks_on()
ax.grid(which='both', linestyle=':', linewidth='0.5', color='gray')
# Show the plot
plt.show()
```
%% Output
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment