Commit 13398e31 authored by jp6g18's avatar jp6g18
Browse files

Merge branch 'Jack' into 'master'

Merge in Jack's changes

Closes #45 and #43

See merge request !23
parents 164f6938 7097056c
Pipeline #883 passed with stage
in 43 seconds
sample_job: #doctests and unitests:
script: python tests/sample_test.py # script: pytest --doctest-modules -v --ignore=GUIs
# Temporarily added back old testing
unittests:
script: python -m unittest discover tests -v
core doctests:
script: python -m doctest -v AmpScan/core.py
registration doctests:
script: python -m doctest -v AmpScan/registration.py
align doctests:
script: python -m doctest -v AmpScan/align.py
...@@ -10,8 +10,13 @@ import vtk ...@@ -10,8 +10,13 @@ import vtk
import math import math
from scipy import spatial from scipy import spatial
from scipy.optimize import minimize from scipy.optimize import minimize
from .core import AmpObject from AmpScan.core import AmpObject
from .ampVis import vtkRenWin from AmpScan.ampVis import vtkRenWin
# For doc examples
import os
staticfh = os.getcwd() + "\\tests\\stl_file.stl"
movingfh = os.getcwd() + "\\tests\\stl_file_2.stl"
class align(object): class align(object):
...@@ -41,9 +46,9 @@ class align(object): ...@@ -41,9 +46,9 @@ class align(object):
Examples Examples
-------- --------
>>> static = AmpScan.AmpObject(staticfh) >>> static = AmpObject(staticfh)
>>> moving = AmpScan.AmpObject(movingfh) >>> moving = AmpObject(movingfh)
>>> al = AmpScan.align(moving, static).m >>> al = align(moving, static).m
""" """
...@@ -176,9 +181,9 @@ class align(object): ...@@ -176,9 +181,9 @@ class align(object):
Examples Examples
-------- --------
>>> static = AmpScan.AmpObject(staticfh) >>> static = AmpObject(staticfh)
>>> moving = AmpScan.AmpObject(movingfh) >>> moving = AmpObject(movingfh)
>>> al = AmpScan.align(moving, static, method='linPoint2Plane').m >>> al = align(moving, static, method='linPoint2Plane').m
""" """
cn = np.c_[np.cross(mv, sn), sn] cn = np.c_[np.cross(mv, sn), sn]
...@@ -229,9 +234,9 @@ class align(object): ...@@ -229,9 +234,9 @@ class align(object):
Examples Examples
-------- --------
>>> static = AmpScan.AmpObject(staticfh) >>> static = AmpObject(staticfh)
>>> moving = AmpScan.AmpObject(movingfh) >>> moving = AmpObject(movingfh)
>>> al = AmpScan.align(moving, static, method='linPoint2Point').m >>> al = align(moving, static, method='linPoint2Point').m
""" """
mCent = mv - mv.mean(axis=0) mCent = mv - mv.mean(axis=0)
...@@ -271,9 +276,9 @@ class align(object): ...@@ -271,9 +276,9 @@ class align(object):
Examples Examples
-------- --------
>>> static = AmpScan.AmpObject(staticfh) >>> static = AmpObject(staticfh)
>>> moving = AmpScan.AmpObject(movingfh) >>> moving = AmpObject(movingfh)
>>> al = AmpScan.align(moving, static, method='optPoint2Point', opt='SLSQP').m >>> al = align(moving, static, method='optPoint2Point', opt='SLSQP').m
""" """
X = np.zeros(6) X = np.zeros(6)
......
...@@ -6,11 +6,17 @@ Copyright: Joshua Steer 2018, Joshua.Steer@soton.ac.uk ...@@ -6,11 +6,17 @@ Copyright: Joshua Steer 2018, Joshua.Steer@soton.ac.uk
""" """
import numpy as np import numpy as np
import os
import struct import struct
from .trim import trimMixin from AmpScan.trim import trimMixin
from .smooth import smoothMixin from AmpScan.smooth import smoothMixin
from .analyse import analyseMixin from AmpScan.analyse import analyseMixin
from .ampVis import visMixin from AmpScan.ampVis import visMixin
# The file path used in doc examples
filename = os.getcwd()+"\\tests\\stl_file.stl"
class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin): class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
r""" r"""
...@@ -36,8 +42,7 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin): ...@@ -36,8 +42,7 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
Examples Examples
------- -------
>>> fh = 'test.stl' >>> amp = AmpObject(filename)
>>> amp = AmpScan.AmpObject(fh)
""" """
...@@ -149,13 +154,12 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin): ...@@ -149,13 +154,12 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
Examples Examples
-------- --------
>>> fh = 'test.stl' >>> amp = AmpObject(filename, unify=False)
>>> amp = AmpObject(fh, unify=False)
>>> amp.vert.shape >>> amp.vert.shape
(600, 3) (44832, 3)
>>> amp.unifyVert() >>> amp.unifyVert()
>>> amp.vert.shape >>> amp.vert.shape
(125, 3) (7530, 3)
""" """
# Requires numpy 1.13 # Requires numpy 1.13
...@@ -314,7 +318,16 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin): ...@@ -314,7 +318,16 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
Translation in [x, y, z] Translation in [x, y, z]
""" """
self.vert[:] += trans
# Check that trans is array like
if isinstance(trans, (list, np.ndarray, tuple)):
# Check that trans has exactly 3 dimensions
if len(trans) == 3:
self.vert[:] += trans
else:
raise ValueError("Translation has incorrect dimensions. Expected 3 but found: " + str(len(trans)))
else:
raise TypeError("Translation is not array_like: " + trans)
def centre(self): def centre(self):
r""" r"""
...@@ -337,10 +350,15 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin): ...@@ -337,10 +350,15 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
Examples Examples
-------- --------
>>> amp = AmpObject('test.stl') >>> amp = AmpObject(filename)
>>> ang = [np.pi/2, -np.pi/4, np.pi/3] >>> ang = [np.pi/2, -np.pi/4, np.pi/3]
>>> amp.rotateAng(ang, ang='rad') >>> amp.rotateAng(ang, ang='rad')
""" """
# Check that ang is valid
if ang not in ('rad', 'deg'):
raise ValueError("Ang expected 'rad' or 'deg' but {} was found".format(ang))
if isinstance(rot, (tuple, list, np.ndarray)): if isinstance(rot, (tuple, list, np.ndarray)):
R = self.rotMatrix(rot, ang) R = self.rotMatrix(rot, ang)
self.rotate(R, norms) self.rotate(R, norms)
...@@ -359,6 +377,18 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin): ...@@ -359,6 +377,18 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
norms: boolean, default True norms: boolean, default True
""" """
if isinstance(R, (list, tuple)):
# Make R a np array if its a list or tuple
R = np.array(R, np.float)
elif not isinstance(R, np.ndarray):
# If
raise TypeError("Expected R to be array-like but found: " + str(type(R)))
if len(R) != 3 or len(R[0]) != 3:
# Incorrect dimensions
if isinstance(R, np.ndarray):
raise ValueError("Expected 3x3 array, but found: {}".format(R.shape))
else:
raise ValueError("Expected 3x3 array, but found: 3x"+str(len(R)))
self.vert[:, :] = np.dot(self.vert, R.T) self.vert[:, :] = np.dot(self.vert, R.T)
if norms is True: if norms is True:
self.norm[:, :] = np.dot(self.norm, R.T) self.norm[:, :] = np.dot(self.norm, R.T)
...@@ -380,9 +410,15 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin): ...@@ -380,9 +410,15 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
""" """
if R is not None: if R is not None:
self.rotate(R, True) if isinstance(R, (tuple, list, np.ndarray)):
self.rotate(R, True)
else:
raise TypeError("Expecting array-like rotation, but found: "+type(R))
if T is not None: if T is not None:
self.translate(T) if isinstance(T, (tuple, list, np.ndarray)):
self.translate(T)
else:
raise TypeError("Expecting array-like translation, but found: "+type(T))
@staticmethod @staticmethod
...@@ -396,7 +432,7 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin): ...@@ -396,7 +432,7 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
rot: array_like rot: array_like
Rotation around [x, y, z] Rotation around [x, y, z]
ang: str, default 'rad' ang: str, default 'rad'
Specift if the Euler angles are in degrees or radians Specify if the Euler angles are in degrees or radians
Returns Returns
------- -------
...@@ -404,8 +440,20 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin): ...@@ -404,8 +440,20 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
The calculated 3x3 rotation matrix The calculated 3x3 rotation matrix
""" """
# Check that rot is valid
if not isinstance(rot, (tuple, list, np.ndarray)):
raise TypeError("Expecting array-like rotation, but found: "+type(rot))
elif len(rot) != 3:
raise ValueError("Expecting 3 arguments but found: {}".format(len(rot)))
# Check that ang is valid
if ang not in ('rad', 'deg'):
raise ValueError("Ang expected 'rad' or 'deg' but {} was found".format(ang))
if ang == 'deg': if ang == 'deg':
rot = np.deg2rad(rot) rot = np.deg2rad(rot)
[angx, angy, angz] = rot [angx, angy, angz] = rot
Rx = np.array([[1, 0, 0], Rx = np.array([[1, 0, 0],
[0, np.cos(angx), -np.sin(angx)], [0, np.cos(angx), -np.sin(angx)],
...@@ -429,8 +477,14 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin): ...@@ -429,8 +477,14 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin):
The axis in which to flip the mesh The axis in which to flip the mesh
""" """
self.vert[:, axis] *= -1.0 if isinstance(axis, int):
# Switch face order to normals face same direction if 0 <= axis < 3: # Check axis is between 0-2
self.faces[:, [1, 2]] = self.faces[:, [2, 1]] self.vert[:, axis] *= -1.0
self.calcNorm() # Switch face order to normals face same direction
self.calcVNorm() self.faces[:, [1, 2]] = self.faces[:, [2, 1]]
self.calcNorm()
self.calcVNorm()
else:
raise ValueError("Expected axis to be within range 0-2 but found: {}".format(axis))
else:
raise TypeError("Expected axis to be int, but found: {}".format(type(axis)))
...@@ -6,9 +6,14 @@ Copyright: Joshua Steer 2018, Joshua.Steer@soton.ac.uk ...@@ -6,9 +6,14 @@ Copyright: Joshua Steer 2018, Joshua.Steer@soton.ac.uk
import numpy as np import numpy as np
import copy import copy
from scipy import spatial from scipy import spatial
from .core import AmpObject from AmpScan.core import AmpObject
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
# For the doc examples
import os
basefh = os.getcwd()+"\\tests\\stl_file.stl"
targfh = os.getcwd()+"\\tests\\stl_file_2.stl"
class registration(object): class registration(object):
r""" r"""
Registration methods between two AmpObject meshes. This function morphs the baseline Registration methods between two AmpObject meshes. This function morphs the baseline
...@@ -36,9 +41,10 @@ class registration(object): ...@@ -36,9 +41,10 @@ class registration(object):
Examples Examples
-------- --------
>>> baseline = AmpScan.AmpObject(basefh) >>> from AmpScan.core import AmpObject
>>> target = AmpScan.AmpObject(targfh) >>> baseline = AmpObject(basefh)
>>> reg = AmpScan.registration(steps=10, neigh=10, smooth=1).reg >>> target = AmpObject(targfh)
>>> reg = registration(baseline, target, steps=10, neigh=10, smooth=1).reg
""" """
def __init__(self, baseline, target, method='point2plane', *args, **kwargs): def __init__(self, baseline, target, method='point2plane', *args, **kwargs):
......
...@@ -52,7 +52,7 @@ class smoothMixin(object): ...@@ -52,7 +52,7 @@ class smoothMixin(object):
def smoothValues(self, n=1): def smoothValues(self, n=1):
""" """
Function to apply a simple laplacian smooth to the values array. Function to apply a simple laplacian smooth to the values array.
Identical to the vertex smoothing expect it applies the smoothing Identical to the vertex smoothing except it applies the smoothing
to the values to the values
Parameters Parameters
......
...@@ -19,20 +19,20 @@ class pca(object): ...@@ -19,20 +19,20 @@ class pca(object):
Examples Examples
-------- --------
>>> import os
>>> p = pca() >>> p = pca()
>>> p.importFolder('/path/') >>> p.importFolder(os.getcwd()+"\\tests\\pca_tests")
>>> p.baseline('dir/baselinefh.stl') >>> p.setBaseline(os.getcwd()+"\\tests\\stl_file_3.stl")
>>> p.register(save = '/regpath/') >>> p.register(save=os.getcwd()+"\\tests\\pca_tests\\")
>>> p.pca() >>> p.pca()
>>> sfs = [0, 0.1, -0.5 ... 0] >>> sfs = [1, 2]
>>> newS = p.newShape(sfs) >>> newS = p.newShape(sfs)
""" """
def __init__(self): def __init__(self):
self.shapes = [] self.shapes = []
def setBaseline(self, baseline): def setBaseline(self, baseline):
r""" r"""
Function to set the baseline mesh used for registration of the Function to set the baseline mesh used for registration of the
...@@ -45,8 +45,7 @@ class pca(object): ...@@ -45,8 +45,7 @@ class pca(object):
""" """
self.baseline = AmpObject(baseline, 'limb') self.baseline = AmpObject(baseline, 'limb')
def importFolder(self, path, unify=True): def importFolder(self, path, unify=True):
r""" r"""
Function to import multiple stl files from folder into the pca object Function to import multiple stl files from folder into the pca object
...@@ -64,7 +63,7 @@ class pca(object): ...@@ -64,7 +63,7 @@ class pca(object):
self.shapes = [AmpObject(os.path.join(path, f), 'limb', unify=unify) for f in self.fnames] self.shapes = [AmpObject(os.path.join(path, f), 'limb', unify=unify) for f in self.fnames]
for s in self.shapes: for s in self.shapes:
s.lp_smooth(3, brim=True) s.lp_smooth(3, brim=True)
def sliceFiles(self, height): def sliceFiles(self, height):
r""" r"""
Function to run a planar trim on all the training data for the PCA Function to run a planar trim on all the training data for the PCA
...@@ -78,7 +77,7 @@ class pca(object): ...@@ -78,7 +77,7 @@ class pca(object):
""" """
for s in self.shapes: for s in self.shapes:
s.planarTrim(height) s.planarTrim(height)
def register(self, scale=None, save=None, baseline=True): def register(self, scale=None, save=None, baseline=True):
r""" r"""
Register all the AmpObject training data to the baseline AmpObject Register all the AmpObject training data to the baseline AmpObject
...@@ -105,8 +104,7 @@ class pca(object): ...@@ -105,8 +104,7 @@ class pca(object):
self.X = np.array([r.vert.flatten() for r in self.registered]).T self.X = np.array([r.vert.flatten() for r in self.registered]).T
if baseline is True: if baseline is True:
self.X = np.c_[self.X, self.baseline.vert.flatten()] self.X = np.c_[self.X, self.baseline.vert.flatten()]
def pca(self): def pca(self):
r""" r"""
Function to run mean centered pca using a singular value decomposition Function to run mean centered pca using a singular value decomposition
...@@ -118,8 +116,8 @@ class pca(object): ...@@ -118,8 +116,8 @@ class pca(object):
(self.pca_U, self.pca_S, self.pca_V) = np.linalg.svd(X_meanC, full_matrices=False) (self.pca_U, self.pca_S, self.pca_V) = np.linalg.svd(X_meanC, full_matrices=False)
self.pc_weights = np.dot(np.diag(self.pca_S), self.pca_V) self.pc_weights = np.dot(np.diag(self.pca_S), self.pca_V)
self.pc_stdevs = np.std(self.pc_weights, axis=1) self.pc_stdevs = np.std(self.pc_weights, axis=1)
def newShape(self, sfs, scale = 'eigs'): def newShape(self, sfs, scale='eigs'):
r""" r"""
Function to calculate a new shape based upon the eigenvalues Function to calculate a new shape based upon the eigenvalues
or stdevs or stdevs
...@@ -134,11 +132,15 @@ class pca(object): ...@@ -134,11 +132,15 @@ class pca(object):
to standard deviations about the mean to standard deviations about the mean
""" """
try: len(sfs) == len(self.pc_stdevs) if not isinstance(sfs, (list, tuple, np.ndarray)):
except: ValueError('sfs must be of the same length as the number of ' raise TypeError('sfs is invalid type (expected array-like, found: {}'.format(type(sfs)))
'principal components') if len(sfs) != len(self.pc_stdevs):
raise ValueError('sfs must be of the same length as the number of '
'principal components (expected {} but found {})'.format(len(self.pc_stdevs), len(sfs)))
if scale == 'eigs': if scale == 'eigs':
sf = (self.pca_U * sfs).sum(axis=1) sf = (self.pca_U * sfs).sum(axis=1)
elif scale == 'std': elif scale == 'std':
sf = (self.pca_U * self.pc_stdevs * sfs).sum(axis=1) sf = (self.pca_U * self.pc_stdevs * sfs).sum(axis=1)
else:
raise ValueError("Invalid scale (expected 'eigs' or 'std' but found{}".format(scale))
return self.pca_mean + sf return self.pca_mean + sf
...@@ -5,6 +5,12 @@ Copyright: Joshua Steer 2018, Joshua.Steer@soton.ac.uk ...@@ -5,6 +5,12 @@ Copyright: Joshua Steer 2018, Joshua.Steer@soton.ac.uk
""" """
import numpy as np import numpy as np
from numbers import Number
import os
# Used by doc tests
filename = os.getcwd() + "\\tests\\stl_file.stl"
class trimMixin(object): class trimMixin(object):
r""" r"""
...@@ -26,30 +32,32 @@ class trimMixin(object): ...@@ -26,30 +32,32 @@ class trimMixin(object):
Examples Examples
-------- --------
>>> amp = AmpObject(fh)
>>> from AmpScan import AmpObject
>>> amp = AmpObject(filename)
>>> amp.planarTrim(100, 2) >>> amp.planarTrim(100, 2)
""" """
# if isinstance(height, float): if isinstance(height, Number) and isinstance(plane, int):
# planar values for each vert on face # planar values for each vert on face
fv = self.vert[self.faces, plane] fv = self.vert[self.faces, plane]
# Number points on each face are above cut plane # Number points on each face are above cut plane
fvlogic = (fv > height).sum(axis=1) fvlogic = (fv > height).sum(axis=1)
# Faces with points both above and below cut plane # Faces with points both above and below cut plane
adjf = self.faces[np.logical_or(fvlogic == 2, fvlogic == 1)] adjf = self.faces[np.logical_or(fvlogic == 2, fvlogic == 1)]
# Get adjacent vertices # Get adjacent vertices
adjv = np.unique(adjf) adjv = np.unique(adjf)
# Get vert above height and set to height # Get vert above height and set to height
abvInd = adjv[self.vert[adjv, plane] > height] abvInd = adjv[self.vert[adjv, plane] > height]
self.vert[abvInd, plane] = height self.vert[abvInd, plane] = height
# Find all verts above plane # Find all verts above plane
delv = self.vert[:, plane] > height delv = self.vert[:, plane] > height
# Reorder verts to account for deleted one # Reorder verts to account for deleted one
vInd = np.cumsum(~delv) - 1 vInd = np.cumsum(~delv) - 1
self.faces = self.faces[fvlogic != 3, :] self.faces = self.faces[fvlogic != 3, :]
self.faces = vInd[self.faces] self.faces = vInd[self.faces]
self.vert = self.vert[~delv, :] self.vert = self.vert[~delv, :]
self.values = self.values[~delv] self.values = self.values[~delv]
self.calcStruct() self.calcStruct()
# else: else:
# raise TypeError("height arg must be a float") raise TypeError("height arg must be a float")
\ No newline at end of file
...@@ -10,9 +10,9 @@ from PyQt5.QtGui import (QColor, QFontMetrics, QImage, QPainter, QIcon, ...@@ -10,9 +10,9 @@ from PyQt5.QtGui import (QColor, QFontMetrics, QImage, QPainter, QIcon,
QOpenGLVersionProfile) QOpenGLVersionProfile)
from PyQt5.QtWidgets import (QAction, QApplication, QGridLayout, QHBoxLayout, from PyQt5.QtWidgets import (QAction, QApplication, QGridLayout, QHBoxLayout,
QMainWindow, QMessageBox, QComboBox, QButtonGroup, QMainWindow, QMessageBox, QComboBox, QButtonGroup,
QOpenGLWidget, QFileDialog,QLabel,QPushButton, QOpenGLWidget, QFileDialog, QLabel, QPushButton,
QSlider, QWidget, QTableWidget, QTableWidgetItem, QSlider, QWidget, QTableWidget, QTableWidgetItem,
QAbstractButton) QAbstractButton, QErrorMessage)
class AmpScanGUI(QMainWindow): class AmpScanGUI(QMainWindow):
...@@ -107,16 +107,19 @@ class AmpScanGUI(QMainWindow): ...@@ -107,16 +107,19 @@ class AmpScanGUI(QMainWindow):
Numpy style docstring. Numpy style docstring.
""" """
self.alCont = AlignControls(self.filesDrop, self) if self.objectsReady(1):
self.alCont.show() self.alCont = AlignControls(self.filesDrop, self)
self.alCont.centre.clicked.connect(self.centreMesh) self.alCont.show()
self.alCont.icp.clicked.connect(self.runICP) self.alCont.centre.clicked.connect(self.centreMesh)
self.alCont.xrotButton.buttonClicked[QAbstractButton].connect(self.rotatex) self.alCont.icp.clicked.connect(self.runICP)
self.alCont.yrotButton.buttonClicked[QAbstractButton].connect(self.rotatey) self.alCont.xrotButton.buttonClicked[QAbstractButton].connect(self.rotatex)
self.alCont.zrotButton.buttonClicked[QAbstractButton].connect(self.rotatez) self.alCont.yrotButton.buttonClicked[QAbstractButton].connect(self.rotatey)
self.alCont.xtraButton.buttonClicked[QAbstractButton].connect(self.transx)