diff --git a/pycgtool.py b/pycgtool.py index 6eb7ca03169c8c0e437da63bf7ed7e7b2c8bc675..d7f119eb024b029d0b0f881464e1cda0d7530efc 100755 --- a/pycgtool.py +++ b/pycgtool.py @@ -34,7 +34,7 @@ def main(args, config): if args.map: cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"}) if config.output_xtc: - cgframe.write_to_xtc_buffer() + cgframe.write_xtc(config.output_name + ".xtc") else: cgframe = frame @@ -53,9 +53,6 @@ def main(args, config): if config.dump_measurements: bonds.dump_values(config.dump_n_values) - if args.map and (config.output_xtc or args.outputxtc): - cgframe.flush_xtc_buffer(config.output_name + ".xtc") - def map_only(args, config): """ @@ -73,8 +70,7 @@ def map_only(args, config): numframes = frame.numframes - args.begin if args.end == -1 else args.end - args.begin for _ in Progress(numframes, postwhile=frame.next_frame): cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"}) - cgframe.write_to_xtc_buffer() - cgframe.flush_xtc_buffer(config.output_name + ".xtc") + cgframe.write_xtc(config.output_name + ".xtc") if __name__ == "__main__": @@ -96,7 +92,7 @@ if __name__ == "__main__": config = Options([("output_name", "out"), ("output", "gro"), ("output_xtc", args.outputxtc), - ("map_only", bool(args.bnd)), + ("map_only", not bool(args.bnd)), ("map_center", "geom"), ("constr_threshold", 100000), ("dump_measurements", bool(args.bnd) and not bool(args.map)), @@ -107,8 +103,6 @@ if __name__ == "__main__": ("generate_angles", True), ("generate_dihedrals", False)], args) - if not args.bnd: - config.set("map_only", True) if not args.map and not args.bnd: parser.error("One or both of -m and -b is required.") diff --git a/pycgtool/frame.py b/pycgtool/frame.py index f0fd2a714def21846d28a27eb72dd5e22e2f4448..944474167444738c335b19f7c8e8acfcb5905132 100644 --- a/pycgtool/frame.py +++ b/pycgtool/frame.py @@ -12,7 +12,7 @@ import numpy as np from simpletraj import trajectory try: - from mdtraj.formats import XTCTrajectoryFile + import mdtraj except ImportError: pass @@ -137,7 +137,7 @@ class Frame: open_xtc = {"simpletraj": self._open_xtc_simpletraj, "mdtraj": self._open_xtc_mdtraj} try: - open_xtc[self._xtc_reader](xtc) + open_xtc[self._xtc_reader](xtc, gro) except KeyError as e: e.args = ("XTC reader {0} is not a valid option.".format(self._xtc_reader)) raise @@ -145,7 +145,7 @@ class Frame: if itp is not None: self._parse_itp(itp) - def _open_xtc_simpletraj(self, xtc): + def _open_xtc_simpletraj(self, xtc, gro=None): try: self.xtc = trajectory.XtcTrajectory(xtc) except OSError as e: @@ -158,9 +158,9 @@ class Frame: raise AssertionError("Number of atoms does not match between gro and xtc files.") self.numframes += self.xtc.numframes - def _open_xtc_mdtraj(self, xtc): + def _open_xtc_mdtraj(self, xtc, gro): try: - self.xtc = XTCTrajectoryFile(xtc) + self.xtc = mdtraj.load_xtc(xtc, top=gro) except OSError as e: if not os.path.isfile(xtc): raise FileNotFoundError(xtc) from e @@ -168,25 +168,11 @@ class Frame: raise except NameError as e: raise ImportError("No module named 'mdtraj'") from e - else: - xyz, time, step, box = self.xtc.read(n_frames=1) - natoms = len(xyz[0]) - if natoms != self.natoms: - print(xyz[0]) - print(natoms, self.natoms) - raise AssertionError("Number of atoms does not match between gro and xtc files.") - # Seek to end to count frames - # self.xtc.seek(0, whence=2) - self.numframes += 1 - self.xtc.seek(0) - while True: - try: - self.xtc.seek(1, whence=1) - self.numframes += 1 - except IndexError: - break - self.xtc.seek(0) + if self.xtc.n_atoms != self.natoms: + raise AssertionError("Number of atoms does not match between gro and xtc files.") + + self.numframes += self.xtc.n_frames def __len__(self): return len(self.residues) @@ -221,25 +207,22 @@ class Frame: raise def _next_frame_mdtraj(self, exclude=None): - if XTCTrajectoryFile is None: - raise ImportError("No module named 'mdtraj'") try: - # self.xtc.seek(self.number) i = 0 - xyz, time, step, box = self.xtc.read(n_frames=1) - xyz = xyz[0] + # This returns a slice of length 1, properties still need to be indexed + xtc_frame = self.xtc[self.number] for res in self.residues: if exclude is not None and res.name in exclude: continue for atom in res: - atom.coords = xyz[i] + atom.coords = xtc_frame.xyz[0][i] i += 1 self.number += 1 - self.box = np.diag(box[0]) / 10. + self.box = xtc_frame.unitcell_lengths[0] return True # IndexError - run out of xtc frames - # AttributeError - we didn't provide an xtc + # AttributeError - we didn't provide an xtc or is wrong reader except (IndexError, AttributeError): return False @@ -265,56 +248,13 @@ class Frame: except (IndexError, AttributeError): return False - class XTCBuffer: - def __init__(self): - self.coords = [] - self.step = [] - self.box = [] - - def __call__(self): - return (np.array(self.coords, dtype=np.float32), - np.array(self.step, dtype=np.int32), - np.array(self.box, dtype=np.float32)) - - def append(self, coords, step, box): - self.coords.append(coords) - self.step.append(step) - self.box.append(box) - - def write_to_xtc_buffer(self): + def write_xtc(self, filename): if self._xtc_buffer is None: - self._xtc_buffer = self.XTCBuffer() - - xyz = np.ndarray((self.natoms, 3)) - i = 0 - for residue in self.residues: - for atom in residue.atoms: - xyz[i] = atom.coords - i += 1 - - box = np.zeros((3, 3), dtype=np.float32) - for i in range(3): - box[i][i] = self.box[i] - - self._xtc_buffer.append(xyz, self.number, box) - - def flush_xtc_buffer(self, filename): - if self._xtc_buffer is not None: try: - xtc = XTCTrajectoryFile(filename, mode="w") - - xyz, step, box = self._xtc_buffer() - xtc.write(xyz, step=step, box=box) - xtc.close() - - self._xtc_buffer = None + self._xtc_buffer = mdtraj.formats.XTCTrajectoryFile(filename, mode="w") except NameError as e: raise ImportError("No module named 'mdtraj'") from e - def write_xtc(self, filename): - if self._xtc_buffer is None: - self._xtc_buffer = XTCTrajectoryFile(filename, mode="w") - xyz = np.ndarray((1, self.natoms, 3), dtype=np.float32) i = 0 for residue in self.residues: diff --git a/test/test_frame.py b/test/test_frame.py index 7c40f3d63f222e6e30185e5a6307ff71430684a7..383fe6eee4c49b9da97481151a1721a289b5d659 100644 --- a/test/test_frame.py +++ b/test/test_frame.py @@ -77,16 +77,24 @@ class FrameTest(unittest.TestCase): def test_frame_read_xtc(self): frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc") + self.assertEqual(663, frame.natoms) # These are the coordinates from the gro file np.testing.assert_allclose(np.array([0.696, 1.33, 1.211]), frame.residues[0].atoms[0].coords) + np.testing.assert_allclose(np.array([1.89868, 1.89868, 1.89868]), + frame.box) + frame.next_frame() # These coordinates are from the xtc file np.testing.assert_allclose(np.array([1.176, 1.152, 1.586]), frame.residues[0].atoms[0].coords) + np.testing.assert_allclose(np.array([1.9052, 1.9052, 1.9052]), + frame.box) frame.next_frame() np.testing.assert_allclose(np.array([1.122, 1.130, 1.534]), frame.residues[0].atoms[0].coords) + np.testing.assert_allclose(np.array([1.90325272, 1.90325272, 1.90325272]), + frame.box) @unittest.skipIf(not mdtraj_present, "MDTRAJ not present") def test_frame_read_xtc_mdtraj_numframes(self): @@ -98,31 +106,30 @@ class FrameTest(unittest.TestCase): def test_frame_read_xtc_mdtraj(self): frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc", xtc_reader="mdtraj") + self.assertEqual(663, frame.natoms) # These are the coordinates from the gro file np.testing.assert_allclose(np.array([0.696, 1.33, 1.211]), frame.residues[0].atoms[0].coords) + np.testing.assert_allclose(np.array([1.89868, 1.89868, 1.89868]), + frame.box) frame.next_frame() # These coordinates are from the xtc file np.testing.assert_allclose(np.array([1.176, 1.152, 1.586]), frame.residues[0].atoms[0].coords) + np.testing.assert_allclose(np.array([1.9052, 1.9052, 1.9052]), + frame.box) frame.next_frame() np.testing.assert_allclose(np.array([1.122, 1.130, 1.534]), frame.residues[0].atoms[0].coords) - - @unittest.skipIf(not mdtraj_present, "MDTRAJ not present") - def test_frame_write_xtc_buffered_mdtraj(self): - frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc", - xtc_reader="mdtraj") - while frame.next_frame(): - frame.write_to_xtc_buffer() - frame.flush_xtc_buffer("test.xtc") + np.testing.assert_allclose(np.array([1.90325272, 1.90325272, 1.90325272]), + frame.box) @unittest.skipIf(not mdtraj_present, "MDTRAJ not present") def test_frame_write_xtc_mdtraj(self): frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc", xtc_reader="mdtraj") while frame.next_frame(): - frame.write_xtc("test2.xtc") + frame.write_xtc("water_test2.xtc") if __name__ == '__main__':