From 29e8836cf4f6153315dc829ae43025ce02b95ccc Mon Sep 17 00:00:00 2001
From: jack-parsons <jack.parsons.uk@icloud.com>
Date: Tue, 23 Jul 2019 16:57:27 +0100
Subject: [PATCH] Adding tests for flip and rotMatrix.

Added validation to both flip and rotMatrix.
---
 AmpScan/core.py    | 35 +++++++++++++++++++++++++++++------
 tests/test_core.py | 45 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 74 insertions(+), 6 deletions(-)

diff --git a/AmpScan/core.py b/AmpScan/core.py
index c691c91..94d8c46 100644
--- a/AmpScan/core.py
+++ b/AmpScan/core.py
@@ -350,6 +350,11 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
         >>> ang = [np.pi/2, -np.pi/4, np.pi/3]
         >>> amp.rotateAng(ang, ang='rad')
         """
+
+        # Check that ang is valid
+        if ang not in ('rad', 'deg'):
+            raise ValueError("Ang expected 'rad' or 'deg' but {} was found".format(ang))
+
         if isinstance(rot, (tuple, list, np.ndarray)):
             R = self.rotMatrix(rot, ang)
             self.rotate(R, norms)
@@ -423,7 +428,7 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
         rot: array_like
             Rotation around [x, y, z]
         ang: str, default 'rad'
-            Specift if the Euler angles are in degrees or radians 
+            Specify if the Euler angles are in degrees or radians
         
         Returns
         -------
@@ -431,8 +436,20 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
             The calculated 3x3 rotation matrix 
     
         """
+
+        # Check that rot is valid
+        if not isinstance(rot, (tuple, list, np.ndarray)):
+            raise TypeError("Expecting array-like rotation, but found: "+type(rot))
+        elif len(rot) != 3:
+            raise ValueError("Expecting 3 arguments but found: {}".format(len(rot)))
+
+        # Check that ang is valid
+        if ang not in ('rad', 'deg'):
+            raise ValueError("Ang expected 'rad' or 'deg' but {} was found".format(ang))
+
         if ang == 'deg':
             rot = np.deg2rad(rot)
+
         [angx, angy, angz] = rot
         Rx = np.array([[1, 0, 0],
                        [0, np.cos(angx), -np.sin(angx)],
@@ -456,8 +473,14 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
             The axis in which to flip the mesh
 
         """
-        self.vert[:, axis] *= -1.0
-        # Switch face order to normals face same direction
-        self.faces[:, [1, 2]] = self.faces[:, [2, 1]]
-        self.calcNorm()
-        self.calcVNorm()
+        if isinstance(axis, int):
+            if 0 <= axis < 3:  # Check axis is between 0-2
+                self.vert[:, axis] *= -1.0
+                # Switch face order to normals face same direction
+                self.faces[:, [1, 2]] = self.faces[:, [2, 1]]
+                self.calcNorm()
+                self.calcVNorm()
+            else:
+                raise ValueError("Expected axis to be within range 0-2 but found: {}".format(axis))
+        else:
+            raise TypeError("Expected axis to be int, but found: {}".format(type(axis)))
diff --git a/tests/test_core.py b/tests/test_core.py
index e53afd2..4c06037 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -54,6 +54,12 @@ class TestCore(unittest.TestCase):
         with self.assertRaises(TypeError):
             self.amp.rotateAng(dict())
 
+        # Tests that incorrect number of elements causes ValueError
+        with self.assertRaises(ValueError):
+            self.amp.rotateAng(rot, "test")
+        with self.assertRaises(ValueError):
+            self.amp.rotateAng(rot, [])
+
     def test_rotate(self):
         """Tests the rotate method of AmpObject"""
         # A test rotation and translation using list
@@ -121,6 +127,45 @@ class TestCore(unittest.TestCase):
         with self.assertRaises(TypeError):
             self.amp.rigidTransform(R=7)
 
+    def test_rot_matrix(self):
+        """Tests the rotMatrix method in AmpObject"""
+
+        # Tests that a transformation by 0 in all axis is 0 matrix
+        all(self.amp.rotMatrix([0, 0, 0])[y][x] == 0
+            for x in range(3)
+            for y in range(3))
+
+        expected = [[1, 0, 0], [0, np.sqrt(3)/2, 1/2], [0, -1/2, np.sqrt(3)/2]]
+        all(self.amp.rotMatrix([np.pi/6, 0, 0])[y][x] == expected[y][x]
+            for x in range(3)
+            for y in range(3))
+
+        # Tests that string passed into rot causes TypeError
+        with self.assertRaises(TypeError):
+            self.amp.rotMatrix(" ")
+        with self.assertRaises(TypeError):
+            self.amp.rotMatrix(dict())
+
+        # Tests that incorrect number of elements causes ValueError
+        with self.assertRaises(ValueError):
+            self.amp.rotMatrix([0, 1])
+        with self.assertRaises(ValueError):
+            self.amp.rotMatrix([0, 1, 3, 0])
+
+    def test_flip(self):
+        """Tests the flip method in AmpObject"""
+        # Check invalid axis types cause TypeError
+        with self.assertRaises(TypeError):
+            self.amp.flip(" ")
+        with self.assertRaises(TypeError):
+            self.amp.flip(dict())
+
+        # Check invalid axis values cause ValueError
+        with self.assertRaises(ValueError):
+            self.amp.flip(-1)
+        with self.assertRaises(ValueError):
+            self.amp.flip(3)
+
 
     @staticmethod
     def get_path(filename):
-- 
GitLab