Skip to content
Snippets Groups Projects
Commit 8301ba3e authored by James Graham's avatar James Graham
Browse files

Add factory classmethod to create Frame from FrameReader

parent 08b193b3
No related branches found
No related tags found
No related merge requests found
...@@ -63,14 +63,10 @@ class Atom: ...@@ -63,14 +63,10 @@ class Atom:
def add_missing_data(self, other): def add_missing_data(self, other):
assert self.name == other.name assert self.name == other.name
assert self.num == other.num assert self.num == other.num
if self.type is None:
self.type = other.type for attr in ("type", "mass", "charge", "coords"):
if self.mass is None: if getattr(self, attr) is None:
self.mass = other.mass setattr(self, attr, getattr(other, attr))
if self.charge is None:
self.charge = other.charge
if self.coords is None:
self.coords = other.coords
class Residue: class Residue:
...@@ -333,6 +329,13 @@ class Frame: ...@@ -333,6 +329,13 @@ class Frame:
if itp is not None: if itp is not None:
self._parse_itp(itp) self._parse_itp(itp)
@classmethod
def instance_from_reader(cls, reader):
obj = cls()
obj._trajreader = reader
obj._trajreader.initialise_frame(obj)
return obj
def __len__(self): def __len__(self):
return len(self.residues) return len(self.residues)
......
...@@ -6,6 +6,7 @@ import logging ...@@ -6,6 +6,7 @@ import logging
import numpy as np import numpy as np
from pycgtool.frame import Atom, Residue, Frame from pycgtool.frame import Atom, Residue, Frame
from pycgtool.frame import FrameReaderSimpleTraj, FrameReaderMDTraj, FrameReader
try: try:
import mdtraj import mdtraj
...@@ -164,6 +165,38 @@ class FrameTest(unittest.TestCase): ...@@ -164,6 +165,38 @@ class FrameTest(unittest.TestCase):
while frame.next_frame(): while frame.next_frame():
frame.write_xtc("water_test2.xtc") frame.write_xtc("water_test2.xtc")
def test_frame_instance_from_reader(self):
reader = FrameReaderSimpleTraj("test/data/water.gro")
frame = Frame.instance_from_reader(reader)
self.assertEqual(221, len(frame.residues))
self.assertEqual("SOL", frame.residues[0].name)
self.assertEqual(3, len(frame.residues[0].atoms))
self.assertEqual("OW", frame.residues[0].atoms[0].name)
np.testing.assert_allclose(np.array([0.696, 1.33, 1.211]),
frame.residues[0].atoms[0].coords)
def test_frame_instance_from_reader_dummy(self):
class DummyReader(FrameReader):
def _initialise_frame(self, frame):
frame.dummy_reader = True
def _read_frame_number(self, number):
return number * 10, [], None
reader = DummyReader(None)
frame = Frame.instance_from_reader(reader)
self.assertTrue(frame.dummy_reader)
frame.next_frame()
self.assertEqual(frame.number, 0)
self.assertEqual(frame.time, 0)
self.assertIsNone(frame.box)
frame.next_frame()
self.assertEqual(frame.number, 1)
self.assertEqual(frame.time, 10)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment