Skip to content
Snippets Groups Projects
Commit b190d03b authored by jack-parsons's avatar jack-parsons
Browse files

Fixing and improving rotation test

parent 2ce1d57b
No related branches found
No related tags found
1 merge request!23Merge in Jack's changes
Pipeline #851 passed
...@@ -389,9 +389,15 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin): ...@@ -389,9 +389,15 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
""" """
if R is not None: if R is not None:
self.rotate(R, True) if isinstance(R, (tuple, list, np.ndarray)):
self.rotate(R, True)
else:
raise TypeError("Expecting array-like rotation, but found: "+type(R))
if T is not None: if T is not None:
self.translate(T) if isinstance(T, (tuple, list, np.ndarray)):
self.translate(T)
else:
raise TypeError("Expecting array-like translation, but found: "+type(T))
@staticmethod @staticmethod
......
...@@ -5,6 +5,7 @@ Testing suite for the core functionality ...@@ -5,6 +5,7 @@ Testing suite for the core functionality
import unittest import unittest
import os import os
import numpy as np import numpy as np
from random import randrange
def suite(): def suite():
...@@ -45,13 +46,16 @@ class TestCore(unittest.TestCase): ...@@ -45,13 +46,16 @@ class TestCore(unittest.TestCase):
Tests the rotate method of AmpObject Tests the rotate method of AmpObject
""" """
# Test rotation on first node # Test rotation on random node
rot = [np.pi/2, -np.pi/4, np.pi/3] n = randrange(len(self.amp.vert))
before_vert_pos = self.amp.vert[0][:] rot = [0, 0, np.pi/3]
before = self.amp.vert[n].copy()
self.amp.rotateAng(rot) self.amp.rotateAng(rot)
after_vert_pos = self.amp.vert[0][:] after_vert_pos = self.amp.vert[n].copy()
np.dot(before_vert_pos, rot) # Use 2D rotation matrix formula to test rotate method on z axis
self.assertAlmostEqual(before_vert_pos, after_vert_pos, TestCore.ACCURACY) expected = [np.cos(rot[2])*before[0]-np.sin(rot[2])*before[1], np.sin(rot[2])*before[0]+np.cos(rot[2])*before[1], before[2]]
# Check all coordinate dimensions are correct
all(self.assertAlmostEqual(expected[i], after_vert_pos[i], TestCore.ACCURACY) for i in range(3))
# Check single floats cause TypeError # Check single floats cause TypeError
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
...@@ -67,15 +71,15 @@ class TestCore(unittest.TestCase): ...@@ -67,15 +71,15 @@ class TestCore(unittest.TestCase):
""" """
# Check that everything has been translated correctly to a certain accuracy # Check that everything has been translated correctly to a certain accuracy
start = self.amp.vert.mean(axis=0)[:] start = self.amp.vert.mean(axis=0).copy()
self.amp.translate([1, -1, 0]) self.amp.translate([1, -1, 0])
end = self.amp.vert.mean(axis=0)[:] end = self.amp.vert.mean(axis=0).copy()
self.assertAlmostEqual(start[0]+1, end[0], places=TestCore.ACCURACY) self.assertAlmostEqual(start[0]+1, end[0], places=TestCore.ACCURACY)
self.assertAlmostEqual(start[1]-1, end[1], places=TestCore.ACCURACY) self.assertAlmostEqual(start[1]-1, end[1], places=TestCore.ACCURACY)
self.assertAlmostEqual(start[2], end[2], places=TestCore.ACCURACY) self.assertAlmostEqual(start[2], end[2], places=TestCore.ACCURACY)
# Check that translating raises TypeError when translating with an invalid type # Check that translating raises TypeError when translating with an invalid type
with self.assertRaises(Exception): with self.assertRaises(TypeError):
self.amp.translate("") self.amp.translate("")
# Check that translating raises ValueError when translating with 2 dimensions # Check that translating raises ValueError when translating with 2 dimensions
...@@ -86,7 +90,29 @@ class TestCore(unittest.TestCase): ...@@ -86,7 +90,29 @@ class TestCore(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.amp.translate([0, 0, 0, 0]) self.amp.translate([0, 0, 0, 0])
def get_path(self, filename): def test_rigid_transform(self):
"""
Test the rigid transform method of AmpObject
"""
# before_vert_pos = self.amp.vert[0][:]
# rot = [np.pi/6, -np.pi/2, np.pi/3]
# tran = [-1, 0, 1]
# self.amp.rigidTransform(R=rot, T=tran)
# after_vert_pos = self.amp.vert[0][:]
# np.dot(before_vert_pos, rot)
# self.assertAlmostEqual()
# Check that translating raises TypeError when translating with an invalid type
with self.assertRaises(TypeError):
self.amp.rigidTransform(T=dict())
# Check that rotating raises TypeError when translating with an invalid type
with self.assertRaises(TypeError):
self.amp.rigidTransform(R=7)
@staticmethod
def get_path(filename):
""" """
Returns the absolute path to the testing files Returns the absolute path to the testing files
......
...@@ -30,7 +30,8 @@ class TestTrim(unittest.TestCase): ...@@ -30,7 +30,8 @@ class TestTrim(unittest.TestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
self.amp.planarTrim([], plane=[]) self.amp.planarTrim([], plane=[])
def get_path(self, filename): @staticmethod
def get_path(filename):
""" """
Returns the absolute path to the testing files Returns the absolute path to the testing files
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment