diff --git a/pycgtool/__main__.py b/pycgtool/__main__.py index d339bfe1c8be5e61ba222bb7f9d6970a85d4bf51..6a10bd5b37323fadd31f9ed40e62334527357fdc 100755 --- a/pycgtool/__main__.py +++ b/pycgtool/__main__.py @@ -44,8 +44,8 @@ class PyCGTOOL: self.out_frame.save(self.get_output_filepath('gro'), frame_number=0) - if self.out_frame.n_frames > 1: - self.train_backmapper() + if self.config.backmapper_resname and self.out_frame.n_frames > 1: + self.train_backmapper(self.config.resname) self.bondset = None if self.config.bondset: @@ -93,29 +93,14 @@ class PyCGTOOL: self.config.out_dir) logger.info('Finished writing bond measurements to file') - @staticmethod - def get_coords(frame: Frame, resname: str) -> np.ndarray: - return np.concatenate([ - frame._trajectory.atom_slice( - [atom.index for atom in residue.atoms]).xyz - for residue in frame._trajectory.topology.residues - if residue.name == resname - ]) - - def train_backmapper(self): - # resname = 'POPC' - # aa_coords = get_coords(aa_frame, resname) - # cg_coords = get_coords(cg_frame, resname) - sel = 'resid 0' + def train_backmapper(self, resname: str): + sel = f'resname {resname}' aa_subset_traj = self.in_frame._trajectory.atom_slice( self.in_frame._trajectory.topology.select(sel)) cg_subset_traj = self.out_frame._trajectory.atom_slice( self.out_frame._trajectory.topology.select(sel)) - cg_subset_traj.save('cg_test.gro') - aa_subset_traj.save('aa_test.gro') - logger.info('Training backmapper') # Param x_valence is approximate number of bonds per CG bead # Values greater than 2 fail for small molecules e.g. sugar test case @@ -123,12 +108,6 @@ class PyCGTOOL: backmapper.fit(cg_subset_traj.xyz, aa_subset_traj.xyz) logger.info('Finished training backmapper') - # logger.info('Testing backmapper') - # backmapped = backmapper.transform(cg_subset_traj.xyz) - # aa_subset_traj.xyz = backmapped - # aa_subset_traj.save('backmapped.gro') - # logger.info('Finished testing backmapper') - backmapper.save(self.get_output_filepath('backmapper.pkl')) @@ -198,6 +177,8 @@ def parse_arguments(arg_list): mapping_options.add_argument("--virtual-map-center", default="geom", choices=["geom", "mass"], help="Virtual site mapping method") + mapping_options.add_argument("--backmapper-resname", default=None, + help="Residue name for which to train a backmapper") # Bond options bond_options = parser.add_argument_group("bond options")