Skip to content
Snippets Groups Projects
Commit 1567614d authored by jack-parsons's avatar jack-parsons
Browse files

Validation for AmpObject rotate method. Fixes #43

Added checks for incorrect dimensions in rotate method of AmpObject.
Added automatic casting to np array if a list or tuple is passed in.
Added tests for these new functions
parent b190d03b
No related branches found
No related tags found
1 merge request!23Merge in Jack's changes
Pipeline #852 passed
...@@ -368,6 +368,18 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin): ...@@ -368,6 +368,18 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
norms: boolean, default True norms: boolean, default True
""" """
if isinstance(R, (list, tuple)):
# Make R a np array if its a list or tuple
R = np.array(R, np.float)
elif not isinstance(R, np.ndarray):
# If
raise TypeError("Expected R to be array-like but found: " + str(type(R)))
if len(R) != 3 or len(R[0]) != 3:
# Incorrect dimensions
if isinstance(R, np.ndarray):
raise ValueError("Expected 3x3 array, but found: {}".format(R.shape))
else:
raise ValueError("Expected 3x3 array, but found: 3x"+str(len(R)))
self.vert[:, :] = np.dot(self.vert, R.T) self.vert[:, :] = np.dot(self.vert, R.T)
if norms is True: if norms is True:
self.norm[:, :] = np.dot(self.norm, R.T) self.norm[:, :] = np.dot(self.norm, R.T)
......
...@@ -41,9 +41,9 @@ class TestCore(unittest.TestCase): ...@@ -41,9 +41,9 @@ class TestCore(unittest.TestCase):
# Check that the mesh is centred correctly (to at least the number of decimal places of ACCURACY) # Check that the mesh is centred correctly (to at least the number of decimal places of ACCURACY)
self.assertTrue(all(centre[i] < (10**-TestCore.ACCURACY) for i in range(3))) self.assertTrue(all(centre[i] < (10**-TestCore.ACCURACY) for i in range(3)))
def test_rotate(self): def test_rotate_ang(self):
""" """
Tests the rotate method of AmpObject Tests the rotateAng method of AmpObject
""" """
# Test rotation on random node # Test rotation on random node
...@@ -65,6 +65,30 @@ class TestCore(unittest.TestCase): ...@@ -65,6 +65,30 @@ class TestCore(unittest.TestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
self.amp.rotateAng(dict()) self.amp.rotateAng(dict())
def test_rotate(self):
"""
Tests the rotate method of AmpObject
"""
# A test rotation and translation using list
m = [[1, 0, 0], [0, np.sqrt(3)/2, 1/2], [0, -1/2, np.sqrt(3)/2]]
self.amp.rotate(m)
# Check single floats cause TypeError
with self.assertRaises(TypeError):
self.amp.rotate(7)
# Check dictionaries cause TypeError
with self.assertRaises(TypeError):
self.amp.rotate(dict())
# Check invalid dimensions cause ValueError
with self.assertRaises(ValueError):
self.amp.rotate([])
with self.assertRaises(ValueError):
self.amp.rotate([[0, 0, 1]])
with self.assertRaises(ValueError):
self.amp.rotate([[], [], []])
def test_translate(self): def test_translate(self):
""" """
Test translating method of AmpObject Test translating method of AmpObject
...@@ -95,13 +119,16 @@ class TestCore(unittest.TestCase): ...@@ -95,13 +119,16 @@ class TestCore(unittest.TestCase):
Test the rigid transform method of AmpObject Test the rigid transform method of AmpObject
""" """
# before_vert_pos = self.amp.vert[0][:] # Test if no transform is applied, vertices aren't affected
# rot = [np.pi/6, -np.pi/2, np.pi/3] before_vert = self.amp.vert.copy()
# tran = [-1, 0, 1] self.amp.rigidTransform(R=None, T=None)
# self.amp.rigidTransform(R=rot, T=tran) all(self.assertEqual(self.amp.vert[y][x], before_vert[y][x])
# after_vert_pos = self.amp.vert[0][:] for y in range(len(self.amp.vert))
# np.dot(before_vert_pos, rot) for x in range(len(self.amp.vert[0])))
# self.assertAlmostEqual()
# A test rotation and translation
m = [[1, 0, 0], [0, np.sqrt(3)/2, 1/2], [0, -1/2, np.sqrt(3)/2]]
self.amp.rigidTransform(R=m, T=[1, 0, -1])
# Check that translating raises TypeError when translating with an invalid type # Check that translating raises TypeError when translating with an invalid type
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
...@@ -111,6 +138,7 @@ class TestCore(unittest.TestCase): ...@@ -111,6 +138,7 @@ class TestCore(unittest.TestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
self.amp.rigidTransform(R=7) self.amp.rigidTransform(R=7)
@staticmethod @staticmethod
def get_path(filename): def get_path(filename):
""" """
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment