From cfc314f34e4bcf01ffb279487e2d0fb0af481f41 Mon Sep 17 00:00:00 2001
From: James Graham <J.A.Graham@soton.ac.uk>
Date: Fri, 16 Dec 2016 16:54:02 +0000
Subject: [PATCH] Remove 'exclude' from Frame, Mapping and Bondset - up to 10%
 performance increase

---
 pycgtool/bondset.py  |  5 +----
 pycgtool/frame.py    | 38 +++++++++++++++-----------------------
 pycgtool/mapping.py  | 10 +++-------
 pycgtool/pycgtool.py | 10 +++++-----
 test/data/sugar.bnd  |  2 --
 test/test_bondset.py | 16 +++++++---------
 6 files changed, 31 insertions(+), 50 deletions(-)

diff --git a/pycgtool/bondset.py b/pycgtool/bondset.py
index 1768c8d..d5af198 100644
--- a/pycgtool/bondset.py
+++ b/pycgtool/bondset.py
@@ -260,13 +260,12 @@ class BondSet:
                     e.args = ("Bead(s) {0} do(es) not exist in residue {1}".format(missing, mol),)
                     raise
 
-    def write_itp(self, filename, mapping, exclude=set()):
+    def write_itp(self, filename, mapping):
         """
         Output a GROMACS .itp file containing atoms/beads and bonded terms.
 
         :param filename: Name of output file
         :param mapping: AA->CG Mapping from which to collect bead properties
-        :param exclude: Set of molecule names to be excluded from itp
         """
         self._populate_atom_numbers(mapping)
         backup_file(filename)
@@ -296,8 +295,6 @@ class BondSet:
             # Print molecule
             not_calc = "  Parameters have not been calculated."
             for mol in self._molecules:
-                if mol in exclude:
-                    continue
                 if mol not in mapping:
                     logger.warning("Molecule '{0}' present in bonding file, but not in mapping.".format(mol) + not_calc)
                     continue
diff --git a/pycgtool/frame.py b/pycgtool/frame.py
index b7a959b..a0a103d 100644
--- a/pycgtool/frame.py
+++ b/pycgtool/frame.py
@@ -8,6 +8,7 @@ Both Frame and Residue are iterable. Residue is indexable with either atom numbe
 import os
 import abc
 import logging
+import itertools
 
 import numpy as np
 
@@ -104,16 +105,11 @@ class Residue:
 
 
 class FrameReader(metaclass=abc.ABCMeta):
-    def __init__(self, topname, trajname=None, exclude=None, frame_start=0):
+    def __init__(self, topname, trajname=None, frame_start=0):
         self._topname = topname
         self._trajname = trajname
         self._frame_number = frame_start
 
-        if exclude is not None:
-            self._exclude = exclude
-        else:
-            self._exclude = set()
-
         self.num_atoms = 0
         self.num_frames = 0
 
@@ -127,21 +123,13 @@ class FrameReader(metaclass=abc.ABCMeta):
         return result
 
     def read_frame_number(self, number, frame):
-        if self._trajname is None:
-            return False
         try:
             time, coords, box = self._read_frame_number(number)
             frame.time = time
             frame.box = box
 
-            i = 0
-            for res in frame.residues:
-                if res.name in self._exclude:
-                    i += len(res.atoms)
-                    continue
-                for atom in res:
-                    atom.coords = coords[i]
-                    i += 1
+            for atom, coord_line in zip(itertools.chain.from_iterable(frame.residues), coords):
+                atom.coords = coord_line
 
         except (IndexError, AttributeError):
             # IndexError - run out of xtc frames
@@ -159,7 +147,7 @@ class FrameReader(metaclass=abc.ABCMeta):
 
 
 class FrameReaderSimpleTraj(FrameReader):
-    def __init__(self, topname, trajname=None, exclude=None, frame_start=0):
+    def __init__(self, topname, trajname=None, frame_start=0):
         """
         Open input XTC file from which to read coordinates using simpletraj library.
 
@@ -167,7 +155,7 @@ class FrameReaderSimpleTraj(FrameReader):
         :param trajname: MD trajectory file to read subsequent frames
         :param frame_start: Frame number to start on, default 0
         """
-        FrameReader.__init__(self, topname, trajname, exclude, frame_start)
+        FrameReader.__init__(self, topname, trajname, frame_start)
 
         from simpletraj import trajectory
 
@@ -229,14 +217,14 @@ class FrameReaderSimpleTraj(FrameReader):
 
 
 class FrameReaderMDTraj(FrameReader):
-    def __init__(self, topname, trajname=None, exclude=None, frame_start=0):
+    def __init__(self, topname, trajname=None, frame_start=0):
         """
         Open input XTC file from which to read coordinates using mdtraj library.
 
         :param topname: GROMACS GRO file from which to read topology
         :param trajname: GROMACS XTC file to read subsequent frames
         """
-        FrameReader.__init__(self, topname, trajname, exclude, frame_start)
+        FrameReader.__init__(self, topname, trajname, frame_start)
 
         try:
             import mdtraj
@@ -307,7 +295,7 @@ class Frame:
     """
     Hold Atom data separated into Residues
     """
-    def __init__(self, gro=None, xtc=None, itp=None, exclude=None, frame_start=0, xtc_reader="simpletraj"):
+    def __init__(self, gro=None, xtc=None, itp=None, frame_start=0, xtc_reader="simpletraj"):
         """
         Return Frame instance having read Residues and Atoms from GRO if provided
 
@@ -329,8 +317,7 @@ class Frame:
             open_xtc = {"simpletraj": FrameReaderSimpleTraj,
                         "mdtraj":     FrameReaderMDTraj}
             try:
-                self._trajreader = open_xtc[xtc_reader](gro, xtc, exclude=exclude,
-                                                        frame_start=frame_start)
+                self._trajreader = open_xtc[xtc_reader](gro, xtc, frame_start=frame_start)
             except KeyError as e:
                 e.args = ("XTC reader {0} is not a valid option.".format(xtc_reader))
                 raise
@@ -362,6 +349,11 @@ class Frame:
         rep += "\n".join(atoms)
         return rep
 
+    def yield_resname_in(self, container):
+        for res in self:
+            if res.name in container:
+                yield res
+
     def next_frame(self):
         """
         Read next frame from input XTC.
diff --git a/pycgtool/mapping.py b/pycgtool/mapping.py
index 64abd39..83204ca 100644
--- a/pycgtool/mapping.py
+++ b/pycgtool/mapping.py
@@ -212,13 +212,12 @@ class Mapping:
 
         return cgframe
 
-    def apply(self, frame, cgframe=None, exclude=None):
+    def apply(self, frame, cgframe=None):
         """
         Apply the AA->CG mapping to an atomistic Frame.
 
         :param frame: Frame to which mapping will be applied
         :param cgframe: CG Frame to remap - optional
-        :param exclude: Set of molecule names to exclude from mapping - e.g. solvent
         :return: Frame instance containing the CG frame
         """
         if self._map_center == "mass" and not self._masses_are_set:
@@ -226,16 +225,13 @@ class Mapping:
 
         if cgframe is None:
             # Frame needs initialising
-            cgframe = self._cg_frame_setup(frame.residues, frame.name)
+            cgframe = self._cg_frame_setup(frame.yield_resname_in(self._mappings), frame.name)
 
         cgframe.time = frame.time
         cgframe.number = frame.number
         cgframe.box = frame.box
 
-        select_predicate = lambda res: res.name in self._mappings and not (exclude is not None and res.name in exclude)
-        aa_residues = filter(select_predicate, frame)
-
-        for aares, cgres in zip(aa_residues, cgframe):
+        for aares, cgres in zip(frame.yield_resname_in(self._mappings), cgframe):
             molmap = self._mappings[aares.name]
 
             for i, (bead, bmap) in enumerate(zip(cgres, molmap)):
diff --git a/pycgtool/pycgtool.py b/pycgtool/pycgtool.py
index 16a1480..dc086aa 100755
--- a/pycgtool/pycgtool.py
+++ b/pycgtool/pycgtool.py
@@ -18,7 +18,7 @@ def main(args, config):
     :param args: Arguments from argparse
     :param config: Configuration dictionary
     """
-    frame = Frame(gro=args.gro, xtc=args.xtc, itp=args.itp, frame_start=args.begin, exclude={"SOL"})
+    frame = Frame(gro=args.gro, xtc=args.xtc, itp=args.itp, frame_start=args.begin)
 
     if args.bnd:
         logger.info("Bond measurements will be made")
@@ -29,7 +29,7 @@ def main(args, config):
     if args.map:
         logger.info("Mapping will be performed")
         mapping = Mapping(args.map, config, itp=args.itp)
-        cgframe = mapping.apply(frame, exclude={"SOL"})
+        cgframe = mapping.apply(frame)
         cgframe.output(config.output_name + ".gro", format=config.output)
     else:
         logger.info("Mapping will not be performed")
@@ -46,7 +46,7 @@ def main(args, config):
         if not frame.next_frame():
             return False
         if args.map:
-            cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"})
+            cgframe = mapping.apply(frame, cgframe=cgframe)
             if config.output_xtc:
                 cgframe.write_xtc(config.output_name + ".xtc")
         else:
@@ -84,7 +84,7 @@ def map_only(args, config):
     """
     frame = Frame(gro=args.gro, xtc=args.xtc)
     mapping = Mapping(args.map, config)
-    cgframe = mapping.apply(frame, exclude={"SOL"})
+    cgframe = mapping.apply(frame)
     cgframe.output(config.output_name + ".gro", format=config.output)
 
     if args.xtc and (config.output_xtc or args.outputxtc):
@@ -93,7 +93,7 @@ def map_only(args, config):
             nonlocal cgframe
             if not frame.next_frame():
                 return False
-            cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"})
+            cgframe = mapping.apply(frame, cgframe=cgframe)
             cgframe.write_xtc(config.output_name + ".xtc")
             return True
 
diff --git a/test/data/sugar.bnd b/test/data/sugar.bnd
index 4ef01f8..9a67739 100644
--- a/test/data/sugar.bnd
+++ b/test/data/sugar.bnd
@@ -1,6 +1,4 @@
 ; comments begin with a semicolon
-[SOL]
-
 [ALLA]
 C1 C2
 C2 C3
diff --git a/test/test_bondset.py b/test/test_bondset.py
index d8eebf4..62a7ca4 100644
--- a/test/test_bondset.py
+++ b/test/test_bondset.py
@@ -47,10 +47,8 @@ class BondSetTest(unittest.TestCase):
 
     def test_bondset_create(self):
         measure = BondSet("test/data/sugar.bnd", DummyOptions)
-        self.assertEqual(2, len(measure))
-        self.assertTrue("SOL" in measure)
+        self.assertEqual(1, len(measure))
         self.assertTrue("ALLA" in measure)
-        self.assertEqual(0, len(measure["SOL"]))
         self.assertEqual(18, len(measure["ALLA"]))
 
     def test_bondset_apply(self):
@@ -96,7 +94,7 @@ class BondSetTest(unittest.TestCase):
 
         cgframe = mapping.apply(frame)
         while frame.next_frame():
-            cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"})
+            cgframe = mapping.apply(frame, cgframe=cgframe)
             measure.apply(cgframe)
 
         measure.boltzmann_invert()
@@ -112,7 +110,7 @@ class BondSetTest(unittest.TestCase):
 
         cgframe = mapping.apply(frame)
         while frame.next_frame():
-            cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"})
+            cgframe = mapping.apply(frame, cgframe=cgframe)
             measure.apply(cgframe)
 
         measure.boltzmann_invert()
@@ -130,7 +128,7 @@ class BondSetTest(unittest.TestCase):
 
         cgframe = mapping.apply(frame)
         while frame.next_frame():
-            cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"})
+            cgframe = mapping.apply(frame, cgframe=cgframe)
             measure.apply(cgframe)
 
         measure.boltzmann_invert()
@@ -148,7 +146,7 @@ class BondSetTest(unittest.TestCase):
 
         cgframe = mapping.apply(frame)
         while frame.next_frame():
-            cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"})
+            cgframe = mapping.apply(frame, cgframe=cgframe)
             measure.apply(cgframe)
 
         measure.boltzmann_invert()
@@ -184,13 +182,13 @@ class BondSetTest(unittest.TestCase):
         cgframe = mapping.apply(frame)
 
         while frame.next_frame():
-            cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"})
+            cgframe = mapping.apply(frame, cgframe=cgframe)
             measure.apply(cgframe)
 
         measure.boltzmann_invert()
 
         logging.disable(logging.WARNING)
-        measure.write_itp("sugar_out.itp", mapping, exclude={"SOL"})
+        measure.write_itp("sugar_out.itp", mapping)
         logging.disable(logging.NOTSET)
 
         self.assertTrue(cmp_whitespace_float("sugar_out.itp", "test/data/sugar_out.itp", float_rel_error=0.001))
-- 
GitLab