diff --git a/dnsmos.py b/dnsmos.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa11ce8d503cd3e470221f10ed7ff2d9ba702eab
--- /dev/null
+++ b/dnsmos.py
@@ -0,0 +1,91 @@
+import os
+from typing import Dict, List, Tuple, Union
+import torch
+import onnxruntime as ort
+import numpy as np
+
+class DNSMOS:
+    SAMPLE_LENGTH = 9.01
+    def __init__(self, dnsmos_path: str = os.path.join(os.path.dirname(__file__), "dnsmos.onnx"), device: Union[Tuple[str, int], torch.device] = None, sample_rate: int = 16000, cache_session: bool = False):
+        if not os.path.exists(dnsmos_path):
+            raise FileNotFoundError(f"DNSMOS model not found at {dnsmos_path}")
+               
+        if device is None:
+            self.device = torch.device("cpu")
+            self.execution_providers = ["CPUExecutionProvider"]
+        elif isinstance(device, tuple):
+            self.device = torch.device(device[0], device[1])
+        elif isinstance(device, torch.device):
+            self.device = device
+        else:
+            raise ValueError("Invalid device argument")
+        
+        if sample_rate not in [16000, 8000]:
+            raise ValueError(f"Sample rate {sample_rate} not supported by DNSMOS. Must be 16000 or 8000")
+        
+        
+        
+        if self.device.type == "cpu":
+            self.execution_providers = ["CPUExecutionProvider"]
+        else:
+            self.execution_providers = [("CUDAExecutionProvider", {"device_id": self.device.index}), "CPUExecutionProvider"]
+        
+        self.dnsmos_path = dnsmos_path
+        self.sample_rate = sample_rate
+        self.cache_session = cache_session
+        if self.cache_session:
+            self.session = ort.InferenceSession(self.dnsmos_path, providers=self.execution_providers)
+            
+            
+        self.poly_ovr = np.poly1d([-0.06766283, 1.11546468, 0.04602535])
+        self.poly_sig = np.poly1d([-0.08397278, 1.22083953, 0.0052439])
+        self.poly_bak = np.poly1d([-0.13166888, 1.60915514, -0.39604546])
+        
+    def _split_tensor(self, denoised: torch.Tensor) -> List[torch.Tensor]:
+        if denoised.ndim == 1:
+            denoised = denoised.unsqueeze(0)
+        elif denoised.ndim > 2:
+            raise ValueError("Tensor must be 1D or 2D")
+
+        sample_length = int(self.SAMPLE_LENGTH * self.sample_rate)
+        split_tensor = list(torch.split(denoised, sample_length, dim=1))
+        # make tail the same length as the rest
+        end_idx = len(split_tensor) - 1
+        while split_tensor[end_idx].shape[1] < sample_length:
+            split_tensor[end_idx] = torch.cat([split_tensor[end_idx],
+                                                  split_tensor[0]], dim=1)
+            split_tensor[end_idx] = split_tensor[end_idx][:,:sample_length]
+        return [sample.detach().cpu().numpy() for sample in split_tensor]
+        
+        
+    def __call__(self, denoised: Union[torch.Tensor, np.ndarray]) -> Dict[str, float]:
+        if self.cache_session:
+            session = self.session
+        else:
+            session = ort.InferenceSession(self.dnsmos_path, providers=self.execution_providers)
+        
+        if isinstance(denoised, np.ndarray):
+            denoised = torch.from_numpy(denoised)
+        
+        samples = self._split_tensor(denoised)
+        scores = {
+            "raw_sig": [],
+            "raw_bak": [],
+            "raw_ovr": [],
+            "sig": [],
+            "bak": [],
+            "ovr": []
+        }
+        
+        for sample_number, split_sample in enumerate(samples):
+            raw_sig, raw_bak, raw_ovr = session.run(None, {"input_1": split_sample})[0][0]
+            scores["raw_sig"].append(raw_sig)
+            scores["raw_bak"].append(raw_bak)
+            scores["raw_ovr"].append(raw_ovr)
+            scores["sig"].append(self.poly_sig(raw_sig))
+            scores["bak"].append(self.poly_bak(raw_bak))
+            scores["ovr"].append(self.poly_ovr(raw_ovr))
+            
+        for key in scores:
+            scores[key] = np.mean(scores[key])
+        return scores
diff --git a/results_viewer.ipynb b/results_viewer.ipynb
index d2cb103945830b50a20aa51e3e80e618212634b6..7d9cd6423c52c2c86d23ec37f7bf66c02d96d9ee 100644
--- a/results_viewer.ipynb
+++ b/results_viewer.ipynb
@@ -2,29 +2,351 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 2,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>raw_sig</th>\n",
+       "      <th>raw_bak</th>\n",
+       "      <th>raw_ovr</th>\n",
+       "      <th>sig</th>\n",
+       "      <th>bak</th>\n",
+       "      <th>ovr</th>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>batch</th>\n",
+       "      <th></th>\n",
+       "      <th></th>\n",
+       "      <th></th>\n",
+       "      <th></th>\n",
+       "      <th></th>\n",
+       "      <th></th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>3.145465</td>\n",
+       "      <td>2.233926</td>\n",
+       "      <td>2.340730</td>\n",
+       "      <td>2.897365</td>\n",
+       "      <td>2.465761</td>\n",
+       "      <td>2.244499</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>3.845078</td>\n",
+       "      <td>2.709487</td>\n",
+       "      <td>2.718259</td>\n",
+       "      <td>3.454555</td>\n",
+       "      <td>2.896134</td>\n",
+       "      <td>2.559756</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>2.835112</td>\n",
+       "      <td>2.169795</td>\n",
+       "      <td>2.134030</td>\n",
+       "      <td>2.668490</td>\n",
+       "      <td>2.276028</td>\n",
+       "      <td>2.047455</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>2.145720</td>\n",
+       "      <td>1.602384</td>\n",
+       "      <td>1.571615</td>\n",
+       "      <td>2.140288</td>\n",
+       "      <td>1.783102</td>\n",
+       "      <td>1.605862</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>3.246001</td>\n",
+       "      <td>2.639835</td>\n",
+       "      <td>2.498782</td>\n",
+       "      <td>3.029038</td>\n",
+       "      <td>2.767745</td>\n",
+       "      <td>2.354629</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>...</th>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2696</th>\n",
+       "      <td>3.028775</td>\n",
+       "      <td>2.215597</td>\n",
+       "      <td>2.249891</td>\n",
+       "      <td>2.815935</td>\n",
+       "      <td>2.383634</td>\n",
+       "      <td>2.160208</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2697</th>\n",
+       "      <td>3.218708</td>\n",
+       "      <td>2.586063</td>\n",
+       "      <td>2.498565</td>\n",
+       "      <td>2.979260</td>\n",
+       "      <td>2.751397</td>\n",
+       "      <td>2.358187</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2698</th>\n",
+       "      <td>3.720781</td>\n",
+       "      <td>2.729928</td>\n",
+       "      <td>2.815541</td>\n",
+       "      <td>3.340067</td>\n",
+       "      <td>2.948403</td>\n",
+       "      <td>2.612177</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2699</th>\n",
+       "      <td>3.243291</td>\n",
+       "      <td>2.289145</td>\n",
+       "      <td>2.401669</td>\n",
+       "      <td>2.966480</td>\n",
+       "      <td>2.512433</td>\n",
+       "      <td>2.290272</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2700</th>\n",
+       "      <td>3.730871</td>\n",
+       "      <td>3.455348</td>\n",
+       "      <td>3.125872</td>\n",
+       "      <td>3.391119</td>\n",
+       "      <td>3.571671</td>\n",
+       "      <td>2.869536</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "<p>2701 rows × 6 columns</p>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "        raw_sig   raw_bak   raw_ovr       sig       bak       ovr\n",
+       "batch                                                            \n",
+       "0      3.145465  2.233926  2.340730  2.897365  2.465761  2.244499\n",
+       "1      3.845078  2.709487  2.718259  3.454555  2.896134  2.559756\n",
+       "2      2.835112  2.169795  2.134030  2.668490  2.276028  2.047455\n",
+       "3      2.145720  1.602384  1.571615  2.140288  1.783102  1.605862\n",
+       "4      3.246001  2.639835  2.498782  3.029038  2.767745  2.354629\n",
+       "...         ...       ...       ...       ...       ...       ...\n",
+       "2696   3.028775  2.215597  2.249891  2.815935  2.383634  2.160208\n",
+       "2697   3.218708  2.586063  2.498565  2.979260  2.751397  2.358187\n",
+       "2698   3.720781  2.729928  2.815541  3.340067  2.948403  2.612177\n",
+       "2699   3.243291  2.289145  2.401669  2.966480  2.512433  2.290272\n",
+       "2700   3.730871  3.455348  3.125872  3.391119  3.571671  2.869536\n",
+       "\n",
+       "[2701 rows x 6 columns]"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "import pandas as pd\n",
-    "individual = pd.read_csv('~/dnsmostesting/dnsmos_scores_individual.csv', index_col=0)\n",
-    "batched = pd.read_csv('~/dnsmostesting/dnsmos_scores_batched.csv', index_col=0)\n",
+    "individual = pd.read_csv('/se/dnsmostesting/dnsmos_scores_individual.csv', index_col=0)\n",
+    "batched = pd.read_csv('/se/dnsmostesting/dnsmos_scores_batched.csv', index_col=0)\n",
     "individual_avg = individual.groupby('batch').mean()\n",
     "individual_avg"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 3,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>raw_sig</th>\n",
+       "      <th>raw_bak</th>\n",
+       "      <th>raw_ovr</th>\n",
+       "      <th>sig</th>\n",
+       "      <th>bak</th>\n",
+       "      <th>ovr</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>3.575741</td>\n",
+       "      <td>3.584303</td>\n",
+       "      <td>3.144969</td>\n",
+       "      <td>3.296980</td>\n",
+       "      <td>3.680074</td>\n",
+       "      <td>2.884885</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>3.733906</td>\n",
+       "      <td>2.355249</td>\n",
+       "      <td>2.552172</td>\n",
+       "      <td>3.392991</td>\n",
+       "      <td>2.663522</td>\n",
+       "      <td>2.452156</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>3.847642</td>\n",
+       "      <td>2.021054</td>\n",
+       "      <td>2.351104</td>\n",
+       "      <td>3.459435</td>\n",
+       "      <td>2.318322</td>\n",
+       "      <td>2.294580</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>1.670071</td>\n",
+       "      <td>1.710422</td>\n",
+       "      <td>1.319525</td>\n",
+       "      <td>1.809921</td>\n",
+       "      <td>1.971086</td>\n",
+       "      <td>1.400098</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>3.641207</td>\n",
+       "      <td>3.647267</td>\n",
+       "      <td>3.193602</td>\n",
+       "      <td>3.337230</td>\n",
+       "      <td>3.721440</td>\n",
+       "      <td>2.918276</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>...</th>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2696</th>\n",
+       "      <td>1.068525</td>\n",
+       "      <td>1.048294</td>\n",
+       "      <td>1.037608</td>\n",
+       "      <td>1.213865</td>\n",
+       "      <td>1.146129</td>\n",
+       "      <td>1.130593</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2697</th>\n",
+       "      <td>3.065434</td>\n",
+       "      <td>2.261132</td>\n",
+       "      <td>2.250670</td>\n",
+       "      <td>2.958564</td>\n",
+       "      <td>2.569281</td>\n",
+       "      <td>2.213821</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2698</th>\n",
+       "      <td>3.856602</td>\n",
+       "      <td>3.040631</td>\n",
+       "      <td>3.061332</td>\n",
+       "      <td>3.464577</td>\n",
+       "      <td>3.279465</td>\n",
+       "      <td>2.826714</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2699</th>\n",
+       "      <td>3.093844</td>\n",
+       "      <td>2.507351</td>\n",
+       "      <td>2.339078</td>\n",
+       "      <td>2.978555</td>\n",
+       "      <td>2.810894</td>\n",
+       "      <td>2.284982</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2700</th>\n",
+       "      <td>3.292927</td>\n",
+       "      <td>3.566518</td>\n",
+       "      <td>2.847978</td>\n",
+       "      <td>3.114832</td>\n",
+       "      <td>3.668201</td>\n",
+       "      <td>2.674032</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "<p>2701 rows × 6 columns</p>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "       raw_sig   raw_bak   raw_ovr       sig       bak       ovr\n",
+       "0     3.575741  3.584303  3.144969  3.296980  3.680074  2.884885\n",
+       "1     3.733906  2.355249  2.552172  3.392991  2.663522  2.452156\n",
+       "2     3.847642  2.021054  2.351104  3.459435  2.318322  2.294580\n",
+       "3     1.670071  1.710422  1.319525  1.809921  1.971086  1.400098\n",
+       "4     3.641207  3.647267  3.193602  3.337230  3.721440  2.918276\n",
+       "...        ...       ...       ...       ...       ...       ...\n",
+       "2696  1.068525  1.048294  1.037608  1.213865  1.146129  1.130593\n",
+       "2697  3.065434  2.261132  2.250670  2.958564  2.569281  2.213821\n",
+       "2698  3.856602  3.040631  3.061332  3.464577  3.279465  2.826714\n",
+       "2699  3.093844  2.507351  2.339078  2.978555  2.810894  2.284982\n",
+       "2700  3.292927  3.566518  2.847978  3.114832  3.668201  2.674032\n",
+       "\n",
+       "[2701 rows x 6 columns]"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
-    "batched_avg"
+    "batched"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [
     {