From 1567614d7efc338a39e9f4e10dabae239f70382e Mon Sep 17 00:00:00 2001
From: jack-parsons <jack.parsons.uk@icloud.com>
Date: Tue, 23 Jul 2019 14:17:06 +0100
Subject: [PATCH] Validation for AmpObject rotate method. Fixes #43

Added checks for incorrect dimensions in rotate method of AmpObject.
Added automatic casting to np array if a list or tuple is passed in.
Added tests for these new functions
---
 AmpScan/core.py     | 12 ++++++++++++
 tests/core_tests.py | 46 ++++++++++++++++++++++++++++++++++++---------
 2 files changed, 49 insertions(+), 9 deletions(-)

diff --git a/AmpScan/core.py b/AmpScan/core.py
index 7463ebd..c691c91 100644
--- a/AmpScan/core.py
+++ b/AmpScan/core.py
@@ -368,6 +368,18 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
         norms: boolean, default True
             
         """
+        if isinstance(R, (list, tuple)):
+            # Make R a np array if its a list or tuple
+            R = np.array(R, np.float)
+        elif not isinstance(R, np.ndarray):
+            # If
+            raise TypeError("Expected R to be array-like but found: " + str(type(R)))
+        if len(R) != 3 or len(R[0]) != 3:
+            # Incorrect dimensions
+            if isinstance(R, np.ndarray):
+                raise ValueError("Expected 3x3 array, but found: {}".format(R.shape))
+            else:
+                raise ValueError("Expected 3x3 array, but found: 3x"+str(len(R)))
         self.vert[:, :] = np.dot(self.vert, R.T)
         if norms is True:
             self.norm[:, :] = np.dot(self.norm, R.T)
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 801878a..1a4edb5 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -41,9 +41,9 @@ class TestCore(unittest.TestCase):
         # Check that the mesh is centred correctly (to at least the number of decimal places of ACCURACY)
         self.assertTrue(all(centre[i] < (10**-TestCore.ACCURACY) for i in range(3)))
 
-    def test_rotate(self):
+    def test_rotate_ang(self):
         """
-        Tests the rotate method of AmpObject
+        Tests the rotateAng method of AmpObject
         """
 
         # Test rotation on random node
@@ -65,6 +65,30 @@ class TestCore(unittest.TestCase):
         with self.assertRaises(TypeError):
             self.amp.rotateAng(dict())
 
+    def test_rotate(self):
+        """
+        Tests the rotate method of AmpObject
+        """
+        # A test rotation and translation using list
+        m = [[1, 0, 0], [0, np.sqrt(3)/2, 1/2], [0, -1/2, np.sqrt(3)/2]]
+        self.amp.rotate(m)
+
+        # Check single floats cause TypeError
+        with self.assertRaises(TypeError):
+            self.amp.rotate(7)
+
+        # Check dictionaries cause TypeError
+        with self.assertRaises(TypeError):
+            self.amp.rotate(dict())
+
+        # Check invalid dimensions cause ValueError
+        with self.assertRaises(ValueError):
+            self.amp.rotate([])
+        with self.assertRaises(ValueError):
+            self.amp.rotate([[0, 0, 1]])
+        with self.assertRaises(ValueError):
+            self.amp.rotate([[], [], []])
+
     def test_translate(self):
         """
         Test translating method of AmpObject
@@ -95,13 +119,16 @@ class TestCore(unittest.TestCase):
         Test the rigid transform method of AmpObject
         """
 
-        # before_vert_pos = self.amp.vert[0][:]
-        # rot = [np.pi/6, -np.pi/2, np.pi/3]
-        # tran = [-1, 0, 1]
-        # self.amp.rigidTransform(R=rot, T=tran)
-        # after_vert_pos = self.amp.vert[0][:]
-        # np.dot(before_vert_pos, rot)
-        # self.assertAlmostEqual()
+        # Test if no transform is applied, vertices aren't affected
+        before_vert = self.amp.vert.copy()
+        self.amp.rigidTransform(R=None, T=None)
+        all(self.assertEqual(self.amp.vert[y][x], before_vert[y][x])
+            for y in range(len(self.amp.vert))
+            for x in range(len(self.amp.vert[0])))
+
+        # A test rotation and translation
+        m = [[1, 0, 0], [0, np.sqrt(3)/2, 1/2], [0, -1/2, np.sqrt(3)/2]]
+        self.amp.rigidTransform(R=m, T=[1, 0, -1])
 
         # Check that translating raises TypeError when translating with an invalid type
         with self.assertRaises(TypeError):
@@ -111,6 +138,7 @@ class TestCore(unittest.TestCase):
         with self.assertRaises(TypeError):
             self.amp.rigidTransform(R=7)
 
+
     @staticmethod
     def get_path(filename):
         """
-- 
GitLab