From f1d0e32a2da82a4e64df8aee6a0ff3244024424a Mon Sep 17 00:00:00 2001
From: jack-parsons <jack.parsons.uk@icloud.com>
Date: Wed, 24 Jul 2019 17:28:04 +0100
Subject: [PATCH] Switching to using pytest for automated doctests and unittest
 discovery.

Currently GUI docs are not working, but I'll wait until we manage branches to fix them.
---
 .gitlab-ci.yml                                | 10 +---
 AmpScan/align.py                              |  2 +-
 AmpScan/core.py                               | 13 +++--
 AmpScan/ssm.py                                | 37 ++++++-------
 AmpScan/trim.py                               |  9 +++-
 .../sample_stl_sphere_ASCII.stl               |  0
 tests/pca_tests/sample_stl_sphere_BIN.stl     |  3 ++
 tests/sample_test_local.py                    | 53 -------------------
 8 files changed, 38 insertions(+), 89 deletions(-)
 rename tests/{ => ascii_examples}/sample_stl_sphere_ASCII.stl (100%)
 create mode 100644 tests/pca_tests/sample_stl_sphere_BIN.stl
 delete mode 100644 tests/sample_test_local.py

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index b36ed0e..fea2f94 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -1,8 +1,2 @@
-unittests:
-    script: python -m unittest discover tests -v
-core doctests:
-    script: python -m doctest -v AmpScan/core.py
-registration doctests:
-    script: python -m doctest -v AmpScan/registration.py
-align doctests:
-    script: python -m doctest -v AmpScan/align.py
\ No newline at end of file
+doctests and unitests:
+    script: pytest --doctest-modules -v --ignore=GUIs
\ No newline at end of file
diff --git a/AmpScan/align.py b/AmpScan/align.py
index 2049023..6327e4c 100644
--- a/AmpScan/align.py
+++ b/AmpScan/align.py
@@ -280,7 +280,7 @@ class align(object):
             
         Examples
         --------
-        >>> import AmpScan, os
+        >>> import os, AmpScan
         >>> staticfh = os.getcwd()+"\\tests\\stl_file.stl"
         >>> movingfh = os.getcwd()+"\\tests\\stl_file_2.stl"
         >>> static = AmpScan.AmpObject(staticfh)
diff --git a/AmpScan/core.py b/AmpScan/core.py
index 834c775..1e2f836 100644
--- a/AmpScan/core.py
+++ b/AmpScan/core.py
@@ -6,12 +6,18 @@ Copyright: Joshua Steer 2018, Joshua.Steer@soton.ac.uk
 """
 
 import numpy as np
+import os
 import struct
 from AmpScan.trim import trimMixin
 from AmpScan.smooth import smoothMixin
 from AmpScan.analyse import analyseMixin
 from AmpScan.ampVis import visMixin
 
+
+# The file path used in doc examples
+filename = os.getcwd()+"\\tests\\stl_file.stl"
+
+
 class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
     r"""
     Base class for the AmpScan project.
@@ -36,9 +42,6 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
     
     Examples
     -------
-    >>> import AmpScan
-    >>> import os
-    >>> filename = os.getcwd()+"\\tests\\stl_file.stl"
     >>> amp = AmpObject(filename)
 
     """
@@ -151,8 +154,6 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
         
         Examples
         --------
-        >>> import os
-        >>> filename = os.getcwd()+"\\tests\\stl_file.stl"
         >>> amp = AmpObject(filename, unify=False)
         >>> amp.vert.shape
         (44832, 3)
@@ -349,8 +350,6 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
         
         Examples
         --------
-        >>> import os
-        >>> filename = os.getcwd()+"\\tests\\stl_file.stl"
         >>> amp = AmpObject(filename)
         >>> ang = [np.pi/2, -np.pi/4, np.pi/3]
         >>> amp.rotateAng(ang, ang='rad')
diff --git a/AmpScan/ssm.py b/AmpScan/ssm.py
index 4cb0541..b1b685b 100644
--- a/AmpScan/ssm.py
+++ b/AmpScan/ssm.py
@@ -19,20 +19,20 @@ class pca(object):
     
     Examples
     --------
+    >>> import os
     >>> p = pca()
-    >>> p.importFolder('/path/')
-    >>> p.baseline('dir/baselinefh.stl')
-    >>> p.register(save = '/regpath/')
+    >>> p.importFolder(os.getcwd()+"\\tests\\pca_tests")
+    >>> p.setBaseline(os.getcwd()+"\\tests\\pca_tests\\sample_stl_sphere_BIN.stl")
+    >>> p.register(save=os.getcwd()+"\\tests\\pca_tests\\")
     >>> p.pca()
-    >>> sfs = [0, 0.1, -0.5 ... 0]
+    >>> sfs = [1, 2]
     >>> newS = p.newShape(sfs)
     
     """
-    
+
     def __init__(self):
         self.shapes = []
-        
-        
+
     def setBaseline(self, baseline):
         r"""
         Function to set the baseline mesh used for registration of the 
@@ -45,8 +45,7 @@ class pca(object):
             
         """
         self.baseline = AmpObject(baseline, 'limb')
-        
-        
+
     def importFolder(self, path, unify=True):
         r"""
         Function to import multiple stl files from folder into the pca object 
@@ -64,7 +63,7 @@ class pca(object):
         self.shapes = [AmpObject(os.path.join(path, f), 'limb', unify=unify) for f in self.fnames]
         for s in self.shapes:
             s.lp_smooth(3, brim=True)
-        
+
     def sliceFiles(self, height):
         r"""
         Function to run a planar trim on all the training data for the PCA 
@@ -78,7 +77,7 @@ class pca(object):
         """
         for s in self.shapes:
             s.planarTrim(height)
-        
+
     def register(self, scale=None, save=None, baseline=True):
         r"""
         Register all the AmpObject training data to the baseline AmpObject 
@@ -105,8 +104,7 @@ class pca(object):
         self.X = np.array([r.vert.flatten() for r in self.registered]).T
         if baseline is True:
             self.X = np.c_[self.X, self.baseline.vert.flatten()]
-        
-        
+
     def pca(self):
         r"""
         Function to run mean centered pca using a singular value decomposition 
@@ -118,8 +116,8 @@ class pca(object):
         (self.pca_U, self.pca_S, self.pca_V) = np.linalg.svd(X_meanC, full_matrices=False)
         self.pc_weights = np.dot(np.diag(self.pca_S), self.pca_V)
         self.pc_stdevs = np.std(self.pc_weights, axis=1)
-    
-    def newShape(self, sfs, scale = 'eigs'):
+
+    def newShape(self, sfs, scale='eigs'):
         r"""
         Function to calculate a new shape based upon the eigenvalues 
         or stdevs
@@ -134,11 +132,14 @@ class pca(object):
             to standard deviations about the mean
 
         """
-        try: len(sfs) == len(self.pc_stdevs)
-        except: ValueError('sfs must be of the same length as the number of '
-                           'principal components')
+        sfs = np.array(sfs)
+        if not sfs.shape == self.pc_stdevs.shape:
+            raise ValueError('sfs must be of the same length as the number of '
+                             'principal components (expected {} but found {})'.format(self.pc_stdevs.shape, sfs.shape))
         if scale == 'eigs':
             sf = (self.pca_U * sfs).sum(axis=1)
         elif scale == 'std':
             sf = (self.pca_U * self.pc_stdevs * sfs).sum(axis=1)
+        else:
+            raise ValueError("Invalid scale (expected 'eigs' or 'std' but found{}".format(scale))
         return self.pca_mean + sf
diff --git a/AmpScan/trim.py b/AmpScan/trim.py
index 7bc61a2..cc0abef 100644
--- a/AmpScan/trim.py
+++ b/AmpScan/trim.py
@@ -6,6 +6,8 @@ Copyright: Joshua Steer 2018, Joshua.Steer@soton.ac.uk
 
 import numpy as np
 from numbers import Number
+import os
+
 
 class trimMixin(object):
     r"""
@@ -27,7 +29,10 @@ class trimMixin(object):
         
         Examples
         --------
-        >>> amp = AmpObject(fh)
+
+        >>> from AmpScan import AmpObject
+        >>> filename = os.getcwd()+"\\tests\\stl_file.stl"
+        >>> amp = AmpObject(filename)
         >>> amp.planarTrim(100, 2)
 
         """
@@ -53,4 +58,4 @@ class trimMixin(object):
             self.values = self.values[~delv]
             self.calcStruct()
         else:
-            raise TypeError("height arg must be a float")
\ No newline at end of file
+            raise TypeError("height arg must be a float")
diff --git a/tests/sample_stl_sphere_ASCII.stl b/tests/ascii_examples/sample_stl_sphere_ASCII.stl
similarity index 100%
rename from tests/sample_stl_sphere_ASCII.stl
rename to tests/ascii_examples/sample_stl_sphere_ASCII.stl
diff --git a/tests/pca_tests/sample_stl_sphere_BIN.stl b/tests/pca_tests/sample_stl_sphere_BIN.stl
new file mode 100644
index 0000000..de40fb5
--- /dev/null
+++ b/tests/pca_tests/sample_stl_sphere_BIN.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1f92d690baf69fd2dc5817c90ee85c5ffb9b36dc627521dc6be136c6c7fd4de6
+size 64084
diff --git a/tests/sample_test_local.py b/tests/sample_test_local.py
deleted file mode 100644
index 7232f18..0000000
--- a/tests/sample_test_local.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import unittest
-import os
-import sys
-
-class TestBasicFunction(unittest.TestCase):
-
-    def SetUp(self):
-        modPath = os.path.abspath(os.getcwd())
-        sys.path.insert(0, modPath)
-    
-    def test_running(self):
-        print("Running sample_test.py")
-        self.assertTrue(True)
-    
-    def test_python_imports(self):
-        import numpy, scipy, matplotlib, vtk, AmpScan.core
-        s = str(type(numpy))
-        self.assertEqual(s, "<class 'module'>")
-        s = str(type(scipy))
-        self.assertEqual(s, "<class 'module'>")
-        s = str(type(matplotlib))
-        self.assertEqual(s, "<class 'module'>")
-        s = str(type(vtk))
-        self.assertEqual(s, "<class 'module'>")
-        s = str(type(AmpScan.core))
-        self.assertEqual(s, "<class 'module'>", "Failed import: AmpScan.core")
-
-    @unittest.expectedFailure
-    def test_failure(self):
-        s = str(type("string"))
-        self.assertEqual(s, "<class 'module'>")
-
-    def test_rotate(self):
-        stlPath = os.path.abspath(os.getcwd()) + "/sample_stl_sphere_BIN.stl"
-        from AmpScan.core import AmpObject
-        Amp = AmpObject(stlPath)
-        s = str(type(Amp))
-        self.assertEqual(s, "<class 'AmpScan.core.AmpObject'>", "Not expected Object")
-        with self.assertRaises(TypeError):
-            Amp.rotateAng(7)
-            Amp.rotateAng({})
-
-    def test_trim(self):
-        # a new test for the trim module
-        stlPath = os.path.abspath(os.getcwd()) + "/sample_stl_sphere_BIN.stl"
-        from AmpScan.core import AmpObject
-        Amp = AmpObject(stlPath)
-        #with self.assertRaises(TypeError):
-            #Amp.planarTrim([], plane=[])
-        
-
-if __name__ == '__main__':
-    unittest.main()
-- 
GitLab