Skip to content
Snippets Groups Projects
Verified Commit 4639bcd6 authored by James Graham's avatar James Graham
Browse files

refactor: add backmapper resname to CLI

parent 65d1b348
No related branches found
No related tags found
No related merge requests found
......@@ -44,8 +44,8 @@ class PyCGTOOL:
self.out_frame.save(self.get_output_filepath('gro'),
frame_number=0)
if self.out_frame.n_frames > 1:
self.train_backmapper()
if self.config.backmapper_resname and self.out_frame.n_frames > 1:
self.train_backmapper(self.config.resname)
self.bondset = None
if self.config.bondset:
......@@ -93,29 +93,14 @@ class PyCGTOOL:
self.config.out_dir)
logger.info('Finished writing bond measurements to file')
@staticmethod
def get_coords(frame: Frame, resname: str) -> np.ndarray:
return np.concatenate([
frame._trajectory.atom_slice(
[atom.index for atom in residue.atoms]).xyz
for residue in frame._trajectory.topology.residues
if residue.name == resname
])
def train_backmapper(self):
# resname = 'POPC'
# aa_coords = get_coords(aa_frame, resname)
# cg_coords = get_coords(cg_frame, resname)
sel = 'resid 0'
def train_backmapper(self, resname: str):
sel = f'resname {resname}'
aa_subset_traj = self.in_frame._trajectory.atom_slice(
self.in_frame._trajectory.topology.select(sel))
cg_subset_traj = self.out_frame._trajectory.atom_slice(
self.out_frame._trajectory.topology.select(sel))
cg_subset_traj.save('cg_test.gro')
aa_subset_traj.save('aa_test.gro')
logger.info('Training backmapper')
# Param x_valence is approximate number of bonds per CG bead
# Values greater than 2 fail for small molecules e.g. sugar test case
......@@ -123,12 +108,6 @@ class PyCGTOOL:
backmapper.fit(cg_subset_traj.xyz, aa_subset_traj.xyz)
logger.info('Finished training backmapper')
# logger.info('Testing backmapper')
# backmapped = backmapper.transform(cg_subset_traj.xyz)
# aa_subset_traj.xyz = backmapped
# aa_subset_traj.save('backmapped.gro')
# logger.info('Finished testing backmapper')
backmapper.save(self.get_output_filepath('backmapper.pkl'))
......@@ -198,6 +177,8 @@ def parse_arguments(arg_list):
mapping_options.add_argument("--virtual-map-center", default="geom",
choices=["geom", "mass"],
help="Virtual site mapping method")
mapping_options.add_argument("--backmapper-resname", default=None,
help="Residue name for which to train a backmapper")
# Bond options
bond_options = parser.add_argument_group("bond options")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment