Skip to content
Snippets Groups Projects
Commit 45500a20 authored by James Graham's avatar James Graham
Browse files

refactor: move __main__ code into class

Avoids passing so much data around manually
parent ab970108
No related branches found
No related tags found
No related merge requests found
...@@ -24,68 +24,70 @@ class ArgumentValidationError(ValueError): ...@@ -24,68 +24,70 @@ class ArgumentValidationError(ValueError):
"""Exception raised for invalid combinations of command line arguments.""" """Exception raised for invalid combinations of command line arguments."""
def get_output_filepath(ext: str, config) -> pathlib.Path: class PyCGTOOL:
def __init__(self, config):
self.config = config
def get_output_filepath(self, ext: str) -> pathlib.Path:
"""Get file path for an output file by extension. """Get file path for an output file by extension.
:param ext: :param ext:
:param config: Program arguments from argparse
""" """
out_dir = pathlib.Path(config.out_dir) out_dir = pathlib.Path(self.config.out_dir)
return out_dir.joinpath(config.output_name + '.' + ext) return out_dir.joinpath(self.config.output_name + '.' + ext)
def measure_bonds(frame: Frame, mapping: typing.Optional[Mapping], def measure_bonds(self, frame: Frame, mapping: typing.Optional[Mapping]) -> None:
config) -> None:
"""Measure bonds at the end of a run. """Measure bonds at the end of a run.
:param frame: :param frame:
:param mapping: :param mapping:
:param config: Program arguments from argparse
""" """
bonds = BondSet(config.bondset, config) bonds = BondSet(self.config.bondset, self.config)
bonds.apply(frame) bonds.apply(frame)
if config.mapping and config.trajectory: if self.config.mapping and self.config.trajectory:
# Only perform Boltzmann Inversion if we have a mapping and a trajectory. # Only perform Boltzmann Inversion if we have a mapping and a trajectory.
# Otherwise we get infinite force constants. # Otherwise we get infinite force constants.
logger.info('Starting Boltzmann Inversion') logger.info('Starting Boltzmann Inversion')
bonds.boltzmann_invert() bonds.boltzmann_invert()
logger.info('Finished Boltzmann Inversion') logger.info('Finished Boltzmann Inversion')
if config.output_forcefield: if self.config.output_forcefield:
logger.info("Writing GROMACS forcefield directory") logger.info("Writing GROMACS forcefield directory")
out_dir = pathlib.Path(config.out_dir) out_dir = pathlib.Path(self.config.out_dir)
forcefield = ForceField(config.output_name, dir_path=out_dir) forcefield = ForceField(self.config.output_name, dir_path=out_dir)
forcefield.write(config.output_name, mapping, bonds) forcefield.write(self.config.output_name, mapping, bonds)
logger.info("Finished writing GROMACS forcefield directory") logger.info("Finished writing GROMACS forcefield directory")
else: else:
bonds.write_itp(get_output_filepath('itp', config), bonds.write_itp(self.get_output_filepath('itp'),
mapping=mapping) mapping=mapping)
if config.dump_measurements: if self.config.dump_measurements:
logger.info('Writing bond measurements to file') logger.info('Writing bond measurements to file')
bonds.dump_values(config.dump_n_values, config.out_dir) bonds.dump_values(self.config.dump_n_values, self.config.out_dir)
logger.info('Finished writing bond measurements to file') logger.info('Finished writing bond measurements to file')
def mapping_loop(frame: Frame, config) -> typing.Tuple[Frame, Mapping]: def mapping_loop(self, frame: Frame) -> typing.Tuple[Frame, Mapping]:
"""Perform mapping loop over input trajectory. """Perform mapping loop over input trajectory.
:param frame: :param frame:
:param config: Program arguments from argparse
""" """
logger.info('Starting AA->CG mapping') logger.info('Starting AA->CG mapping')
mapping = Mapping(config.mapping, config, itp_filename=config.itp) mapping = Mapping(self.config.mapping, self.config, itp_filename=self.config.itp)
cg_frame = mapping.apply(frame) cg_frame = mapping.apply(frame)
cg_frame.save(get_output_filepath('gro', config), frame_number=0) cg_frame.save(self.get_output_filepath('gro'), frame_number=0)
logging.info('Finished AA->CG mapping') logging.info('Finished AA->CG mapping')
self.train_backmapper(frame, cg_frame)
return cg_frame, mapping return cg_frame, mapping
def get_coords(frame: Frame, resname: str) -> np.ndarray: def get_coords(self, frame: Frame, resname: str) -> np.ndarray:
return np.concatenate([ return np.concatenate([
frame._trajectory.atom_slice([atom.index for atom in residue.atoms]).xyz frame._trajectory.atom_slice([atom.index for atom in residue.atoms]).xyz
for residue in frame._trajectory.topology.residues for residue in frame._trajectory.topology.residues
...@@ -93,7 +95,7 @@ def get_coords(frame: Frame, resname: str) -> np.ndarray: ...@@ -93,7 +95,7 @@ def get_coords(frame: Frame, resname: str) -> np.ndarray:
]) ])
def train_backmapper(aa_frame: Frame, cg_frame: Frame): def train_backmapper(self, aa_frame: Frame, cg_frame: Frame):
# resname = 'POPC' # resname = 'POPC'
# aa_coords = get_coords(aa_frame, resname) # aa_coords = get_coords(aa_frame, resname)
# cg_coords = get_coords(cg_frame, resname) # cg_coords = get_coords(cg_frame, resname)
...@@ -115,37 +117,34 @@ def train_backmapper(aa_frame: Frame, cg_frame: Frame): ...@@ -115,37 +117,34 @@ def train_backmapper(aa_frame: Frame, cg_frame: Frame):
aa_subset_traj.save('backmapped.gro') aa_subset_traj.save('backmapped.gro')
logger.info('Finished testing backmapper') logger.info('Finished testing backmapper')
backmapper.save('backmapper.pkl') backmapper.save(self.get_output_filepath('backmapper.pkl'))
def full_run(config): def full_run(self):
"""Main function of the program PyCGTOOL. """Main function of the program PyCGTOOL.
Performs the complete AA->CG mapping and outputs a files dependent on given input. Performs the complete AA->CG mapping and outputs a files dependent on given input.
:param config: Program arguments from argparse
""" """
frame = Frame( frame = Frame(
topology_file=config.topology, topology_file=self.config.topology,
trajectory_file=config.trajectory, # May be None trajectory_file=self.config.trajectory, # May be None
frame_start=config.begin, frame_start=self.config.begin,
frame_end=config.end) frame_end=self.config.end)
frame._trajectory.make_molecules_whole(inplace=True) frame._trajectory.make_molecules_whole(inplace=True)
if config.mapping: if self.config.mapping:
cg_frame, mapping = mapping_loop(frame, config) cg_frame, mapping = self.mapping_loop(frame)
train_backmapper(frame, cg_frame)
else: else:
logger.info('Skipping AA->CG mapping') logger.info('Skipping AA->CG mapping')
mapping = None mapping = None
cg_frame = frame cg_frame = frame
if config.output_xtc: if self.config.output_xtc:
cg_frame.save(get_output_filepath('xtc', config)) cg_frame.save(self.get_output_filepath('xtc'))
if config.bondset: if self.config.bondset:
measure_bonds(cg_frame, mapping, config) self.measure_bonds(cg_frame, mapping)
class BooleanAction(argparse.Action): class BooleanAction(argparse.Action):
...@@ -280,7 +279,7 @@ def main(): ...@@ -280,7 +279,7 @@ def main():
logging.basicConfig(level=args.log_level, logging.basicConfig(level=args.log_level,
format='%(message)s', format='%(message)s',
datefmt='[%X]', datefmt='[%X]',
handlers=[RichHandler()]) handlers=[RichHandler(rich_tracebacks=True)])
banner = """\ banner = """\
_____ _____ _____ _______ ____ ____ _ _____ _____ _____ _______ ____ ____ _
...@@ -305,19 +304,21 @@ def main(): ...@@ -305,19 +304,21 @@ def main():
logger.info(30 * '-') logger.info(30 * '-')
try: try:
pycgtool = PyCGTOOL(args)
if args.profile: if args.profile:
with cProfile.Profile() as profiler: with cProfile.Profile() as profiler:
full_run(args) pycgtool.full_run()
profiler.dump_stats('gprof.out') profiler.dump_stats('gprof.out')
else: else:
full_run(args) pycgtool.full_run()
logger.info('Finished processing - goodbye!') logger.info('Finished processing - goodbye!')
except Exception as exc: except Exception as exc:
logger.error(exc) logger.exception(exc)
if __name__ == "__main__": if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment