Skip to content
Snippets Groups Projects
Commit c02ddb7a authored by James Graham's avatar James Graham
Browse files

Added cleaner handling of Ctrl-C

Main loops now moved entirely inside progress iterator
parent 7cc99cbe
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/env python3
import argparse
import sys
from pycgtool.frame import Frame
from pycgtool.mapping import Mapping
......@@ -18,7 +19,7 @@ def main(args, config):
:param args: Arguments from argparse
:param config: Configuration dictionary
"""
frame = Frame(gro=args.gro, xtc=args.xtc, itp=args.itp, frame_start=args.begin)
frame = Frame(gro=args.gro, xtc=args.xtc, itp=args.itp, frame_start=args.begin, xtc_reader="mdtraj")
if args.bnd:
bonds = BondSet(args.bnd, config)
......@@ -29,18 +30,21 @@ def main(args, config):
cgframe.output(config.output_name + ".gro", format=config.output)
# Main loop - perform mapping and measurement on every frame in XTC
numframes = frame.numframes - args.begin if args.end == -1 else args.end - args.begin
for _ in Progress(numframes, postwhile=frame.next_frame):
def main_loop():
nonlocal cgframe
frame.next_frame()
if args.map:
cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"})
if config.output_xtc:
cgframe.write_xtc(config.output_name + ".xtc")
else:
cgframe = frame
if args.bnd:
bonds.apply(cgframe)
numframes = frame.numframes - args.begin if args.end == -1 else args.end - args.begin
Progress(numframes, postwhile=main_loop).run()
if args.bnd:
if args.map:
bonds.boltzmann_invert()
......@@ -68,10 +72,16 @@ def map_only(args, config):
if args.xtc and (config.output_xtc or args.outputxtc):
numframes = frame.numframes - args.begin if args.end == -1 else args.end - args.begin
for _ in Progress(numframes, postwhile=frame.next_frame):
# Main loop - perform mapping and measurement on every frame in XTC
def main_loop():
nonlocal cgframe
frame.next_frame()
cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"})
cgframe.write_xtc(config.output_name + ".xtc")
Progress(numframes, postwhile=main_loop).run()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Perform coarse-grain mapping of atomistic trajectory")
......@@ -84,7 +94,6 @@ if __name__ == "__main__":
parser.add_argument('--interactive', default=False, action='store_true')
parser.add_argument('--outputxtc', default=False, action='store_true')
# parser.add_argument('-f', '--frames', type=int, default=-1, help="Number of frames to read")
input_files.add_argument('--begin', type=int, default=0, help="Frame number to begin")
input_files.add_argument('--end', type=int, default=-1, help="Frame number to end")
......@@ -108,7 +117,10 @@ if __name__ == "__main__":
parser.error("One or both of -m and -b is required.")
if args.interactive:
try:
config.interactive()
except KeyboardInterrupt:
sys.exit(0)
else:
print("Using GRO: {0}".format(args.gro))
print("Using XTC: {0}".format(args.xtc))
......
......@@ -146,6 +146,12 @@ class Frame:
self._parse_itp(itp)
def _open_xtc_simpletraj(self, xtc, gro=None):
"""
Open input XTC file from which to read coordinates using simpletraj library.
:param xtc: GROMACS XTC file to read subsequent frames
:param gro: GROMACS GRO file - not used
"""
try:
self.xtc = trajectory.XtcTrajectory(xtc)
except OSError as e:
......@@ -159,6 +165,12 @@ class Frame:
self.numframes += self.xtc.numframes
def _open_xtc_mdtraj(self, xtc, gro):
"""
Open input XTC file from which to read coordinates using mdtraj library.
:param xtc: GROMACS XTC file to read subsequent frames
:param gro: GROMACS GRO file from which to read topology
"""
try:
self.xtc = mdtraj.load_xtc(xtc, top=gro)
except OSError as e:
......@@ -194,7 +206,7 @@ class Frame:
def next_frame(self, exclude=None):
"""
Read next frame from XTC
Read next frame from input XTC.
:return: True if successful else False
"""
......@@ -207,6 +219,11 @@ class Frame:
raise
def _next_frame_mdtraj(self, exclude=None):
"""
Read next frame from XTC using mdtraj library.
:return: True if successful else False
"""
try:
i = 0
# This returns a slice of length 1, properties still need to be indexed
......@@ -227,6 +244,11 @@ class Frame:
return False
def _next_frame_simpletraj(self, exclude=None):
"""
Read next frame from XTC using simpletraj library.
:return: True if successful else False
"""
try:
self.xtc.get_frame(self.number)
i = 0
......@@ -249,7 +271,13 @@ class Frame:
return False
def write_xtc(self, filename):
"""
Write frame to output XTC file.
:param filename: XTC filename to write to
"""
if self._xtc_buffer is None:
backup_file(filename, verbose=True)
try:
self._xtc_buffer = mdtraj.formats.XTCTrajectoryFile(filename, mode="w")
except NameError as e:
......@@ -269,7 +297,6 @@ class Frame:
box[0][i][i] = self.box[i]
self._xtc_buffer.write(xyz, step=step, box=box)
# self._xtc_buffer.close()
def _parse_gro(self, filename):
"""
......
......@@ -220,6 +220,9 @@ class Progress:
:param postwhile: Function to check after each iteration, stops if False
:param quiet: Skip printing of progress bar - for testing
"""
if prewhile is not None:
raise NotImplementedError("Prewhile conditions are not yet implemented")
self._maxits = maxits
self._length = length
self._prewhile = prewhile
......@@ -235,16 +238,20 @@ class Progress:
"""
Allow iteration over Progress while testing prewhile and postwhile conditions.
Will catch Ctrl-C and return control as if the iterator has been fully consumed.
:return: Iteration number
"""
try:
if self._postwhile is not None and self._its > 0 and not self._postwhile():
self._stop()
if self._prewhile is not None and not self._prewhile():
except KeyboardInterrupt:
print(end="\r")
self._stop()
self._its += 1
if self._its % 10 == 0 and not self._quiet:
if self._its % 1 == 0 and not self._quiet:
self._display()
if self._its >= self._maxits:
......@@ -252,16 +259,24 @@ class Progress:
return self._its
def _stop(self):
if not self._quiet:
def run(self):
"""
Iterate through self until stopped by maximum iterations or False condition.
"""
collections.deque(self, maxlen=0)
@property
def _bar(self):
done = int(self._length * (self._its / self._maxits))
left = self._length - done
return "{0} [".format(self._its) + done * "#" + left * "-" + "] {0}".format(self._maxits)
def _stop(self):
if not self._quiet:
time_taken = int(time.clock() - self._start_time)
print("{0} [".format(self._its) + done * "#" + left * "-" + "] {0} took {1}s".format(self._maxits, time_taken))
print(self._bar + " took {0}s".format(time_taken))
raise StopIteration
def _display(self):
done = int(self._length * (self._its / self._maxits))
left = self._length - done
time_remain = int((time.clock() - self._start_time) * ((self._maxits - self._its) / self._its))
print("{0} [".format(self._its) + done * "#" + left * "-" + "] {0} {1}s left".format(self._maxits, time_remain), end="\r")
print(self._bar + " {0}s left".format(time_remain), end="\r")
......@@ -121,6 +121,7 @@ class Mapping:
Apply the AA->CG mapping to an atomistic Frame.
:param frame: Frame to which mapping will be applied
:param cgframe: CG Frame to remap - optional
:param exclude: Set of molecule names to exclude from mapping - e.g. solvent
:return: A new Frame instance containing the CG frame
"""
......
......@@ -126,6 +126,10 @@ class FrameTest(unittest.TestCase):
@unittest.skipIf(not mdtraj_present, "MDTRAJ not present")
def test_frame_write_xtc_mdtraj(self):
try:
os.remove("water_test2.xtc")
except IOError:
pass
frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc",
xtc_reader="mdtraj")
while frame.next_frame():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment