diff --git a/pycgtool/mapping.py b/pycgtool/mapping.py index 19a3283aaaa3c391e419c289a4f23fc109f1b3fa..ce92e6347045f2a6cb1178db9874c4cbe85703ac 100644 --- a/pycgtool/mapping.py +++ b/pycgtool/mapping.py @@ -38,8 +38,9 @@ class BeadMap(Atom): """ Atom.__init__(self, name=name, type=type, charge=charge, mass=mass) self.atoms = atoms + # NB: weights are overwritten in Mapping.__init__ if an itp file is provided self.weights = {"geom": None, - "mass": np.ones(len(self.atoms), dtype=np.float32)} + "first": np.array([[1]] + [[0] for _ in range(len(self.atoms) - 1)], dtype=np.float32)} def __iter__(self): """ @@ -63,15 +64,6 @@ class EmptyBeadError(Exception): pass -def coordinate_weight(center, atom): - centers = {"geom": lambda at: at.coords, - "mass": lambda at: at.coords * at.mass} - try: - return centers[center](atom) - except KeyError: - raise ValueError("Invalid map-center type '{0}'".format(center)) - - class Mapping: """ Class used to perform the AA->CG mapping. @@ -175,7 +167,14 @@ class Mapping: if self._map_center == "geom": bead.coords = calc_coords(ref_coords, coords, box) else: - weights = bmap.weights[self._map_center] + try: + weights = bmap.weights[self._map_center] + except KeyError as e: + if self._map_center == "mass": + e.args = ("Error with mapping type 'mass', did you provide an itp file?",) + else: + e.args = ("Error, unknown mapping type '{0}'".format(e.args[0]),) + raise bead.coords = calc_coords_weight(ref_coords, coords, box, weights) return res diff --git a/test/test_mapping.py b/test/test_mapping.py index 554f6381be3f9caae1a99609c84d1aaff4123c26..008a1a8dbe9e81c53d26776de180d6d284d02e49 100644 --- a/test/test_mapping.py +++ b/test/test_mapping.py @@ -5,17 +5,13 @@ import os import numpy as np from pycgtool.mapping import Mapping -from pycgtool.frame import Frame, Atom +from pycgtool.frame import Frame class DummyOptions: map_center = "geom" -class DummyOptionsMass: - map_center = "mass" - - class MappingTest(unittest.TestCase): def test_mapping_create(self): mapping = Mapping("test/data/water.map", DummyOptions) @@ -40,16 +36,33 @@ class MappingTest(unittest.TestCase): cgframe = mapping.apply(frame) np.testing.assert_allclose(frame[0][0].coords, cgframe[0][0].coords) - def test_mapping_weights(self): + def test_mapping_weights_geom(self): frame = Frame("test/data/two.gro") - mapping = Mapping("test/data/two.map", DummyOptions, itp="test/data/two.itp") cg = mapping.apply(frame) np.testing.assert_allclose(np.array([1.5, 1.5, 1.5]), cg[0][0].coords) - mapping = Mapping("test/data/two.map", DummyOptionsMass, itp="test/data/two.itp") + def test_mapping_weights_mass(self): + frame = Frame("test/data/two.gro") + options = DummyOptions() + options.map_center = "mass" + + mapping = Mapping("test/data/two.map", options, itp="test/data/two.itp") cg = mapping.apply(frame) np.testing.assert_allclose(np.array([2., 2., 2.]), cg[0][0].coords) + with self.assertRaises(Exception): + mapping = Mapping("test/data/two.map", options) + cg = mapping.apply(frame) + + def test_mapping_weights_first(self): + frame = Frame("test/data/two.gro") + options = DummyOptions() + options.map_center = "first" + + mapping = Mapping("test/data/two.map", options, itp="test/data/two.itp") + cg = mapping.apply(frame) + np.testing.assert_allclose(np.array([1., 1., 1.]), cg[0][0].coords) +