Skip to content
Snippets Groups Projects
Select Git revision
  • 29e8836cf4f6153315dc829ae43025ce02b95ccc
  • master default
  • Omar
  • Jack
4 results

test_core.py

Blame
  • test_core.py 6.80 KiB
    """
    Testing suite for the core module
    """
    
    import unittest
    import os
    import numpy as np
    from random import randrange
    
    
    class TestCore(unittest.TestCase):
        ACCURACY = 5  # The number of decimal places to value accuracy for - needed due to floating point inaccuracies
    
        def setUp(self):
            """
            Runs before each unit test
            Sets up the AmpObject object using "sample_stl_sphere_BIN.stl"
            """
            from AmpScan.core import AmpObject
            stl_path = self.get_path("sample_stl_sphere_BIN.stl")
            self.amp = AmpObject(stl_path)
    
        def test_centre(self):
            """Test the centre method of AmpObject"""
    
            # Translate the mesh
            self.amp.translate([1, 0, 0])
            # Recenter the mesh
            self.amp.centre()
            centre = self.amp.vert.mean(axis=0)
    
            # Check that the mesh is centred correctly (to at least the number of decimal places of ACCURACY)
            self.assertTrue(all(centre[i] < (10**-TestCore.ACCURACY) for i in range(3)))
    
        def test_rotate_ang(self):
            """Tests the rotateAng method of AmpObject"""
    
            # Test rotation on random node
            n = randrange(len(self.amp.vert))
            rot = [0, 0, np.pi/3]
            before = self.amp.vert[n].copy()
            self.amp.rotateAng(rot)
            after_vert_pos = self.amp.vert[n].copy()
            # Use 2D rotation matrix formula to test rotate method on z axis
            expected = [np.cos(rot[2])*before[0]-np.sin(rot[2])*before[1], np.sin(rot[2])*before[0]+np.cos(rot[2])*before[1], before[2]]
            # Check all coordinate dimensions are correct
            all(self.assertAlmostEqual(expected[i], after_vert_pos[i], TestCore.ACCURACY) for i in range(3))
    
            # Check single floats cause TypeError
            with self.assertRaises(TypeError):
                self.amp.rotateAng(7)
    
            # Check dictionaries cause TypeError
            with self.assertRaises(TypeError):
                self.amp.rotateAng(dict())
    
            # Tests that incorrect number of elements causes ValueError
            with self.assertRaises(ValueError):
                self.amp.rotateAng(rot, "test")
            with self.assertRaises(ValueError):
                self.amp.rotateAng(rot, [])
    
        def test_rotate(self):
            """Tests the rotate method of AmpObject"""
            # A test rotation and translation using list
            m = [[1, 0, 0], [0, np.sqrt(3)/2, 1/2], [0, -1/2, np.sqrt(3)/2]]
            self.amp.rotate(m)
    
            # Check single floats cause TypeError
            with self.assertRaises(TypeError):
                self.amp.rotate(7)
    
            # Check dictionaries cause TypeError
            with self.assertRaises(TypeError):
                self.amp.rotate(dict())
    
            # Check invalid dimensions cause ValueError
            with self.assertRaises(ValueError):
                self.amp.rotate([])
            with self.assertRaises(ValueError):
                self.amp.rotate([[0, 0, 1]])
            with self.assertRaises(ValueError):
                self.amp.rotate([[], [], []])
    
        def test_translate(self):
            """Test translating method of AmpObject"""
    
            # Check that everything has been translated correctly to a certain accuracy
            start = self.amp.vert.mean(axis=0).copy()
            self.amp.translate([1, -1, 0])
            end = self.amp.vert.mean(axis=0).copy()
            self.assertAlmostEqual(start[0]+1, end[0], places=TestCore.ACCURACY)
            self.assertAlmostEqual(start[1]-1, end[1], places=TestCore.ACCURACY)
            self.assertAlmostEqual(start[2], end[2], places=TestCore.ACCURACY)
    
            # Check that translating raises TypeError when translating with an invalid type
            with self.assertRaises(TypeError):
                self.amp.translate("")
    
            # Check that translating raises ValueError when translating with 2 dimensions
            with self.assertRaises(ValueError):
                self.amp.translate([0, 0])
    
            # Check that translating raises ValueError when translating with 4 dimensions
            with self.assertRaises(ValueError):
                self.amp.translate([0, 0, 0, 0])
    
        def test_rigid_transform(self):
            """Test the rigid transform method of AmpObject"""
    
            # Test if no transform is applied, vertices aren't affected
            before_vert = self.amp.vert.copy()
            self.amp.rigidTransform(R=None, T=None)
            all(self.assertEqual(self.amp.vert[y][x], before_vert[y][x])
                for y in range(len(self.amp.vert))
                for x in range(len(self.amp.vert[0])))
    
            # A test rotation and translation
            m = [[1, 0, 0], [0, np.sqrt(3)/2, 1/2], [0, -1/2, np.sqrt(3)/2]]
            self.amp.rigidTransform(R=m, T=[1, 0, -1])
    
            # Check that translating raises TypeError when translating with an invalid type
            with self.assertRaises(TypeError):
                self.amp.rigidTransform(T=dict())
    
            # Check that rotating raises TypeError when translating with an invalid type
            with self.assertRaises(TypeError):
                self.amp.rigidTransform(R=7)
    
        def test_rot_matrix(self):
            """Tests the rotMatrix method in AmpObject"""
    
            # Tests that a transformation by 0 in all axis is 0 matrix
            all(self.amp.rotMatrix([0, 0, 0])[y][x] == 0
                for x in range(3)
                for y in range(3))
    
            expected = [[1, 0, 0], [0, np.sqrt(3)/2, 1/2], [0, -1/2, np.sqrt(3)/2]]
            all(self.amp.rotMatrix([np.pi/6, 0, 0])[y][x] == expected[y][x]
                for x in range(3)
                for y in range(3))
    
            # Tests that string passed into rot causes TypeError
            with self.assertRaises(TypeError):
                self.amp.rotMatrix(" ")
            with self.assertRaises(TypeError):
                self.amp.rotMatrix(dict())
    
            # Tests that incorrect number of elements causes ValueError
            with self.assertRaises(ValueError):
                self.amp.rotMatrix([0, 1])
            with self.assertRaises(ValueError):
                self.amp.rotMatrix([0, 1, 3, 0])
    
        def test_flip(self):
            """Tests the flip method in AmpObject"""
            # Check invalid axis types cause TypeError
            with self.assertRaises(TypeError):
                self.amp.flip(" ")
            with self.assertRaises(TypeError):
                self.amp.flip(dict())
    
            # Check invalid axis values cause ValueError
            with self.assertRaises(ValueError):
                self.amp.flip(-1)
            with self.assertRaises(ValueError):
                self.amp.flip(3)
    
    
        @staticmethod
        def get_path(filename):
            """
            Returns the absolute path to the testing files
    
            :param filename: Name of the file in tests folder
            :return: The absolute path to the file
            """
    
            # Check if the parent directory is tests (this is for Pycharm unittests)
            if os.path.basename(os.getcwd()) == "tests":
                # This is for Pycharm testing
                stl_path = filename
            else:
                # This is for the Gitlab testing
                stl_path = os.path.abspath(os.getcwd()) + "\\tests\\"+filename
            return stl_path