From b190d03ba43259e5ffe86316aec13ce0a07614ca Mon Sep 17 00:00:00 2001 From: jack-parsons <jack.parsons.uk@icloud.com> Date: Tue, 23 Jul 2019 13:29:32 +0100 Subject: [PATCH] Fixing and improving rotation test --- AmpScan/core.py | 10 ++++++++-- tests/core_tests.py | 46 +++++++++++++++++++++++++++++++++++---------- tests/test_trim.py | 3 ++- 3 files changed, 46 insertions(+), 13 deletions(-) diff --git a/AmpScan/core.py b/AmpScan/core.py index 0e5e8f7..7463ebd 100644 --- a/AmpScan/core.py +++ b/AmpScan/core.py @@ -389,9 +389,15 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin): """ if R is not None: - self.rotate(R, True) + if isinstance(R, (tuple, list, np.ndarray)): + self.rotate(R, True) + else: + raise TypeError("Expecting array-like rotation, but found: "+type(R)) if T is not None: - self.translate(T) + if isinstance(T, (tuple, list, np.ndarray)): + self.translate(T) + else: + raise TypeError("Expecting array-like translation, but found: "+type(T)) @staticmethod diff --git a/tests/core_tests.py b/tests/core_tests.py index 4af62ef..801878a 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -5,6 +5,7 @@ Testing suite for the core functionality import unittest import os import numpy as np +from random import randrange def suite(): @@ -45,13 +46,16 @@ class TestCore(unittest.TestCase): Tests the rotate method of AmpObject """ - # Test rotation on first node - rot = [np.pi/2, -np.pi/4, np.pi/3] - before_vert_pos = self.amp.vert[0][:] + # Test rotation on random node + n = randrange(len(self.amp.vert)) + rot = [0, 0, np.pi/3] + before = self.amp.vert[n].copy() self.amp.rotateAng(rot) - after_vert_pos = self.amp.vert[0][:] - np.dot(before_vert_pos, rot) - self.assertAlmostEqual(before_vert_pos, after_vert_pos, TestCore.ACCURACY) + after_vert_pos = self.amp.vert[n].copy() + # Use 2D rotation matrix formula to test rotate method on z axis + expected = [np.cos(rot[2])*before[0]-np.sin(rot[2])*before[1], np.sin(rot[2])*before[0]+np.cos(rot[2])*before[1], before[2]] + # Check all coordinate dimensions are correct + all(self.assertAlmostEqual(expected[i], after_vert_pos[i], TestCore.ACCURACY) for i in range(3)) # Check single floats cause TypeError with self.assertRaises(TypeError): @@ -67,15 +71,15 @@ class TestCore(unittest.TestCase): """ # Check that everything has been translated correctly to a certain accuracy - start = self.amp.vert.mean(axis=0)[:] + start = self.amp.vert.mean(axis=0).copy() self.amp.translate([1, -1, 0]) - end = self.amp.vert.mean(axis=0)[:] + end = self.amp.vert.mean(axis=0).copy() self.assertAlmostEqual(start[0]+1, end[0], places=TestCore.ACCURACY) self.assertAlmostEqual(start[1]-1, end[1], places=TestCore.ACCURACY) self.assertAlmostEqual(start[2], end[2], places=TestCore.ACCURACY) # Check that translating raises TypeError when translating with an invalid type - with self.assertRaises(Exception): + with self.assertRaises(TypeError): self.amp.translate("") # Check that translating raises ValueError when translating with 2 dimensions @@ -86,7 +90,29 @@ class TestCore(unittest.TestCase): with self.assertRaises(ValueError): self.amp.translate([0, 0, 0, 0]) - def get_path(self, filename): + def test_rigid_transform(self): + """ + Test the rigid transform method of AmpObject + """ + + # before_vert_pos = self.amp.vert[0][:] + # rot = [np.pi/6, -np.pi/2, np.pi/3] + # tran = [-1, 0, 1] + # self.amp.rigidTransform(R=rot, T=tran) + # after_vert_pos = self.amp.vert[0][:] + # np.dot(before_vert_pos, rot) + # self.assertAlmostEqual() + + # Check that translating raises TypeError when translating with an invalid type + with self.assertRaises(TypeError): + self.amp.rigidTransform(T=dict()) + + # Check that rotating raises TypeError when translating with an invalid type + with self.assertRaises(TypeError): + self.amp.rigidTransform(R=7) + + @staticmethod + def get_path(filename): """ Returns the absolute path to the testing files diff --git a/tests/test_trim.py b/tests/test_trim.py index 169bc41..82c49dd 100644 --- a/tests/test_trim.py +++ b/tests/test_trim.py @@ -30,7 +30,8 @@ class TestTrim(unittest.TestCase): with self.assertRaises(TypeError): self.amp.planarTrim([], plane=[]) - def get_path(self, filename): + @staticmethod + def get_path(filename): """ Returns the absolute path to the testing files -- GitLab