Skip to content
GitLab
Explore
Sign in
Register
Primary navigation
Search or go to…
Project
D
dnsmostesting
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Joseph Omar
dnsmostesting
Commits
08274a0e
Commit
08274a0e
authored
1 year ago
by
Joseph Omar
Browse files
Options
Downloads
Patches
Plain Diff
added outputs and dnsmos.py
parent
8b658650
No related branches found
No related tags found
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
dnsmos.py
+91
-0
91 additions, 0 deletions
dnsmos.py
results_viewer.ipynb
+330
-8
330 additions, 8 deletions
results_viewer.ipynb
with
421 additions
and
8 deletions
dnsmos.py
0 → 100644
+
91
−
0
View file @
08274a0e
import
os
from
typing
import
Dict
,
List
,
Tuple
,
Union
import
torch
import
onnxruntime
as
ort
import
numpy
as
np
class
DNSMOS
:
SAMPLE_LENGTH
=
9.01
def
__init__
(
self
,
dnsmos_path
:
str
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"
dnsmos.onnx
"
),
device
:
Union
[
Tuple
[
str
,
int
],
torch
.
device
]
=
None
,
sample_rate
:
int
=
16000
,
cache_session
:
bool
=
False
):
if
not
os
.
path
.
exists
(
dnsmos_path
):
raise
FileNotFoundError
(
f
"
DNSMOS model not found at
{
dnsmos_path
}
"
)
if
device
is
None
:
self
.
device
=
torch
.
device
(
"
cpu
"
)
self
.
execution_providers
=
[
"
CPUExecutionProvider
"
]
elif
isinstance
(
device
,
tuple
):
self
.
device
=
torch
.
device
(
device
[
0
],
device
[
1
])
elif
isinstance
(
device
,
torch
.
device
):
self
.
device
=
device
else
:
raise
ValueError
(
"
Invalid device argument
"
)
if
sample_rate
not
in
[
16000
,
8000
]:
raise
ValueError
(
f
"
Sample rate
{
sample_rate
}
not supported by DNSMOS. Must be 16000 or 8000
"
)
if
self
.
device
.
type
==
"
cpu
"
:
self
.
execution_providers
=
[
"
CPUExecutionProvider
"
]
else
:
self
.
execution_providers
=
[(
"
CUDAExecutionProvider
"
,
{
"
device_id
"
:
self
.
device
.
index
}),
"
CPUExecutionProvider
"
]
self
.
dnsmos_path
=
dnsmos_path
self
.
sample_rate
=
sample_rate
self
.
cache_session
=
cache_session
if
self
.
cache_session
:
self
.
session
=
ort
.
InferenceSession
(
self
.
dnsmos_path
,
providers
=
self
.
execution_providers
)
self
.
poly_ovr
=
np
.
poly1d
([
-
0.06766283
,
1.11546468
,
0.04602535
])
self
.
poly_sig
=
np
.
poly1d
([
-
0.08397278
,
1.22083953
,
0.0052439
])
self
.
poly_bak
=
np
.
poly1d
([
-
0.13166888
,
1.60915514
,
-
0.39604546
])
def
_split_tensor
(
self
,
denoised
:
torch
.
Tensor
)
->
List
[
torch
.
Tensor
]:
if
denoised
.
ndim
==
1
:
denoised
=
denoised
.
unsqueeze
(
0
)
elif
denoised
.
ndim
>
2
:
raise
ValueError
(
"
Tensor must be 1D or 2D
"
)
sample_length
=
int
(
self
.
SAMPLE_LENGTH
*
self
.
sample_rate
)
split_tensor
=
list
(
torch
.
split
(
denoised
,
sample_length
,
dim
=
1
))
# make tail the same length as the rest
end_idx
=
len
(
split_tensor
)
-
1
while
split_tensor
[
end_idx
].
shape
[
1
]
<
sample_length
:
split_tensor
[
end_idx
]
=
torch
.
cat
([
split_tensor
[
end_idx
],
split_tensor
[
0
]],
dim
=
1
)
split_tensor
[
end_idx
]
=
split_tensor
[
end_idx
][:,:
sample_length
]
return
[
sample
.
detach
().
cpu
().
numpy
()
for
sample
in
split_tensor
]
def
__call__
(
self
,
denoised
:
Union
[
torch
.
Tensor
,
np
.
ndarray
])
->
Dict
[
str
,
float
]:
if
self
.
cache_session
:
session
=
self
.
session
else
:
session
=
ort
.
InferenceSession
(
self
.
dnsmos_path
,
providers
=
self
.
execution_providers
)
if
isinstance
(
denoised
,
np
.
ndarray
):
denoised
=
torch
.
from_numpy
(
denoised
)
samples
=
self
.
_split_tensor
(
denoised
)
scores
=
{
"
raw_sig
"
:
[],
"
raw_bak
"
:
[],
"
raw_ovr
"
:
[],
"
sig
"
:
[],
"
bak
"
:
[],
"
ovr
"
:
[]
}
for
sample_number
,
split_sample
in
enumerate
(
samples
):
raw_sig
,
raw_bak
,
raw_ovr
=
session
.
run
(
None
,
{
"
input_1
"
:
split_sample
})[
0
][
0
]
scores
[
"
raw_sig
"
].
append
(
raw_sig
)
scores
[
"
raw_bak
"
].
append
(
raw_bak
)
scores
[
"
raw_ovr
"
].
append
(
raw_ovr
)
scores
[
"
sig
"
].
append
(
self
.
poly_sig
(
raw_sig
))
scores
[
"
bak
"
].
append
(
self
.
poly_bak
(
raw_bak
))
scores
[
"
ovr
"
].
append
(
self
.
poly_ovr
(
raw_ovr
))
for
key
in
scores
:
scores
[
key
]
=
np
.
mean
(
scores
[
key
])
return
scores
This diff is collapsed.
Click to expand it.
results_viewer.ipynb
+
330
−
8
View file @
08274a0e
...
...
@@ -2,29 +2,351 @@
"cells": [
{
"cell_type": "code",
"execution_count": 2
1
,
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>raw_sig</th>\n",
" <th>raw_bak</th>\n",
" <th>raw_ovr</th>\n",
" <th>sig</th>\n",
" <th>bak</th>\n",
" <th>ovr</th>\n",
" </tr>\n",
" <tr>\n",
" <th>batch</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>3.145465</td>\n",
" <td>2.233926</td>\n",
" <td>2.340730</td>\n",
" <td>2.897365</td>\n",
" <td>2.465761</td>\n",
" <td>2.244499</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>3.845078</td>\n",
" <td>2.709487</td>\n",
" <td>2.718259</td>\n",
" <td>3.454555</td>\n",
" <td>2.896134</td>\n",
" <td>2.559756</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2.835112</td>\n",
" <td>2.169795</td>\n",
" <td>2.134030</td>\n",
" <td>2.668490</td>\n",
" <td>2.276028</td>\n",
" <td>2.047455</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2.145720</td>\n",
" <td>1.602384</td>\n",
" <td>1.571615</td>\n",
" <td>2.140288</td>\n",
" <td>1.783102</td>\n",
" <td>1.605862</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3.246001</td>\n",
" <td>2.639835</td>\n",
" <td>2.498782</td>\n",
" <td>3.029038</td>\n",
" <td>2.767745</td>\n",
" <td>2.354629</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2696</th>\n",
" <td>3.028775</td>\n",
" <td>2.215597</td>\n",
" <td>2.249891</td>\n",
" <td>2.815935</td>\n",
" <td>2.383634</td>\n",
" <td>2.160208</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2697</th>\n",
" <td>3.218708</td>\n",
" <td>2.586063</td>\n",
" <td>2.498565</td>\n",
" <td>2.979260</td>\n",
" <td>2.751397</td>\n",
" <td>2.358187</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2698</th>\n",
" <td>3.720781</td>\n",
" <td>2.729928</td>\n",
" <td>2.815541</td>\n",
" <td>3.340067</td>\n",
" <td>2.948403</td>\n",
" <td>2.612177</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2699</th>\n",
" <td>3.243291</td>\n",
" <td>2.289145</td>\n",
" <td>2.401669</td>\n",
" <td>2.966480</td>\n",
" <td>2.512433</td>\n",
" <td>2.290272</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2700</th>\n",
" <td>3.730871</td>\n",
" <td>3.455348</td>\n",
" <td>3.125872</td>\n",
" <td>3.391119</td>\n",
" <td>3.571671</td>\n",
" <td>2.869536</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2701 rows × 6 columns</p>\n",
"</div>"
],
"text/plain": [
" raw_sig raw_bak raw_ovr sig bak ovr\n",
"batch \n",
"0 3.145465 2.233926 2.340730 2.897365 2.465761 2.244499\n",
"1 3.845078 2.709487 2.718259 3.454555 2.896134 2.559756\n",
"2 2.835112 2.169795 2.134030 2.668490 2.276028 2.047455\n",
"3 2.145720 1.602384 1.571615 2.140288 1.783102 1.605862\n",
"4 3.246001 2.639835 2.498782 3.029038 2.767745 2.354629\n",
"... ... ... ... ... ... ...\n",
"2696 3.028775 2.215597 2.249891 2.815935 2.383634 2.160208\n",
"2697 3.218708 2.586063 2.498565 2.979260 2.751397 2.358187\n",
"2698 3.720781 2.729928 2.815541 3.340067 2.948403 2.612177\n",
"2699 3.243291 2.289145 2.401669 2.966480 2.512433 2.290272\n",
"2700 3.730871 3.455348 3.125872 3.391119 3.571671 2.869536\n",
"\n",
"[2701 rows x 6 columns]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"individual = pd.read_csv('
~
/dnsmostesting/dnsmos_scores_individual.csv', index_col=0)\n",
"batched = pd.read_csv('
~
/dnsmostesting/dnsmos_scores_batched.csv', index_col=0)\n",
"individual = pd.read_csv('
/se
/dnsmostesting/dnsmos_scores_individual.csv', index_col=0)\n",
"batched = pd.read_csv('
/se
/dnsmostesting/dnsmos_scores_batched.csv', index_col=0)\n",
"individual_avg = individual.groupby('batch').mean()\n",
"individual_avg"
]
},
{
"cell_type": "code",
"execution_count":
null
,
"execution_count":
3
,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>raw_sig</th>\n",
" <th>raw_bak</th>\n",
" <th>raw_ovr</th>\n",
" <th>sig</th>\n",
" <th>bak</th>\n",
" <th>ovr</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>3.575741</td>\n",
" <td>3.584303</td>\n",
" <td>3.144969</td>\n",
" <td>3.296980</td>\n",
" <td>3.680074</td>\n",
" <td>2.884885</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>3.733906</td>\n",
" <td>2.355249</td>\n",
" <td>2.552172</td>\n",
" <td>3.392991</td>\n",
" <td>2.663522</td>\n",
" <td>2.452156</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3.847642</td>\n",
" <td>2.021054</td>\n",
" <td>2.351104</td>\n",
" <td>3.459435</td>\n",
" <td>2.318322</td>\n",
" <td>2.294580</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1.670071</td>\n",
" <td>1.710422</td>\n",
" <td>1.319525</td>\n",
" <td>1.809921</td>\n",
" <td>1.971086</td>\n",
" <td>1.400098</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3.641207</td>\n",
" <td>3.647267</td>\n",
" <td>3.193602</td>\n",
" <td>3.337230</td>\n",
" <td>3.721440</td>\n",
" <td>2.918276</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2696</th>\n",
" <td>1.068525</td>\n",
" <td>1.048294</td>\n",
" <td>1.037608</td>\n",
" <td>1.213865</td>\n",
" <td>1.146129</td>\n",
" <td>1.130593</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2697</th>\n",
" <td>3.065434</td>\n",
" <td>2.261132</td>\n",
" <td>2.250670</td>\n",
" <td>2.958564</td>\n",
" <td>2.569281</td>\n",
" <td>2.213821</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2698</th>\n",
" <td>3.856602</td>\n",
" <td>3.040631</td>\n",
" <td>3.061332</td>\n",
" <td>3.464577</td>\n",
" <td>3.279465</td>\n",
" <td>2.826714</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2699</th>\n",
" <td>3.093844</td>\n",
" <td>2.507351</td>\n",
" <td>2.339078</td>\n",
" <td>2.978555</td>\n",
" <td>2.810894</td>\n",
" <td>2.284982</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2700</th>\n",
" <td>3.292927</td>\n",
" <td>3.566518</td>\n",
" <td>2.847978</td>\n",
" <td>3.114832</td>\n",
" <td>3.668201</td>\n",
" <td>2.674032</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2701 rows × 6 columns</p>\n",
"</div>"
],
"text/plain": [
" raw_sig raw_bak raw_ovr sig bak ovr\n",
"0 3.575741 3.584303 3.144969 3.296980 3.680074 2.884885\n",
"1 3.733906 2.355249 2.552172 3.392991 2.663522 2.452156\n",
"2 3.847642 2.021054 2.351104 3.459435 2.318322 2.294580\n",
"3 1.670071 1.710422 1.319525 1.809921 1.971086 1.400098\n",
"4 3.641207 3.647267 3.193602 3.337230 3.721440 2.918276\n",
"... ... ... ... ... ... ...\n",
"2696 1.068525 1.048294 1.037608 1.213865 1.146129 1.130593\n",
"2697 3.065434 2.261132 2.250670 2.958564 2.569281 2.213821\n",
"2698 3.856602 3.040631 3.061332 3.464577 3.279465 2.826714\n",
"2699 3.093844 2.507351 2.339078 2.978555 2.810894 2.284982\n",
"2700 3.292927 3.566518 2.847978 3.114832 3.668201 2.674032\n",
"\n",
"[2701 rows x 6 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batched
_avg
"
"batched"
]
},
{
"cell_type": "code",
"execution_count":
25
,
"execution_count":
4
,
"metadata": {},
"outputs": [
{
...
...
%% Cell type:code id: tags:
```
python
import
pandas
as
pd
individual
=
pd
.
read_csv
(
'
~
/dnsmostesting/dnsmos_scores_individual.csv
'
,
index_col
=
0
)
batched
=
pd
.
read_csv
(
'
~
/dnsmostesting/dnsmos_scores_batched.csv
'
,
index_col
=
0
)
individual
=
pd
.
read_csv
(
'
/se
/dnsmostesting/dnsmos_scores_individual.csv
'
,
index_col
=
0
)
batched
=
pd
.
read_csv
(
'
/se
/dnsmostesting/dnsmos_scores_batched.csv
'
,
index_col
=
0
)
individual_avg
=
individual
.
groupby
(
'
batch
'
).
mean
()
individual_avg
```
%% Output
raw_sig raw_bak raw_ovr sig bak ovr
batch
0 3.145465 2.233926 2.340730 2.897365 2.465761 2.244499
1 3.845078 2.709487 2.718259 3.454555 2.896134 2.559756
2 2.835112 2.169795 2.134030 2.668490 2.276028 2.047455
3 2.145720 1.602384 1.571615 2.140288 1.783102 1.605862
4 3.246001 2.639835 2.498782 3.029038 2.767745 2.354629
... ... ... ... ... ... ...
2696 3.028775 2.215597 2.249891 2.815935 2.383634 2.160208
2697 3.218708 2.586063 2.498565 2.979260 2.751397 2.358187
2698 3.720781 2.729928 2.815541 3.340067 2.948403 2.612177
2699 3.243291 2.289145 2.401669 2.966480 2.512433 2.290272
2700 3.730871 3.455348 3.125872 3.391119 3.571671 2.869536
[2701 rows x 6 columns]
%% Cell type:code id: tags:
```
python
batched
_avg
batched
```
%% Output
raw_sig raw_bak raw_ovr sig bak ovr
0 3.575741 3.584303 3.144969 3.296980 3.680074 2.884885
1 3.733906 2.355249 2.552172 3.392991 2.663522 2.452156
2 3.847642 2.021054 2.351104 3.459435 2.318322 2.294580
3 1.670071 1.710422 1.319525 1.809921 1.971086 1.400098
4 3.641207 3.647267 3.193602 3.337230 3.721440 2.918276
... ... ... ... ... ... ...
2696 1.068525 1.048294 1.037608 1.213865 1.146129 1.130593
2697 3.065434 2.261132 2.250670 2.958564 2.569281 2.213821
2698 3.856602 3.040631 3.061332 3.464577 3.279465 2.826714
2699 3.093844 2.507351 2.339078 2.978555 2.810894 2.284982
2700 3.292927 3.566518 2.847978 3.114832 3.668201 2.674032
[2701 rows x 6 columns]
%% Cell type:code id: tags:
```
python
import
matplotlib.pyplot
as
plt
# Iterate over each column in batched and individual
for
column
in
batched
.
columns
:
# Create a new figure and axes for each plot
fig
,
ax
=
plt
.
subplots
(
figsize
=
(
6
,
6
))
# Create a box plot for the column in batched
ax
.
boxplot
(
batched
[
column
],
positions
=
[
0
],
widths
=
0.6
,
patch_artist
=
True
,
boxprops
=
dict
(
facecolor
=
'
blue
'
))
# Create a box plot for the column in individual
ax
.
boxplot
(
individual
[
column
],
positions
=
[
1
],
widths
=
0.6
,
patch_artist
=
True
,
boxprops
=
dict
(
facecolor
=
'
orange
'
))
# Set the x-axis labels
ax
.
set_xticks
([
0
,
1
])
ax
.
set_xticklabels
([
'
batched
'
,
'
individual
'
])
# Set the y-axis label
ax
.
set_ylabel
(
column
)
# Set the title of the figure
ax
.
set_title
(
f
'
Box Plot Comparison:
{
column
}
'
)
# Add minor ticks and gridlines
ax
.
minorticks_on
()
ax
.
grid
(
which
=
'
both
'
,
linestyle
=
'
:
'
,
linewidth
=
'
0.5
'
,
color
=
'
gray
'
)
# Show the plot
plt
.
show
()
```
%% Output
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment