diff --git a/pycgtool/parsers/json.py b/pycgtool/parsers/json.py index 8ad4d5b5a1b35b07edf53308e9551e380048d9d5..a41bffed300a935d734f4d545b73452b37a00f1d 100644 --- a/pycgtool/parsers/json.py +++ b/pycgtool/parsers/json.py @@ -13,43 +13,63 @@ class AttrDict(dict): self.__dict__ = self -class CFG: +class Parser: """ Class to read data from JSON files. Supports including other files and filtering a single section. """ - def __init__(self, filename, from_section=None): + def __init__(self, filename, section=None): """ Create a new CFG JSON parser. :param filename: JSON file to read - :param from_section: Optional section to select from file + :param section: Optional section to select from file """ + with open(filename) as f: self._json = json.load(f, object_hook=AttrDict) + included = set() + # Recurse through include lists and add to self._json - while self._json.include: - include_file = os.path.join(os.path.dirname(filename), self._json.include.pop()) - with open(include_file) as include_file: - include_json = json.load(include_file, object_hook=AttrDict) + try: + while self._json.include: + include_file = os.path.join(os.path.dirname(filename), self._json.include.pop()) + if include_file in included: + continue + included.add(include_file) + + with open(include_file) as include_file: + include_json = json.load(include_file, object_hook=AttrDict) - for curr, incl in zip(self._json.values(), include_json.values()): - try: - curr += incl - except TypeError: - curr.update(incl) + for sec_name, sec_data in include_json.items(): + try: + # Assume is list + self._json[sec_name] += sec_data + except TypeError: + # Is actually a dictionary + self._json[sec_name].update(sec_data) + except KeyError: + # Doesn't exist in self._json, add it + self._json[sec_name] = sec_data + del self._json.include + except AttributeError: + # File doesn't have an include section + pass self._records = self._json - if from_section is not None: + if section is not None: try: - self._records = self._json[from_section] + self._records = self._json[section] except KeyError as e: - e.args = ("Section '{0}' not in file '{1}'".format(from_section, filename),) + e.args = ("Section '{0}' not in file '{1}'".format(section, filename),) raise def __getitem__(self, item): return self._records[item] + def __getattr__(self, item): + return self._records[item] + def __contains__(self, item): return item in self._records diff --git a/test/data/martini.json b/test/data/martini.json index 5bd77332ac150e35c452f62b528f4fef3c1cf446..86b805528b220df5420378f64287678f40fd74b3 100644 --- a/test/data/martini.json +++ b/test/data/martini.json @@ -1,4 +1,3 @@ { - "include":["martini_lipids.json", "martini_aminoacids.json"], - "molecules":{} + "include":["martini_lipids.json", "martini_aminoacids.json"] } diff --git a/test/test_parsers_json.py b/test/test_parsers_json.py index 8a35bc9f0bce7b5938e2dea352afca503e91c0b6..033611566122dc0af887d48fa1f6ad9808a76c6e 100644 --- a/test/test_parsers_json.py +++ b/test/test_parsers_json.py @@ -1,22 +1,51 @@ import unittest -from pycgtool.parsers.json import CFG, jsonify +from pycgtool.parsers.json import Parser, jsonify class TestParsersJson(unittest.TestCase): + def test_json_read(self): + parser = Parser("test/data/sugar.json") + + self.assertTrue("molecules" in parser) + self.assertTrue("ALLA" in parser.molecules) + + self.assertTrue("beads" in parser.molecules["ALLA"]) + self.assertEqual("C1", parser.molecules["ALLA"].beads[0].name) + + self.assertTrue("bonds" in parser.molecules["ALLA"]) + self.assertEqual(["C1", "C2"], parser.molecules["ALLA"].bonds[0]) + + def test_json_section(self): + molecules = Parser("test/data/sugar.json", section="molecules") + + self.assertTrue("ALLA" in molecules) + + self.assertTrue("beads" in molecules["ALLA"]) + self.assertEqual("C1", molecules["ALLA"].beads[0].name) + + self.assertTrue("bonds" in molecules["ALLA"]) + self.assertEqual(["C1", "C2"], molecules["ALLA"].bonds[0]) + + def test_include_file(self): + martini = Parser("test/data/martini.json") + self.assertTrue("molecules" in martini) + self.assertTrue("DOPC" in martini.molecules) + self.assertTrue("GLY" in martini.molecules) + def test_json_water(self): - cfg = CFG("test/data/water.json", "molecules") + parser = Parser("test/data/water.json", "molecules") - self.assertTrue("SOL" in cfg) - self.assertEqual(1, len(cfg["SOL"].beads)) - bead = cfg["SOL"].beads[0] + self.assertTrue("SOL" in parser) + self.assertEqual(1, len(parser["SOL"].beads)) + bead = parser["SOL"].beads[0] self.assertEqual("W", bead.name) self.assertEqual("P4", bead.type) self.assertEqual(["OW", "HW1", "HW2"], bead.atoms) self.assertEqual(0, len(cfg["SOL"].bonds)) def test_json_sugar(self): - cfg = CFG("test/data/sugar.json", "molecules") + parser = Parser("test/data/sugar.json", "molecules") ref_beads = [["C1", "P3", "C1", "O1"], ["C2", "P3", "C2", "O2"], @@ -32,31 +61,26 @@ class TestParsersJson(unittest.TestCase): ["C5", "O5"], ["O5", "C1"]] - self.assertTrue("ALLA" in cfg) + self.assertTrue("ALLA" in parser) - self.assertEqual(6, len(cfg["ALLA"].beads)) - for ref_bead, bead in zip(ref_beads, cfg["ALLA"].beads): + self.assertEqual(6, len(parser["ALLA"].beads)) + for ref_bead, bead in zip(ref_beads, parser["ALLA"].beads): self.assertEqual(ref_bead[0], bead.name) self.assertEqual(ref_bead[1], bead.type) self.assertEqual(ref_bead[2:], bead.atoms) - self.assertEqual(6, len(cfg["ALLA"].bonds)) - for ref_bond, bond in zip(ref_bonds, cfg["ALLA"].bonds): + self.assertEqual(6, len(parser["ALLA"].bonds)) + for ref_bond, bond in zip(ref_bonds, parser["ALLA"].bonds): self.assertEqual(ref_bond, bond) - def test_include_file(self): - cfg = CFG("test/data/martini.json", "molecules") - self.assertTrue("DOPC" in cfg) - self.assertTrue("GLY" in cfg) - def test_missing_section(self): with self.assertRaises(KeyError): - cfg = CFG("test/data/water.json", "potato") + parser = Parser("test/data/water.json", "potato") def test_convert(self): jsonify("test/data/sugar.map", "test/data/sugar.bnd", "test.json") - test_json = CFG("test.json", from_section="molecules") - ref_json = CFG("test/data/sugar.json", from_section="molecules") + test_json = Parser("test.json", section="molecules") + ref_json = Parser("test/data/sugar.json", section="molecules") for tbead, rbead in zip(test_json["ALLA"].beads, ref_json["ALLA"].beads): self.assertEqual(tbead, rbead)