From 9b63e31b389985696a7178ebe3e89548b6dc2ff2 Mon Sep 17 00:00:00 2001
From: James Graham <J.A.Graham@soton.ac.uk>
Date: Wed, 11 May 2016 17:18:38 +0100
Subject: [PATCH] Working JSON config reader with includes

---
 pycgtool/parsers/json.py          | 42 +++++++++----------
 test/data/martini.json            |  4 ++
 test/data/martini_aminoacids.json |  9 +++++
 test/data/martini_lipids.json     |  9 +++++
 test/data/twice.json              | 15 -------
 test/data/water.json              | 18 +++++----
 test/test_parsers_json.py         | 67 ++++++++++++++++++++-----------
 7 files changed, 98 insertions(+), 66 deletions(-)
 create mode 100644 test/data/martini.json
 create mode 100644 test/data/martini_aminoacids.json
 create mode 100644 test/data/martini_lipids.json
 delete mode 100644 test/data/twice.json

diff --git a/pycgtool/parsers/json.py b/pycgtool/parsers/json.py
index ca6efe8..bbfe9ae 100644
--- a/pycgtool/parsers/json.py
+++ b/pycgtool/parsers/json.py
@@ -1,12 +1,5 @@
 import json
-
-
-class DuplicateSectionError(Exception):
-    """
-    Exception used to indicate that a section has appeared twice in a file.
-    """
-    def __repr__(self):
-        return "Section {0} appears twice in file {1}.".format(*self.args)
+import os
 
 
 class Record(dict):
@@ -18,21 +11,28 @@ class Record(dict):
 
 
 class CFG:
-    def __init__(self, filename):
+    def __init__(self, filename, from_section=None):
         with open(filename) as f:
-            try:
-                self._json = json.load(f, object_hook=Record)
-            except ValueError:
-                raise DuplicateSectionError()
+            self._json = json.load(f, object_hook=Record)
 
-    def __getitem__(self, item):
-        return Record(self._json[item])
+        while self._json.include:
+            include_file = os.path.join(os.path.dirname(filename), self._json.include.pop())
+            with open(include_file) as include_file:
+                include_json = json.load(include_file, object_hook=Record)
 
-    def __contains__(self, item):
-        return item in self._json
+            for curr, incl in zip(self._json.values(), include_json.values()):
+                try:
+                    curr += incl
+                except TypeError:
+                    curr.update(incl)
 
-    def __enter__(self):
-        return self
+        if from_section is not None:
+            self._records = self._json[from_section]
+        else:
+            self._records = self._json
 
-    def __exit__(self, type, value, traceback):
-        pass
+    def __getitem__(self, item):
+        return self._records[item]
+
+    def __contains__(self, item):
+        return item in self._records
diff --git a/test/data/martini.json b/test/data/martini.json
new file mode 100644
index 0000000..5bd7733
--- /dev/null
+++ b/test/data/martini.json
@@ -0,0 +1,4 @@
+{
+    "include":["martini_lipids.json", "martini_aminoacids.json"],
+    "molecules":{}
+}
diff --git a/test/data/martini_aminoacids.json b/test/data/martini_aminoacids.json
new file mode 100644
index 0000000..b5d7f37
--- /dev/null
+++ b/test/data/martini_aminoacids.json
@@ -0,0 +1,9 @@
+{
+    "include":[],
+    "molecules":{
+        "GLY":{
+               "beads":[],
+               "bonds":[]
+        }
+    }
+}
diff --git a/test/data/martini_lipids.json b/test/data/martini_lipids.json
new file mode 100644
index 0000000..b9c64f7
--- /dev/null
+++ b/test/data/martini_lipids.json
@@ -0,0 +1,9 @@
+{
+    "include":[],
+    "molecules":{
+        "DOPC":{
+               "beads":[],
+               "bonds":[]
+        }
+    }
+}
diff --git a/test/data/twice.json b/test/data/twice.json
deleted file mode 100644
index 8f67dec..0000000
--- a/test/data/twice.json
+++ /dev/null
@@ -1,15 +0,0 @@
-{"SOL":{
-   "beads":[
-        {"name":"W", "type":"P4", "atoms":["OW", "HW1", "HW2"]}
-   ],
-   "bonds":[
-   ]
-}}
-
-{"SOL":{
-   "beads":[
-        {"name":"W", "type":"P4", "atoms":["OW", "HW1", "HW2"]}
-   ],
-   "bonds":[
-   ]
-}}
diff --git a/test/data/water.json b/test/data/water.json
index b9597be..af0c31b 100644
--- a/test/data/water.json
+++ b/test/data/water.json
@@ -1,7 +1,11 @@
-{"SOL":{
-   "beads":[
-        {"name":"W", "type":"P4", "atoms":["OW", "HW1", "HW2"]}
-   ],
-   "bonds":[
-   ]
-}}
+{
+    "include":[],
+    "molecules":{
+        "SOL":{
+            "beads":[
+                {"name":"W", "type":"P4", "atoms":["OW", "HW1", "HW2"]}
+            ],
+            "bonds":[]
+        }
+    }
+}
diff --git a/test/test_parsers_json.py b/test/test_parsers_json.py
index c7aebaf..c6f6c05 100644
--- a/test/test_parsers_json.py
+++ b/test/test_parsers_json.py
@@ -1,32 +1,53 @@
 import unittest
 
-from pycgtool.parsers.json import CFG, DuplicateSectionError
+from pycgtool.parsers.json import CFG
 
 
 class TestParsersJson(unittest.TestCase):
-    def test_cfg_with(self):
-        with CFG("test/data/water.json"):
-            pass
-
-    def test_cfg_get_section(self):
-        with CFG("test/data/water.json") as cfg:
-            self.assertTrue("SOL" in cfg)
-            self.assertEqual(1, len(cfg["SOL"].beads))
-            bead = cfg["SOL"].beads[0]
-            self.assertEqual("W", bead.name)
-            self.assertEqual("P4", bead.type)
-            self.assertEqual(["OW", "HW1", "HW2"], bead.atoms)
-            self.assertEqual(0, len(cfg["SOL"].bonds))
-
-    def test_cfg_duplicate_error(self):
-        with self.assertRaises(DuplicateSectionError):
-            CFG("test/data/twice.json")
-
-    @unittest.expectedFailure
+    def test_json_water(self):
+        cfg = CFG("test/data/water.json", "molecules")
+
+        self.assertTrue("SOL" in cfg)
+        self.assertEqual(1, len(cfg["SOL"].beads))
+        bead = cfg["SOL"].beads[0]
+        self.assertEqual("W", bead.name)
+        self.assertEqual("P4", bead.type)
+        self.assertEqual(["OW", "HW1", "HW2"], bead.atoms)
+        self.assertEqual(0, len(cfg["SOL"].bonds))
+
+    def test_json_sugar(self):
+        cfg = CFG("test/data/sugar.json", "molecules")
+
+        ref_beads = [["C1", "P3", "C1", "O1"],
+                     ["C2", "P3", "C2", "O2"],
+                     ["C3", "P3", "C3", "O3"],
+                     ["C4", "P3", "C4", "O4"],
+                     ["C5", "P2", "C5", "C6", "O6"],
+                     ["O5", "P4", "O5"]]
+
+        ref_bonds = [["C1", "C2"],
+                     ["C2", "C3"],
+                     ["C3", "C4"],
+                     ["C4", "C5"],
+                     ["C5", "O5"],
+                     ["O5", "C1"]]
+
+        self.assertTrue("ALLA" in cfg)
+
+        self.assertEqual(6, len(cfg["ALLA"].beads))
+        for ref_bead, bead in zip(ref_beads, cfg["ALLA"].beads):
+            self.assertEqual(ref_bead[0], bead.name)
+            self.assertEqual(ref_bead[1], bead.type)
+            self.assertEqual(ref_bead[2:], bead.atoms)
+
+        self.assertEqual(6, len(cfg["ALLA"].bonds))
+        for ref_bond, bond in zip(ref_bonds, cfg["ALLA"].bonds):
+            self.assertEqual(ref_bond, bond)
+
     def test_include_file(self):
-        with CFG("test/data/martini.json") as cfg:
-            self.assertTrue("DOPC" in cfg)
-            self.assertTrue("GLY" in cfg)
+        cfg = CFG("test/data/martini.json", "molecules")
+        self.assertTrue("DOPC" in cfg)
+        self.assertTrue("GLY" in cfg)
 
 
 if __name__ == '__main__':
-- 
GitLab