Skip to content
Snippets Groups Projects
Commit 7cc99cbe authored by James Graham's avatar James Graham
Browse files

Remove buffered XTC output - keep framewise output

Update mdtraj XTC input to use Trajectory class
parent 224917b5
Branches
No related tags found
No related merge requests found
...@@ -34,7 +34,7 @@ def main(args, config): ...@@ -34,7 +34,7 @@ def main(args, config):
if args.map: if args.map:
cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"}) cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"})
if config.output_xtc: if config.output_xtc:
cgframe.write_to_xtc_buffer() cgframe.write_xtc(config.output_name + ".xtc")
else: else:
cgframe = frame cgframe = frame
...@@ -53,9 +53,6 @@ def main(args, config): ...@@ -53,9 +53,6 @@ def main(args, config):
if config.dump_measurements: if config.dump_measurements:
bonds.dump_values(config.dump_n_values) bonds.dump_values(config.dump_n_values)
if args.map and (config.output_xtc or args.outputxtc):
cgframe.flush_xtc_buffer(config.output_name + ".xtc")
def map_only(args, config): def map_only(args, config):
""" """
...@@ -73,8 +70,7 @@ def map_only(args, config): ...@@ -73,8 +70,7 @@ def map_only(args, config):
numframes = frame.numframes - args.begin if args.end == -1 else args.end - args.begin numframes = frame.numframes - args.begin if args.end == -1 else args.end - args.begin
for _ in Progress(numframes, postwhile=frame.next_frame): for _ in Progress(numframes, postwhile=frame.next_frame):
cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"}) cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"})
cgframe.write_to_xtc_buffer() cgframe.write_xtc(config.output_name + ".xtc")
cgframe.flush_xtc_buffer(config.output_name + ".xtc")
if __name__ == "__main__": if __name__ == "__main__":
...@@ -96,7 +92,7 @@ if __name__ == "__main__": ...@@ -96,7 +92,7 @@ if __name__ == "__main__":
config = Options([("output_name", "out"), config = Options([("output_name", "out"),
("output", "gro"), ("output", "gro"),
("output_xtc", args.outputxtc), ("output_xtc", args.outputxtc),
("map_only", bool(args.bnd)), ("map_only", not bool(args.bnd)),
("map_center", "geom"), ("map_center", "geom"),
("constr_threshold", 100000), ("constr_threshold", 100000),
("dump_measurements", bool(args.bnd) and not bool(args.map)), ("dump_measurements", bool(args.bnd) and not bool(args.map)),
...@@ -107,8 +103,6 @@ if __name__ == "__main__": ...@@ -107,8 +103,6 @@ if __name__ == "__main__":
("generate_angles", True), ("generate_angles", True),
("generate_dihedrals", False)], ("generate_dihedrals", False)],
args) args)
if not args.bnd:
config.set("map_only", True)
if not args.map and not args.bnd: if not args.map and not args.bnd:
parser.error("One or both of -m and -b is required.") parser.error("One or both of -m and -b is required.")
......
...@@ -12,7 +12,7 @@ import numpy as np ...@@ -12,7 +12,7 @@ import numpy as np
from simpletraj import trajectory from simpletraj import trajectory
try: try:
from mdtraj.formats import XTCTrajectoryFile import mdtraj
except ImportError: except ImportError:
pass pass
...@@ -137,7 +137,7 @@ class Frame: ...@@ -137,7 +137,7 @@ class Frame:
open_xtc = {"simpletraj": self._open_xtc_simpletraj, open_xtc = {"simpletraj": self._open_xtc_simpletraj,
"mdtraj": self._open_xtc_mdtraj} "mdtraj": self._open_xtc_mdtraj}
try: try:
open_xtc[self._xtc_reader](xtc) open_xtc[self._xtc_reader](xtc, gro)
except KeyError as e: except KeyError as e:
e.args = ("XTC reader {0} is not a valid option.".format(self._xtc_reader)) e.args = ("XTC reader {0} is not a valid option.".format(self._xtc_reader))
raise raise
...@@ -145,7 +145,7 @@ class Frame: ...@@ -145,7 +145,7 @@ class Frame:
if itp is not None: if itp is not None:
self._parse_itp(itp) self._parse_itp(itp)
def _open_xtc_simpletraj(self, xtc): def _open_xtc_simpletraj(self, xtc, gro=None):
try: try:
self.xtc = trajectory.XtcTrajectory(xtc) self.xtc = trajectory.XtcTrajectory(xtc)
except OSError as e: except OSError as e:
...@@ -158,9 +158,9 @@ class Frame: ...@@ -158,9 +158,9 @@ class Frame:
raise AssertionError("Number of atoms does not match between gro and xtc files.") raise AssertionError("Number of atoms does not match between gro and xtc files.")
self.numframes += self.xtc.numframes self.numframes += self.xtc.numframes
def _open_xtc_mdtraj(self, xtc): def _open_xtc_mdtraj(self, xtc, gro):
try: try:
self.xtc = XTCTrajectoryFile(xtc) self.xtc = mdtraj.load_xtc(xtc, top=gro)
except OSError as e: except OSError as e:
if not os.path.isfile(xtc): if not os.path.isfile(xtc):
raise FileNotFoundError(xtc) from e raise FileNotFoundError(xtc) from e
...@@ -168,25 +168,11 @@ class Frame: ...@@ -168,25 +168,11 @@ class Frame:
raise raise
except NameError as e: except NameError as e:
raise ImportError("No module named 'mdtraj'") from e raise ImportError("No module named 'mdtraj'") from e
else:
xyz, time, step, box = self.xtc.read(n_frames=1) if self.xtc.n_atoms != self.natoms:
natoms = len(xyz[0])
if natoms != self.natoms:
print(xyz[0])
print(natoms, self.natoms)
raise AssertionError("Number of atoms does not match between gro and xtc files.") raise AssertionError("Number of atoms does not match between gro and xtc files.")
# Seek to end to count frames self.numframes += self.xtc.n_frames
# self.xtc.seek(0, whence=2)
self.numframes += 1
self.xtc.seek(0)
while True:
try:
self.xtc.seek(1, whence=1)
self.numframes += 1
except IndexError:
break
self.xtc.seek(0)
def __len__(self): def __len__(self):
return len(self.residues) return len(self.residues)
...@@ -221,25 +207,22 @@ class Frame: ...@@ -221,25 +207,22 @@ class Frame:
raise raise
def _next_frame_mdtraj(self, exclude=None): def _next_frame_mdtraj(self, exclude=None):
if XTCTrajectoryFile is None:
raise ImportError("No module named 'mdtraj'")
try: try:
# self.xtc.seek(self.number)
i = 0 i = 0
xyz, time, step, box = self.xtc.read(n_frames=1) # This returns a slice of length 1, properties still need to be indexed
xyz = xyz[0] xtc_frame = self.xtc[self.number]
for res in self.residues: for res in self.residues:
if exclude is not None and res.name in exclude: if exclude is not None and res.name in exclude:
continue continue
for atom in res: for atom in res:
atom.coords = xyz[i] atom.coords = xtc_frame.xyz[0][i]
i += 1 i += 1
self.number += 1 self.number += 1
self.box = np.diag(box[0]) / 10. self.box = xtc_frame.unitcell_lengths[0]
return True return True
# IndexError - run out of xtc frames # IndexError - run out of xtc frames
# AttributeError - we didn't provide an xtc # AttributeError - we didn't provide an xtc or is wrong reader
except (IndexError, AttributeError): except (IndexError, AttributeError):
return False return False
...@@ -265,56 +248,13 @@ class Frame: ...@@ -265,56 +248,13 @@ class Frame:
except (IndexError, AttributeError): except (IndexError, AttributeError):
return False return False
class XTCBuffer: def write_xtc(self, filename):
def __init__(self):
self.coords = []
self.step = []
self.box = []
def __call__(self):
return (np.array(self.coords, dtype=np.float32),
np.array(self.step, dtype=np.int32),
np.array(self.box, dtype=np.float32))
def append(self, coords, step, box):
self.coords.append(coords)
self.step.append(step)
self.box.append(box)
def write_to_xtc_buffer(self):
if self._xtc_buffer is None: if self._xtc_buffer is None:
self._xtc_buffer = self.XTCBuffer()
xyz = np.ndarray((self.natoms, 3))
i = 0
for residue in self.residues:
for atom in residue.atoms:
xyz[i] = atom.coords
i += 1
box = np.zeros((3, 3), dtype=np.float32)
for i in range(3):
box[i][i] = self.box[i]
self._xtc_buffer.append(xyz, self.number, box)
def flush_xtc_buffer(self, filename):
if self._xtc_buffer is not None:
try: try:
xtc = XTCTrajectoryFile(filename, mode="w") self._xtc_buffer = mdtraj.formats.XTCTrajectoryFile(filename, mode="w")
xyz, step, box = self._xtc_buffer()
xtc.write(xyz, step=step, box=box)
xtc.close()
self._xtc_buffer = None
except NameError as e: except NameError as e:
raise ImportError("No module named 'mdtraj'") from e raise ImportError("No module named 'mdtraj'") from e
def write_xtc(self, filename):
if self._xtc_buffer is None:
self._xtc_buffer = XTCTrajectoryFile(filename, mode="w")
xyz = np.ndarray((1, self.natoms, 3), dtype=np.float32) xyz = np.ndarray((1, self.natoms, 3), dtype=np.float32)
i = 0 i = 0
for residue in self.residues: for residue in self.residues:
......
...@@ -77,16 +77,24 @@ class FrameTest(unittest.TestCase): ...@@ -77,16 +77,24 @@ class FrameTest(unittest.TestCase):
def test_frame_read_xtc(self): def test_frame_read_xtc(self):
frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc") frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc")
self.assertEqual(663, frame.natoms)
# These are the coordinates from the gro file # These are the coordinates from the gro file
np.testing.assert_allclose(np.array([0.696, 1.33, 1.211]), np.testing.assert_allclose(np.array([0.696, 1.33, 1.211]),
frame.residues[0].atoms[0].coords) frame.residues[0].atoms[0].coords)
np.testing.assert_allclose(np.array([1.89868, 1.89868, 1.89868]),
frame.box)
frame.next_frame() frame.next_frame()
# These coordinates are from the xtc file # These coordinates are from the xtc file
np.testing.assert_allclose(np.array([1.176, 1.152, 1.586]), np.testing.assert_allclose(np.array([1.176, 1.152, 1.586]),
frame.residues[0].atoms[0].coords) frame.residues[0].atoms[0].coords)
np.testing.assert_allclose(np.array([1.9052, 1.9052, 1.9052]),
frame.box)
frame.next_frame() frame.next_frame()
np.testing.assert_allclose(np.array([1.122, 1.130, 1.534]), np.testing.assert_allclose(np.array([1.122, 1.130, 1.534]),
frame.residues[0].atoms[0].coords) frame.residues[0].atoms[0].coords)
np.testing.assert_allclose(np.array([1.90325272, 1.90325272, 1.90325272]),
frame.box)
@unittest.skipIf(not mdtraj_present, "MDTRAJ not present") @unittest.skipIf(not mdtraj_present, "MDTRAJ not present")
def test_frame_read_xtc_mdtraj_numframes(self): def test_frame_read_xtc_mdtraj_numframes(self):
...@@ -98,31 +106,30 @@ class FrameTest(unittest.TestCase): ...@@ -98,31 +106,30 @@ class FrameTest(unittest.TestCase):
def test_frame_read_xtc_mdtraj(self): def test_frame_read_xtc_mdtraj(self):
frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc", frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc",
xtc_reader="mdtraj") xtc_reader="mdtraj")
self.assertEqual(663, frame.natoms)
# These are the coordinates from the gro file # These are the coordinates from the gro file
np.testing.assert_allclose(np.array([0.696, 1.33, 1.211]), np.testing.assert_allclose(np.array([0.696, 1.33, 1.211]),
frame.residues[0].atoms[0].coords) frame.residues[0].atoms[0].coords)
np.testing.assert_allclose(np.array([1.89868, 1.89868, 1.89868]),
frame.box)
frame.next_frame() frame.next_frame()
# These coordinates are from the xtc file # These coordinates are from the xtc file
np.testing.assert_allclose(np.array([1.176, 1.152, 1.586]), np.testing.assert_allclose(np.array([1.176, 1.152, 1.586]),
frame.residues[0].atoms[0].coords) frame.residues[0].atoms[0].coords)
np.testing.assert_allclose(np.array([1.9052, 1.9052, 1.9052]),
frame.box)
frame.next_frame() frame.next_frame()
np.testing.assert_allclose(np.array([1.122, 1.130, 1.534]), np.testing.assert_allclose(np.array([1.122, 1.130, 1.534]),
frame.residues[0].atoms[0].coords) frame.residues[0].atoms[0].coords)
np.testing.assert_allclose(np.array([1.90325272, 1.90325272, 1.90325272]),
@unittest.skipIf(not mdtraj_present, "MDTRAJ not present") frame.box)
def test_frame_write_xtc_buffered_mdtraj(self):
frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc",
xtc_reader="mdtraj")
while frame.next_frame():
frame.write_to_xtc_buffer()
frame.flush_xtc_buffer("test.xtc")
@unittest.skipIf(not mdtraj_present, "MDTRAJ not present") @unittest.skipIf(not mdtraj_present, "MDTRAJ not present")
def test_frame_write_xtc_mdtraj(self): def test_frame_write_xtc_mdtraj(self):
frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc", frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc",
xtc_reader="mdtraj") xtc_reader="mdtraj")
while frame.next_frame(): while frame.next_frame():
frame.write_xtc("test2.xtc") frame.write_xtc("water_test2.xtc")
if __name__ == '__main__': if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment