Skip to content
Snippets Groups Projects
Commit 3b333aa8 authored by James Graham's avatar James Graham
Browse files

Add MDAnalysis FrameReader

Move FrameReaders out into own module, still shares tests with Frame
parent fc57e818
No related branches found
No related tags found
No related merge requests found
...@@ -5,14 +5,11 @@ The Frame class may contain multiple Residues which may each contain multiple At ...@@ -5,14 +5,11 @@ The Frame class may contain multiple Residues which may each contain multiple At
Both Frame and Residue are iterable. Residue is indexable with either atom numbers or names. Both Frame and Residue are iterable. Residue is indexable with either atom numbers or names.
""" """
import os
import abc
import logging import logging
import itertools
import numpy as np import numpy as np
from .util import backup_file, FixedFormatUnpacker from .util import backup_file
from .parsers.cfg import CFG from .parsers.cfg import CFG
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -110,217 +107,6 @@ class Residue: ...@@ -110,217 +107,6 @@ class Residue:
self.name_to_num[atom.name] = len(self.atoms) - 1 self.name_to_num[atom.name] = len(self.atoms) - 1
class FrameReader(metaclass=abc.ABCMeta):
def __init__(self, topname, trajname=None, frame_start=0):
self._topname = topname
self._trajname = trajname
self._frame_number = frame_start
self.num_atoms = 0
self.num_frames = 0
def initialise_frame(self, frame):
self._initialise_frame(frame)
def read_next(self, frame):
result = self.read_frame_number(self._frame_number, frame)
if result:
self._frame_number += 1
return result
def read_frame_number(self, number, frame):
try:
time, coords, box = self._read_frame_number(number)
frame.time = time
frame.box = box
for atom, coord_line in zip(itertools.chain.from_iterable(frame.residues), coords):
atom.coords = coord_line
except (IndexError, AttributeError):
# IndexError - run out of xtc frames
# AttributeError - we didn't provide an xtc
return False
return True
@abc.abstractmethod
def _initialise_frame(self, frame):
pass
@abc.abstractmethod
def _read_frame_number(self, number):
pass
class FrameReaderSimpleTraj(FrameReader):
def __init__(self, topname, trajname=None, frame_start=0):
"""
Open input XTC file from which to read coordinates using simpletraj library.
:param topname: MD topology file - not used
:param trajname: MD trajectory file to read subsequent frames
:param frame_start: Frame number to start on, default 0
"""
FrameReader.__init__(self, topname, trajname, frame_start)
from simpletraj import trajectory
if trajname is not None:
try:
self._traj = trajectory.get_trajectory(trajname)
except OSError as e:
if not os.path.isfile(trajname):
raise FileNotFoundError(trajname) from e
e.args = ("Error opening file '{0}'".format(trajname),)
raise
self.num_atoms = self._traj.numatoms
self.num_frames = self._traj.numframes
def _initialise_frame(self, frame):
"""
Parse a GROMACS GRO file and create Residues/Atoms
Required before reading coordinates from XTC file
:param frame: Frame instance to initialise from GRO file
"""
with open(self._topname) as gro:
frame.name = gro.readline().strip()
self.num_atoms = int(gro.readline())
frame.natoms = self.num_atoms
resnum_last = None
atnum = 0
unpacker = FixedFormatUnpacker("I5,2A5,5X,3F8",
FixedFormatUnpacker.FormatStyle.Fortran)
for _ in range(self.num_atoms):
resnum, resname, atomname, *coords = unpacker.unpack(gro.readline())
coords = np.array(coords, dtype=np.float32)
if resnum != resnum_last:
frame.residues.append(Residue(name=resname,
num=resnum))
resnum_last = resnum
atnum = 0
atom = Atom(name=atomname, num=atnum, coords=coords)
frame.residues[-1].add_atom(atom)
atnum += 1
frame.box = np.array([float(x) for x in gro.readline().split()[0:3]], dtype=np.float32)
def _read_frame_number(self, number):
"""
Read next frame from XTC using simpletraj library.
"""
self._traj.get_frame(number)
# SimpleTraj uses Angstrom, we want nanometers
xyz = self._traj.x / 10
box = np.diag(self._traj.box)[0:3] / 10
return self._traj.time, xyz, box
class FrameReaderMDTraj(FrameReader):
def __init__(self, topname, trajname=None, frame_start=0):
"""
Open input XTC file from which to read coordinates using mdtraj library.
:param topname: GROMACS GRO file from which to read topology
:param trajname: GROMACS XTC file to read subsequent frames
"""
FrameReader.__init__(self, topname, trajname, frame_start)
try:
import mdtraj
except ImportError as e:
if "scipy" in e.msg:
e.msg = "The MDTraj FrameReader also requires Scipy"
else:
e.msg = "The MDTraj FrameReader requires the module MDTraj (and probably Scipy)"
raise
logger.warning("WARNING: Using MDTraj which renames solvent molecules")
try:
if trajname is None:
self._traj = mdtraj.load(topname)
else:
self._traj = mdtraj.load(trajname, top=topname)
except OSError as e:
if not os.path.isfile(topname):
raise FileNotFoundError(topname) from e
if not os.path.isfile(trajname):
raise FileNotFoundError(trajname) from e
e.args = ("Error opening file '{0}' or '{1}'".format(topname, trajname),)
raise
self.num_atoms = self._traj.n_atoms
self.num_frames = self._traj.n_frames
def _initialise_frame(self, frame):
"""
Parse a GROMACS GRO file and create Residues/Atoms
Required before reading coordinates from XTC file
:param frame: Frame instance to initialise from GRO file
"""
import mdtraj
top = mdtraj.load(self._topname)
frame.name = ""
self.num_atoms = top.n_atoms
frame.natoms = top.n_atoms
frame.residues = [Residue(name=res.name, num=res.resSeq) for res in top.topology.residues]
for atom in top.topology.atoms:
new_atom = Atom(name=atom.name, num=atom.serial,
coords=top.xyz[0][atom.index])
frame.residues[atom.residue.index].add_atom(new_atom)
frame.box = top.unitcell_lengths[0]
def _read_frame_number(self, number):
"""
Read next frame from XTC using mdtraj library.
"""
return self._traj.time[number], self._traj.xyz[number], self._traj.unitcell_lengths[number]
class FrameReaderMDAnalysis(FrameReader):
def __init__(self, topname, trajname=None, frame_start=0):
import MDAnalysis
super().__init__(topname, trajname, frame_start)
if trajname is None:
self._traj = MDAnalysis.Universe(topname)
else:
self._traj = MDAnalysis.Universe(topname, trajname)
self.num_atoms = self._traj.atoms.n_atoms
self.num_frames = self._traj.trajectory.n_frames
def _initialise_frame(self, frame):
frame.name = ""
frame.natoms = self.num_atoms
import MDAnalysis
topol = MDAnalysis.Universe(self._topname)
frame.box = topol.dimensions[0:3] / 10.
for res in topol.residues:
residue = Residue(name=res.resname, num=res.resnum)
for atom in res.atoms:
residue.add_atom(Atom(name=atom.name, num=atom.id, coords=atom.position / 10.))
frame.residues.append(residue)
def _read_frame_number(self, number):
traj_frame = self._traj.trajectory[number]
return traj_frame.time, traj_frame.positions / 10., traj_frame.dimensions[0:3] / 10.
class Frame: class Frame:
""" """
Hold Atom data separated into Residues Hold Atom data separated into Residues
...@@ -345,14 +131,8 @@ class Frame: ...@@ -345,14 +131,8 @@ class Frame:
self._xtc_buffer = None self._xtc_buffer = None
if gro is not None: if gro is not None:
open_xtc = {"simpletraj": FrameReaderSimpleTraj, from .framereader import get_frame_reader
"mdtraj": FrameReaderMDTraj, self._trajreader = get_frame_reader(gro, traj=xtc, frame_start=frame_start)
"mdanalysis": FrameReaderMDAnalysis}
try:
self._trajreader = open_xtc[xtc_reader](gro, xtc, frame_start=frame_start)
except KeyError as e:
e.args = ("XTC reader {0} is not a valid option.".format(xtc_reader))
raise
self._trajreader.initialise_frame(self) self._trajreader.initialise_frame(self)
......
"""
This module contains classes for reading trajectories into a Frame instance.
Multiple readers are defined, allowing different underlying trajectory libraries to be used.
This module is tightly coupled with the frame.py module, so shares a set of unit tests.
"""
import os
import abc
import itertools
import logging
import collections
import numpy as np
from .frame import Atom, Residue
from .util import FixedFormatUnpacker
logger = logging.getLogger(__name__)
class UnsupportedFormatException(Exception):
pass
def get_frame_reader(top, traj=None, frame_start=0, name=None):
readers = collections.OrderedDict([
("simpletraj", FrameReaderSimpleTraj),
("mdtraj", FrameReaderMDTraj),
("mdanalysis", FrameReaderMDAnalysis),
])
try:
return readers[name](top, traj, frame_start)
except KeyError as e:
if name is not None:
e.args = ("Frame reader '{0}' is not a valid option.".format(name),)
raise
for name, reader in readers.items(): # Return first reader that accepts given files
try:
return reader(top, traj, frame_start)
except (UnsupportedFormatException, ImportError):
continue
raise UnsupportedFormatException("None of the available readers support the trajector format provided")
class FrameReader(metaclass=abc.ABCMeta):
def __init__(self, topname, trajname=None, frame_start=0):
self._topname = topname
self._trajname = trajname
self._frame_number = frame_start
self.num_atoms = 0
self.num_frames = 0
def initialise_frame(self, frame):
self._initialise_frame(frame)
def read_next(self, frame):
result = self.read_frame_number(self._frame_number, frame)
if result:
self._frame_number += 1
return result
def read_frame_number(self, number, frame):
try:
time, coords, box = self._read_frame_number(number)
frame.time = time
frame.box = box
for atom, coord_line in zip(itertools.chain.from_iterable(frame.residues), coords):
atom.coords = coord_line
except (IndexError, AttributeError):
# IndexError - run out of xtc frames
# AttributeError - we didn't provide an xtc
return False
return True
@abc.abstractmethod
def _initialise_frame(self, frame):
pass
@abc.abstractmethod
def _read_frame_number(self, number):
pass
class FrameReaderSimpleTraj(FrameReader):
def __init__(self, topname, trajname=None, frame_start=0):
"""
Open input XTC file from which to read coordinates using simpletraj library.
:param topname: MD topology file - not used
:param trajname: MD trajectory file to read subsequent frames
:param frame_start: Frame number to start on, default 0
"""
FrameReader.__init__(self, topname, trajname, frame_start)
with open(self._topname) as gro:
gro.readline()
try:
self.num_atoms = int(gro.readline())
except ValueError as e:
raise UnsupportedFormatException from e
from simpletraj import trajectory
if trajname is not None:
try:
self._traj = trajectory.get_trajectory(trajname)
except OSError as e:
if not os.path.isfile(trajname):
raise FileNotFoundError(trajname) from e
e.args = ("Error opening file '{0}'".format(trajname),)
raise
if self._traj.numatoms != self.num_atoms:
raise UnsupportedFormatException
self.num_frames = self._traj.numframes
def _initialise_frame(self, frame):
"""
Parse a GROMACS GRO file and create Residues/Atoms
Required before reading coordinates from XTC file
:param frame: Frame instance to initialise from GRO file
"""
with open(self._topname) as gro:
frame.name = gro.readline().strip()
try:
self.num_atoms = int(gro.readline())
except ValueError as e:
raise UnsupportedFormatException from e
frame.natoms = self.num_atoms
resnum_last = None
atnum = 0
unpacker = FixedFormatUnpacker("I5,2A5,5X,3F8",
FixedFormatUnpacker.FormatStyle.Fortran)
for _ in range(self.num_atoms):
resnum, resname, atomname, *coords = unpacker.unpack(gro.readline())
coords = np.array(coords, dtype=np.float32)
if resnum != resnum_last:
frame.residues.append(Residue(name=resname,
num=resnum))
resnum_last = resnum
atnum = 0
atom = Atom(name=atomname, num=atnum, coords=coords)
frame.residues[-1].add_atom(atom)
atnum += 1
frame.box = np.array([float(x) for x in gro.readline().split()[0:3]], dtype=np.float32)
def _read_frame_number(self, number):
"""
Read next frame from XTC using simpletraj library.
"""
self._traj.get_frame(number)
# SimpleTraj uses Angstrom, we want nanometers
xyz = self._traj.x / 10
box = np.diag(self._traj.box)[0:3] / 10
return self._traj.time, xyz, box
class FrameReaderMDTraj(FrameReader):
def __init__(self, topname, trajname=None, frame_start=0):
"""
Open input XTC file from which to read coordinates using mdtraj library.
:param topname: GROMACS GRO file from which to read topology
:param trajname: GROMACS XTC file to read subsequent frames
"""
FrameReader.__init__(self, topname, trajname, frame_start)
try:
import mdtraj
except ImportError as e:
if "scipy" in e.msg:
e.msg = "The MDTraj FrameReader also requires Scipy"
else:
e.msg = "The MDTraj FrameReader requires the module MDTraj (and probably Scipy)"
raise
logger.warning("WARNING: Using MDTraj which renames solvent molecules")
try:
if trajname is None:
self._traj = mdtraj.load(topname)
else:
self._traj = mdtraj.load(trajname, top=topname)
except OSError as e:
if not os.path.isfile(topname):
raise FileNotFoundError(topname) from e
if trajname is not None and not os.path.isfile(trajname):
raise FileNotFoundError(trajname) from e
if "no loader for filename" in repr(e):
raise UnsupportedFormatException from e
e.args = ("Error opening file '{0}' or '{1}'".format(topname, trajname),)
raise
self.num_atoms = self._traj.n_atoms
self.num_frames = self._traj.n_frames
def _initialise_frame(self, frame):
"""
Parse a GROMACS GRO file and create Residues/Atoms
Required before reading coordinates from XTC file
:param frame: Frame instance to initialise from GRO file
"""
import mdtraj
top = mdtraj.load(self._topname)
frame.name = ""
self.num_atoms = top.n_atoms
frame.natoms = top.n_atoms
frame.residues = [Residue(name=res.name, num=res.resSeq) for res in top.topology.residues]
for atom in top.topology.atoms:
new_atom = Atom(name=atom.name, num=atom.serial,
coords=top.xyz[0][atom.index])
frame.residues[atom.residue.index].add_atom(new_atom)
frame.box = top.unitcell_lengths[0]
def _read_frame_number(self, number):
"""
Read next frame from XTC using mdtraj library.
"""
return self._traj.time[number], self._traj.xyz[number], self._traj.unitcell_lengths[number]
class FrameReaderMDAnalysis(FrameReader):
def __init__(self, topname, trajname=None, frame_start=0):
import MDAnalysis
super().__init__(topname, trajname, frame_start)
try:
if trajname is None:
self._traj = MDAnalysis.Universe(topname)
else:
self._traj = MDAnalysis.Universe(topname, trajname)
except ValueError as e:
if "isn't a valid topology format" in repr(e):
raise UnsupportedFormatException from e
raise
self.num_atoms = self._traj.atoms.n_atoms
self.num_frames = self._traj.trajectory.n_frames
def _initialise_frame(self, frame):
frame.name = ""
frame.natoms = self.num_atoms
import MDAnalysis
topol = MDAnalysis.Universe(self._topname)
frame.box = topol.dimensions[0:3] / 10.
for res in topol.residues:
residue = Residue(name=res.resname, num=res.resnum)
for atom in res.atoms:
residue.add_atom(Atom(name=atom.name, num=atom.id, coords=atom.position / 10.))
frame.residues.append(residue)
def _read_frame_number(self, number):
traj_frame = self._traj.trajectory[number]
return traj_frame.time, traj_frame.positions / 10., traj_frame.dimensions[0:3] / 10.
...@@ -6,7 +6,8 @@ import logging ...@@ -6,7 +6,8 @@ import logging
import numpy as np import numpy as np
from pycgtool.frame import Atom, Residue, Frame from pycgtool.frame import Atom, Residue, Frame
from pycgtool.frame import FrameReaderSimpleTraj, FrameReaderMDAnalysis, FrameReaderMDTraj, FrameReader from pycgtool.framereader import FrameReaderSimpleTraj, FrameReaderMDAnalysis, FrameReaderMDTraj
from pycgtool.framereader import FrameReader, get_frame_reader, UnsupportedFormatException
try: try:
import mdtraj import mdtraj
...@@ -94,7 +95,7 @@ class FrameTest(unittest.TestCase): ...@@ -94,7 +95,7 @@ class FrameTest(unittest.TestCase):
self.helper_read_xtc(frame, first_only=True) self.helper_read_xtc(frame, first_only=True)
@unittest.skipIf(not mdtraj_present, "MDTRAJ or Scipy not present") @unittest.skipIf(not mdtraj_present, "MDTraj or Scipy not present")
def test_frame_mdtraj_read_gro(self): def test_frame_mdtraj_read_gro(self):
logging.disable(logging.WARNING) logging.disable(logging.WARNING)
frame = Frame("test/data/water.gro", xtc_reader="mdtraj") frame = Frame("test/data/water.gro", xtc_reader="mdtraj")
...@@ -109,13 +110,24 @@ class FrameTest(unittest.TestCase): ...@@ -109,13 +110,24 @@ class FrameTest(unittest.TestCase):
self.helper_read_xtc(frame, first_only=True) self.helper_read_xtc(frame, first_only=True)
@unittest.skipIf(not mdtraj_present, "MDTRAJ or Scipy not present") @unittest.skipIf(not mdtraj_present, "MDTraj or Scipy not present")
def test_frame_mdtraj_read_pdb(self): def test_frame_mdtraj_read_pdb(self):
reader = FrameReaderMDTraj("test/data/water.pdb") reader = FrameReaderMDTraj("test/data/water.pdb")
frame = Frame.instance_from_reader(reader) frame = Frame.instance_from_reader(reader)
self.helper_read_xtc(frame, first_only=True, skip_names=True) self.helper_read_xtc(frame, first_only=True, skip_names=True)
@unittest.skipIf(not mdtraj_present and not mdanalysis_present, "Neither MDTraj or MDAnalysis is present")
def test_frame_any_read_pdb(self):
reader = get_frame_reader("test/data/water.pdb")
frame = Frame.instance_from_reader(reader)
self.helper_read_xtc(frame, first_only=True, skip_names=True)
def test_frame_any_read_unsupported(self):
with self.assertRaises(UnsupportedFormatException):
reader = get_frame_reader("test/data/dppc.map")
@unittest.skipIf(not mdanalysis_present, "MDAnalysis not present") @unittest.skipIf(not mdanalysis_present, "MDAnalysis not present")
def test_frame_mdanalysis_read_pdb(self): def test_frame_mdanalysis_read_pdb(self):
reader = FrameReaderMDAnalysis("test/data/water.pdb") reader = FrameReaderMDAnalysis("test/data/water.pdb")
...@@ -134,7 +146,7 @@ class FrameTest(unittest.TestCase): ...@@ -134,7 +146,7 @@ class FrameTest(unittest.TestCase):
xtc_reader="simpletraj") xtc_reader="simpletraj")
self.assertEqual(11, frame.numframes) self.assertEqual(11, frame.numframes)
@unittest.skipIf(not mdtraj_present, "MDTRAJ or Scipy not present") @unittest.skipIf(not mdtraj_present, "MDTraj or Scipy not present")
def test_frame_read_xtc_mdtraj_numframes(self): def test_frame_read_xtc_mdtraj_numframes(self):
logging.disable(logging.WARNING) logging.disable(logging.WARNING)
frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc", frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc",
...@@ -147,7 +159,7 @@ class FrameTest(unittest.TestCase): ...@@ -147,7 +159,7 @@ class FrameTest(unittest.TestCase):
xtc_reader="simpletraj") xtc_reader="simpletraj")
self.helper_read_xtc(frame) self.helper_read_xtc(frame)
@unittest.skipIf(not mdtraj_present, "MDTRAJ or Scipy not present") @unittest.skipIf(not mdtraj_present, "MDTraj or Scipy not present")
def test_frame_mdtraj_read_xtc(self): def test_frame_mdtraj_read_xtc(self):
logging.disable(logging.WARNING) logging.disable(logging.WARNING)
frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc", frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc",
...@@ -163,7 +175,7 @@ class FrameTest(unittest.TestCase): ...@@ -163,7 +175,7 @@ class FrameTest(unittest.TestCase):
self.helper_read_xtc(frame) self.helper_read_xtc(frame)
@unittest.skipIf(not mdtraj_present, "MDTRAJ or Scipy not present") @unittest.skipIf(not mdtraj_present, "MDTraj or Scipy not present")
def test_frame_write_xtc_mdtraj(self): def test_frame_write_xtc_mdtraj(self):
try: try:
os.remove("water_test2.xtc") os.remove("water_test2.xtc")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment