From cfc314f34e4bcf01ffb279487e2d0fb0af481f41 Mon Sep 17 00:00:00 2001 From: James Graham <J.A.Graham@soton.ac.uk> Date: Fri, 16 Dec 2016 16:54:02 +0000 Subject: [PATCH] Remove 'exclude' from Frame, Mapping and Bondset up to 10% performance increase --- pycgtool/bondset.py | 5 +---- pycgtool/frame.py | 38 +++++++++++++++----------------------- pycgtool/mapping.py | 10 +++------- pycgtool/pycgtool.py | 10 +++++----- test/data/sugar.bnd | 2 -- test/test_bondset.py | 16 +++++++--------- 6 files changed, 31 insertions(+), 50 deletions(-) diff --git a/pycgtool/bondset.py b/pycgtool/bondset.py index 1768c8d..d5af198 100644 --- a/pycgtool/bondset.py +++ b/pycgtool/bondset.py @@ -260,13 +260,12 @@ class BondSet: e.args = ("Bead(s) {0} do(es) not exist in residue {1}".format(missing, mol),) raise - def write_itp(self, filename, mapping, exclude=set()): + def write_itp(self, filename, mapping): """ Output a GROMACS .itp file containing atoms/beads and bonded terms. :param filename: Name of output file :param mapping: AA->CG Mapping from which to collect bead properties - :param exclude: Set of molecule names to be excluded from itp """ self._populate_atom_numbers(mapping) backup_file(filename) @@ -296,8 +295,6 @@ class BondSet: # Print molecule not_calc = " Parameters have not been calculated." for mol in self._molecules: - if mol in exclude: - continue if mol not in mapping: logger.warning("Molecule '{0}' present in bonding file, but not in mapping.".format(mol) + not_calc) continue diff --git a/pycgtool/frame.py b/pycgtool/frame.py index b7a959b..a0a103d 100644 --- a/pycgtool/frame.py +++ b/pycgtool/frame.py @@ -8,6 +8,7 @@ Both Frame and Residue are iterable. Residue is indexable with either atom numbe import os import abc import logging +import itertools import numpy as np @@ -104,16 +105,11 @@ class Residue: class FrameReader(metaclass=abc.ABCMeta): - def __init__(self, topname, trajname=None, exclude=None, frame_start=0): + def __init__(self, topname, trajname=None, frame_start=0): self._topname = topname self._trajname = trajname self._frame_number = frame_start - if exclude is not None: - self._exclude = exclude - else: - self._exclude = set() - self.num_atoms = 0 self.num_frames = 0 @@ -127,21 +123,13 @@ class FrameReader(metaclass=abc.ABCMeta): return result def read_frame_number(self, number, frame): - if self._trajname is None: - return False try: time, coords, box = self._read_frame_number(number) frame.time = time frame.box = box - i = 0 - for res in frame.residues: - if res.name in self._exclude: - i += len(res.atoms) - continue - for atom in res: - atom.coords = coords[i] - i += 1 + for atom, coord_line in zip(itertools.chain.from_iterable(frame.residues), coords): + atom.coords = coord_line except (IndexError, AttributeError): # IndexError - run out of xtc frames @@ -159,7 +147,7 @@ class FrameReader(metaclass=abc.ABCMeta): class FrameReaderSimpleTraj(FrameReader): - def __init__(self, topname, trajname=None, exclude=None, frame_start=0): + def __init__(self, topname, trajname=None, frame_start=0): """ Open input XTC file from which to read coordinates using simpletraj library. @@ -167,7 +155,7 @@ class FrameReaderSimpleTraj(FrameReader): :param trajname: MD trajectory file to read subsequent frames :param frame_start: Frame number to start on, default 0 """ - FrameReader.__init__(self, topname, trajname, exclude, frame_start) + FrameReader.__init__(self, topname, trajname, frame_start) from simpletraj import trajectory @@ -229,14 +217,14 @@ class FrameReaderSimpleTraj(FrameReader): class FrameReaderMDTraj(FrameReader): - def __init__(self, topname, trajname=None, exclude=None, frame_start=0): + def __init__(self, topname, trajname=None, frame_start=0): """ Open input XTC file from which to read coordinates using mdtraj library. :param topname: GROMACS GRO file from which to read topology :param trajname: GROMACS XTC file to read subsequent frames """ - FrameReader.__init__(self, topname, trajname, exclude, frame_start) + FrameReader.__init__(self, topname, trajname, frame_start) try: import mdtraj @@ -307,7 +295,7 @@ class Frame: """ Hold Atom data separated into Residues """ - def __init__(self, gro=None, xtc=None, itp=None, exclude=None, frame_start=0, xtc_reader="simpletraj"): + def __init__(self, gro=None, xtc=None, itp=None, frame_start=0, xtc_reader="simpletraj"): """ Return Frame instance having read Residues and Atoms from GRO if provided @@ -329,8 +317,7 @@ class Frame: open_xtc = {"simpletraj": FrameReaderSimpleTraj, "mdtraj": FrameReaderMDTraj} try: - self._trajreader = open_xtc[xtc_reader](gro, xtc, exclude=exclude, - frame_start=frame_start) + self._trajreader = open_xtc[xtc_reader](gro, xtc, frame_start=frame_start) except KeyError as e: e.args = ("XTC reader {0} is not a valid option.".format(xtc_reader)) raise @@ -362,6 +349,11 @@ class Frame: rep += "\n".join(atoms) return rep + def yield_resname_in(self, container): + for res in self: + if res.name in container: + yield res + def next_frame(self): """ Read next frame from input XTC. diff --git a/pycgtool/mapping.py b/pycgtool/mapping.py index 64abd39..83204ca 100644 --- a/pycgtool/mapping.py +++ b/pycgtool/mapping.py @@ -212,13 +212,12 @@ class Mapping: return cgframe - def apply(self, frame, cgframe=None, exclude=None): + def apply(self, frame, cgframe=None): """ Apply the AA->CG mapping to an atomistic Frame. :param frame: Frame to which mapping will be applied :param cgframe: CG Frame to remap - optional - :param exclude: Set of molecule names to exclude from mapping - e.g. solvent :return: Frame instance containing the CG frame """ if self._map_center == "mass" and not self._masses_are_set: @@ -226,16 +225,13 @@ class Mapping: if cgframe is None: # Frame needs initialising - cgframe = self._cg_frame_setup(frame.residues, frame.name) + cgframe = self._cg_frame_setup(frame.yield_resname_in(self._mappings), frame.name) cgframe.time = frame.time cgframe.number = frame.number cgframe.box = frame.box - select_predicate = lambda res: res.name in self._mappings and not (exclude is not None and res.name in exclude) - aa_residues = filter(select_predicate, frame) - - for aares, cgres in zip(aa_residues, cgframe): + for aares, cgres in zip(frame.yield_resname_in(self._mappings), cgframe): molmap = self._mappings[aares.name] for i, (bead, bmap) in enumerate(zip(cgres, molmap)): diff --git a/pycgtool/pycgtool.py b/pycgtool/pycgtool.py index 16a1480..dc086aa 100755 --- a/pycgtool/pycgtool.py +++ b/pycgtool/pycgtool.py @@ -18,7 +18,7 @@ def main(args, config): :param args: Arguments from argparse :param config: Configuration dictionary """ - frame = Frame(gro=args.gro, xtc=args.xtc, itp=args.itp, frame_start=args.begin, exclude={"SOL"}) + frame = Frame(gro=args.gro, xtc=args.xtc, itp=args.itp, frame_start=args.begin) if args.bnd: logger.info("Bond measurements will be made") @@ -29,7 +29,7 @@ def main(args, config): if args.map: logger.info("Mapping will be performed") mapping = Mapping(args.map, config, itp=args.itp) - cgframe = mapping.apply(frame, exclude={"SOL"}) + cgframe = mapping.apply(frame) cgframe.output(config.output_name + ".gro", format=config.output) else: logger.info("Mapping will not be performed") @@ -46,7 +46,7 @@ def main(args, config): if not frame.next_frame(): return False if args.map: - cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"}) + cgframe = mapping.apply(frame, cgframe=cgframe) if config.output_xtc: cgframe.write_xtc(config.output_name + ".xtc") else: @@ -84,7 +84,7 @@ def map_only(args, config): """ frame = Frame(gro=args.gro, xtc=args.xtc) mapping = Mapping(args.map, config) - cgframe = mapping.apply(frame, exclude={"SOL"}) + cgframe = mapping.apply(frame) cgframe.output(config.output_name + ".gro", format=config.output) if args.xtc and (config.output_xtc or args.outputxtc): @@ -93,7 +93,7 @@ def map_only(args, config): nonlocal cgframe if not frame.next_frame(): return False - cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"}) + cgframe = mapping.apply(frame, cgframe=cgframe) cgframe.write_xtc(config.output_name + ".xtc") return True diff --git a/test/data/sugar.bnd b/test/data/sugar.bnd index 4ef01f8..9a67739 100644 --- a/test/data/sugar.bnd +++ b/test/data/sugar.bnd @@ -1,6 +1,4 @@ ; comments begin with a semicolon -[SOL] - [ALLA] C1 C2 C2 C3 diff --git a/test/test_bondset.py b/test/test_bondset.py index d8eebf4..62a7ca4 100644 --- a/test/test_bondset.py +++ b/test/test_bondset.py @@ -47,10 +47,8 @@ class BondSetTest(unittest.TestCase): def test_bondset_create(self): measure = BondSet("test/data/sugar.bnd", DummyOptions) - self.assertEqual(2, len(measure)) - self.assertTrue("SOL" in measure) + self.assertEqual(1, len(measure)) self.assertTrue("ALLA" in measure) - self.assertEqual(0, len(measure["SOL"])) self.assertEqual(18, len(measure["ALLA"])) def test_bondset_apply(self): @@ -96,7 +94,7 @@ class BondSetTest(unittest.TestCase): cgframe = mapping.apply(frame) while frame.next_frame(): - cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"}) + cgframe = mapping.apply(frame, cgframe=cgframe) measure.apply(cgframe) measure.boltzmann_invert() @@ -112,7 +110,7 @@ class BondSetTest(unittest.TestCase): cgframe = mapping.apply(frame) while frame.next_frame(): - cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"}) + cgframe = mapping.apply(frame, cgframe=cgframe) measure.apply(cgframe) measure.boltzmann_invert() @@ -130,7 +128,7 @@ class BondSetTest(unittest.TestCase): cgframe = mapping.apply(frame) while frame.next_frame(): - cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"}) + cgframe = mapping.apply(frame, cgframe=cgframe) measure.apply(cgframe) measure.boltzmann_invert() @@ -148,7 +146,7 @@ class BondSetTest(unittest.TestCase): cgframe = mapping.apply(frame) while frame.next_frame(): - cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"}) + cgframe = mapping.apply(frame, cgframe=cgframe) measure.apply(cgframe) measure.boltzmann_invert() @@ -184,13 +182,13 @@ class BondSetTest(unittest.TestCase): cgframe = mapping.apply(frame) while frame.next_frame(): - cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"}) + cgframe = mapping.apply(frame, cgframe=cgframe) measure.apply(cgframe) measure.boltzmann_invert() logging.disable(logging.WARNING) - measure.write_itp("sugar_out.itp", mapping, exclude={"SOL"}) + measure.write_itp("sugar_out.itp", mapping) logging.disable(logging.NOTSET) self.assertTrue(cmp_whitespace_float("sugar_out.itp", "test/data/sugar_out.itp", float_rel_error=0.001)) -- GitLab