Skip to content
Snippets Groups Projects
Commit 2992e1da authored by James Graham
Browse files

refactor: clean up flow in __main__

parent 45500a20
No related branches found
No related tags found
No related merge requests found
......@@ -6,7 +6,7 @@ import logging
import pathlib
import sys
import textwrap
import typing
import time
from mdplus.multiscale import GLIMPS
import numpy as np
......@@ -28,6 +28,33 @@ class PyCGTOOL:
def __init__(self, config):
# Run the configured workflow at construction time: load the input frame,
# optionally apply the AA->CG mapping, optionally measure bonds and
# optionally write the output trajectory.
#
# :param config: Options object; this method reads its topology, trajectory,
#                begin, end, mapping, itp, bondset and output_xtc attributes.
#
# NOTE(review): indentation was lost in this view, so the exact nesting of
# the statements under each `if` cannot be confirmed here - check against
# the repository before relying on it.
self.config = config
# Input frame, loaded from the topology and (optional) trajectory files.
self.in_frame = Frame(
topology_file=self.config.topology,
trajectory_file=self.config.trajectory, # May be None
frame_start=self.config.begin,
frame_end=self.config.end)
# Defaults when no mapping is configured: no Mapping object, and the
# output frame is the very same object as the input frame.
self.mapping = None
self.out_frame = self.in_frame
if self.config.mapping:
self.mapping = Mapping(self.config.mapping,
self.config,
itp_filename=self.config.itp)
self.out_frame = self.mapping.apply(self.in_frame)
# Save to the '.gro' output path; frame_number=0 presumably selects the
# first frame only - TODO confirm against Frame.save.
self.out_frame.save(self.get_output_filepath('gro'),
frame_number=0)
# The backmapper is only trained when more than one frame is available.
if self.out_frame.n_frames > 1:
self.train_backmapper(self.in_frame, self.out_frame)
# Bond measurement is optional and driven by the bondset option.
self.bondset = None
if self.config.bondset:
self.bondset = BondSet(self.config.bondset, self.config)
self.measure_bonds()
# Optionally write the output frame to the '.xtc' output path.
if self.config.output_xtc:
self.out_frame.save(self.get_output_filepath('xtc'))
def get_output_filepath(self, ext: str) -> pathlib.Path:
"""Get file path for an output file by extension.
......@@ -36,117 +63,73 @@ class PyCGTOOL:
out_dir = pathlib.Path(self.config.out_dir)
return out_dir.joinpath(self.config.output_name + '.' + ext)
# NOTE(review): this span interleaves two versions of `measure_bonds` (two
# `def` lines follow). One takes no arguments and works on `self.bondset`,
# `self.out_frame` and `self.mapping`; the other takes `(frame, mapping)` and
# builds a local `bonds`. Most statements below therefore appear in pairs that
# perform the same step on `bonds` and on `self.bondset`. `__init__` calls
# `self.measure_bonds()` without arguments, while `full_run` calls it with
# `(cg_frame, mapping)`. Keep exactly one version when reconstructing the file.
def measure_bonds(self) -> None:
"""Measure bonds at the end of a run."""
self.bondset.apply(self.out_frame)
def measure_bonds(self, frame: Frame, mapping: typing.Optional[Mapping]) -> None:
"""Measure bonds at the end of a run.
:param frame:
:param mapping:
"""
bonds = BondSet(self.config.bondset, self.config)
bonds.apply(frame)
# Two alternative guards for the Boltzmann Inversion step: the first tests
# the configuration options, the second tests the loaded objects.
if self.config.mapping and self.config.trajectory:
if self.mapping is not None and self.out_frame.n_frames > 1:
# Only perform Boltzmann Inversion if we have a mapping and a trajectory.
# Otherwise we get infinite force constants.
logger.info('Starting Boltzmann Inversion')
bonds.boltzmann_invert()
self.bondset.boltzmann_invert()
logger.info('Finished Boltzmann Inversion')
# Output is either a GROMACS forcefield directory or a single '.itp' file,
# selected by the output_forcefield option.
if self.config.output_forcefield:
logger.info("Writing GROMACS forcefield directory")
out_dir = pathlib.Path(self.config.out_dir)
forcefield = ForceField(self.config.output_name, dir_path=out_dir)
forcefield.write(self.config.output_name, mapping, bonds)
forcefield = ForceField(self.config.output_name,
dir_path=out_dir)
forcefield.write(self.config.output_name, self.mapping,
self.bondset)
logger.info("Finished writing GROMACS forcefield directory")
else:
bonds.write_itp(self.get_output_filepath('itp'),
mapping=mapping)
self.bondset.write_itp(self.get_output_filepath('itp'),
mapping=self.mapping)
# Optionally dump the measured values; dump_n_values is passed through to
# dump_values together with the output directory.
if self.config.dump_measurements:
logger.info('Writing bond measurements to file')
bonds.dump_values(self.config.dump_n_values, self.config.out_dir)
self.bondset.dump_values(self.config.dump_n_values,
self.config.out_dir)
logger.info('Finished writing bond measurements to file')
def mapping_loop(self, frame: Frame) -> typing.Tuple[Frame, Mapping]:
    """Map an atomistic frame to its coarse-grained representation.

    Builds a ``Mapping`` from the configured mapping file, applies it to
    ``frame``, saves the result to the ``.gro`` output path with
    ``frame_number=0`` and trains the backmapper on the input/output pair.

    :param frame: Atomistic input frame to be mapped
    :return: Tuple of the coarse-grained frame and the mapping that made it
    """
    logger.info('Starting AA->CG mapping')
    mapping = Mapping(self.config.mapping,
                      self.config,
                      itp_filename=self.config.itp)

    cg_frame = mapping.apply(frame)
    cg_frame.save(self.get_output_filepath('gro'), frame_number=0)
    # Log through the module-level logger rather than the root ``logging``
    # module, so this message is emitted the same way as the matching
    # 'Starting' message above.
    logger.info('Finished AA->CG mapping')

    self.train_backmapper(frame, cg_frame)
    return cg_frame, mapping
def get_coords(self, frame: Frame, resname: str) -> np.ndarray:
    """Collect the coordinates of every residue named ``resname``.

    Each matching residue is sliced out of the frame's trajectory by its
    atom indices, and the ``xyz`` arrays of the slices are concatenated
    along the first axis.

    :param frame: Frame whose ``_trajectory`` is searched
    :param resname: Residue name to select
    :return: Concatenated coordinate array of all matching residues
    :raises ValueError: Raised by ``np.concatenate`` when no residue matches
    """
    trajectory = frame._trajectory
    # The element expression appeared twice in a row (one-line and wrapped
    # form of the same call); only a single copy belongs here.
    return np.concatenate([
        trajectory.atom_slice([atom.index for atom in residue.atoms]).xyz
        for residue in trajectory.topology.residues
        if residue.name == resname
    ])
def train_backmapper(self, aa_frame: Frame, cg_frame: Frame):
# Fit a GLIMPS backmapper on the coordinates of one residue selection taken
# from the CG and AA trajectories, then save it to the output path with the
# extension 'backmapper.pkl'.
#
# :param aa_frame: Atomistic frame; its `_trajectory` is sliced below
# :param cg_frame: Coarse-grained frame; its `_trajectory` is sliced below
#
# NOTE(review): this span interleaves two versions of the method. The
# subset selection appears twice ('resid 1' then 'resid 0'), `backmapper` is
# constructed twice (`GLIMPS()` then `GLIMPS(x_valence=2)`), and the
# "Testing backmapper" section appears both live and commented out. Keep only
# one variant of each when reconstructing the file.
# resname = 'POPC'
# aa_coords = get_coords(aa_frame, resname)
# cg_coords = get_coords(cg_frame, resname)
cg_subset_traj = cg_frame._trajectory.atom_slice(cg_frame._trajectory.topology.select('resid 1'))
aa_subset_traj = aa_frame._trajectory.atom_slice(aa_frame._trajectory.topology.select('resid 1'))
cg_subset_traj = cg_frame._trajectory.atom_slice(
cg_frame._trajectory.topology.select('resid 0'))
aa_subset_traj = aa_frame._trajectory.atom_slice(
aa_frame._trajectory.topology.select('resid 0'))
# NOTE(review): these files are saved under bare filenames instead of going
# through get_output_filepath like the final save - presumably debugging
# output; confirm whether they should live in the output directory.
cg_subset_traj.save('cg_test.gro')
aa_subset_traj.save('aa_test.gro')
logger.info('Training backmapper')
backmapper = GLIMPS()
# Param x_valence is approximate number of bonds per CG bead
# Values greater than 2 fail for small molecules e.g. sugar test case
backmapper = GLIMPS(x_valence=2)
# Fit on CG coordinates as input and AA coordinates as target.
backmapper.fit(cg_subset_traj.xyz, aa_subset_traj.xyz)
logger.info('Finished training backmapper')
logger.info('Testing backmapper')
backmapped = backmapper.transform(cg_subset_traj.xyz)
# Overwrites the AA subset coordinates in place with the backmapped ones
# before saving them.
aa_subset_traj.xyz = backmapped
aa_subset_traj.save('backmapped.gro')
logger.info('Finished testing backmapper')
# logger.info('Testing backmapper')
# backmapped = backmapper.transform(cg_subset_traj.xyz)
# aa_subset_traj.xyz = backmapped
# aa_subset_traj.save('backmapped.gro')
# logger.info('Finished testing backmapper')
backmapper.save(self.get_output_filepath('backmapper.pkl'))
def full_run(self):
"""Main function of the program PyCGTOOL.
Loads the input frame, performs the AA->CG mapping if a mapping is configured
and writes output files depending on the given configuration.
"""
# NOTE(review): indentation was lost in this view; the nesting of the
# statements under each `if`/`else` is assumed from context - confirm.
frame = Frame(
topology_file=self.config.topology,
trajectory_file=self.config.trajectory, # May be None
frame_start=self.config.begin,
frame_end=self.config.end)
# Modifies the loaded trajectory in place (inplace=True); presumably repairs
# molecules split across periodic boundaries - TODO confirm.
frame._trajectory.make_molecules_whole(inplace=True)
if self.config.mapping:
cg_frame, mapping = self.mapping_loop(frame)
else:
# Without a mapping, the input frame is used directly as the output frame.
logger.info('Skipping AA->CG mapping')
mapping = None
cg_frame = frame
# Optionally write the output frame to the '.xtc' output path.
if self.config.output_xtc:
cg_frame.save(self.get_output_filepath('xtc'))
# Bond measurement only runs when a bondset is configured.
if self.config.bondset:
self.measure_bonds(cg_frame, mapping)
class BooleanAction(argparse.Action):
"""Set up a boolean argparse argument with matching `--no` argument.
......@@ -274,6 +257,7 @@ def validate_arguments(args):
def main():
start_time = time.time()
args = parse_arguments(sys.argv[1:])
logging.basicConfig(level=args.log_level,
......@@ -281,7 +265,7 @@ def main():
datefmt='[%X]',
handlers=[RichHandler(rich_tracebacks=True)])
banner = """\
banner = r"""
_____ _____ _____ _______ ____ ____ _
| __ \ / ____/ ____|__ __/ __ \ / __ \| |
| |__) | _| | | | __ | | | | | | | | | |
......@@ -304,17 +288,19 @@ def main():
logger.info(30 * '-')
try:
pycgtool = PyCGTOOL(args)
if args.profile:
with cProfile.Profile() as profiler:
pycgtool.full_run()
pycgtool = PyCGTOOL(args)
profiler.dump_stats('gprof.out')
else:
pycgtool.full_run()
pycgtool = PyCGTOOL(args)
elapsed_time = time.time() - start_time
logger.info(
f'Processed {pycgtool.out_frame.n_frames} frames in {elapsed_time:.2f}s'
)
logger.info('Finished processing - goodbye!')
except Exception as exc:
......
......@@ -15,6 +15,8 @@ logger = logging.getLogger(__name__)
# Make NumPy raise FloatingPointError for every floating-point error category
# instead of emitting a warning.
np.seterr(all="raise")
# Type alias for arguments that accept a filesystem path either as a
# pathlib.Path or as a plain string.
PathLike = typing.Union[pathlib.Path, str]
class UnsupportedFormatException(Exception):
"""Exception raised when a topology/trajectory format cannot be parsed."""
......@@ -35,10 +37,8 @@ class NonMatchingSystemError(ValueError):
class Frame:
"""Load and store data from a simulation trajectory."""
def __init__(self,
topology_file: typing.Optional[typing.Union[pathlib.Path,
str]] = None,
trajectory_file: typing.Optional[typing.Union[pathlib.Path,
str]] = None,
topology_file: typing.Optional[PathLike] = None,
trajectory_file: typing.Optional[PathLike] = None,
frame_start: int = 0,
frame_end: typing.Optional[int] = None):
"""Load a simulation trajectory.
......@@ -60,8 +60,10 @@ class Frame:
logging.info('Loading trajectory file - this may take a while')
self._trajectory = mdtraj.load(str(trajectory_file),
top=self._topology)
self._slice_trajectory(frame_start, frame_end)
logging.info('Finished loading trajectory file')
self._trajectory = self._slice_trajectory(frame_start, frame_end)
logging.info(
'Finished loading trajectory file - loaded %d frames',
self._trajectory.n_frames)
except ValueError as exc:
raise NonMatchingSystemError from exc
......
......@@ -344,6 +344,7 @@ class Mapping:
:param aa_residues: Iterable of atomistic residues to map from
:return: New CG Frame instance
"""
logger.info('Initialising output frame')
cg_frame = Frame()
missing_mappings = set()
......@@ -364,6 +365,7 @@ class Mapping:
for bmap in mol_map:
cg_frame.add_atom(bmap.name, None, cg_res)
logger.info('Finished initialising output frame')
return cg_frame
def apply(self, frame: Frame, cg_frame: typing.Optional[Frame] = None):
......@@ -385,6 +387,7 @@ class Mapping:
if not np.all(unitcell_lengths):
unitcell_lengths = None
logger.info('Applying AA->CG mapping')
residues_to_map = (res for res in frame.residues
if res.name in self._mappings)
for aa_res, cg_res in zip(residues_to_map, cg_frame.residues):
......@@ -417,6 +420,7 @@ class Mapping:
cg_frame.build_trajectory()
logger.info('Finished applying AA->CG mapping')
return cg_frame
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment