Skip to content
Snippets Groups Projects
Commit c02ddb7a authored by James Graham's avatar James Graham
Browse files

Added cleaner handling of Ctrl-C

Main loops now moved entirely inside progress iterator
parent 7cc99cbe
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
import sys
from pycgtool.frame import Frame from pycgtool.frame import Frame
from pycgtool.mapping import Mapping from pycgtool.mapping import Mapping
...@@ -18,7 +19,7 @@ def main(args, config): ...@@ -18,7 +19,7 @@ def main(args, config):
:param args: Arguments from argparse :param args: Arguments from argparse
:param config: Configuration dictionary :param config: Configuration dictionary
""" """
frame = Frame(gro=args.gro, xtc=args.xtc, itp=args.itp, frame_start=args.begin) frame = Frame(gro=args.gro, xtc=args.xtc, itp=args.itp, frame_start=args.begin, xtc_reader="mdtraj")
if args.bnd: if args.bnd:
bonds = BondSet(args.bnd, config) bonds = BondSet(args.bnd, config)
...@@ -29,18 +30,21 @@ def main(args, config): ...@@ -29,18 +30,21 @@ def main(args, config):
cgframe.output(config.output_name + ".gro", format=config.output) cgframe.output(config.output_name + ".gro", format=config.output)
# Main loop - perform mapping and measurement on every frame in XTC # Main loop - perform mapping and measurement on every frame in XTC
numframes = frame.numframes - args.begin if args.end == -1 else args.end - args.begin def main_loop():
for _ in Progress(numframes, postwhile=frame.next_frame): nonlocal cgframe
frame.next_frame()
if args.map: if args.map:
cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"}) cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"})
if config.output_xtc: if config.output_xtc:
cgframe.write_xtc(config.output_name + ".xtc") cgframe.write_xtc(config.output_name + ".xtc")
else: else:
cgframe = frame cgframe = frame
if args.bnd: if args.bnd:
bonds.apply(cgframe) bonds.apply(cgframe)
numframes = frame.numframes - args.begin if args.end == -1 else args.end - args.begin
Progress(numframes, postwhile=main_loop).run()
if args.bnd: if args.bnd:
if args.map: if args.map:
bonds.boltzmann_invert() bonds.boltzmann_invert()
...@@ -68,10 +72,16 @@ def map_only(args, config): ...@@ -68,10 +72,16 @@ def map_only(args, config):
if args.xtc and (config.output_xtc or args.outputxtc): if args.xtc and (config.output_xtc or args.outputxtc):
numframes = frame.numframes - args.begin if args.end == -1 else args.end - args.begin numframes = frame.numframes - args.begin if args.end == -1 else args.end - args.begin
for _ in Progress(numframes, postwhile=frame.next_frame):
# Main loop - perform mapping and measurement on every frame in XTC
def main_loop():
nonlocal cgframe
frame.next_frame()
cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"}) cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"})
cgframe.write_xtc(config.output_name + ".xtc") cgframe.write_xtc(config.output_name + ".xtc")
Progress(numframes, postwhile=main_loop).run()
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Perform coarse-grain mapping of atomistic trajectory") parser = argparse.ArgumentParser(description="Perform coarse-grain mapping of atomistic trajectory")
...@@ -84,7 +94,6 @@ if __name__ == "__main__": ...@@ -84,7 +94,6 @@ if __name__ == "__main__":
parser.add_argument('--interactive', default=False, action='store_true') parser.add_argument('--interactive', default=False, action='store_true')
parser.add_argument('--outputxtc', default=False, action='store_true') parser.add_argument('--outputxtc', default=False, action='store_true')
# parser.add_argument('-f', '--frames', type=int, default=-1, help="Number of frames to read")
input_files.add_argument('--begin', type=int, default=0, help="Frame number to begin") input_files.add_argument('--begin', type=int, default=0, help="Frame number to begin")
input_files.add_argument('--end', type=int, default=-1, help="Frame number to end") input_files.add_argument('--end', type=int, default=-1, help="Frame number to end")
...@@ -108,7 +117,10 @@ if __name__ == "__main__": ...@@ -108,7 +117,10 @@ if __name__ == "__main__":
parser.error("One or both of -m and -b is required.") parser.error("One or both of -m and -b is required.")
if args.interactive: if args.interactive:
try:
config.interactive() config.interactive()
except KeyboardInterrupt:
sys.exit(0)
else: else:
print("Using GRO: {0}".format(args.gro)) print("Using GRO: {0}".format(args.gro))
print("Using XTC: {0}".format(args.xtc)) print("Using XTC: {0}".format(args.xtc))
......
...@@ -146,6 +146,12 @@ class Frame: ...@@ -146,6 +146,12 @@ class Frame:
self._parse_itp(itp) self._parse_itp(itp)
def _open_xtc_simpletraj(self, xtc, gro=None): def _open_xtc_simpletraj(self, xtc, gro=None):
"""
Open input XTC file from which to read coordinates using simpletraj library.
:param xtc: GROMACS XTC file to read subsequent frames
:param gro: GROMACS GRO file - not used
"""
try: try:
self.xtc = trajectory.XtcTrajectory(xtc) self.xtc = trajectory.XtcTrajectory(xtc)
except OSError as e: except OSError as e:
...@@ -159,6 +165,12 @@ class Frame: ...@@ -159,6 +165,12 @@ class Frame:
self.numframes += self.xtc.numframes self.numframes += self.xtc.numframes
def _open_xtc_mdtraj(self, xtc, gro): def _open_xtc_mdtraj(self, xtc, gro):
"""
Open input XTC file from which to read coordinates using mdtraj library.
:param xtc: GROMACS XTC file to read subsequent frames
:param gro: GROMACS GRO file from which to read topology
"""
try: try:
self.xtc = mdtraj.load_xtc(xtc, top=gro) self.xtc = mdtraj.load_xtc(xtc, top=gro)
except OSError as e: except OSError as e:
...@@ -194,7 +206,7 @@ class Frame: ...@@ -194,7 +206,7 @@ class Frame:
def next_frame(self, exclude=None): def next_frame(self, exclude=None):
""" """
Read next frame from XTC Read next frame from input XTC.
:return: True if successful else False :return: True if successful else False
""" """
...@@ -207,6 +219,11 @@ class Frame: ...@@ -207,6 +219,11 @@ class Frame:
raise raise
def _next_frame_mdtraj(self, exclude=None): def _next_frame_mdtraj(self, exclude=None):
"""
Read next frame from XTC using mdtraj library.
:return: True if successful else False
"""
try: try:
i = 0 i = 0
# This returns a slice of length 1, properties still need to be indexed # This returns a slice of length 1, properties still need to be indexed
...@@ -227,6 +244,11 @@ class Frame: ...@@ -227,6 +244,11 @@ class Frame:
return False return False
def _next_frame_simpletraj(self, exclude=None): def _next_frame_simpletraj(self, exclude=None):
"""
Read next frame from XTC using simpletraj library.
:return: True if successful else False
"""
try: try:
self.xtc.get_frame(self.number) self.xtc.get_frame(self.number)
i = 0 i = 0
...@@ -249,7 +271,13 @@ class Frame: ...@@ -249,7 +271,13 @@ class Frame:
return False return False
def write_xtc(self, filename): def write_xtc(self, filename):
"""
Write frame to output XTC file.
:param filename: XTC filename to write to
"""
if self._xtc_buffer is None: if self._xtc_buffer is None:
backup_file(filename, verbose=True)
try: try:
self._xtc_buffer = mdtraj.formats.XTCTrajectoryFile(filename, mode="w") self._xtc_buffer = mdtraj.formats.XTCTrajectoryFile(filename, mode="w")
except NameError as e: except NameError as e:
...@@ -269,7 +297,6 @@ class Frame: ...@@ -269,7 +297,6 @@ class Frame:
box[0][i][i] = self.box[i] box[0][i][i] = self.box[i]
self._xtc_buffer.write(xyz, step=step, box=box) self._xtc_buffer.write(xyz, step=step, box=box)
# self._xtc_buffer.close()
def _parse_gro(self, filename): def _parse_gro(self, filename):
""" """
......
...@@ -220,6 +220,9 @@ class Progress: ...@@ -220,6 +220,9 @@ class Progress:
:param postwhile: Function to check after each iteration, stops if False :param postwhile: Function to check after each iteration, stops if False
:param quiet: Skip printing of progress bar - for testing :param quiet: Skip printing of progress bar - for testing
""" """
if prewhile is not None:
raise NotImplementedError("Prewhile conditions are not yet implemented")
self._maxits = maxits self._maxits = maxits
self._length = length self._length = length
self._prewhile = prewhile self._prewhile = prewhile
...@@ -235,16 +238,20 @@ class Progress: ...@@ -235,16 +238,20 @@ class Progress:
""" """
Allow iteration over Progress while testing prewhile and postwhile conditions. Allow iteration over Progress while testing prewhile and postwhile conditions.
Will catch Ctrl-C and return control as if the iterator has been fully consumed.
:return: Iteration number :return: Iteration number
""" """
try:
if self._postwhile is not None and self._its > 0 and not self._postwhile(): if self._postwhile is not None and self._its > 0 and not self._postwhile():
self._stop() self._stop()
if self._prewhile is not None and not self._prewhile(): except KeyboardInterrupt:
print(end="\r")
self._stop() self._stop()
self._its += 1 self._its += 1
if self._its % 10 == 0 and not self._quiet: if self._its % 1 == 0 and not self._quiet:
self._display() self._display()
if self._its >= self._maxits: if self._its >= self._maxits:
...@@ -252,16 +259,24 @@ class Progress: ...@@ -252,16 +259,24 @@ class Progress:
return self._its return self._its
def _stop(self): def run(self):
if not self._quiet: """
Iterate through self until stopped by maximum iterations or False condition.
"""
collections.deque(self, maxlen=0)
@property
def _bar(self):
done = int(self._length * (self._its / self._maxits)) done = int(self._length * (self._its / self._maxits))
left = self._length - done left = self._length - done
return "{0} [".format(self._its) + done * "#" + left * "-" + "] {0}".format(self._maxits)
def _stop(self):
if not self._quiet:
time_taken = int(time.clock() - self._start_time) time_taken = int(time.clock() - self._start_time)
print("{0} [".format(self._its) + done * "#" + left * "-" + "] {0} took {1}s".format(self._maxits, time_taken)) print(self._bar + " took {0}s".format(time_taken))
raise StopIteration raise StopIteration
def _display(self): def _display(self):
done = int(self._length * (self._its / self._maxits))
left = self._length - done
time_remain = int((time.clock() - self._start_time) * ((self._maxits - self._its) / self._its)) time_remain = int((time.clock() - self._start_time) * ((self._maxits - self._its) / self._its))
print("{0} [".format(self._its) + done * "#" + left * "-" + "] {0} {1}s left".format(self._maxits, time_remain), end="\r") print(self._bar + " {0}s left".format(time_remain), end="\r")
...@@ -121,6 +121,7 @@ class Mapping: ...@@ -121,6 +121,7 @@ class Mapping:
Apply the AA->CG mapping to an atomistic Frame. Apply the AA->CG mapping to an atomistic Frame.
:param frame: Frame to which mapping will be applied :param frame: Frame to which mapping will be applied
:param cgframe: CG Frame to remap - optional
:param exclude: Set of molecule names to exclude from mapping - e.g. solvent :param exclude: Set of molecule names to exclude from mapping - e.g. solvent
:return: A new Frame instance containing the CG frame :return: A new Frame instance containing the CG frame
""" """
......
...@@ -126,6 +126,10 @@ class FrameTest(unittest.TestCase): ...@@ -126,6 +126,10 @@ class FrameTest(unittest.TestCase):
@unittest.skipIf(not mdtraj_present, "MDTRAJ not present") @unittest.skipIf(not mdtraj_present, "MDTRAJ not present")
def test_frame_write_xtc_mdtraj(self): def test_frame_write_xtc_mdtraj(self):
try:
os.remove("water_test2.xtc")
except IOError:
pass
frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc", frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc",
xtc_reader="mdtraj") xtc_reader="mdtraj")
while frame.next_frame(): while frame.next_frame():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment