From 29e8836cf4f6153315dc829ae43025ce02b95ccc Mon Sep 17 00:00:00 2001 From: jack-parsons <jack.parsons.uk@icloud.com> Date: Tue, 23 Jul 2019 16:57:27 +0100 Subject: [PATCH] Adding tests for flip and rotMatrix. Added validation to both flip and rotMatrix. --- AmpScan/core.py | 35 +++++++++++++++++++++++++++++------ tests/test_core.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 6 deletions(-) diff --git a/AmpScan/core.py b/AmpScan/core.py index c691c91..94d8c46 100644 --- a/AmpScan/core.py +++ b/AmpScan/core.py @@ -350,6 +350,11 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin): >>> ang = [np.pi/2, -np.pi/4, np.pi/3] >>> amp.rotateAng(ang, ang='rad') """ + + # Check that ang is valid + if ang not in ('rad', 'deg'): + raise ValueError("Ang expected 'rad' or 'deg' but {} was found".format(ang)) + if isinstance(rot, (tuple, list, np.ndarray)): R = self.rotMatrix(rot, ang) self.rotate(R, norms) @@ -423,7 +428,7 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin): rot: array_like Rotation around [x, y, z] ang: str, default 'rad' - Specift if the Euler angles are in degrees or radians + Specify if the Euler angles are in degrees or radians Returns ------- @@ -431,8 +436,20 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin): The calculated 3x3 rotation matrix """ + + # Check that rot is valid + if not isinstance(rot, (tuple, list, np.ndarray)): + raise TypeError("Expecting array-like rotation, but found: "+type(rot)) + elif len(rot) != 3: + raise ValueError("Expecting 3 arguments but found: {}".format(len(rot))) + + # Check that ang is valid + if ang not in ('rad', 'deg'): + raise ValueError("Ang expected 'rad' or 'deg' but {} was found".format(ang)) + if ang == 'deg': rot = np.deg2rad(rot) + [angx, angy, angz] = rot Rx = np.array([[1, 0, 0], [0, np.cos(angx), -np.sin(angx)], @@ -456,8 +473,14 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin): The axis in which to flip the mesh """ - self.vert[:, axis] *= -1.0 - # Switch face order to normals face same direction - self.faces[:, [1, 2]] = self.faces[:, [2, 1]] - self.calcNorm() - self.calcVNorm() + if isinstance(axis, int): + if 0 <= axis < 3: # Check axis is between 0-2 + self.vert[:, axis] *= -1.0 + # Switch face order to normals face same direction + self.faces[:, [1, 2]] = self.faces[:, [2, 1]] + self.calcNorm() + self.calcVNorm() + else: + raise ValueError("Expected axis to be within range 0-2 but found: {}".format(axis)) + else: + raise TypeError("Expected axis to be int, but found: {}".format(type(axis))) diff --git a/tests/test_core.py b/tests/test_core.py index e53afd2..4c06037 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -54,6 +54,12 @@ class TestCore(unittest.TestCase): with self.assertRaises(TypeError): self.amp.rotateAng(dict()) + # Tests that incorrect number of elements causes ValueError + with self.assertRaises(ValueError): + self.amp.rotateAng(rot, "test") + with self.assertRaises(ValueError): + self.amp.rotateAng(rot, []) + def test_rotate(self): """Tests the rotate method of AmpObject""" # A test rotation and translation using list @@ -121,6 +127,45 @@ class TestCore(unittest.TestCase): with self.assertRaises(TypeError): self.amp.rigidTransform(R=7) + def test_rot_matrix(self): + """Tests the rotMatrix method in AmpObject""" + + # Tests that a transformation by 0 in all axis is 0 matrix + all(self.amp.rotMatrix([0, 0, 0])[y][x] == 0 + for x in range(3) + for y in range(3)) + + expected = [[1, 0, 0], [0, np.sqrt(3)/2, 1/2], [0, -1/2, np.sqrt(3)/2]] + all(self.amp.rotMatrix([np.pi/6, 0, 0])[y][x] == expected[y][x] + for x in range(3) + for y in range(3)) + + # Tests that string passed into rot causes TypeError + with self.assertRaises(TypeError): + self.amp.rotMatrix(" ") + with self.assertRaises(TypeError): + self.amp.rotMatrix(dict()) + + # Tests that incorrect number of elements causes ValueError + with self.assertRaises(ValueError): + self.amp.rotMatrix([0, 1]) + with self.assertRaises(ValueError): + self.amp.rotMatrix([0, 1, 3, 0]) + + def test_flip(self): + """Tests the flip method in AmpObject""" + # Check invalid axis types cause TypeError + with self.assertRaises(TypeError): + self.amp.flip(" ") + with self.assertRaises(TypeError): + self.amp.flip(dict()) + + # Check invalid axis values cause ValueError + with self.assertRaises(ValueError): + self.amp.flip(-1) + with self.assertRaises(ValueError): + self.amp.flip(3) + @staticmethod def get_path(filename): -- GitLab