diff --git a/AmpScan/core.py b/AmpScan/core.py index 3fd4dcdd103c83e62cf2cf6a36112a4137f3dc04..0e5e8f7927e66edd011116925c4195ccdd324da8 100644 --- a/AmpScan/core.py +++ b/AmpScan/core.py @@ -316,7 +316,7 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin): """ # Check that trans is array like - if isinstance(trans, (list, np.ndarray)): + if isinstance(trans, (list, np.ndarray, tuple)): # Check that trans has exactly 3 dimensions if len(trans) == 3: self.vert[:] += trans diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ecadf745f3a5552524d7a87ebd7db0fd9b0ae12f --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +from tests import * diff --git a/tests/basic_tests.py b/tests/basic_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..472616b926a9e2e2b0338965e70db9b4d9632c6d --- /dev/null +++ b/tests/basic_tests.py @@ -0,0 +1,37 @@ +import unittest +import os +import sys + + +def suite(): + return unittest.TestLoader().loadTestsFromTestCase(TestBasicFunction) + + +class TestBasicFunction(unittest.TestCase): + ACCURACY = 3 # The number of decimal places to value accuracy for + + def SetUp(self): + modPath = os.path.abspath(os.getcwd()) + sys.path.insert(0, modPath) + + def test_running(self): + print("Running sample_test.py") + self.assertTrue(True) + + def test_python_imports(self): + import numpy, scipy, matplotlib, vtk, AmpScan.core + s = str(type(numpy)) + self.assertEqual(s, "<class 'module'>") + s = str(type(scipy)) + self.assertEqual(s, "<class 'module'>") + s = str(type(matplotlib)) + self.assertEqual(s, "<class 'module'>") + s = str(type(vtk)) + self.assertEqual(s, "<class 'module'>") + s = str(type(AmpScan.core)) + self.assertEqual(s, "<class 'module'>", "Failed import: AmpScan.core") + + @unittest.expectedFailure + def test_failure(self): + s = str(type("string")) + self.assertEqual(s, "<class 'module'>") diff --git a/tests/core_tests.py b/tests/core_tests.py index 6ef33b086f346cc678170775a1b03c1709d16079..e8f883195d96bba9674425352c61433b33ffdd17 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -3,3 +3,92 @@ Testing suite for the core functionality """ +import unittest +import os +import sys + + +def suite(): + return unittest.TestLoader().loadTestsFromTestCase(TestCore) + + +class TestCore(unittest.TestCase): + ACCURACY = 3 # The number of decimal places to value accuracy for + + def setUp(self): + """ + Set up the AmpObject object from "sample_stl_sphere_BIN.stl" + """ + from AmpScan.core import AmpObject + stl_path = self.get_path("sample_stl_sphere_BIN.stl") + self.amp = AmpObject(stl_path) + + def test_centre(self): + """ + Test the centre method of AmpObject + """ + + # Translate the mesh + self.amp.translate([1, 0, 0]) + # Recenter the mesh + self.amp.centre() + centre = self.amp.vert.mean(axis=0) + + # Check that the mesh is centred correctly (to at least the number of decimal places of ACCURACY) + self.assertTrue(all(centre[i] < (10**-TestCore.ACCURACY) for i in range(3))) + + def test_rotate(self): + s = str(type(self.amp)) + self.assertEqual(s, "<class 'AmpScan.core.AmpObject'>", "Not expected Object") + with self.assertRaises(TypeError): + self.amp.rotateAng(7) + self.amp.rotateAng({}) + + def test_trim(self): + # a new test for the trim module + stlPath = self.get_path("sample_stl_sphere_BIN.stl") + from AmpScan.core import AmpObject + Amp = AmpObject(stlPath) + #with self.assertRaises(TypeError): + #Amp.planarTrim([], plane=[]) + + def test_translate(self): + # Test translating method of AmpObject + + # Check that everything has been translated correctly to a certain accuracy + start = self.amp.vert.mean(axis=0)[:] + self.amp.translate([1, -1, 0]) + end = self.amp.vert.mean(axis=0)[:] + self.assertAlmostEqual(start[0], end[0]-1, places=TestCore.ACCURACY) + self.assertAlmostEqual(start[1], end[1]+1, places=TestCore.ACCURACY) + self.assertAlmostEqual(start[2], end[2], places=TestCore.ACCURACY) + + # Check that translating raises TypeError when translating with an invalid type + with self.assertRaises(Exception): + self.amp.translate("") + + # Check that translating raises ValueError when translating with 2 dimensions + with self.assertRaises(ValueError): + self.amp.translate([0, 0]) + + # Check that translating raises ValueError when translating with 4 dimensions + with self.assertRaises(ValueError): + self.amp.translate([0, 0, 0, 0]) + + def get_path(self, filename): + """ + Method to get the absolute path to the testing files + + :param filename: Name of the file in tests folder + :return: The absolute path to the file + """ + + # Check if the parent directory is tests (this is for Pycharm unittests) + if os.path.basename(os.getcwd()) == "tests": + # This is for Pycharm testing + stlPath = filename + else: + # This is for the Gitlab testing + stlPath = os.path.abspath(os.getcwd()) + "\\tests\\"+filename + return stlPath + diff --git a/tests/sample_test.py b/tests/sample_test.py index 68950ad892fd108f55cf22119b4769ec265d6daf..499b96dc74861d541eeb80a68a257a624ac883a6 100644 --- a/tests/sample_test.py +++ b/tests/sample_test.py @@ -1,114 +1,22 @@ import unittest import os import sys +import tests.core_tests +import tests.basic_tests -class TestBasicFunction(unittest.TestCase): - ACCURACY = 3 # The number of decimal places to value accuracy for +def suite(): + """ + Get all the unittests for the whole project + :return: The suite containing the test suites for each module + """ + s = unittest.TestSuite() + # Add the tests to the suite + s.addTest(tests.core_tests.suite()) + s.addTest(tests.basic_tests.suite()) + return s - def SetUp(self): - modPath = os.path.abspath(os.getcwd()) - sys.path.insert(0, modPath) - - def test_running(self): - print("Running sample_test.py") - self.assertTrue(True) - - def test_python_imports(self): - import numpy, scipy, matplotlib, vtk, AmpScan.core - s = str(type(numpy)) - self.assertEqual(s, "<class 'module'>") - s = str(type(scipy)) - self.assertEqual(s, "<class 'module'>") - s = str(type(matplotlib)) - self.assertEqual(s, "<class 'module'>") - s = str(type(vtk)) - self.assertEqual(s, "<class 'module'>") - s = str(type(AmpScan.core)) - self.assertEqual(s, "<class 'module'>", "Failed import: AmpScan.core") - - @unittest.expectedFailure - def test_failure(self): - s = str(type("string")) - self.assertEqual(s, "<class 'module'>") - - def test_rotate(self): - from AmpScan.core import AmpObject - stlPath = self.get_path("sample_stl_sphere_BIN.stl") - Amp = AmpObject(stlPath) - s = str(type(Amp)) - self.assertEqual(s, "<class 'AmpScan.core.AmpObject'>", "Not expected Object") - with self.assertRaises(TypeError): - Amp.rotateAng(7) - Amp.rotateAng({}) - - def test_trim(self): - # a new test for the trim module - stlPath = self.get_path("sample_stl_sphere_BIN.stl") - from AmpScan.core import AmpObject - Amp = AmpObject(stlPath) - #with self.assertRaises(TypeError): - #Amp.planarTrim([], plane=[]) - - def test_translate(self): - # Test translating method of AmpObject - - from AmpScan.core import AmpObject - stlPath = self.get_path("sample_stl_sphere_BIN.stl") - amp = AmpObject(stlPath) - - # Check that everything has been translated correctly to a certain accuracy - start = amp.vert.mean(axis=0)[:] - amp.translate([1, -1, 0]) - end = amp.vert.mean(axis=0)[:] - self.assertAlmostEqual(start[0], end[0]-1, places=TestBasicFunction.ACCURACY) - self.assertAlmostEqual(start[1], end[1]+1, places=TestBasicFunction.ACCURACY) - self.assertAlmostEqual(start[2], end[2], places=TestBasicFunction.ACCURACY) - - # Check that translating raises TypeError when translating with an invalid type - with self.assertRaises(Exception): - amp.translate("") - - # Check that translating raises ValueError when translating with 2 dimensions - with self.assertRaises(ValueError): - amp.translate([0, 0]) - - # Check that translating raises ValueError when translating with 4 dimensions - with self.assertRaises(ValueError): - amp.translate([0, 0, 0, 0]) - - def test_centre(self): - # Test the centre method of AmpObject - from AmpScan.core import AmpObject - stlPath = self.get_path("sample_stl_sphere_BIN.stl") - amp = AmpObject(stlPath) - - # Translate the mesh - amp.translate([1, 0, 0]) - # Recenter the mesh - amp.centre() - centre = amp.vert.mean(axis=0) - - # Check that the mesh is centred correctly (to at least the number of decimal places of ACCURACY) - self.assertTrue(all(centre[i] < (10**-TestBasicFunction.ACCURACY) for i in range(3))) - - def get_path(self, filename): - """ - Method to get the absolute path to the testing files - - :param filename: Name of the file in tests folder - :return: The absolute path to the file - """ - - # Check if the parent directory is tests (this is for Pycharm unittests) - if os.path.basename(os.getcwd()) == "tests": - # This is for Pycharm testing - stlPath = filename - else: - # This is for the Gitlab testing - stlPath = os.path.abspath(os.getcwd()) + "\\tests\\"+filename - return stlPath - if __name__ == '__main__': - unittest.main() + # Run the test suites + unittest.TextTestRunner().run(suite())