From a17e8f931035eebb338ac68a4c3a3f85d83b1213 Mon Sep 17 00:00:00 2001
From: jack-parsons <jack.parsons.uk@icloud.com>
Date: Mon, 22 Jul 2019 16:08:42 +0100
Subject: [PATCH] Adding test for translating (may want to move to separate
 file)

---
 AmpScan/core.py      | 11 ++++++++++-
 tests/sample_test.py | 27 +++++++++++++++++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/AmpScan/core.py b/AmpScan/core.py
index 1b64866..3fd4dcd 100644
--- a/AmpScan/core.py
+++ b/AmpScan/core.py
@@ -314,7 +314,16 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
             Translation in [x, y, z]
 
         """
-        self.vert[:] += trans
+
+        # Check that trans is array like
+        if isinstance(trans, (list, np.ndarray)):
+            # Check that trans has exactly 3 dimensions
+            if len(trans) == 3:
+                self.vert[:] += trans
+            else:
+                raise ValueError("Translation has incorrect dimensions. Expected 3 but found: " + str(len(trans)))
+        else:
+            raise TypeError("Translation is not array_like: " + trans)
 
     def centre(self):
         r"""
diff --git a/tests/sample_test.py b/tests/sample_test.py
index 5be234c..05e0071 100644
--- a/tests/sample_test.py
+++ b/tests/sample_test.py
@@ -2,7 +2,9 @@ import unittest
 import os
 import sys
 
+
 class TestBasicFunction(unittest.TestCase):
+    ACCURACY = 3  # The number of decimal places to value accuracy for
 
     def SetUp(self):
         modPath = os.path.abspath(os.getcwd())
@@ -48,6 +50,31 @@ class TestBasicFunction(unittest.TestCase):
         #with self.assertRaises(TypeError):
             #Amp.planarTrim([], plane=[])
 
+    def test_translate(self):
+        from AmpScan.core import AmpObject
+        stlPath = self.get_path("sample_stl_sphere_BIN.stl")
+        amp = AmpObject(stlPath)
+
+        # Check that everything has been translated by 1
+        start = amp.vert.mean(axis=0)[:]
+        amp.translate([1, -1, 0])
+        end = amp.vert.mean(axis=0)[:]
+        self.assertAlmostEqual(start[0], end[0]-1, places=TestBasicFunction.ACCURACY)
+        self.assertAlmostEqual(start[1], end[1]+1, places=TestBasicFunction.ACCURACY)
+        self.assertAlmostEqual(start[2], end[2], places=TestBasicFunction.ACCURACY)
+
+        # Check that translating raises TypeError when translating with an invalid type
+        with self.assertRaises(Exception):
+            amp.translate("")
+
+        # Check that translating raises ValueError when translating with 2 dimensions
+        with self.assertRaises(ValueError):
+            amp.translate([0, 0])
+
+        # Check that translating raises ValueError when translating with 4 dimensions
+        with self.assertRaises(ValueError):
+            amp.translate([0, 0, 0, 0])
+
     def get_path(self, filename):
         """
         Method to get the absolute path to the testing files
-- 
GitLab