From 7f1b1ae254808cdcd31274717008de06993da797 Mon Sep 17 00:00:00 2001
From: James Graham <J.A.Graham@soton.ac.uk>
Date: Fri, 10 Jun 2016 11:59:58 +0100
Subject: [PATCH] Working XTC output CGFrame now persistent - carries XTC
 buffer

---
 pycgtool.py         | 13 +++++++----
 pycgtool/frame.py   | 57 +++++++++++++++++++++++++++++++++++++++++----
 pycgtool/mapping.py |  6 +++--
 test/test_frame.py  | 19 ++++++++++++++-
 4 files changed, 82 insertions(+), 13 deletions(-)

diff --git a/pycgtool.py b/pycgtool.py
index e779d86..4512c52 100755
--- a/pycgtool.py
+++ b/pycgtool.py
@@ -32,9 +32,9 @@ def main(args, config):
     numframes = frame.numframes - args.begin if args.end == -1 else args.end - args.begin
     for _ in Progress(numframes, postwhile=frame.next_frame):
         if args.map:
-            cgframe = mapping.apply(frame, exclude={"SOL"})
+            cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"})
             if config.output_xtc:
-                cgframe.write_xtc("out.xtc")
+                cgframe.write_to_xtc_buffer()
         else:
             cgframe = frame
 
@@ -53,6 +53,9 @@ def main(args, config):
         if config.dump_measurements:
             bonds.dump_values(config.dump_n_values)
 
+    if args.map and config.output_xtc:
+        cgframe.flush_xtc_buffer("out.xtc")
+
 
 def map_only(args, config):
     """
@@ -69,9 +72,9 @@ def map_only(args, config):
     if args.xtc and config.output_xtc:
         numframes = frame.numframes - args.begin if args.end == -1 else args.end - args.begin
         for _ in Progress(numframes, postwhile=frame.next_frame):
-            cgframe = mapping.apply(frame, exclude={"SOL"})
-            cgframe.write_xtc("out.xtc")
-
+            cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"})
+            cgframe.write_to_xtc_buffer()
+        cgframe.flush_xtc_buffer("out.xtc")
 
 
 if __name__ == "__main__":
diff --git a/pycgtool/frame.py b/pycgtool/frame.py
index e6f2d9e..0edb531 100644
--- a/pycgtool/frame.py
+++ b/pycgtool/frame.py
@@ -122,7 +122,7 @@ class Frame:
         self.numframes = 0
         self.box = np.zeros(3, dtype=np.float32)
 
-        self._out_xtc = None
+        self._xtc_buffer = None
 
         if gro is not None:
             self._parse_gro(gro)
@@ -257,9 +257,52 @@ class Frame:
         except (IndexError, AttributeError):
             return False
 
-    def write_xtc(self, filename=None):
-        if self._out_xtc is None:
-           self._out_xtc = XTCTrajectoryFile(filename, mode="w")
+    class XTCBuffer:
+        def __init__(self):
+            self.coords = []
+            self.step = []
+            self.box = []
+
+        def __call__(self):
+            return (np.array(self.coords, dtype=np.float32),
+                    np.array(self.step,   dtype=np.int32),
+                    np.array(self.box,    dtype=np.float32))
+
+        def append(self, coords, step, box):
+            self.coords.append(coords)
+            self.step.append(step)
+            self.box.append(box)
+
+    def write_to_xtc_buffer(self):
+        if self._xtc_buffer is None:
+            self._xtc_buffer = self.XTCBuffer()
+
+        xyz = np.ndarray((self.natoms, 3))
+        i = 0
+        for residue in self.residues:
+            for atom in residue.atoms:
+                xyz[i] = atom.coords
+                i += 1
+
+        box = np.zeros((3, 3), dtype=np.float32)
+        for i in range(3):
+            box[i][i] = self.box[i]
+
+        self._xtc_buffer.append(xyz, self.number, box)
+
+    def flush_xtc_buffer(self, filename):
+        if self._xtc_buffer is not None:
+            xtc = XTCTrajectoryFile(filename, mode="w")
+
+            xyz, step, box = self._xtc_buffer()
+            xtc.write(xyz, step=step, box=box)
+            xtc.close()
+
+            self._xtc_buffer = None
+
+    def write_xtc(self, filename):
+        if self._xtc_buffer is None:
+            self._xtc_buffer = XTCTrajectoryFile(filename, mode="w")
 
         xyz = np.ndarray((1, self.natoms, 3), dtype=np.float32)
         i = 0
@@ -267,11 +310,15 @@ class Frame:
             for atom in residue.atoms:
                 xyz[0][i] = atom.coords
                 i += 1
+
         step = np.array([self.number], dtype=np.int32)
+
         box = np.zeros((1, 3, 3), dtype=np.float32)
         for i in range(3):
             box[0][i][i] = self.box[i]
-        self._out_xtc.write(xyz, step=step, box=box)
+
+        self._xtc_buffer.write(xyz, step=step, box=box)
+        # self._xtc_buffer.close()
 
     def _parse_gro(self, filename):
         """
diff --git a/pycgtool/mapping.py b/pycgtool/mapping.py
index ce92e63..d0f3e78 100644
--- a/pycgtool/mapping.py
+++ b/pycgtool/mapping.py
@@ -116,7 +116,7 @@ class Mapping:
     def __iter__(self):
         return iter(self._mappings)
 
-    def apply(self, frame, exclude=None):
+    def apply(self, frame, cgframe=None, exclude=None):
         """
         Apply the AA->CG mapping to an atomistic Frame.
 
@@ -124,8 +124,10 @@ class Mapping:
         :param exclude: Set of molecule names to exclude from mapping - e.g. solvent
         :return: A new Frame instance containing the CG frame
         """
-        cgframe = Frame()
+        if cgframe is None:
+            cgframe = Frame()
         cgframe.name = frame.name
+        cgframe.number = frame.number
         cgframe.box = frame.box
         cgframe.natoms = 0
         cgframe.residues = []
diff --git a/test/test_frame.py b/test/test_frame.py
index f46d87b..7c40f3d 100644
--- a/test/test_frame.py
+++ b/test/test_frame.py
@@ -6,6 +6,12 @@ import numpy as np
 
 from pycgtool.frame import Atom, Residue, Frame
 
+try:
+    import mdtraj
+    mdtraj_present = True
+except ImportError:
+    mdtraj_present = False
+
 
 class AtomTest(unittest.TestCase):
     def test_atom_create(self):
@@ -82,11 +88,13 @@ class FrameTest(unittest.TestCase):
         np.testing.assert_allclose(np.array([1.122, 1.130, 1.534]),
                                    frame.residues[0].atoms[0].coords)
 
+    @unittest.skipIf(not mdtraj_present, "MDTRAJ not present")
     def test_frame_read_xtc_mdtraj_numframes(self):
         frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc",
                       xtc_reader="mdtraj")
         self.assertEqual(12, frame.numframes)
 
+    @unittest.skipIf(not mdtraj_present, "MDTRAJ not present")
     def test_frame_read_xtc_mdtraj(self):
         frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc",
                       xtc_reader="mdtraj")
@@ -101,11 +109,20 @@ class FrameTest(unittest.TestCase):
         np.testing.assert_allclose(np.array([1.122, 1.130, 1.534]),
                                    frame.residues[0].atoms[0].coords)
 
+    @unittest.skipIf(not mdtraj_present, "MDTRAJ not present")
+    def test_frame_write_xtc_buffered_mdtraj(self):
+        frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc",
+                      xtc_reader="mdtraj")
+        while frame.next_frame():
+            frame.write_to_xtc_buffer()
+        frame.flush_xtc_buffer("test.xtc")
+
+    @unittest.skipIf(not mdtraj_present, "MDTRAJ not present")
     def test_frame_write_xtc_mdtraj(self):
         frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc",
                       xtc_reader="mdtraj")
         while frame.next_frame():
-            frame.write_xtc("test.xtc")
+            frame.write_xtc("test2.xtc")
 
 
 if __name__ == '__main__':
-- 
GitLab