From 320e7da20e9d750390f8e743b7a9498bd071b69d Mon Sep 17 00:00:00 2001
From: James Graham <j.graham@soton.ac.uk>
Date: Thu, 18 Mar 2021 19:44:45 +0000
Subject: [PATCH] refactor: split up main init method

---
 .gitignore           |  1 +
 pycgtool/__main__.py | 33 +++++++++++++++++++--------------
 pycgtool/frame.py    |  2 +-
 3 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/.gitignore b/.gitignore
index 29da604..2b742f2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,7 @@
 *.offsets
 
 # Dependencies and runtime
+/.venv/
 /env/
 /minenv/
 /venv/
diff --git a/pycgtool/__main__.py b/pycgtool/__main__.py
index 77f770a..52f8574 100755
--- a/pycgtool/__main__.py
+++ b/pycgtool/__main__.py
@@ -7,6 +7,7 @@ import pathlib
 import sys
 import textwrap
 import time
+import typing
 
 from mdplus.multiscale import GLIMPS
 from rich.logging import RichHandler
@@ -16,6 +17,8 @@ from .mapping import Mapping
 from .bondset import BondSet
 from .forcefield import ForceField
 
+PathLike = typing.Union[pathlib.Path, str]
+
 logger = logging.getLogger(__name__)
 
 
@@ -36,15 +39,7 @@ class PyCGTOOL:
         self.mapping = None
         self.out_frame = self.in_frame
         if self.config.mapping:
-            self.mapping = Mapping(self.config.mapping,
-                                   self.config,
-                                   itp_filename=self.config.itp)
-            self.out_frame = self.mapping.apply(self.in_frame)
-            self.out_frame.save(self.get_output_filepath('gro'),
-                                frame_number=0)
-
-            if self.config.backmapper_resname and self.out_frame.n_frames > 1:
-                self.train_backmapper(self.config.resname)
+            self.mapping, self.out_frame = self.apply_mapping(self.in_frame)
 
         self.bondset = None
         if self.config.bondset:
@@ -54,14 +49,24 @@ class PyCGTOOL:
         if self.config.output_xtc:
             self.out_frame.save(self.get_output_filepath('xtc'))
 
-    def get_output_filepath(self, ext: str) -> pathlib.Path:
-        """Get file path for an output file by extension.
-
-        :param ext:
-        """
+    def get_output_filepath(self, ext: PathLike) -> pathlib.Path:
+        """Get file path for an output file by extension."""
         out_dir = pathlib.Path(self.config.out_dir)
         return out_dir.joinpath(self.config.output_name + '.' + ext)
 
+    def apply_mapping(self, in_frame: Frame) -> typing.Tuple[Mapping, Frame]:
+        """Map input frame to output using requested mapping file."""
+        mapping = Mapping(self.config.mapping,
+                          self.config,
+                          itp_filename=self.config.itp)
+        out_frame = mapping.apply(in_frame)
+        out_frame.save(self.get_output_filepath('gro'), frame_number=0)
+
+        if self.config.backmapper_resname and self.out_frame.n_frames > 1:
+            self.train_backmapper(self.config.resname)
+
+        return mapping, out_frame
+
     def measure_bonds(self) -> None:
         """Measure bonds at the end of a run."""
         self.bondset.apply(self.out_frame)
diff --git a/pycgtool/frame.py b/pycgtool/frame.py
index 8aa7275..06c3e27 100644
--- a/pycgtool/frame.py
+++ b/pycgtool/frame.py
@@ -193,7 +193,7 @@ class Frame:
         return self._topology.add_atom(name, element, residue)
 
     def save(self,
-             filename: str,
+             filename: PathLike,
              frame_number: typing.Optional[int] = None,
              **kwargs) -> None:
         """Write trajctory to file.
-- 
GitLab