From ab97010846815df7309b646d0bfe9f9e246890ba Mon Sep 17 00:00:00 2001
From: James Graham <j.graham@soton.ac.uk>
Date: Fri, 19 Feb 2021 23:15:24 +0000
Subject: [PATCH] feat: add initial trial for backmapping

Uses the backmapper from Charles Laughton's MDPlus
---
 .gitignore           |  2 +
 poetry.lock          | 97 +++++++++++++++++++++++++++++++++++++++++++-
 pycgtool/__main__.py | 37 +++++++++++++++++
 pyproject.toml       |  1 +
 4 files changed, 136 insertions(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index 6f36e7c..29da604 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,6 +11,7 @@
 /minenv/
 /venv/
 *.pyc
+poetry.toml
 
 # Development and Testing
 .cache
@@ -31,3 +32,4 @@ settings.json
 .cache/
 /notes/
 *.offsets
+*.pkl
diff --git a/poetry.lock b/poetry.lock
index d208057..1cc0faf 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -235,6 +235,14 @@ MarkupSafe = ">=0.23"
 [package.extras]
 i18n = ["Babel (>=0.8)"]
 
+[[package]]
+name = "joblib"
+version = "1.0.1"
+description = "Lightweight pipelining with Python functions"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
 [[package]]
 name = "lazy-object-proxy"
 version = "1.4.3"
@@ -259,6 +267,25 @@ category = "dev"
 optional = false
 python-versions = "*"
 
+[[package]]
+name = "mdplus"
+version = "0.0.4"
+description = "Tools for molecular dynamics simulation setup and analysis."
+category = "main"
+optional = false
+python-versions = "*"
+develop = false
+
+[package.dependencies]
+scikit-learn = "*"
+scipy = "*"
+
+[package.source]
+type = "git"
+url = "https://bitbucket.org/claughton/mdplus.git"
+reference = "master"
+resolved_reference = "acd94bb4988011f37724bdbd9f63e62f83c56c95"
+
 [[package]]
 name = "mdtraj"
 version = "1.9.5"
@@ -586,6 +613,26 @@ python-versions = "*"
 [package.dependencies]
 docutils = ">=0.7"
 
+[[package]]
+name = "scikit-learn"
+version = "0.24.1"
+description = "A set of python modules for machine learning and data mining"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+joblib = ">=0.11"
+numpy = ">=1.13.3"
+scipy = ">=0.19.1"
+threadpoolctl = ">=2.0.0"
+
+[package.extras]
+benchmark = ["matplotlib (>=2.1.1)", "pandas (>=0.25.0)", "memory-profiler (>=0.57.0)"]
+docs = ["matplotlib (>=2.1.1)", "scikit-image (>=0.13)", "pandas (>=0.25.0)", "seaborn (>=0.9.0)", "memory-profiler (>=0.57.0)", "sphinx (>=3.2.0)", "sphinx-gallery (>=0.7.0)", "numpydoc (>=1.0.0)", "Pillow (>=7.1.2)", "sphinx-prompt (>=1.3.0)"]
+examples = ["matplotlib (>=2.1.1)", "scikit-image (>=0.13)", "pandas (>=0.25.0)", "seaborn (>=0.9.0)"]
+tests = ["matplotlib (>=2.1.1)", "scikit-image (>=0.13)", "pandas (>=0.25.0)", "pytest (>=5.0.1)", "pytest-cov (>=2.9.0)", "flake8 (>=3.8.2)", "mypy (>=0.770)", "pyamg (>=4.0.0)"]
+
 [[package]]
 name = "scipy"
 version = "1.5.4"
@@ -760,6 +807,14 @@ python-versions = ">=3.5"
 lint = ["flake8", "mypy", "docutils-stubs"]
 test = ["pytest"]
 
+[[package]]
+name = "threadpoolctl"
+version = "2.1.0"
+description = "threadpoolctl"
+category = "main"
+optional = false
+python-versions = ">=3.5"
+
 [[package]]
 name = "toml"
 version = "0.10.2"
@@ -847,7 +902,7 @@ testing = ["pytest (>=3.5,!=3.7.3)", "pytest-checkdocs (>=1.2.3)", "pytest-flake
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.6"
-content-hash = "f35c140d13741413ff816520486bec91842863ee8991b5177b11c0bb8a76ff05"
+content-hash = "cb58621a19006e715ec04d35c426637d55b88df8478cb043bebcaea8ef2cbcc5"
 
 [metadata.files]
 alabaster = [
@@ -1020,6 +1075,10 @@ jinja2 = [
     {file = "Jinja2-2.11.3-py2.py3-none-any.whl", hash = "sha256:03e47ad063331dd6a3f04a43eddca8a966a26ba0c5b7207a9a9e4e08f1b29419"},
     {file = "Jinja2-2.11.3.tar.gz", hash = "sha256:a6d58433de0ae800347cab1fa3043cebbabe8baa9d29e668f1c768cb87a333c6"},
 ]
+joblib = [
+    {file = "joblib-1.0.1-py3-none-any.whl", hash = "sha256:feeb1ec69c4d45129954f1b7034954241eedfd6ba39b5e9e4b6883be3332d5e5"},
+    {file = "joblib-1.0.1.tar.gz", hash = "sha256:9c17567692206d2f3fb9ecf5e991084254fe631665c450b443761c4186a613f7"},
+]
 lazy-object-proxy = [
     {file = "lazy-object-proxy-1.4.3.tar.gz", hash = "sha256:f3900e8a5de27447acbf900b4750b0ddfd7ec1ea7fbaf11dfa911141bc522af0"},
     {file = "lazy_object_proxy-1.4.3-cp27-cp27m-macosx_10_13_x86_64.whl", hash = "sha256:a2238e9d1bb71a56cd710611a1614d1194dc10a175c1e08d75e1a7bcc250d442"},
@@ -1082,6 +1141,7 @@ mccabe = [
     {file = "mccabe-0.6.1-py2.py3-none-any.whl", hash = "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42"},
     {file = "mccabe-0.6.1.tar.gz", hash = "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f"},
 ]
+mdplus = []
 mdtraj = [
     {file = "mdtraj-1.9.5-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:c02a9a589acc98dd3cc4db9b0cb21725f5e2cb9484c14cc4fa032a4663c7a1e9"},
     {file = "mdtraj-1.9.5-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:8816f9a91826e46a413fd1a5ffd663b4076cf4c3b8930bb2046814d1438a88ee"},
@@ -1255,6 +1315,37 @@ rich = [
 rstcheck = [
     {file = "rstcheck-3.3.1.tar.gz", hash = "sha256:92c4f79256a54270e0402ba16a2f92d0b3c15c8f4410cb9c57127067c215741f"},
 ]
+scikit-learn = [
+    {file = "scikit-learn-0.24.1.tar.gz", hash = "sha256:a0334a1802e64d656022c3bfab56a73fbd6bf4b1298343f3688af2151810bbdf"},
+    {file = "scikit_learn-0.24.1-cp36-cp36m-macosx_10_13_x86_64.whl", hash = "sha256:9bed8a1ef133c8e2f13966a542cb8125eac7f4b67dcd234197c827ba9c7dd3e0"},
+    {file = "scikit_learn-0.24.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:a36e159a0521e13bbe15ca8c8d038b3a1dd4c7dad18d276d76992e03b92cf643"},
+    {file = "scikit_learn-0.24.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:c658432d8a20e95398f6bb95ff9731ce9dfa343fdf21eea7ec6a7edfacd4b4d9"},
+    {file = "scikit_learn-0.24.1-cp36-cp36m-manylinux2010_i686.whl", hash = "sha256:9dfa564ef27e8e674aa1cc74378416d580ac4ede1136c13dd555a87996e13422"},
+    {file = "scikit_learn-0.24.1-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:9c6097b6a9b2bafc5e0f31f659e6ab5e131383209c30c9e978c5b8abdac5ed2a"},
+    {file = "scikit_learn-0.24.1-cp36-cp36m-win32.whl", hash = "sha256:7b04691eb2f41d2c68dbda8d1bd3cb4ef421bdc43aaa56aeb6c762224552dfb6"},
+    {file = "scikit_learn-0.24.1-cp36-cp36m-win_amd64.whl", hash = "sha256:1adf483e91007a87171d7ce58c34b058eb5dab01b5fee6052f15841778a8ecd8"},
+    {file = "scikit_learn-0.24.1-cp37-cp37m-macosx_10_13_x86_64.whl", hash = "sha256:ddb52d088889f5596bc4d1de981f2eca106b58243b6679e4782f3ba5096fd645"},
+    {file = "scikit_learn-0.24.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:a29460499c1e62b7a830bb57ca42e615375a6ab1bcad053cd25b493588348ea8"},
+    {file = "scikit_learn-0.24.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:0567a2d29ad08af98653300c623bd8477b448fe66ced7198bef4ed195925f082"},
+    {file = "scikit_learn-0.24.1-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:99349d77f54e11f962d608d94dfda08f0c9e5720d97132233ebdf35be2858b2d"},
+    {file = "scikit_learn-0.24.1-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:83b21ff053b1ff1c018a2d24db6dd3ea339b1acfbaa4d9c881731f43748d8b3b"},
+    {file = "scikit_learn-0.24.1-cp37-cp37m-win32.whl", hash = "sha256:c3deb3b19dd9806acf00cf0d400e84562c227723013c33abefbbc3cf906596e9"},
+    {file = "scikit_learn-0.24.1-cp37-cp37m-win_amd64.whl", hash = "sha256:d54dbaadeb1425b7d6a66bf44bee2bb2b899fe3e8850b8e94cfb9c904dcb46d0"},
+    {file = "scikit_learn-0.24.1-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:3c4f07f47c04e81b134424d53c3f5e16dfd7f494e44fd7584ba9ce9de2c5e6c1"},
+    {file = "scikit_learn-0.24.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:c13ebac42236b1c46397162471ea1c46af68413000e28b9309f8c05722c65a09"},
+    {file = "scikit_learn-0.24.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:4ddd2b6f7449a5d539ff754fa92d75da22de261fd8fdcfb3596799fadf255101"},
+    {file = "scikit_learn-0.24.1-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:826b92bf45b8ad80444814e5f4ac032156dd481e48d7da33d611f8fe96d5f08b"},
+    {file = "scikit_learn-0.24.1-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:259ec35201e82e2db1ae2496f229e63f46d7f1695ae68eef9350b00dc74ba52f"},
+    {file = "scikit_learn-0.24.1-cp38-cp38-win32.whl", hash = "sha256:8772b99d683be8f67fcc04789032f1b949022a0e6880ee7b75a7ec97dbbb5d0b"},
+    {file = "scikit_learn-0.24.1-cp38-cp38-win_amd64.whl", hash = "sha256:ed9d65594948678827f4ff0e7ae23344e2f2b4cabbca057ccaed3118fdc392ca"},
+    {file = "scikit_learn-0.24.1-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:8aa1b3ac46b80eaa552b637eeadbbce3be5931e4b5002b964698e33a1b589e1e"},
+    {file = "scikit_learn-0.24.1-cp39-cp39-manylinux1_i686.whl", hash = "sha256:c7f4eb77504ac586d8ac1bde1b0c04b504487210f95297235311a0ab7edd7e38"},
+    {file = "scikit_learn-0.24.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:087dfede39efb06ab30618f9ab55a0397f29c38d63cd0ab88d12b500b7d65fd7"},
+    {file = "scikit_learn-0.24.1-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:895dbf2030aa7337649e36a83a007df3c9811396b4e2fa672a851160f36ce90c"},
+    {file = "scikit_learn-0.24.1-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:9a24d1ccec2a34d4cd3f2a1f86409f3f5954cc23d4d2270ba0d03cf018aa4780"},
+    {file = "scikit_learn-0.24.1-cp39-cp39-win32.whl", hash = "sha256:fab31f48282ebf54dd69f6663cd2d9800096bad1bb67bbc9c9ac84eb77b41972"},
+    {file = "scikit_learn-0.24.1-cp39-cp39-win_amd64.whl", hash = "sha256:4562dcf4793e61c5d0f89836d07bc37521c3a1889da8f651e2c326463c4bd697"},
+]
 scipy = [
     {file = "scipy-1.5.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4f12d13ffbc16e988fa40809cbbd7a8b45bc05ff6ea0ba8e3e41f6f4db3a9e47"},
     {file = "scipy-1.5.4-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:a254b98dbcc744c723a838c03b74a8a34c0558c9ac5c86d5561703362231107d"},
@@ -1329,6 +1420,10 @@ sphinxcontrib-serializinghtml = [
     {file = "sphinxcontrib-serializinghtml-1.1.4.tar.gz", hash = "sha256:eaa0eccc86e982a9b939b2b82d12cc5d013385ba5eadcc7e4fed23f4405f77bc"},
     {file = "sphinxcontrib_serializinghtml-1.1.4-py2.py3-none-any.whl", hash = "sha256:f242a81d423f59617a8e5cf16f5d4d74e28ee9a66f9e5b637a18082991db5a9a"},
 ]
+threadpoolctl = [
+    {file = "threadpoolctl-2.1.0-py3-none-any.whl", hash = "sha256:38b74ca20ff3bb42caca8b00055111d74159ee95c4370882bbff2b93d24da725"},
+    {file = "threadpoolctl-2.1.0.tar.gz", hash = "sha256:ddc57c96a38beb63db45d6c159b5ab07b6bced12c45a1f07b2b92f272aebfa6b"},
+]
 toml = [
     {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
     {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
diff --git a/pycgtool/__main__.py b/pycgtool/__main__.py
index a49263f..102e94f 100755
--- a/pycgtool/__main__.py
+++ b/pycgtool/__main__.py
@@ -8,6 +8,8 @@ import sys
 import textwrap
 import typing
 
+from mdplus.multiscale import GLIMPS
+import numpy as np
 from rich.logging import RichHandler
 
 from .frame import Frame
@@ -83,6 +85,39 @@ def mapping_loop(frame: Frame, config) -> typing.Tuple[Frame, Mapping]:
     return cg_frame, mapping
 
 
+def get_coords(frame: Frame, resname: str) -> np.ndarray:
+    """Stack, along axis 0, the coordinates of each residue named ``resname``."""
+    traj = frame._trajectory
+    matches = [res for res in traj.topology.residues if res.name == resname]
+    per_residue = [traj.atom_slice([a.index for a in res.atoms]).xyz for res in matches]
+    return np.concatenate(per_residue)
+
+
+def train_backmapper(aa_frame: Frame, cg_frame: Frame):
+    """Fit a GLIMPS backmapper (CG -> AA) on selection ``resid 1`` and save it.
+
+    After fitting, the model is applied back to its own training input and the
+    result written out.  Files written to the working directory:
+    ``cg_test.gro``, ``aa_test.gro``, ``backmapped.gro``, ``backmapper.pkl``.
+    """
+    selection = 'resid 1'  # trial only: a single hard-coded selection
+    cg_traj, aa_traj = cg_frame._trajectory, aa_frame._trajectory
+    cg_subset = cg_traj.atom_slice(cg_traj.topology.select(selection))
+    aa_subset = aa_traj.atom_slice(aa_traj.topology.select(selection))
+    cg_subset.save('cg_test.gro')
+    aa_subset.save('aa_test.gro')
+
+    logger.info('Training backmapper')
+    model = GLIMPS()
+    model.fit(cg_subset.xyz, aa_subset.xyz)
+    logger.info('Finished training backmapper')
+    logger.info('Testing backmapper')
+    aa_subset.xyz = model.transform(cg_subset.xyz)  # reuse the AA topology for output
+    aa_subset.save('backmapped.gro')
+    logger.info('Finished testing backmapper')
+    model.save('backmapper.pkl')
+
+
 def full_run(config):
     """Main function of the program PyCGTOOL.
 
@@ -95,9 +130,11 @@ def full_run(config):
         trajectory_file=config.trajectory,  # May be None
         frame_start=config.begin,
         frame_end=config.end)
+    frame._trajectory.make_molecules_whole(inplace=True)
 
     if config.mapping:
         cg_frame, mapping = mapping_loop(frame, config)
+        train_backmapper(frame, cg_frame)
 
     else:
         logger.info('Skipping AA->CG mapping')
diff --git a/pyproject.toml b/pyproject.toml
index 3db3f8c..52fe41c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,6 +35,7 @@ numpy = [
 cython = "^0.29.21"
 mdtraj = "^1.9.5"
 rich = "^9.2.0"
+mdplus = { git = "https://bitbucket.org/claughton/mdplus.git" }
 
 [tool.poetry.dev-dependencies]
 prospector = "^1.3.0"
-- 
GitLab