From 57b90d01bcbe0b4024f6be7369ebbf77b4ef1ecb Mon Sep 17 00:00:00 2001
From: Joshua Steer <joshua.w.steer@gmail.com>
Date: Fri, 4 May 2018 17:58:44 +0100
Subject: [PATCH] Created registration as a seperate group of methods that use
 AmpObjs as inputs

---
 AmpScan/AmpScanGUI.py   |  37 +++++++-------
 AmpScan/__init__.py     |   2 +-
 AmpScan/ampVis.py       |   4 +-
 AmpScan/core.py         |  90 +++++++++++++++++++++------------
 AmpScan/registration.py | 108 ++++++++++++++++++----------------------
 5 files changed, 125 insertions(+), 116 deletions(-)

diff --git a/AmpScan/AmpScanGUI.py b/AmpScan/AmpScanGUI.py
index 5f31c95..ee98335 100644
--- a/AmpScan/AmpScanGUI.py
+++ b/AmpScan/AmpScanGUI.py
@@ -1,7 +1,7 @@
 import sys
 import numpy as np
 from .core import AmpObject
-from .registration import regObject
+from .registration import registration
 from .ampVis import qtVtkWindow
 from .pressSens import pressSense
 from PyQt5.QtCore import QPoint, QSize, Qt, QTimer, QRect, pyqtSignal
@@ -19,6 +19,7 @@ class AmpScanGUI(QMainWindow):
         super(AmpScanGUI, self).__init__()
         self.vtkWidget = qtVtkWindow()
         self.renWin = self.vtkWidget._RenderWindow
+        self.renWin.setBackground()
         self.mainWidget = QWidget()
         self.AmpObj = None
 #        self.CMap = np.array([[212.0, 221.0, 225.0],
@@ -48,9 +49,9 @@ class AmpScanGUI(QMainWindow):
     def chooseSocket(self):
         self.sockfname = QFileDialog.getOpenFileName(self, 'Open file',
                                             filter="Meshes (*.stl)")
-        self.AmpObj.addData(self.sockfname[0], stype='socket')
-        self.AmpObj.addActor(stype='socket')
-        self.AmpObj.lp_smooth(stype='socket')
+        self.socket = AmpObject(self.sockfname[0], stype='socket')
+        self.socket.addActor()
+        self.socket.lp_smooth()
         
     def align(self):
         self.renWin.setnumViewports(2)
@@ -60,26 +61,22 @@ class AmpScanGUI(QMainWindow):
 #        self.renWin.render(self.AmpObj.actors, dispActors=['limb',])
 #        self.renWin.render(self.AmpObj.actors, dispActors=['socket',],
 #                              viewport=1)
-        self.renWin.renderActors(self.AmpObj.actors,
-                              dispActors=['limb', 'socket'],
-                              viewport=0)
-        self.renWin.renderActors(self.AmpObj.actors,
-                              dispActors=['limb', 'socket'],
-                              viewport=1)
-        self.AmpObj.actors['limb'].setColor([1.0, 0.0, 0.0])
-        self.AmpObj.actors['limb'].setOpacity(0.5)
-        self.AmpObj.actors['socket'].setColor([0.0, 0.0, 1.0])
-        self.AmpObj.actors['socket'].setOpacity(0.5)
+        self.renWin.renderActors([self.AmpObj.actor, self.socket.actor],
+                                 viewport=0)
+        self.renWin.renderActors([self.AmpObj.actor, self.socket.actor],
+                                 viewport=1)
+        self.AmpObj.actor.setColor([1.0, 0.0, 0.0])
+        self.AmpObj.actor.setOpacity(0.5)
+        self.socket.actor.setColor([0.0, 0.0, 1.0])
+        self.socket.actor.setOpacity(0.5)
         
     def register(self):
         self.renWin.setnumViewports(1)
         self.renWin.setProjection()
-        self.RegObj = regObject(self.AmpObj)
-        self.RegObj.registration(steps=5, baseline='socket', target='limb', 
-                                 reg = 'reglimb', direct=True)
-        self.RegObj.addActor(stype='reglimb', CMap=self.AmpObj.CMapN2P)
-        self.renWin.renderActors(self.AmpObj.actors, ['reglimb',], shading=False)
-        self.renWin.setScalarBar(self.AmpObj.actors['reglimb'])
+        self.RegObj = registration(self.socket, self.AmpObj)
+        self.RegObj.addActor(CMap=self.AmpObj.CMapN2P)
+        self.renWin.renderActors([self.RegObj.actor,])
+        self.renWin.setScalarBar(self.RegObj.actor)
     
     def analyse(self):
         self.RegObj.plot_slices()
diff --git a/AmpScan/__init__.py b/AmpScan/__init__.py
index b437ce0..6160639 100644
--- a/AmpScan/__init__.py
+++ b/AmpScan/__init__.py
@@ -6,6 +6,6 @@ Created on Thu Dec 15 13:50:41 2016
 """
 
 from .core import AmpObject
-from .registration import regObject
+from .registration import registration
 from .AmpScanGUI import AmpScanGUI
 from .socketDesignGUI import socketDesignGUI
diff --git a/AmpScan/ampVis.py b/AmpScan/ampVis.py
index 4f1e581..0cada59 100644
--- a/AmpScan/ampVis.py
+++ b/AmpScan/ampVis.py
@@ -119,7 +119,7 @@ class vtkRenWin(vtk.vtkRenderWindow):
             self.rens = self.rens[:n]
         elif dif > 0:
             for i in range(dif):
-                self.rens.append(vtkRender())
+                self.rens.append(vtk.vtkRenderer())
                 self.axes.append(vtk.vtkCubeAxesActor())
                 self.AddRenderer(self.rens[-1])
                 if len(self.cams) < len(self.rens):
@@ -239,7 +239,7 @@ class visMixin(object):
             self.actor.setValues(self.values)
             self.actor.setCMap(CMap, bands)
             self.actor.setScalarRange(sRange)
-            self.actor.Mapper.SetLookupTable(self.lut)
+            self.actor.Mapper.SetLookupTable(self.actor.lut)
         self.actor.setNorm()
 
     class ampActor(vtk.vtkActor):
diff --git a/AmpScan/core.py b/AmpScan/core.py
index 3b3620c..72b9c31 100644
--- a/AmpScan/core.py
+++ b/AmpScan/core.py
@@ -44,31 +44,36 @@ from .tsbSocketDesign import socketDesignMixin
 class AmpObject(trimMixin, smoothMixin, analyseMixin, 
                 visMixin, feMixin, socketDesignMixin):
 
-    def __init__(self, data, stype='limb'):
-        c1 = [31.0, 73.0, 125.0]
-        c3 = [170.0, 75.0, 65.0]
-        c2 = [212.0, 221.0, 225.0]
-        CMap1 = np.c_[[np.linspace(st, en) for (st, en) in zip(c1, c2)]]
-        CMap2 = np.c_[[np.linspace(st, en) for (st, en) in zip(c2, c3)]]
-        CMap = np.c_[CMap1[:, :-1], CMap2]
-        self.CMapN2P = np.transpose(CMap)/255.0
-        self.CMap02P = np.flip(np.transpose(CMap1)/255.0, axis=0)
+    def __init__(self, data=None, stype='limb'):
         self.stype = stype
         self.values = None
-        if stype is 'FE':
-            self.addFE([Data,])
-        else:
-            self.read_stl(data)
+        self.createCMap()
+        if isinstance(data, str):    
+            if stype is 'FE':
+                self.addFE([Data,])
+            else:
+                self.read_stl(data)
+        elif isinstance(data, dict):
+            for k, v in data.items():
+                setattr(self, k, v)
+            self.calcStruct()
     
     def createCMap(self, cmap=None, n = 50):
         """
         Function to generate a colormap for the AmpObj
         """
         if cmap is None:
-            cmap = n
+            c1 = [31.0, 73.0, 125.0]
+            c3 = [170.0, 75.0, 65.0]
+            c2 = [212.0, 221.0, 225.0]
+            CMap1 = np.c_[[np.linspace(st, en) for (st, en) in zip(c1, c2)]]
+            CMap2 = np.c_[[np.linspace(st, en) for (st, en) in zip(c2, c3)]]
+            CMap = np.c_[CMap1[:, :-1], CMap2]
+            self.CMapN2P = np.transpose(CMap)/255.0
+            self.CMap02P = np.flip(np.transpose(CMap1)/255.0, axis=0)
 
 
-    def read_stl(self, filename, unify=True, edges=True, vNorm=True):
+    def read_stl(self, filename, unify=True):
         """
         Function to read .stl file from filename and import data into 
         the AmpObj 
@@ -106,14 +111,24 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin,
         self.norm = norm
         # Call function to unify vertices of the array
         if unify is True:
-            self.unify_vertices()
+            self.unifyVert()
         # Call function to calculate the edges array
+        self.calcStruct()
+        
+    def calcStruct(self, norm=True, edges=True, 
+                   edgeFaces=True, faceEdges=True, vNorm=True):
+        if norm is True:
+            self.calcNorm()
         if edges is True:
-            self.computeEdges()
+            self.calcEdges()
+        if edgeFaces is True:
+            self.calcEdgeFaces()
+        if faceEdges is True:
+            self.calcFaceEdges()
         if vNorm is True:
             self.calcVNorm()
 
-    def unify_vertices(self):
+    def unifyVert(self):
         """
         Function to unify coincident vertices of the mesh to reduce
         size of the vertices array enabling speed increases
@@ -124,7 +139,7 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin,
         self.faces = np.resize(indC[self.faces], 
                                (len(self.norm), 3)).astype(np.int32)
 
-    def computeEdges(self):
+    def calcEdges(self):
         """
         Function to compute the edges array, the edges on each face, 
         and the faces on each edge
@@ -139,12 +154,20 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin,
         # Get edges array
         self.edges = np.reshape(self.faces[:, [0, 1, 0, 2, 1, 2]], [-1, 2])
         self.edges = np.sort(self.edges, 1)
-        # Get edges on each face 
-        self.edgesFace = np.reshape(range(len(self.faces)*3), [-1,3])
         # Unify the edges
         self.edges, indC = np.unique(self.edges, return_inverse=True, axis=0)
+
+    def calcEdgeFaces(self):
+        edges = np.reshape(self.faces[:, [0, 1, 0, 2, 1, 2]], [-1, 2])
+        edges = np.sort(edges, 1)
+        # Unify the edges
+        edges, indC = np.unique(edges, return_inverse=True, axis=0)
+        # Get edges on each face 
+        self.edgesFace = np.reshape(range(len(self.faces)*3), [-1,3])
         #Remap the edgesFace array 
         self.edgesFace = indC[self.edgesFace].astype(np.int32)
+
+    def calcFaceEdges(self):
         #Initiate the faceEdges array
         self.faceEdges = np.empty([len(self.edges), 2], dtype=np.int32)
         self.faceEdges.fill(-99999)
@@ -156,7 +179,19 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin,
         logic = np.zeros([len(eF)], dtype=bool)
         logic[eFInd] = True
         self.faceEdges[eF[logic], 0] = fInd[logic]
-        self.faceEdges[eF[~logic], 1] = fInd[~logic]
+        self.faceEdges[eF[~logic], 1] = fInd[~logic]        
+        
+
+    def calcNorm(self):
+        """
+        Calculate the normal of each face of the AmpObj
+        """
+        norms = np.cross(self.vert[self.faces[:,1]] -
+                         self.vert[self.faces[:,0]],
+                         self.vert[self.faces[:,2]] -
+                         self.vert[self.faces[:,0]])
+        mag = np.linalg.norm(norms, axis=1)
+        self.norm = np.divide(norms, mag[:,None])
         
     def calcVNorm(self):
         """
@@ -198,17 +233,6 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin,
         data_write.tofile(fh)
         fh.close()
 
-    def calc_norm(self):
-        """
-        Calculate the normal of each face of the AmpObj
-        """
-        norms = np.cross(self.vert[self.faces[:,1]] -
-                         self.vert[self.faces[:,0]],
-                         self.vert[self.faces[:,2]] -
-                         self.vert[self.faces[:,0]])
-        mag = np.linalg.norm(norms, axis=1)
-        self.norm = np.divide(norms, mag[:,None])
-
     def translate(self, trans):
         """
         Translate the AmpObj in 3D space
diff --git a/AmpScan/registration.py b/AmpScan/registration.py
index da56c68..c5db75e 100644
--- a/AmpScan/registration.py
+++ b/AmpScan/registration.py
@@ -6,6 +6,7 @@ Created on Wed Sep 13 16:07:10 2017
 """
 import numpy as np
 import pandas as pd 
+import copy
 from scipy import spatial
 from .core import AmpObject
 """
@@ -51,74 +52,61 @@ Child classes:
     FE mesh 
 """
 
-class regObject(AmpObject):
+def registration(baseline, target, method='default', steps=5, direct=True):
     
-    def __init__(self, Data=None, stype='AmpObj'):
-        super(regObject, self).__init__(Data, stype)
-
-    def registration(self, steps=1, baseline='limb',
-                     target='socket', reg = 'reglimb', direct=True):
-        """
-        Function to register the regObject to the baseline mesh
-        
-        Parameters
-        ----------
-        Steps: int, default 1
-            Number of iterations
-        """
-        bData = getattr(self, baseline)
-        tData = getattr(self, target)
-        bV = bData['vert']
-        # Calculate the face centroids of the regObject
-        tData['fC'] = tData['vert'][tData['faces']].mean(axis=1)
-        # Construct knn tree
-        tTree = spatial.cKDTree(tData['fC'])
-        for step in np.arange(steps, 0, -1):
-            # Index of 10 centroids nearest to each baseline vertex
-            ind = tTree.query(bV, 10)[1]
-            D = np.zeros(bV.shape)
-            # Define normals for faces of nearest faces
-            norms = tData['norm'][ind]
-            # Get a point on each face
-            fPoints = tData['vert'][tData['faces'][ind, 0]]
-            # Calculate dot product between point on face and normals
-            d = np.einsum('ijk, ijk->ij', norms, fPoints)
-            t = d - np.einsum('ijk, ik->ij', norms, bV)
-            # Calculate new points
-            G = np.einsum('ijk, ij->ijk', norms, t)
-            GMag = np.sqrt(np.einsum('ijk, ijk->ij', G, G)).argmin(axis=1)
-            # Define vector from baseline point to intersect point
-            D = G[np.arange(len(G)), GMag, :]
-            bV = bV + D/step
-        regData = dict(bData)
-        regData['vert'] = bV
-        setattr(self, reg, regData)
-        self.calcError(baseline, reg, direct)
-
-        
-    def calcError(self, baseline='limb', target='reglimb', direct=True):
-        # This is kinda slow
-        bData = getattr(self, baseline)
-        tData = getattr(self, target)
+    """
+    Function to register the regObject to the baseline mesh
+    
+    Parameters
+    ----------
+    Steps: int, default 1
+        Number of iterations
+    """
+    bV = baseline.vert
+    # Calc FaceCentroids
+    fC = target.vert[target.faces].mean(axis=1)
+    # Construct knn tree
+    tTree = spatial.cKDTree(fC)
+    for step in np.arange(steps, 0, -1):
+        # Index of 10 centroids nearest to each baseline vertex
+        ind = tTree.query(bV, 10)[1]
+        D = np.zeros(bV.shape)
+        # Define normals for faces of nearest faces
+        norms = target.norm[ind]
+        # Get a point on each face
+        fPoints = target.vert[target.faces[ind, 0]]
+        # Calculate dot product between point on face and normals
+        d = np.einsum('ijk, ijk->ij', norms, fPoints)
+        t = d - np.einsum('ijk, ik->ij', norms, bV)
+        # Calculate new points
+        G = np.einsum('ijk, ij->ijk', norms, t)
+        GMag = np.sqrt(np.einsum('ijk, ijk->ij', G, G)).argmin(axis=1)
+        # Define vector from baseline point to intersect point
+        D = G[np.arange(len(G)), GMag, :]
+        bV = bV + D/step
+    bData = dict(zip(['vert', 'faces'], [bV, baseline.faces]))
+    regObj = AmpObject(bData, stype='reg')
+    
+    def calcError(baseline, regObj, direct=True):
         if direct is True:
-            values = np.linalg.norm(tData['vert'] - bData['vert'], axis=1)
-            # Calculate vertex normals on target from normal of surrounding faces
-            vNorm = np.zeros(tData['vert'].shape)
-            for face, norm in zip(tData['faces'], tData['norm']):
-                vNorm[face, :] += norm
-            vNorm = vNorm / np.linalg.norm(vNorm, axis=1)[:, None]
+            values = np.linalg.norm(regObj.vert - baseline.vert, axis=1)
             # Calculate the unit vector normal between corresponding vertices
             # baseline and target
-            vector = (tData['vert'] - bData['vert'])/values[:, None]
+            vector = (regObj.vert - baseline.vert)/values[:, None]
             # Calculate angle between the two unit vectors using normal of cross
             # product between vNorm and vector and dot
-            normcrossP = np.linalg.norm(np.cross(vector, vNorm), axis=1)
-            dotP = np.einsum('ij,ij->i', vector, vNorm)
+            normcrossP = np.linalg.norm(np.cross(vector, target.vNorm), axis=1)
+            dotP = np.einsum('ij,ij->i', vector, target.vNorm)
             angle = np.arctan2(normcrossP, dotP)
             polarity = np.ones(angle.shape)
             polarity[angle < np.pi/2] =-1.0
-            tData['values'] = values * polarity
+            values = values * polarity
+            return values
         else:
-            tData['values'] = np.linalg.norm(tData['vert'] - bData['vert'],
-                                           axis=1)
+            values = np.linalg.norm(regObj.vert - baseline.vert, axis=1)
+            return values
+
+    regObj.values = calcError(baseline, regObj, False)
+    return regObj
+
         
-- 
GitLab