Skip to content
Snippets Groups Projects
Commit 7f1b1ae2 authored by James Graham's avatar James Graham
Browse files

Working XTC output

CGFrame now persistent - carries XTC buffer
parent f3c1083f
Branches
No related tags found
No related merge requests found
...@@ -32,9 +32,9 @@ def main(args, config): ...@@ -32,9 +32,9 @@ def main(args, config):
numframes = frame.numframes - args.begin if args.end == -1 else args.end - args.begin numframes = frame.numframes - args.begin if args.end == -1 else args.end - args.begin
for _ in Progress(numframes, postwhile=frame.next_frame): for _ in Progress(numframes, postwhile=frame.next_frame):
if args.map: if args.map:
cgframe = mapping.apply(frame, exclude={"SOL"}) cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"})
if config.output_xtc: if config.output_xtc:
cgframe.write_xtc("out.xtc") cgframe.write_to_xtc_buffer()
else: else:
cgframe = frame cgframe = frame
...@@ -53,6 +53,9 @@ def main(args, config): ...@@ -53,6 +53,9 @@ def main(args, config):
if config.dump_measurements: if config.dump_measurements:
bonds.dump_values(config.dump_n_values) bonds.dump_values(config.dump_n_values)
if args.map and config.output_xtc:
cgframe.flush_xtc_buffer("out.xtc")
def map_only(args, config): def map_only(args, config):
""" """
...@@ -69,9 +72,9 @@ def map_only(args, config): ...@@ -69,9 +72,9 @@ def map_only(args, config):
if args.xtc and config.output_xtc: if args.xtc and config.output_xtc:
numframes = frame.numframes - args.begin if args.end == -1 else args.end - args.begin numframes = frame.numframes - args.begin if args.end == -1 else args.end - args.begin
for _ in Progress(numframes, postwhile=frame.next_frame): for _ in Progress(numframes, postwhile=frame.next_frame):
cgframe = mapping.apply(frame, exclude={"SOL"}) cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"})
cgframe.write_xtc("out.xtc") cgframe.write_to_xtc_buffer()
cgframe.flush_xtc_buffer("out.xtc")
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -122,7 +122,7 @@ class Frame: ...@@ -122,7 +122,7 @@ class Frame:
self.numframes = 0 self.numframes = 0
self.box = np.zeros(3, dtype=np.float32) self.box = np.zeros(3, dtype=np.float32)
self._out_xtc = None self._xtc_buffer = None
if gro is not None: if gro is not None:
self._parse_gro(gro) self._parse_gro(gro)
...@@ -257,9 +257,52 @@ class Frame: ...@@ -257,9 +257,52 @@ class Frame:
except (IndexError, AttributeError): except (IndexError, AttributeError):
return False return False
def write_xtc(self, filename=None): class XTCBuffer:
if self._out_xtc is None: def __init__(self):
self._out_xtc = XTCTrajectoryFile(filename, mode="w") self.coords = []
self.step = []
self.box = []
def __call__(self):
return (np.array(self.coords, dtype=np.float32),
np.array(self.step, dtype=np.int32),
np.array(self.box, dtype=np.float32))
def append(self, coords, step, box):
self.coords.append(coords)
self.step.append(step)
self.box.append(box)
def write_to_xtc_buffer(self):
if self._xtc_buffer is None:
self._xtc_buffer = self.XTCBuffer()
xyz = np.ndarray((self.natoms, 3))
i = 0
for residue in self.residues:
for atom in residue.atoms:
xyz[i] = atom.coords
i += 1
box = np.zeros((3, 3), dtype=np.float32)
for i in range(3):
box[i][i] = self.box[i]
self._xtc_buffer.append(xyz, self.number, box)
def flush_xtc_buffer(self, filename):
if self._xtc_buffer is not None:
xtc = XTCTrajectoryFile(filename, mode="w")
xyz, step, box = self._xtc_buffer()
xtc.write(xyz, step=step, box=box)
xtc.close()
self._xtc_buffer = None
def write_xtc(self, filename):
if self._xtc_buffer is None:
self._xtc_buffer = XTCTrajectoryFile(filename, mode="w")
xyz = np.ndarray((1, self.natoms, 3), dtype=np.float32) xyz = np.ndarray((1, self.natoms, 3), dtype=np.float32)
i = 0 i = 0
...@@ -267,11 +310,15 @@ class Frame: ...@@ -267,11 +310,15 @@ class Frame:
for atom in residue.atoms: for atom in residue.atoms:
xyz[0][i] = atom.coords xyz[0][i] = atom.coords
i += 1 i += 1
step = np.array([self.number], dtype=np.int32) step = np.array([self.number], dtype=np.int32)
box = np.zeros((1, 3, 3), dtype=np.float32) box = np.zeros((1, 3, 3), dtype=np.float32)
for i in range(3): for i in range(3):
box[0][i][i] = self.box[i] box[0][i][i] = self.box[i]
self._out_xtc.write(xyz, step=step, box=box)
self._xtc_buffer.write(xyz, step=step, box=box)
# self._xtc_buffer.close()
def _parse_gro(self, filename): def _parse_gro(self, filename):
""" """
......
...@@ -116,7 +116,7 @@ class Mapping: ...@@ -116,7 +116,7 @@ class Mapping:
def __iter__(self): def __iter__(self):
return iter(self._mappings) return iter(self._mappings)
def apply(self, frame, exclude=None): def apply(self, frame, cgframe=None, exclude=None):
""" """
Apply the AA->CG mapping to an atomistic Frame. Apply the AA->CG mapping to an atomistic Frame.
...@@ -124,8 +124,10 @@ class Mapping: ...@@ -124,8 +124,10 @@ class Mapping:
:param exclude: Set of molecule names to exclude from mapping - e.g. solvent :param exclude: Set of molecule names to exclude from mapping - e.g. solvent
:return: A new Frame instance containing the CG frame :return: A new Frame instance containing the CG frame
""" """
if cgframe is None:
cgframe = Frame() cgframe = Frame()
cgframe.name = frame.name cgframe.name = frame.name
cgframe.number = frame.number
cgframe.box = frame.box cgframe.box = frame.box
cgframe.natoms = 0 cgframe.natoms = 0
cgframe.residues = [] cgframe.residues = []
......
...@@ -6,6 +6,12 @@ import numpy as np ...@@ -6,6 +6,12 @@ import numpy as np
from pycgtool.frame import Atom, Residue, Frame from pycgtool.frame import Atom, Residue, Frame
try:
import mdtraj
mdtraj_present = True
except ImportError:
mdtraj_present = False
class AtomTest(unittest.TestCase): class AtomTest(unittest.TestCase):
def test_atom_create(self): def test_atom_create(self):
...@@ -82,11 +88,13 @@ class FrameTest(unittest.TestCase): ...@@ -82,11 +88,13 @@ class FrameTest(unittest.TestCase):
np.testing.assert_allclose(np.array([1.122, 1.130, 1.534]), np.testing.assert_allclose(np.array([1.122, 1.130, 1.534]),
frame.residues[0].atoms[0].coords) frame.residues[0].atoms[0].coords)
@unittest.skipIf(not mdtraj_present, "MDTRAJ not present")
def test_frame_read_xtc_mdtraj_numframes(self): def test_frame_read_xtc_mdtraj_numframes(self):
frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc", frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc",
xtc_reader="mdtraj") xtc_reader="mdtraj")
self.assertEqual(12, frame.numframes) self.assertEqual(12, frame.numframes)
@unittest.skipIf(not mdtraj_present, "MDTRAJ not present")
def test_frame_read_xtc_mdtraj(self): def test_frame_read_xtc_mdtraj(self):
frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc", frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc",
xtc_reader="mdtraj") xtc_reader="mdtraj")
...@@ -101,11 +109,20 @@ class FrameTest(unittest.TestCase): ...@@ -101,11 +109,20 @@ class FrameTest(unittest.TestCase):
np.testing.assert_allclose(np.array([1.122, 1.130, 1.534]), np.testing.assert_allclose(np.array([1.122, 1.130, 1.534]),
frame.residues[0].atoms[0].coords) frame.residues[0].atoms[0].coords)
@unittest.skipIf(not mdtraj_present, "MDTRAJ not present")
def test_frame_write_xtc_buffered_mdtraj(self):
frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc",
xtc_reader="mdtraj")
while frame.next_frame():
frame.write_to_xtc_buffer()
frame.flush_xtc_buffer("test.xtc")
@unittest.skipIf(not mdtraj_present, "MDTRAJ not present")
def test_frame_write_xtc_mdtraj(self): def test_frame_write_xtc_mdtraj(self):
frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc", frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc",
xtc_reader="mdtraj") xtc_reader="mdtraj")
while frame.next_frame(): while frame.next_frame():
frame.write_xtc("test.xtc") frame.write_xtc("test2.xtc")
if __name__ == '__main__': if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment