From 1567614d7efc338a39e9f4e10dabae239f70382e Mon Sep 17 00:00:00 2001 From: jack-parsons <jack.parsons.uk@icloud.com> Date: Tue, 23 Jul 2019 14:17:06 +0100 Subject: [PATCH] Validation for AmpObject rotate method. Fixes #43 Added checks for incorrect dimensions in rotate method of AmpObject. Added automatic casting to np array if a list or tuple is passed in. Added tests for these new functions --- AmpScan/core.py | 12 ++++++++++++ tests/core_tests.py | 46 ++++++++++++++++++++++++++++++++++++--------- 2 files changed, 49 insertions(+), 9 deletions(-) diff --git a/AmpScan/core.py b/AmpScan/core.py index 7463ebd..c691c91 100644 --- a/AmpScan/core.py +++ b/AmpScan/core.py @@ -368,6 +368,18 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin): norms: boolean, default True """ + if isinstance(R, (list, tuple)): + # Make R a np array if its a list or tuple + R = np.array(R, np.float) + elif not isinstance(R, np.ndarray): + # If + raise TypeError("Expected R to be array-like but found: " + str(type(R))) + if len(R) != 3 or len(R[0]) != 3: + # Incorrect dimensions + if isinstance(R, np.ndarray): + raise ValueError("Expected 3x3 array, but found: {}".format(R.shape)) + else: + raise ValueError("Expected 3x3 array, but found: 3x"+str(len(R))) self.vert[:, :] = np.dot(self.vert, R.T) if norms is True: self.norm[:, :] = np.dot(self.norm, R.T) diff --git a/tests/core_tests.py b/tests/core_tests.py index 801878a..1a4edb5 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -41,9 +41,9 @@ class TestCore(unittest.TestCase): # Check that the mesh is centred correctly (to at least the number of decimal places of ACCURACY) self.assertTrue(all(centre[i] < (10**-TestCore.ACCURACY) for i in range(3))) - def test_rotate(self): + def test_rotate_ang(self): """ - Tests the rotate method of AmpObject + Tests the rotateAng method of AmpObject """ # Test rotation on random node @@ -65,6 +65,30 @@ class TestCore(unittest.TestCase): with self.assertRaises(TypeError): self.amp.rotateAng(dict()) + def test_rotate(self): + """ + Tests the rotate method of AmpObject + """ + # A test rotation and translation using list + m = [[1, 0, 0], [0, np.sqrt(3)/2, 1/2], [0, -1/2, np.sqrt(3)/2]] + self.amp.rotate(m) + + # Check single floats cause TypeError + with self.assertRaises(TypeError): + self.amp.rotate(7) + + # Check dictionaries cause TypeError + with self.assertRaises(TypeError): + self.amp.rotate(dict()) + + # Check invalid dimensions cause ValueError + with self.assertRaises(ValueError): + self.amp.rotate([]) + with self.assertRaises(ValueError): + self.amp.rotate([[0, 0, 1]]) + with self.assertRaises(ValueError): + self.amp.rotate([[], [], []]) + def test_translate(self): """ Test translating method of AmpObject @@ -95,13 +119,16 @@ class TestCore(unittest.TestCase): Test the rigid transform method of AmpObject """ - # before_vert_pos = self.amp.vert[0][:] - # rot = [np.pi/6, -np.pi/2, np.pi/3] - # tran = [-1, 0, 1] - # self.amp.rigidTransform(R=rot, T=tran) - # after_vert_pos = self.amp.vert[0][:] - # np.dot(before_vert_pos, rot) - # self.assertAlmostEqual() + # Test if no transform is applied, vertices aren't affected + before_vert = self.amp.vert.copy() + self.amp.rigidTransform(R=None, T=None) + all(self.assertEqual(self.amp.vert[y][x], before_vert[y][x]) + for y in range(len(self.amp.vert)) + for x in range(len(self.amp.vert[0]))) + + # A test rotation and translation + m = [[1, 0, 0], [0, np.sqrt(3)/2, 1/2], [0, -1/2, np.sqrt(3)/2]] + self.amp.rigidTransform(R=m, T=[1, 0, -1]) # Check that translating raises TypeError when translating with an invalid type with self.assertRaises(TypeError): @@ -111,6 +138,7 @@ class TestCore(unittest.TestCase): with self.assertRaises(TypeError): self.amp.rigidTransform(R=7) + @staticmethod def get_path(filename): """ -- GitLab