From b190d03ba43259e5ffe86316aec13ce0a07614ca Mon Sep 17 00:00:00 2001
From: jack-parsons <jack.parsons.uk@icloud.com>
Date: Tue, 23 Jul 2019 13:29:32 +0100
Subject: [PATCH] Fixing and improving rotation test

---
 AmpScan/core.py     | 10 ++++++++--
 tests/core_tests.py | 46 +++++++++++++++++++++++++++++++++++----------
 tests/test_trim.py  |  3 ++-
 3 files changed, 46 insertions(+), 13 deletions(-)

diff --git a/AmpScan/core.py b/AmpScan/core.py
index 0e5e8f7..7463ebd 100644
--- a/AmpScan/core.py
+++ b/AmpScan/core.py
@@ -389,9 +389,15 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
             
         """
         if R is not None:
-            self.rotate(R, True)
+            if isinstance(R, (tuple, list, np.ndarray)):
+                self.rotate(R, True)
+            else:
+                raise TypeError("Expecting array-like rotation, but found: "+type(R))
         if T is not None:
-            self.translate(T)
+            if isinstance(T, (tuple, list, np.ndarray)):
+                self.translate(T)
+            else:
+                raise TypeError("Expecting array-like translation, but found: "+type(T))
         
 
     @staticmethod
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 4af62ef..801878a 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -5,6 +5,7 @@ Testing suite for the core functionality
 import unittest
 import os
 import numpy as np
+from random import randrange
 
 
 def suite():
@@ -45,13 +46,16 @@ class TestCore(unittest.TestCase):
         Tests the rotate method of AmpObject
         """
 
-        # Test rotation on first node
-        rot = [np.pi/2, -np.pi/4, np.pi/3]
-        before_vert_pos = self.amp.vert[0][:]
+        # Test rotation on random node
+        n = randrange(len(self.amp.vert))
+        rot = [0, 0, np.pi/3]
+        before = self.amp.vert[n].copy()
         self.amp.rotateAng(rot)
-        after_vert_pos = self.amp.vert[0][:]
-        np.dot(before_vert_pos, rot)
-        self.assertAlmostEqual(before_vert_pos, after_vert_pos, TestCore.ACCURACY)
+        after_vert_pos = self.amp.vert[n].copy()
+        # Use 2D rotation matrix formula to test rotate method on z axis
+        expected = [np.cos(rot[2])*before[0]-np.sin(rot[2])*before[1], np.sin(rot[2])*before[0]+np.cos(rot[2])*before[1], before[2]]
+        # Check all coordinate dimensions are correct
+        all(self.assertAlmostEqual(expected[i], after_vert_pos[i], TestCore.ACCURACY) for i in range(3))
 
         # Check single floats cause TypeError
         with self.assertRaises(TypeError):
@@ -67,15 +71,15 @@ class TestCore(unittest.TestCase):
         """
 
         # Check that everything has been translated correctly to a certain accuracy
-        start = self.amp.vert.mean(axis=0)[:]
+        start = self.amp.vert.mean(axis=0).copy()
         self.amp.translate([1, -1, 0])
-        end = self.amp.vert.mean(axis=0)[:]
+        end = self.amp.vert.mean(axis=0).copy()
         self.assertAlmostEqual(start[0]+1, end[0], places=TestCore.ACCURACY)
         self.assertAlmostEqual(start[1]-1, end[1], places=TestCore.ACCURACY)
         self.assertAlmostEqual(start[2], end[2], places=TestCore.ACCURACY)
 
         # Check that translating raises TypeError when translating with an invalid type
-        with self.assertRaises(Exception):
+        with self.assertRaises(TypeError):
             self.amp.translate("")
 
         # Check that translating raises ValueError when translating with 2 dimensions
@@ -86,7 +90,29 @@ class TestCore(unittest.TestCase):
         with self.assertRaises(ValueError):
             self.amp.translate([0, 0, 0, 0])
 
-    def get_path(self, filename):
+    def test_rigid_transform(self):
+        """
+        Test the rigid transform method of AmpObject
+        """
+
+        # before_vert_pos = self.amp.vert[0][:]
+        # rot = [np.pi/6, -np.pi/2, np.pi/3]
+        # tran = [-1, 0, 1]
+        # self.amp.rigidTransform(R=rot, T=tran)
+        # after_vert_pos = self.amp.vert[0][:]
+        # np.dot(before_vert_pos, rot)
+        # self.assertAlmostEqual()
+
+        # Check that translating raises TypeError when translating with an invalid type
+        with self.assertRaises(TypeError):
+            self.amp.rigidTransform(T=dict())
+
+        # Check that rotating raises TypeError when translating with an invalid type
+        with self.assertRaises(TypeError):
+            self.amp.rigidTransform(R=7)
+
+    @staticmethod
+    def get_path(filename):
         """
         Returns the absolute path to the testing files
 
diff --git a/tests/test_trim.py b/tests/test_trim.py
index 169bc41..82c49dd 100644
--- a/tests/test_trim.py
+++ b/tests/test_trim.py
@@ -30,7 +30,8 @@ class TestTrim(unittest.TestCase):
         with self.assertRaises(TypeError):
             self.amp.planarTrim([], plane=[])
 
-    def get_path(self, filename):
+    @staticmethod
+    def get_path(filename):
         """
         Returns the absolute path to the testing files
 
-- 
GitLab