diff --git a/pycgtool.py b/pycgtool.py index d7f119eb024b029d0b0f881464e1cda0d7530efc..a33e6b04f4cc7e47fce41b9935e4d840df89a15a 100755 --- a/pycgtool.py +++ b/pycgtool.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import argparse +import sys from pycgtool.frame import Frame from pycgtool.mapping import Mapping @@ -18,7 +19,7 @@ def main(args, config): :param args: Arguments from argparse :param config: Configuration dictionary """ - frame = Frame(gro=args.gro, xtc=args.xtc, itp=args.itp, frame_start=args.begin) + frame = Frame(gro=args.gro, xtc=args.xtc, itp=args.itp, frame_start=args.begin, xtc_reader="mdtraj") if args.bnd: bonds = BondSet(args.bnd, config) @@ -29,18 +30,21 @@ def main(args, config): cgframe.output(config.output_name + ".gro", format=config.output) # Main loop - perform mapping and measurement on every frame in XTC - numframes = frame.numframes - args.begin if args.end == -1 else args.end - args.begin - for _ in Progress(numframes, postwhile=frame.next_frame): + def main_loop(): + nonlocal cgframe + frame.next_frame() if args.map: cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"}) if config.output_xtc: cgframe.write_xtc(config.output_name + ".xtc") else: cgframe = frame - if args.bnd: bonds.apply(cgframe) + numframes = frame.numframes - args.begin if args.end == -1 else args.end - args.begin + Progress(numframes, postwhile=main_loop).run() + if args.bnd: if args.map: bonds.boltzmann_invert() @@ -68,10 +72,16 @@ def map_only(args, config): if args.xtc and (config.output_xtc or args.outputxtc): numframes = frame.numframes - args.begin if args.end == -1 else args.end - args.begin - for _ in Progress(numframes, postwhile=frame.next_frame): + + # Main loop - perform mapping and measurement on every frame in XTC + def main_loop(): + nonlocal cgframe + frame.next_frame() cgframe = mapping.apply(frame, cgframe=cgframe, exclude={"SOL"}) cgframe.write_xtc(config.output_name + ".xtc") + Progress(numframes, postwhile=main_loop).run() + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Perform coarse-grain mapping of atomistic trajectory") @@ -84,7 +94,6 @@ if __name__ == "__main__": parser.add_argument('--interactive', default=False, action='store_true') parser.add_argument('--outputxtc', default=False, action='store_true') - # parser.add_argument('-f', '--frames', type=int, default=-1, help="Number of frames to read") input_files.add_argument('--begin', type=int, default=0, help="Frame number to begin") input_files.add_argument('--end', type=int, default=-1, help="Frame number to end") @@ -108,7 +117,10 @@ if __name__ == "__main__": parser.error("One or both of -m and -b is required.") if args.interactive: - config.interactive() + try: + config.interactive() + except KeyboardInterrupt: + sys.exit(0) else: print("Using GRO: {0}".format(args.gro)) print("Using XTC: {0}".format(args.xtc)) diff --git a/pycgtool/frame.py b/pycgtool/frame.py index 944474167444738c335b19f7c8e8acfcb5905132..ab2742bc2e408624add81e3ce9bfcb4d14841222 100644 --- a/pycgtool/frame.py +++ b/pycgtool/frame.py @@ -146,6 +146,12 @@ class Frame: self._parse_itp(itp) def _open_xtc_simpletraj(self, xtc, gro=None): + """ + Open input XTC file from which to read coordinates using simpletraj library. + + :param xtc: GROMACS XTC file to read subsequent frames + :param gro: GROMACS GRO file - not used + """ try: self.xtc = trajectory.XtcTrajectory(xtc) except OSError as e: @@ -159,6 +165,12 @@ class Frame: self.numframes += self.xtc.numframes def _open_xtc_mdtraj(self, xtc, gro): + """ + Open input XTC file from which to read coordinates using mdtraj library. + + :param xtc: GROMACS XTC file to read subsequent frames + :param gro: GROMACS GRO file from which to read topology + """ try: self.xtc = mdtraj.load_xtc(xtc, top=gro) except OSError as e: @@ -194,7 +206,7 @@ class Frame: def next_frame(self, exclude=None): """ - Read next frame from XTC + Read next frame from input XTC. :return: True if successful else False """ @@ -207,6 +219,11 @@ class Frame: raise def _next_frame_mdtraj(self, exclude=None): + """ + Read next frame from XTC using mdtraj library. + + :return: True if successful else False + """ try: i = 0 # This returns a slice of length 1, properties still need to be indexed @@ -227,6 +244,11 @@ class Frame: return False def _next_frame_simpletraj(self, exclude=None): + """ + Read next frame from XTC using simpletraj library. + + :return: True if successful else False + """ try: self.xtc.get_frame(self.number) i = 0 @@ -249,7 +271,13 @@ class Frame: return False def write_xtc(self, filename): + """ + Write frame to output XTC file. + + :param filename: XTC filename to write to + """ if self._xtc_buffer is None: + backup_file(filename, verbose=True) try: self._xtc_buffer = mdtraj.formats.XTCTrajectoryFile(filename, mode="w") except NameError as e: @@ -269,7 +297,6 @@ class Frame: box[0][i][i] = self.box[i] self._xtc_buffer.write(xyz, step=step, box=box) - # self._xtc_buffer.close() def _parse_gro(self, filename): """ diff --git a/pycgtool/interface.py b/pycgtool/interface.py index 2949359793d2e173cfbd3651fbb060db541acf79..02c80a6d83cb7180bf5c5c2f84b57f66cd4d8e03 100644 --- a/pycgtool/interface.py +++ b/pycgtool/interface.py @@ -220,6 +220,9 @@ class Progress: :param postwhile: Function to check after each iteration, stops if False :param quiet: Skip printing of progress bar - for testing """ + if prewhile is not None: + raise NotImplementedError("Prewhile conditions are not yet implemented") + self._maxits = maxits self._length = length self._prewhile = prewhile @@ -235,16 +238,20 @@ class Progress: """ Allow iteration over Progress while testing prewhile and postwhile conditions. + Will catch Ctrl-C and return control as if the iterator has been fully consumed. + :return: Iteration number """ - if self._postwhile is not None and self._its > 0 and not self._postwhile(): - self._stop() + try: + if self._postwhile is not None and self._its > 0 and not self._postwhile(): + self._stop() - if self._prewhile is not None and not self._prewhile(): + except KeyboardInterrupt: + print(end="\r") self._stop() self._its += 1 - if self._its % 10 == 0 and not self._quiet: + if self._its % 1 == 0 and not self._quiet: self._display() if self._its >= self._maxits: @@ -252,16 +259,24 @@ class Progress: return self._its + def run(self): + """ + Iterate through self until stopped by maximum iterations or False condition. + """ + collections.deque(self, maxlen=0) + + @property + def _bar(self): + done = int(self._length * (self._its / self._maxits)) + left = self._length - done + return "{0} [".format(self._its) + done * "#" + left * "-" + "] {0}".format(self._maxits) + def _stop(self): if not self._quiet: - done = int(self._length * (self._its / self._maxits)) - left = self._length - done time_taken = int(time.clock() - self._start_time) - print("{0} [".format(self._its) + done * "#" + left * "-" + "] {0} took {1}s".format(self._maxits, time_taken)) + print(self._bar + " took {0}s".format(time_taken)) raise StopIteration def _display(self): - done = int(self._length * (self._its / self._maxits)) - left = self._length - done time_remain = int((time.clock() - self._start_time) * ((self._maxits - self._its) / self._its)) - print("{0} [".format(self._its) + done * "#" + left * "-" + "] {0} {1}s left".format(self._maxits, time_remain), end="\r") + print(self._bar + " {0}s left".format(time_remain), end="\r") diff --git a/pycgtool/mapping.py b/pycgtool/mapping.py index 16828db316987aa2b3b3653d80d28b603939602e..701490369ed05fbe710704c9f8c45a8653a9ab4d 100644 --- a/pycgtool/mapping.py +++ b/pycgtool/mapping.py @@ -121,6 +121,7 @@ class Mapping: Apply the AA->CG mapping to an atomistic Frame. :param frame: Frame to which mapping will be applied + :param cgframe: CG Frame to remap - optional :param exclude: Set of molecule names to exclude from mapping - e.g. solvent :return: A new Frame instance containing the CG frame """ diff --git a/test/test_frame.py b/test/test_frame.py index 383fe6eee4c49b9da97481151a1721a289b5d659..604e88f32671f2c4ead547e18f975f2a1cab7f83 100644 --- a/test/test_frame.py +++ b/test/test_frame.py @@ -126,6 +126,10 @@ class FrameTest(unittest.TestCase): @unittest.skipIf(not mdtraj_present, "MDTRAJ not present") def test_frame_write_xtc_mdtraj(self): + try: + os.remove("water_test2.xtc") + except IOError: + pass frame = Frame(gro="test/data/water.gro", xtc="test/data/water.xtc", xtc_reader="mdtraj") while frame.next_frame():