Skip to content
Snippets Groups Projects
Commit 45500a20 authored by James Graham's avatar James Graham
Browse files

refactor: move __main__ code into class

Avoids passing so much data around manually
parent ab970108
Branches main
No related tags found
No related merge requests found
......@@ -24,68 +24,70 @@ class ArgumentValidationError(ValueError):
"""Exception raised for invalid combinations of command line arguments."""
def get_output_filepath(ext: str, config) -> pathlib.Path:
class PyCGTOOL:
def __init__(self, config):
self.config = config
def get_output_filepath(self, ext: str) -> pathlib.Path:
"""Get file path for an output file by extension.
:param ext:
:param config: Program arguments from argparse
"""
out_dir = pathlib.Path(config.out_dir)
return out_dir.joinpath(config.output_name + '.' + ext)
out_dir = pathlib.Path(self.config.out_dir)
return out_dir.joinpath(self.config.output_name + '.' + ext)
def measure_bonds(frame: Frame, mapping: typing.Optional[Mapping],
config) -> None:
def measure_bonds(self, frame: Frame, mapping: typing.Optional[Mapping]) -> None:
"""Measure bonds at the end of a run.
:param frame:
:param mapping:
:param config: Program arguments from argparse
"""
bonds = BondSet(config.bondset, config)
bonds = BondSet(self.config.bondset, self.config)
bonds.apply(frame)
if config.mapping and config.trajectory:
if self.config.mapping and self.config.trajectory:
# Only perform Boltzmann Inversion if we have a mapping and a trajectory.
# Otherwise we get infinite force constants.
logger.info('Starting Boltzmann Inversion')
bonds.boltzmann_invert()
logger.info('Finished Boltzmann Inversion')
if config.output_forcefield:
if self.config.output_forcefield:
logger.info("Writing GROMACS forcefield directory")
out_dir = pathlib.Path(config.out_dir)
forcefield = ForceField(config.output_name, dir_path=out_dir)
forcefield.write(config.output_name, mapping, bonds)
out_dir = pathlib.Path(self.config.out_dir)
forcefield = ForceField(self.config.output_name, dir_path=out_dir)
forcefield.write(self.config.output_name, mapping, bonds)
logger.info("Finished writing GROMACS forcefield directory")
else:
bonds.write_itp(get_output_filepath('itp', config),
bonds.write_itp(self.get_output_filepath('itp'),
mapping=mapping)
if config.dump_measurements:
if self.config.dump_measurements:
logger.info('Writing bond measurements to file')
bonds.dump_values(config.dump_n_values, config.out_dir)
bonds.dump_values(self.config.dump_n_values, self.config.out_dir)
logger.info('Finished writing bond measurements to file')
def mapping_loop(frame: Frame, config) -> typing.Tuple[Frame, Mapping]:
def mapping_loop(self, frame: Frame) -> typing.Tuple[Frame, Mapping]:
"""Perform mapping loop over input trajectory.
:param frame:
:param config: Program arguments from argparse
"""
logger.info('Starting AA->CG mapping')
mapping = Mapping(config.mapping, config, itp_filename=config.itp)
mapping = Mapping(self.config.mapping, self.config, itp_filename=self.config.itp)
cg_frame = mapping.apply(frame)
cg_frame.save(get_output_filepath('gro', config), frame_number=0)
cg_frame.save(self.get_output_filepath('gro'), frame_number=0)
logging.info('Finished AA->CG mapping')
self.train_backmapper(frame, cg_frame)
return cg_frame, mapping
def get_coords(frame: Frame, resname: str) -> np.ndarray:
def get_coords(self, frame: Frame, resname: str) -> np.ndarray:
return np.concatenate([
frame._trajectory.atom_slice([atom.index for atom in residue.atoms]).xyz
for residue in frame._trajectory.topology.residues
......@@ -93,7 +95,7 @@ def get_coords(frame: Frame, resname: str) -> np.ndarray:
])
def train_backmapper(aa_frame: Frame, cg_frame: Frame):
def train_backmapper(self, aa_frame: Frame, cg_frame: Frame):
# resname = 'POPC'
# aa_coords = get_coords(aa_frame, resname)
# cg_coords = get_coords(cg_frame, resname)
......@@ -115,37 +117,34 @@ def train_backmapper(aa_frame: Frame, cg_frame: Frame):
aa_subset_traj.save('backmapped.gro')
logger.info('Finished testing backmapper')
backmapper.save('backmapper.pkl')
backmapper.save(self.get_output_filepath('backmapper.pkl'))
def full_run(config):
def full_run(self):
"""Main function of the program PyCGTOOL.
Performs the complete AA->CG mapping and outputs a files dependent on given input.
:param config: Program arguments from argparse
"""
frame = Frame(
topology_file=config.topology,
trajectory_file=config.trajectory, # May be None
frame_start=config.begin,
frame_end=config.end)
topology_file=self.config.topology,
trajectory_file=self.config.trajectory, # May be None
frame_start=self.config.begin,
frame_end=self.config.end)
frame._trajectory.make_molecules_whole(inplace=True)
if config.mapping:
cg_frame, mapping = mapping_loop(frame, config)
train_backmapper(frame, cg_frame)
if self.config.mapping:
cg_frame, mapping = self.mapping_loop(frame)
else:
logger.info('Skipping AA->CG mapping')
mapping = None
cg_frame = frame
if config.output_xtc:
cg_frame.save(get_output_filepath('xtc', config))
if self.config.output_xtc:
cg_frame.save(self.get_output_filepath('xtc'))
if config.bondset:
measure_bonds(cg_frame, mapping, config)
if self.config.bondset:
self.measure_bonds(cg_frame, mapping)
class BooleanAction(argparse.Action):
......@@ -280,7 +279,7 @@ def main():
logging.basicConfig(level=args.log_level,
format='%(message)s',
datefmt='[%X]',
handlers=[RichHandler()])
handlers=[RichHandler(rich_tracebacks=True)])
banner = """\
_____ _____ _____ _______ ____ ____ _
......@@ -305,19 +304,21 @@ def main():
logger.info(30 * '-')
try:
pycgtool = PyCGTOOL(args)
if args.profile:
with cProfile.Profile() as profiler:
full_run(args)
pycgtool.full_run()
profiler.dump_stats('gprof.out')
else:
full_run(args)
pycgtool.full_run()
logger.info('Finished processing - goodbye!')
except Exception as exc:
logger.error(exc)
logger.exception(exc)
if __name__ == "__main__":
......
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment