Skip to content
Snippets Groups Projects
Commit 65765f94 authored by James Graham's avatar James Graham
Browse files

Approx 25-30% performance improvement on membrane benchmark from coordinate mapping function

Fixed some complaints by linter
parent 91ccc467
No related branches found
No related tags found
No related merge requests found
......@@ -10,10 +10,11 @@ import logging
import numpy as np
try:
from tqdm import tqdm
except ImportError:
pass
from .util import tqdm_dummy as tqdm
from .util import sliding, dist_with_pbc, transpose_and_sample
from .util import extend_graph_chain, backup_file
......@@ -373,11 +374,8 @@ class BondSet:
bond_iter = itertools.chain(*self._molecules.values())
bond_iter_wrap = bond_iter
if progress:
try:
total = sum(map(len, self._molecules.values()))
bond_iter_wrap = tqdm(bond_iter, total=total, ncols=80)
except NameError:
pass
for bond in bond_iter_wrap:
try:
......
......@@ -176,20 +176,21 @@ class FrameReaderSimpleTraj(FrameReader):
Parse a GROMACS GRO file and create Residues/Atoms
Required before reading coordinates from XTC file
:param filename: Filename of GROMACS GRO to read
:param frame: Frame instance to initialise from GRO file
"""
with open(self._topname) as gro:
frame.name = gro.readline().strip()
self.num_atoms = int(gro.readline())
frame.natoms = self.num_atoms
resnum_last = None
atnum = 0
unpacker = FixedFormatUnpacker("I5,A5,A5,X5,F8,F8,F8",
unpacker = FixedFormatUnpacker("I5,2A5,5X,3F8",
FixedFormatUnpacker.FormatStyle.Fortran)
for _ in range(self.num_atoms):
resnum, resname, atomname, x, y, z = unpacker.unpack(gro.readline())
coords = np.array([x, y, z], dtype=np.float32)
resnum, resname, atomname, *coords = unpacker.unpack(gro.readline())
coords = np.array(coords, dtype=np.float32)
if resnum != resnum_last:
frame.residues.append(Residue(name=resname,
......@@ -201,8 +202,7 @@ class FrameReaderSimpleTraj(FrameReader):
frame.residues[-1].add_atom(atom)
atnum += 1
line = gro.readline()
frame.box = np.array([float(x) for x in line.split()[0:3]], dtype=np.float32)
frame.box = np.array([float(x) for x in gro.readline().split()[0:3]], dtype=np.float32)
def _read_frame_number(self, number):
"""
......@@ -234,6 +234,7 @@ class FrameReaderMDTraj(FrameReader):
else:
e.msg = "The MDTraj FrameReader requires the module MDTraj (and probably Scipy)"
raise
logger.warning("WARNING: Using MDTraj which renames solvent molecules")
try:
if trajname is None:
......@@ -256,26 +257,16 @@ class FrameReaderMDTraj(FrameReader):
Parse a GROMACS GRO file and create Residues/Atoms
Required before reading coordinates from XTC file
:param filename: Filename of GROMACS GRO to read
:param frame: Frame instance to initialise from GRO file
"""
try:
import mdtraj
except ImportError as e:
if "scipy" in e.msg:
e.msg = "The MDTraj FrameReader also requires Scipy"
else:
e.msg = "The MDTraj FrameReader requires the module MDTraj (and probably Scipy)"
raise
logger.warning("WARNING: Using MDTraj which renames solvent molecules")
top = mdtraj.load(self._topname)
frame.name = ""
self.num_atoms = top.n_atoms
frame.natoms = top.n_atoms
for residue in top.topology.residues:
frame.residues.append(Residue(name=residue.name,
num=residue.resSeq))
frame.residues = [Residue(name=res.name, num=res.resSeq) for res in top.topology.residues]
for atom in top.topology.atoms:
new_atom = Atom(name=atom.name, num=atom.serial,
......@@ -304,6 +295,7 @@ class Frame:
:param itp: GROMACS ITP file to read masses and charges
:return: Frame instance
"""
self.name = ""
self.residues = []
self.number = frame_start - 1
self.time = 0
......
......@@ -20,8 +20,6 @@ except ImportError:
from .util import NumbaDummy
numba = NumbaDummy()
np.seterr(all="raise")
logger = logging.getLogger(__name__)
......@@ -29,7 +27,7 @@ class BeadMap(Atom):
"""
POD class holding values relating to the AA->CG transformation for a single bead.
"""
__slots__ = ["name", "type", "atoms", "charge", "mass", "weights"]
__slots__ = ["name", "type", "atoms", "charge", "mass", "weights", "weights_dict"]
def __init__(self, name=None, type=None, atoms=None, charge=0, mass=0):
"""
......@@ -45,8 +43,9 @@ class BeadMap(Atom):
Atom.__init__(self, name=name, type=type, charge=charge, mass=mass)
self.atoms = atoms
# NB: Mass weights are added in Mapping.__init__ if an itp file is provided
self.weights = {"geom": np.array([[1. / len(atoms)] for _ in atoms], dtype=np.float32),
self.weights_dict = {"geom": np.array([[1. / len(atoms)] for _ in atoms], dtype=np.float32),
"first": np.array([[1.]] + [[0.] for _ in atoms[1:]], dtype=np.float32)}
self.weights = self.weights_dict["geom"]
def __iter__(self):
"""
......@@ -118,7 +117,7 @@ class Mapping:
mass_array = np.array([[atoms[atom][1]] for atom in bead], dtype=np.float32)
bead.mass = sum(mass_array)
mass_array /= bead.mass
bead.weights["mass"] = mass_array
bead.weights_dict["mass"] = mass_array
for atom in bead:
if self._manual_charges[molname]:
......@@ -128,6 +127,12 @@ class Mapping:
self._masses_are_set = True
if self._map_center == "mass" and not self._masses_are_set:
self._guess_atom_masses()
for molname, mapping in self._mappings.items():
for bmap in mapping:
bmap.weights = bmap.weights_dict[self._map_center]
def __len__(self):
return len(self._mappings)
......@@ -170,7 +175,7 @@ class Mapping:
raise RuntimeError(msg)
mass_array /= bead.mass
bead.weights["mass"] = mass_array
bead.weights_dict["mass"] = mass_array
self._masses_are_set = True
......@@ -220,9 +225,6 @@ class Mapping:
:param cgframe: CG Frame to remap - optional
:return: Frame instance containing the CG frame
"""
if self._map_center == "mass" and not self._masses_are_set:
self._guess_atom_masses()
if cgframe is None:
# Frame needs initialising
cgframe = self._cg_frame_setup(frame.yield_resname_in(self._mappings), frame.name)
......@@ -231,20 +233,19 @@ class Mapping:
cgframe.number = frame.number
cgframe.box = frame.box
coord_func = calc_coords_weight if frame.box[0] * frame.box[1] * frame.box[2] else calc_coords_weight_nobox
for aares, cgres in zip(frame.yield_resname_in(self._mappings), cgframe):
molmap = self._mappings[aares.name]
for i, (bead, bmap) in enumerate(zip(cgres, molmap)):
ref_coords = aares[bmap[0]].coords
coords = np.array([aares[atom].coords for atom in bmap], dtype=np.float32)
try:
weights = bmap.weights[self._map_center]
except KeyError as e:
e.args = ("Unknown mapping type '{0}'.".format(e.args[0]),)
raise
if len(bmap) == 1:
bead.coords = ref_coords
continue
bead.coords = calc_coords_weight(ref_coords, coords, cgframe.box, weights)
coords = np.asarray([aares[atom].coords for atom in bmap], dtype=np.float32)
bead.coords = coord_func(ref_coords, coords, cgframe.box, bmap.weights)
return cgframe
......@@ -260,13 +261,25 @@ def calc_coords_weight(ref_coords, coords, box, weights):
:param weights: Array of atom weights, must sum to 1
:return: Coordinates of CG bead
"""
n = len(coords)
result = np.zeros(3, dtype=np.float32)
for i in range(n):
tmp_coords = coords[i]
tmp_coords -= ref_coords
if box[0] * box[1] * box[2] != 0:
tmp_coords -= box * np.rint(tmp_coords / box)
result += weights[i] * tmp_coords
vectors = coords - ref_coords
vectors -= box * np.rint(vectors / box)
result = np.sum(weights * vectors, axis=0)
result += ref_coords
return result
@numba.jit
def calc_coords_weight_nobox(ref_coords, coords, box, weights):
"""
Calculate the coordinates of a single CG bead from weighted component atom coordinates.
:param ref_coords: Coordinates of reference atom, usually first atom in bead
:param coords: Array of coordinates of component atoms
:param box: PBC box vectors
:param weights: Array of atom weights, must sum to 1
:return: Coordinates of CG bead
"""
vectors = coords - ref_coords
result = np.sum(weights * vectors, axis=0)
result += ref_coords
return result
......@@ -13,7 +13,6 @@ import re
from collections import namedtuple
import numpy as np
np.seterr(all="raise")
logger = logging.getLogger(__name__)
......@@ -24,7 +23,7 @@ class NumbaDummy:
"""
def __getattr__(self, item):
if item == "jit":
return self.jit
return NumbaDummy.jit
return self
def __getitem__(self, item):
......@@ -33,7 +32,8 @@ class NumbaDummy:
def __call__(self, *args, **kwargs):
return None
def jit(self, *args, **kwargs):
@staticmethod
def jit(*args, **kwargs):
"""
Dummy version of numba.jit decorator, does nothing
"""
......@@ -144,7 +144,7 @@ def dist_with_pbc(pos1, pos2, box):
:return: Vector between two points
"""
d = pos2 - pos1
if box[0] * box[1] * box[2] != 0:
if box[0] * box[1] * box[2]:
d -= box * np.rint(d / box)
return d
......@@ -417,11 +417,13 @@ class SimpleEnum(object):
return self.value == other.value
@classmethod
def enum(cls, name, keys=list(), values=None):
def enum(cls, name, keys=None, values=None):
def returner(val):
return lambda _: val
enum_cls = type(name, (cls.Enum,), {})
if keys is None:
keys = []
if values is None:
for key in keys:
prop = property(returner(cls.EnumItem(name, key)))
......@@ -439,7 +441,6 @@ class SimpleEnum(object):
return cls.enum(name, key_val_dict.keys(), key_val_dict.values())
# TODO testing
class FixedFormatUnpacker(object):
"""
Unpack strings printed in fixed format.
......@@ -503,3 +504,7 @@ class FixedFormatUnpacker(object):
if format_item.type is not None:
items.append(format_item.type(string_part.strip()))
return items
def tqdm_dummy(iterable, **kwargs):
return iterable
......@@ -130,7 +130,7 @@ class UtilTest(unittest.TestCase):
enum2 = SimpleEnum.enum("enum2", ["one", "two", "three"])
with self.assertRaises(TypeError):
tmp = enum2.one == enum.one
assert enum2.one == enum.one
self.assertTrue("one" in enum)
self.assertFalse("four" in enum)
......@@ -155,7 +155,7 @@ class UtilTest(unittest.TestCase):
enum2 = SimpleEnum.enum("enum2", ["one", "two", "three"])
with self.assertRaises(TypeError):
tmp = enum2.one == enum.one
assert enum2.one == enum.one
self.assertTrue("one" in enum)
self.assertEqual(111, enum.one.value)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment