Skip to content
Snippets Groups Projects
Commit 1567614d authored by jack-parsons's avatar jack-parsons
Browse files

Validation for AmpObject rotate method. Fixes #43

Added checks for incorrect dimensions in rotate method of AmpObject.
Added automatic casting to np array if a list or tuple is passed in.
Added tests for these new functions
parent b190d03b
No related branches found
No related tags found
1 merge request!23Merge in Jack's changes
Pipeline #852 passed
......@@ -368,6 +368,18 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
norms: boolean, default True
"""
if isinstance(R, (list, tuple)):
# Make R a np array if its a list or tuple
R = np.array(R, np.float)
elif not isinstance(R, np.ndarray):
# If
raise TypeError("Expected R to be array-like but found: " + str(type(R)))
if len(R) != 3 or len(R[0]) != 3:
# Incorrect dimensions
if isinstance(R, np.ndarray):
raise ValueError("Expected 3x3 array, but found: {}".format(R.shape))
else:
raise ValueError("Expected 3x3 array, but found: 3x"+str(len(R)))
self.vert[:, :] = np.dot(self.vert, R.T)
if norms is True:
self.norm[:, :] = np.dot(self.norm, R.T)
......
......@@ -41,9 +41,9 @@ class TestCore(unittest.TestCase):
# Check that the mesh is centred correctly (to at least the number of decimal places of ACCURACY)
self.assertTrue(all(centre[i] < (10**-TestCore.ACCURACY) for i in range(3)))
def test_rotate(self):
def test_rotate_ang(self):
"""
Tests the rotate method of AmpObject
Tests the rotateAng method of AmpObject
"""
# Test rotation on random node
......@@ -65,6 +65,30 @@ class TestCore(unittest.TestCase):
with self.assertRaises(TypeError):
self.amp.rotateAng(dict())
def test_rotate(self):
"""
Tests the rotate method of AmpObject
"""
# A test rotation and translation using list
m = [[1, 0, 0], [0, np.sqrt(3)/2, 1/2], [0, -1/2, np.sqrt(3)/2]]
self.amp.rotate(m)
# Check single floats cause TypeError
with self.assertRaises(TypeError):
self.amp.rotate(7)
# Check dictionaries cause TypeError
with self.assertRaises(TypeError):
self.amp.rotate(dict())
# Check invalid dimensions cause ValueError
with self.assertRaises(ValueError):
self.amp.rotate([])
with self.assertRaises(ValueError):
self.amp.rotate([[0, 0, 1]])
with self.assertRaises(ValueError):
self.amp.rotate([[], [], []])
def test_translate(self):
"""
Test translating method of AmpObject
......@@ -95,13 +119,16 @@ class TestCore(unittest.TestCase):
Test the rigid transform method of AmpObject
"""
# before_vert_pos = self.amp.vert[0][:]
# rot = [np.pi/6, -np.pi/2, np.pi/3]
# tran = [-1, 0, 1]
# self.amp.rigidTransform(R=rot, T=tran)
# after_vert_pos = self.amp.vert[0][:]
# np.dot(before_vert_pos, rot)
# self.assertAlmostEqual()
# Test if no transform is applied, vertices aren't affected
before_vert = self.amp.vert.copy()
self.amp.rigidTransform(R=None, T=None)
all(self.assertEqual(self.amp.vert[y][x], before_vert[y][x])
for y in range(len(self.amp.vert))
for x in range(len(self.amp.vert[0])))
# A test rotation and translation
m = [[1, 0, 0], [0, np.sqrt(3)/2, 1/2], [0, -1/2, np.sqrt(3)/2]]
self.amp.rigidTransform(R=m, T=[1, 0, -1])
# Check that translating raises TypeError when translating with an invalid type
with self.assertRaises(TypeError):
......@@ -111,6 +138,7 @@ class TestCore(unittest.TestCase):
with self.assertRaises(TypeError):
self.amp.rigidTransform(R=7)
@staticmethod
def get_path(filename):
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment