diff --git a/pycgtool/frame.py b/pycgtool/frame.py index 2c4b76fe2b6dd388a9bb21593ccbe443572c602e..5edd16e4dc45ce54b070d83befca2d9ce79ea92a 100644 --- a/pycgtool/frame.py +++ b/pycgtool/frame.py @@ -5,14 +5,11 @@ The Frame class may contain multiple Residues which may each contain multiple At Both Frame and Residue are iterable. Residue is indexable with either atom numbers or names. """ -import os -import abc import logging -import itertools import numpy as np -from .util import backup_file, FixedFormatUnpacker +from .util import backup_file from .parsers.cfg import CFG logger = logging.getLogger(__name__) @@ -110,217 +107,6 @@ class Residue: self.name_to_num[atom.name] = len(self.atoms) - 1 -class FrameReader(metaclass=abc.ABCMeta): - def __init__(self, topname, trajname=None, frame_start=0): - self._topname = topname - self._trajname = trajname - self._frame_number = frame_start - - self.num_atoms = 0 - self.num_frames = 0 - - def initialise_frame(self, frame): - self._initialise_frame(frame) - - def read_next(self, frame): - result = self.read_frame_number(self._frame_number, frame) - if result: - self._frame_number += 1 - return result - - def read_frame_number(self, number, frame): - try: - time, coords, box = self._read_frame_number(number) - frame.time = time - frame.box = box - - for atom, coord_line in zip(itertools.chain.from_iterable(frame.residues), coords): - atom.coords = coord_line - - except (IndexError, AttributeError): - # IndexError - run out of xtc frames - # AttributeError - we didn't provide an xtc - return False - return True - - @abc.abstractmethod - def _initialise_frame(self, frame): - pass - - @abc.abstractmethod - def _read_frame_number(self, number): - pass - - -class FrameReaderSimpleTraj(FrameReader): - def __init__(self, topname, trajname=None, frame_start=0): - """ - Open input XTC file from which to read coordinates using simpletraj library. - - :param topname: MD topology file - not used - :param trajname: MD trajectory file to read subsequent frames - :param frame_start: Frame number to start on, default 0 - """ - FrameReader.__init__(self, topname, trajname, frame_start) - - from simpletraj import trajectory - - if trajname is not None: - try: - self._traj = trajectory.get_trajectory(trajname) - except OSError as e: - if not os.path.isfile(trajname): - raise FileNotFoundError(trajname) from e - e.args = ("Error opening file '{0}'".format(trajname),) - raise - - self.num_atoms = self._traj.numatoms - self.num_frames = self._traj.numframes - - def _initialise_frame(self, frame): - """ - Parse a GROMACS GRO file and create Residues/Atoms - Required before reading coordinates from XTC file - - :param frame: Frame instance to initialise from GRO file - """ - with open(self._topname) as gro: - frame.name = gro.readline().strip() - self.num_atoms = int(gro.readline()) - frame.natoms = self.num_atoms - resnum_last = None - atnum = 0 - - unpacker = FixedFormatUnpacker("I5,2A5,5X,3F8", - FixedFormatUnpacker.FormatStyle.Fortran) - - for _ in range(self.num_atoms): - resnum, resname, atomname, *coords = unpacker.unpack(gro.readline()) - coords = np.array(coords, dtype=np.float32) - - if resnum != resnum_last: - frame.residues.append(Residue(name=resname, - num=resnum)) - resnum_last = resnum - atnum = 0 - - atom = Atom(name=atomname, num=atnum, coords=coords) - frame.residues[-1].add_atom(atom) - atnum += 1 - - frame.box = np.array([float(x) for x in gro.readline().split()[0:3]], dtype=np.float32) - - def _read_frame_number(self, number): - """ - Read next frame from XTC using simpletraj library. - """ - self._traj.get_frame(number) - # SimpleTraj uses Angstrom, we want nanometers - xyz = self._traj.x / 10 - box = np.diag(self._traj.box)[0:3] / 10 - - return self._traj.time, xyz, box - - -class FrameReaderMDTraj(FrameReader): - def __init__(self, topname, trajname=None, frame_start=0): - """ - Open input XTC file from which to read coordinates using mdtraj library. - - :param topname: GROMACS GRO file from which to read topology - :param trajname: GROMACS XTC file to read subsequent frames - """ - FrameReader.__init__(self, topname, trajname, frame_start) - - try: - import mdtraj - except ImportError as e: - if "scipy" in e.msg: - e.msg = "The MDTraj FrameReader also requires Scipy" - else: - e.msg = "The MDTraj FrameReader requires the module MDTraj (and probably Scipy)" - raise - logger.warning("WARNING: Using MDTraj which renames solvent molecules") - - try: - if trajname is None: - self._traj = mdtraj.load(topname) - else: - self._traj = mdtraj.load(trajname, top=topname) - except OSError as e: - if not os.path.isfile(topname): - raise FileNotFoundError(topname) from e - if not os.path.isfile(trajname): - raise FileNotFoundError(trajname) from e - e.args = ("Error opening file '{0}' or '{1}'".format(topname, trajname),) - raise - - self.num_atoms = self._traj.n_atoms - self.num_frames = self._traj.n_frames - - def _initialise_frame(self, frame): - """ - Parse a GROMACS GRO file and create Residues/Atoms - Required before reading coordinates from XTC file - - :param frame: Frame instance to initialise from GRO file - """ - import mdtraj - top = mdtraj.load(self._topname) - - frame.name = "" - self.num_atoms = top.n_atoms - frame.natoms = top.n_atoms - - frame.residues = [Residue(name=res.name, num=res.resSeq) for res in top.topology.residues] - - for atom in top.topology.atoms: - new_atom = Atom(name=atom.name, num=atom.serial, - coords=top.xyz[0][atom.index]) - frame.residues[atom.residue.index].add_atom(new_atom) - - frame.box = top.unitcell_lengths[0] - - def _read_frame_number(self, number): - """ - Read next frame from XTC using mdtraj library. - """ - return self._traj.time[number], self._traj.xyz[number], self._traj.unitcell_lengths[number] - - -class FrameReaderMDAnalysis(FrameReader): - def __init__(self, topname, trajname=None, frame_start=0): - import MDAnalysis - - super().__init__(topname, trajname, frame_start) - - if trajname is None: - self._traj = MDAnalysis.Universe(topname) - else: - self._traj = MDAnalysis.Universe(topname, trajname) - - self.num_atoms = self._traj.atoms.n_atoms - self.num_frames = self._traj.trajectory.n_frames - - def _initialise_frame(self, frame): - frame.name = "" - frame.natoms = self.num_atoms - - import MDAnalysis - topol = MDAnalysis.Universe(self._topname) - frame.box = topol.dimensions[0:3] / 10. - - for res in topol.residues: - residue = Residue(name=res.resname, num=res.resnum) - for atom in res.atoms: - residue.add_atom(Atom(name=atom.name, num=atom.id, coords=atom.position / 10.)) - frame.residues.append(residue) - - def _read_frame_number(self, number): - traj_frame = self._traj.trajectory[number] - return traj_frame.time, traj_frame.positions / 10., traj_frame.dimensions[0:3] / 10. - - class Frame: """ Hold Atom data separated into Residues @@ -345,14 +131,8 @@ class Frame: self._xtc_buffer = None if gro is not None: - open_xtc = {"simpletraj": FrameReaderSimpleTraj, - "mdtraj": FrameReaderMDTraj, - "mdanalysis": FrameReaderMDAnalysis} - try: - self._trajreader = open_xtc[xtc_reader](gro, xtc, frame_start=frame_start) - except KeyError as e: - e.args = ("XTC reader {0} is not a valid option.".format(xtc_reader)) - raise + from .framereader import get_frame_reader + self._trajreader = get_frame_reader(gro, traj=xtc, frame_start=frame_start) self._trajreader.initialise_frame(self) diff --git a/pycgtool/framereader.py b/pycgtool/framereader.py new file mode 100644 index 0000000000000000000000000000000000000000..9a4fcb3d79ffd25d0a30c29cf63a96da5112ed73 --- /dev/null +++ b/pycgtool/framereader.py @@ -0,0 +1,275 @@ +""" +This module contains classes for reading trajectories into a Frame instance. + +Multiple readers are defined, allowing different underlying trajectory libraries to be used. +This module is tightly coupled with the frame.py module, so shares a set of unit tests. +""" + +import os +import abc +import itertools +import logging +import collections + +import numpy as np + +from .frame import Atom, Residue +from .util import FixedFormatUnpacker + +logger = logging.getLogger(__name__) + + +class UnsupportedFormatException(Exception): + pass + + +def get_frame_reader(top, traj=None, frame_start=0, name=None): + readers = collections.OrderedDict([ + ("simpletraj", FrameReaderSimpleTraj), + ("mdtraj", FrameReaderMDTraj), + ("mdanalysis", FrameReaderMDAnalysis), + ]) + + try: + return readers[name](top, traj, frame_start) + except KeyError as e: + if name is not None: + e.args = ("Frame reader '{0}' is not a valid option.".format(name),) + raise + for name, reader in readers.items(): # Return first reader that accepts given files + try: + return reader(top, traj, frame_start) + except (UnsupportedFormatException, ImportError): + continue + raise UnsupportedFormatException("None of the available readers support the trajector format provided") + + +class FrameReader(metaclass=abc.ABCMeta): + def __init__(self, topname, trajname=None, frame_start=0): + self._topname = topname + self._trajname = trajname + self._frame_number = frame_start + + self.num_atoms = 0 + self.num_frames = 0 + + def initialise_frame(self, frame): + self._initialise_frame(frame) + + def read_next(self, frame): + result = self.read_frame_number(self._frame_number, frame) + if result: + self._frame_number += 1 + return result + + def read_frame_number(self, number, frame): + try: + time, coords, box = self._read_frame_number(number) + frame.time = time + frame.box = box + + for atom, coord_line in zip(itertools.chain.from_iterable(frame.residues), coords): + atom.coords = coord_line + + except (IndexError, AttributeError): + # IndexError - run out of xtc frames + # AttributeError - we didn't provide an xtc + return False + return True + + @abc.abstractmethod + def _initialise_frame(self, frame): + pass + + @abc.abstractmethod + def _read_frame_number(self, number): + pass + + +class FrameReaderSimpleTraj(FrameReader): + def __init__(self, topname, trajname=None, frame_start=0): + """ + Open input XTC file from which to read coordinates using simpletraj library. + + :param topname: MD topology file - not used + :param trajname: MD trajectory file to read subsequent frames + :param frame_start: Frame number to start on, default 0 + """ + FrameReader.__init__(self, topname, trajname, frame_start) + + with open(self._topname) as gro: + gro.readline() + try: + self.num_atoms = int(gro.readline()) + except ValueError as e: + raise UnsupportedFormatException from e + + from simpletraj import trajectory + + if trajname is not None: + try: + self._traj = trajectory.get_trajectory(trajname) + except OSError as e: + if not os.path.isfile(trajname): + raise FileNotFoundError(trajname) from e + e.args = ("Error opening file '{0}'".format(trajname),) + raise + + if self._traj.numatoms != self.num_atoms: + raise UnsupportedFormatException + self.num_frames = self._traj.numframes + + def _initialise_frame(self, frame): + """ + Parse a GROMACS GRO file and create Residues/Atoms + Required before reading coordinates from XTC file + + :param frame: Frame instance to initialise from GRO file + """ + with open(self._topname) as gro: + frame.name = gro.readline().strip() + try: + self.num_atoms = int(gro.readline()) + except ValueError as e: + raise UnsupportedFormatException from e + + frame.natoms = self.num_atoms + resnum_last = None + atnum = 0 + + unpacker = FixedFormatUnpacker("I5,2A5,5X,3F8", + FixedFormatUnpacker.FormatStyle.Fortran) + + for _ in range(self.num_atoms): + resnum, resname, atomname, *coords = unpacker.unpack(gro.readline()) + coords = np.array(coords, dtype=np.float32) + + if resnum != resnum_last: + frame.residues.append(Residue(name=resname, + num=resnum)) + resnum_last = resnum + atnum = 0 + + atom = Atom(name=atomname, num=atnum, coords=coords) + frame.residues[-1].add_atom(atom) + atnum += 1 + + frame.box = np.array([float(x) for x in gro.readline().split()[0:3]], dtype=np.float32) + + def _read_frame_number(self, number): + """ + Read next frame from XTC using simpletraj library. + """ + self._traj.get_frame(number) + # SimpleTraj uses Angstrom, we want nanometers + xyz = self._traj.x / 10 + box = np.diag(self._traj.box)[0:3] / 10 + + return self._traj.time, xyz, box + + +class FrameReaderMDTraj(FrameReader): + def __init__(self, topname, trajname=None, frame_start=0): + """ + Open input XTC file from which to read coordinates using mdtraj library. + + :param topname: GROMACS GRO file from which to read topology + :param trajname: GROMACS XTC file to read subsequent frames + """ + FrameReader.__init__(self, topname, trajname, frame_start) + + try: + import mdtraj + except ImportError as e: + if "scipy" in e.msg: + e.msg = "The MDTraj FrameReader also requires Scipy" + else: + e.msg = "The MDTraj FrameReader requires the module MDTraj (and probably Scipy)" + raise + logger.warning("WARNING: Using MDTraj which renames solvent molecules") + + try: + if trajname is None: + self._traj = mdtraj.load(topname) + else: + self._traj = mdtraj.load(trajname, top=topname) + except OSError as e: + if not os.path.isfile(topname): + raise FileNotFoundError(topname) from e + if trajname is not None and not os.path.isfile(trajname): + raise FileNotFoundError(trajname) from e + if "no loader for filename" in repr(e): + raise UnsupportedFormatException from e + e.args = ("Error opening file '{0}' or '{1}'".format(topname, trajname),) + raise + + self.num_atoms = self._traj.n_atoms + self.num_frames = self._traj.n_frames + + def _initialise_frame(self, frame): + """ + Parse a GROMACS GRO file and create Residues/Atoms + Required before reading coordinates from XTC file + + :param frame: Frame instance to initialise from GRO file + """ + import mdtraj + top = mdtraj.load(self._topname) + + frame.name = "" + self.num_atoms = top.n_atoms + frame.natoms = top.n_atoms + + frame.residues = [Residue(name=res.name, num=res.resSeq) for res in top.topology.residues] + + for atom in top.topology.atoms: + new_atom = Atom(name=atom.name, num=atom.serial, + coords=top.xyz[0][atom.index]) + frame.residues[atom.residue.index].add_atom(new_atom) + + frame.box = top.unitcell_lengths[0] + + def _read_frame_number(self, number): + """ + Read next frame from XTC using mdtraj library. + """ + return self._traj.time[number], self._traj.xyz[number], self._traj.unitcell_lengths[number] + + +class FrameReaderMDAnalysis(FrameReader): + def __init__(self, topname, trajname=None, frame_start=0): + import MDAnalysis + + super().__init__(topname, trajname, frame_start) + + try: + if trajname is None: + self._traj = MDAnalysis.Universe(topname) + else: + self._traj = MDAnalysis.Universe(topname, trajname) + except ValueError as e: + if "isn't a valid topology format" in repr(e): + raise UnsupportedFormatException from e + raise + + self.num_atoms = self._traj.atoms.n_atoms + self.num_frames = self._traj.trajectory.n_frames + + def _initialise_frame(self, frame): + frame.name = "" + frame.natoms = self.num_atoms + + import MDAnalysis + topol = MDAnalysis.Universe(self._topname) + frame.box = topol.dimensions[0:3] / 10. + + for res in topol.residues: + residue = Residue(name=res.resname, num=res.resnum) + for atom in res.atoms: + residue.add_atom(Atom(name=atom.name, num=atom.id, coords=atom.position / 10.)) + frame.residues.append(residue) + + def _read_frame_number(self, number): + traj_frame = self._traj.trajectory[number] + return traj_frame.time, traj_frame.positions / 10., traj_frame.dimensions[0:3] / 10. + diff --git a/test/test_frame.py b/test/test_frame.py index 344ff8b7c71bb355f1baf13c67f9a26f519c46ef..6b9f1846af622fb4fc6e350ee5c58b214c8033e4 100644 --- a/test/test_frame.py +++ b/test/test_frame.py @@ -6,7 +6,8 @@ import logging import numpy as np from pycgtool.frame import Atom, Residue, Frame -from pycgtool.frame import FrameReaderSimpleTraj, FrameReaderMDAnalysis, FrameReaderMDTraj, FrameReader +from pycgtool.framereader import FrameReaderSimpleTraj, FrameReaderMDAnalysis, FrameReaderMDTraj +from pycgtool.framereader import FrameReader, get_frame_reader, UnsupportedFormatException try: import mdtraj @@ -94,7 +95,7 @@ class FrameTest(unittest.TestCase): self.helper_read_xtc(frame, first_only=True) - @unittest.skipIf(not mdtraj_present, "MDTRAJ or Scipy not present") + @unittest.skipIf(not mdtraj_present, "MDTraj or Scipy not present") def test_frame_mdtraj_read_gro(self): logging.disable(logging.WARNING) frame = Frame("test/data/water.gro", xtc_reader="mdtraj") @@ -109,13 +110,24 @@ class FrameTest(unittest.TestCase): self.helper_read_xtc(frame, first_only=True) - @unittest.skipIf(not mdtraj_present, "MDTRAJ or Scipy not present") + @unittest.skipIf(not mdtraj_present, "MDTraj or Scipy not present") def test_frame_mdtraj_read_pdb(self): reader = FrameReaderMDTraj("test/data/water.pdb") frame = Frame.instance_from_reader(reader) self.helper_read_xtc(frame, first_only=True, skip_names=True) + @unittest.skipIf(not mdtraj_present and not mdanalysis_present, "Neither MDTraj or MDAnalysis is present") + def test_frame_any_read_pdb(self): + reader = get_frame_reader("test/data/water.pdb") + frame = Frame.instance_from_reader(reader) + + self.helper_read_xtc(frame, first_only=True, skip_names=True) + + def test_frame_any_read_unsupported(self): + with self.assertRaises(UnsupportedFormatException): + reader = get_frame_reader("test/data/dppc.map") + @unittest.skipIf(not mdanalysis_present, "MDAnalysis not present") def test_frame_mdanalysis_read_pdb(self): reader = FrameReaderMDAnalysis("test/data/water.pdb") @@ -134,7 +146,7 @@ class FrameTest(unittest.TestCase): xtc_reader="simpletraj") self.assertEqual(11, frame.numframes) - @unittest.skipIf(not mdtraj_present, "MDTRAJ or Scipy not present") + @unittest.skipIf(not mdtraj_present, "MDTraj or Scipy not present") def test_frame_read_xtc_mdtraj_numframes(self): logging.disable(logging.WARNING) frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc", @@ -147,7 +159,7 @@ class FrameTest(unittest.TestCase): xtc_reader="simpletraj") self.helper_read_xtc(frame) - @unittest.skipIf(not mdtraj_present, "MDTRAJ or Scipy not present") + @unittest.skipIf(not mdtraj_present, "MDTraj or Scipy not present") def test_frame_mdtraj_read_xtc(self): logging.disable(logging.WARNING) frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc", @@ -163,7 +175,7 @@ class FrameTest(unittest.TestCase): self.helper_read_xtc(frame) - @unittest.skipIf(not mdtraj_present, "MDTRAJ or Scipy not present") + @unittest.skipIf(not mdtraj_present, "MDTraj or Scipy not present") def test_frame_write_xtc_mdtraj(self): try: os.remove("water_test2.xtc")