From 57b90d01bcbe0b4024f6be7369ebbf77b4ef1ecb Mon Sep 17 00:00:00 2001 From: Joshua Steer <joshua.w.steer@gmail.com> Date: Fri, 4 May 2018 17:58:44 +0100 Subject: [PATCH] Created registration as a seperate group of methods that use AmpObjs as inputs --- AmpScan/AmpScanGUI.py | 37 +++++++------- AmpScan/__init__.py | 2 +- AmpScan/ampVis.py | 4 +- AmpScan/core.py | 90 +++++++++++++++++++++------------ AmpScan/registration.py | 108 ++++++++++++++++++---------------------- 5 files changed, 125 insertions(+), 116 deletions(-) diff --git a/AmpScan/AmpScanGUI.py b/AmpScan/AmpScanGUI.py index 5f31c95..ee98335 100644 --- a/AmpScan/AmpScanGUI.py +++ b/AmpScan/AmpScanGUI.py @@ -1,7 +1,7 @@ import sys import numpy as np from .core import AmpObject -from .registration import regObject +from .registration import registration from .ampVis import qtVtkWindow from .pressSens import pressSense from PyQt5.QtCore import QPoint, QSize, Qt, QTimer, QRect, pyqtSignal @@ -19,6 +19,7 @@ class AmpScanGUI(QMainWindow): super(AmpScanGUI, self).__init__() self.vtkWidget = qtVtkWindow() self.renWin = self.vtkWidget._RenderWindow + self.renWin.setBackground() self.mainWidget = QWidget() self.AmpObj = None # self.CMap = np.array([[212.0, 221.0, 225.0], @@ -48,9 +49,9 @@ class AmpScanGUI(QMainWindow): def chooseSocket(self): self.sockfname = QFileDialog.getOpenFileName(self, 'Open file', filter="Meshes (*.stl)") - self.AmpObj.addData(self.sockfname[0], stype='socket') - self.AmpObj.addActor(stype='socket') - self.AmpObj.lp_smooth(stype='socket') + self.socket = AmpObject(self.sockfname[0], stype='socket') + self.socket.addActor() + self.socket.lp_smooth() def align(self): self.renWin.setnumViewports(2) @@ -60,26 +61,22 @@ class AmpScanGUI(QMainWindow): # self.renWin.render(self.AmpObj.actors, dispActors=['limb',]) # self.renWin.render(self.AmpObj.actors, dispActors=['socket',], # viewport=1) - self.renWin.renderActors(self.AmpObj.actors, - dispActors=['limb', 'socket'], - viewport=0) - self.renWin.renderActors(self.AmpObj.actors, - dispActors=['limb', 'socket'], - viewport=1) - self.AmpObj.actors['limb'].setColor([1.0, 0.0, 0.0]) - self.AmpObj.actors['limb'].setOpacity(0.5) - self.AmpObj.actors['socket'].setColor([0.0, 0.0, 1.0]) - self.AmpObj.actors['socket'].setOpacity(0.5) + self.renWin.renderActors([self.AmpObj.actor, self.socket.actor], + viewport=0) + self.renWin.renderActors([self.AmpObj.actor, self.socket.actor], + viewport=1) + self.AmpObj.actor.setColor([1.0, 0.0, 0.0]) + self.AmpObj.actor.setOpacity(0.5) + self.socket.actor.setColor([0.0, 0.0, 1.0]) + self.socket.actor.setOpacity(0.5) def register(self): self.renWin.setnumViewports(1) self.renWin.setProjection() - self.RegObj = regObject(self.AmpObj) - self.RegObj.registration(steps=5, baseline='socket', target='limb', - reg = 'reglimb', direct=True) - self.RegObj.addActor(stype='reglimb', CMap=self.AmpObj.CMapN2P) - self.renWin.renderActors(self.AmpObj.actors, ['reglimb',], shading=False) - self.renWin.setScalarBar(self.AmpObj.actors['reglimb']) + self.RegObj = registration(self.socket, self.AmpObj) + self.RegObj.addActor(CMap=self.AmpObj.CMapN2P) + self.renWin.renderActors([self.RegObj.actor,]) + self.renWin.setScalarBar(self.RegObj.actor) def analyse(self): self.RegObj.plot_slices() diff --git a/AmpScan/__init__.py b/AmpScan/__init__.py index b437ce0..6160639 100644 --- a/AmpScan/__init__.py +++ b/AmpScan/__init__.py @@ -6,6 +6,6 @@ Created on Thu Dec 15 13:50:41 2016 """ from .core import AmpObject -from .registration import regObject +from .registration import registration from .AmpScanGUI import AmpScanGUI from .socketDesignGUI import socketDesignGUI diff --git a/AmpScan/ampVis.py b/AmpScan/ampVis.py index 4f1e581..0cada59 100644 --- a/AmpScan/ampVis.py +++ b/AmpScan/ampVis.py @@ -119,7 +119,7 @@ class vtkRenWin(vtk.vtkRenderWindow): self.rens = self.rens[:n] elif dif > 0: for i in range(dif): - self.rens.append(vtkRender()) + self.rens.append(vtk.vtkRenderer()) self.axes.append(vtk.vtkCubeAxesActor()) self.AddRenderer(self.rens[-1]) if len(self.cams) < len(self.rens): @@ -239,7 +239,7 @@ class visMixin(object): self.actor.setValues(self.values) self.actor.setCMap(CMap, bands) self.actor.setScalarRange(sRange) - self.actor.Mapper.SetLookupTable(self.lut) + self.actor.Mapper.SetLookupTable(self.actor.lut) self.actor.setNorm() class ampActor(vtk.vtkActor): diff --git a/AmpScan/core.py b/AmpScan/core.py index 3b3620c..72b9c31 100644 --- a/AmpScan/core.py +++ b/AmpScan/core.py @@ -44,31 +44,36 @@ from .tsbSocketDesign import socketDesignMixin class AmpObject(trimMixin, smoothMixin, analyseMixin, visMixin, feMixin, socketDesignMixin): - def __init__(self, data, stype='limb'): - c1 = [31.0, 73.0, 125.0] - c3 = [170.0, 75.0, 65.0] - c2 = [212.0, 221.0, 225.0] - CMap1 = np.c_[[np.linspace(st, en) for (st, en) in zip(c1, c2)]] - CMap2 = np.c_[[np.linspace(st, en) for (st, en) in zip(c2, c3)]] - CMap = np.c_[CMap1[:, :-1], CMap2] - self.CMapN2P = np.transpose(CMap)/255.0 - self.CMap02P = np.flip(np.transpose(CMap1)/255.0, axis=0) + def __init__(self, data=None, stype='limb'): self.stype = stype self.values = None - if stype is 'FE': - self.addFE([Data,]) - else: - self.read_stl(data) + self.createCMap() + if isinstance(data, str): + if stype is 'FE': + self.addFE([Data,]) + else: + self.read_stl(data) + elif isinstance(data, dict): + for k, v in data.items(): + setattr(self, k, v) + self.calcStruct() def createCMap(self, cmap=None, n = 50): """ Function to generate a colormap for the AmpObj """ if cmap is None: - cmap = n + c1 = [31.0, 73.0, 125.0] + c3 = [170.0, 75.0, 65.0] + c2 = [212.0, 221.0, 225.0] + CMap1 = np.c_[[np.linspace(st, en) for (st, en) in zip(c1, c2)]] + CMap2 = np.c_[[np.linspace(st, en) for (st, en) in zip(c2, c3)]] + CMap = np.c_[CMap1[:, :-1], CMap2] + self.CMapN2P = np.transpose(CMap)/255.0 + self.CMap02P = np.flip(np.transpose(CMap1)/255.0, axis=0) - def read_stl(self, filename, unify=True, edges=True, vNorm=True): + def read_stl(self, filename, unify=True): """ Function to read .stl file from filename and import data into the AmpObj @@ -106,14 +111,24 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, self.norm = norm # Call function to unify vertices of the array if unify is True: - self.unify_vertices() + self.unifyVert() # Call function to calculate the edges array + self.calcStruct() + + def calcStruct(self, norm=True, edges=True, + edgeFaces=True, faceEdges=True, vNorm=True): + if norm is True: + self.calcNorm() if edges is True: - self.computeEdges() + self.calcEdges() + if edgeFaces is True: + self.calcEdgeFaces() + if faceEdges is True: + self.calcFaceEdges() if vNorm is True: self.calcVNorm() - def unify_vertices(self): + def unifyVert(self): """ Function to unify coincident vertices of the mesh to reduce size of the vertices array enabling speed increases @@ -124,7 +139,7 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, self.faces = np.resize(indC[self.faces], (len(self.norm), 3)).astype(np.int32) - def computeEdges(self): + def calcEdges(self): """ Function to compute the edges array, the edges on each face, and the faces on each edge @@ -139,12 +154,20 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, # Get edges array self.edges = np.reshape(self.faces[:, [0, 1, 0, 2, 1, 2]], [-1, 2]) self.edges = np.sort(self.edges, 1) - # Get edges on each face - self.edgesFace = np.reshape(range(len(self.faces)*3), [-1,3]) # Unify the edges self.edges, indC = np.unique(self.edges, return_inverse=True, axis=0) + + def calcEdgeFaces(self): + edges = np.reshape(self.faces[:, [0, 1, 0, 2, 1, 2]], [-1, 2]) + edges = np.sort(edges, 1) + # Unify the edges + edges, indC = np.unique(edges, return_inverse=True, axis=0) + # Get edges on each face + self.edgesFace = np.reshape(range(len(self.faces)*3), [-1,3]) #Remap the edgesFace array self.edgesFace = indC[self.edgesFace].astype(np.int32) + + def calcFaceEdges(self): #Initiate the faceEdges array self.faceEdges = np.empty([len(self.edges), 2], dtype=np.int32) self.faceEdges.fill(-99999) @@ -156,7 +179,19 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, logic = np.zeros([len(eF)], dtype=bool) logic[eFInd] = True self.faceEdges[eF[logic], 0] = fInd[logic] - self.faceEdges[eF[~logic], 1] = fInd[~logic] + self.faceEdges[eF[~logic], 1] = fInd[~logic] + + + def calcNorm(self): + """ + Calculate the normal of each face of the AmpObj + """ + norms = np.cross(self.vert[self.faces[:,1]] - + self.vert[self.faces[:,0]], + self.vert[self.faces[:,2]] - + self.vert[self.faces[:,0]]) + mag = np.linalg.norm(norms, axis=1) + self.norm = np.divide(norms, mag[:,None]) def calcVNorm(self): """ @@ -198,17 +233,6 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, data_write.tofile(fh) fh.close() - def calc_norm(self): - """ - Calculate the normal of each face of the AmpObj - """ - norms = np.cross(self.vert[self.faces[:,1]] - - self.vert[self.faces[:,0]], - self.vert[self.faces[:,2]] - - self.vert[self.faces[:,0]]) - mag = np.linalg.norm(norms, axis=1) - self.norm = np.divide(norms, mag[:,None]) - def translate(self, trans): """ Translate the AmpObj in 3D space diff --git a/AmpScan/registration.py b/AmpScan/registration.py index da56c68..c5db75e 100644 --- a/AmpScan/registration.py +++ b/AmpScan/registration.py @@ -6,6 +6,7 @@ Created on Wed Sep 13 16:07:10 2017 """ import numpy as np import pandas as pd +import copy from scipy import spatial from .core import AmpObject """ @@ -51,74 +52,61 @@ Child classes: FE mesh """ -class regObject(AmpObject): +def registration(baseline, target, method='default', steps=5, direct=True): - def __init__(self, Data=None, stype='AmpObj'): - super(regObject, self).__init__(Data, stype) - - def registration(self, steps=1, baseline='limb', - target='socket', reg = 'reglimb', direct=True): - """ - Function to register the regObject to the baseline mesh - - Parameters - ---------- - Steps: int, default 1 - Number of iterations - """ - bData = getattr(self, baseline) - tData = getattr(self, target) - bV = bData['vert'] - # Calculate the face centroids of the regObject - tData['fC'] = tData['vert'][tData['faces']].mean(axis=1) - # Construct knn tree - tTree = spatial.cKDTree(tData['fC']) - for step in np.arange(steps, 0, -1): - # Index of 10 centroids nearest to each baseline vertex - ind = tTree.query(bV, 10)[1] - D = np.zeros(bV.shape) - # Define normals for faces of nearest faces - norms = tData['norm'][ind] - # Get a point on each face - fPoints = tData['vert'][tData['faces'][ind, 0]] - # Calculate dot product between point on face and normals - d = np.einsum('ijk, ijk->ij', norms, fPoints) - t = d - np.einsum('ijk, ik->ij', norms, bV) - # Calculate new points - G = np.einsum('ijk, ij->ijk', norms, t) - GMag = np.sqrt(np.einsum('ijk, ijk->ij', G, G)).argmin(axis=1) - # Define vector from baseline point to intersect point - D = G[np.arange(len(G)), GMag, :] - bV = bV + D/step - regData = dict(bData) - regData['vert'] = bV - setattr(self, reg, regData) - self.calcError(baseline, reg, direct) - - - def calcError(self, baseline='limb', target='reglimb', direct=True): - # This is kinda slow - bData = getattr(self, baseline) - tData = getattr(self, target) + """ + Function to register the regObject to the baseline mesh + + Parameters + ---------- + Steps: int, default 1 + Number of iterations + """ + bV = baseline.vert + # Calc FaceCentroids + fC = target.vert[target.faces].mean(axis=1) + # Construct knn tree + tTree = spatial.cKDTree(fC) + for step in np.arange(steps, 0, -1): + # Index of 10 centroids nearest to each baseline vertex + ind = tTree.query(bV, 10)[1] + D = np.zeros(bV.shape) + # Define normals for faces of nearest faces + norms = target.norm[ind] + # Get a point on each face + fPoints = target.vert[target.faces[ind, 0]] + # Calculate dot product between point on face and normals + d = np.einsum('ijk, ijk->ij', norms, fPoints) + t = d - np.einsum('ijk, ik->ij', norms, bV) + # Calculate new points + G = np.einsum('ijk, ij->ijk', norms, t) + GMag = np.sqrt(np.einsum('ijk, ijk->ij', G, G)).argmin(axis=1) + # Define vector from baseline point to intersect point + D = G[np.arange(len(G)), GMag, :] + bV = bV + D/step + bData = dict(zip(['vert', 'faces'], [bV, baseline.faces])) + regObj = AmpObject(bData, stype='reg') + + def calcError(baseline, regObj, direct=True): if direct is True: - values = np.linalg.norm(tData['vert'] - bData['vert'], axis=1) - # Calculate vertex normals on target from normal of surrounding faces - vNorm = np.zeros(tData['vert'].shape) - for face, norm in zip(tData['faces'], tData['norm']): - vNorm[face, :] += norm - vNorm = vNorm / np.linalg.norm(vNorm, axis=1)[:, None] + values = np.linalg.norm(regObj.vert - baseline.vert, axis=1) # Calculate the unit vector normal between corresponding vertices # baseline and target - vector = (tData['vert'] - bData['vert'])/values[:, None] + vector = (regObj.vert - baseline.vert)/values[:, None] # Calculate angle between the two unit vectors using normal of cross # product between vNorm and vector and dot - normcrossP = np.linalg.norm(np.cross(vector, vNorm), axis=1) - dotP = np.einsum('ij,ij->i', vector, vNorm) + normcrossP = np.linalg.norm(np.cross(vector, target.vNorm), axis=1) + dotP = np.einsum('ij,ij->i', vector, target.vNorm) angle = np.arctan2(normcrossP, dotP) polarity = np.ones(angle.shape) polarity[angle < np.pi/2] =-1.0 - tData['values'] = values * polarity + values = values * polarity + return values else: - tData['values'] = np.linalg.norm(tData['vert'] - bData['vert'], - axis=1) + values = np.linalg.norm(regObj.vert - baseline.vert, axis=1) + return values + + regObj.values = calcError(baseline, regObj, False) + return regObj + -- GitLab