Skip to content
Snippets Groups Projects
Commit ab46ee40 authored by Joshua Steer's avatar Joshua Steer
Browse files

Updated align code so it now has point2point and point2plane functionality

parent 51baaa60
No related branches found
No related tags found
No related merge requests found
...@@ -102,9 +102,7 @@ class align(object): ...@@ -102,9 +102,7 @@ class align(object):
[moving.vert, moving.faces, moving.values])) [moving.vert, moving.faces, moving.values]))
alData = copy.deepcopy(mData) alData = copy.deepcopy(mData)
self.m = AmpObject(alData, stype='reg') self.m = AmpObject(alData, stype='reg')
self.m.calcVNorm()
self.s = static self.s = static
self.s.calcVNorm()
if method is not None: if method is not None:
getattr(self, method)(*args, **kwargs) getattr(self, method)(*args, **kwargs)
...@@ -141,7 +139,8 @@ class align(object): ...@@ -141,7 +139,8 @@ class align(object):
dist = dist.min(axis=1) dist = dist.min(axis=1)
return dist.sum() return dist.sum()
def linearICP(self, maxiter=20, inlier=0.5, initTransform=None): def linearICP(self, metric = 'point2point',
maxiter=20, inlier=1.0, initTransform=None):
""" """
Iterative Closest Point algorithm which relies on using least squares Iterative Closest Point algorithm which relies on using least squares
method on a having made the minimisation problem into a set of linear method on a having made the minimisation problem into a set of linear
...@@ -150,19 +149,20 @@ class align(object): ...@@ -150,19 +149,20 @@ class align(object):
# Define the rotation, translation, error and quaterion arrays # Define the rotation, translation, error and quaterion arrays
Rs = np.zeros([3, 3, maxiter+1]) Rs = np.zeros([3, 3, maxiter+1])
Ts = np.zeros([3, maxiter+1]) Ts = np.zeros([3, maxiter+1])
qs = np.r_[np.ones([1, maxiter+1]), # qs = np.r_[np.ones([1, maxiter+1]),
np.zeros([6, maxiter+1])] # np.zeros([6, maxiter+1])]
dq = np.zeros([7, maxiter+1]) # dq = np.zeros([7, maxiter+1])
dTheta = np.zeros([maxiter+1]) dTheta = np.zeros([maxiter+1])
err = np.zeros([maxiter+1]) err = np.zeros([maxiter+1])
if initTransform is None: if initTransform is None:
initTransform = np.eye(4) initTransform = np.eye(4)
Rs[:, :, 0] = initTransform[:3, :3] Rs[:, :, 0] = initTransform[:3, :3]
Ts[:, 0] = initTransform[3, :3] Ts[:, 0] = initTransform[3, :3]
qs[:4, 0] = self.rot2quat(Rs[:, :, 0]) # qs[:4, 0] = self.rot2quat(Rs[:, :, 0])
qs[4:, 0] = Ts[:, 0] # qs[4:, 0] = Ts[:, 0]
# Define # Define
kdTree = spatial.cKDTree(self.s.vert) fC = self.s.vert[self.s.faces].mean(axis=1)
kdTree = spatial.cKDTree(fC)
self.m.rigidTransform(Rs[:, :, 0], Ts[:, 0]) self.m.rigidTransform(Rs[:, :, 0], Ts[:, 0])
inlier = math.ceil(self.m.vert.shape[0]*inlier) inlier = math.ceil(self.m.vert.shape[0]*inlier)
[dist, idx] = kdTree.query(self.m.vert, 1) [dist, idx] = kdTree.query(self.m.vert, 1)
...@@ -173,9 +173,14 @@ class align(object): ...@@ -173,9 +173,14 @@ class align(object):
[dist, idx, sort] = dist[:inlier], idx[:inlier], sort[:inlier] [dist, idx, sort] = dist[:inlier], idx[:inlier], sort[:inlier]
err[0] = math.sqrt(dist.mean()) err[0] = math.sqrt(dist.mean())
for i in range(maxiter): for i in range(maxiter):
[R, T] = self.point2plane(self.m.vert[sort], if metric == 'point2point':
self.s.vert[idx, :], [R, T] = getattr(self, metric)(self.m.vert[sort],
self.s.vNorm[idx, :]) fC[idx, :])
else:
[R, T] = getattr(self, metric)(self.m.vert[sort],
fC[idx, :],
self.s.norm[idx, :])
Rs[:, :, i+1] = np.dot(R, Rs[:, :, i]) Rs[:, :, i+1] = np.dot(R, Rs[:, :, i])
Ts[:, i+1] = np.dot(R, Ts[:, i]) + T Ts[:, i+1] = np.dot(R, Ts[:, i]) + T
self.m.rigidTransform(R, T) self.m.rigidTransform(R, T)
...@@ -184,7 +189,7 @@ class align(object): ...@@ -184,7 +189,7 @@ class align(object):
[dist, idx] = [dist[sort], idx[sort]] [dist, idx] = [dist[sort], idx[sort]]
[dist, idx, sort] = dist[:inlier], idx[:inlier], sort[:inlier] [dist, idx, sort] = dist[:inlier], idx[:inlier], sort[:inlier]
err[i+1] = math.sqrt(dist.mean()) err[i+1] = math.sqrt(dist.mean())
qs[:, i+1] = np.r_[self.rot2quat(R), T] # qs[:, i+1] = np.r_[self.rot2quat(R), T]
R = Rs[:, :, -1] R = Rs[:, :, -1]
#Simpl #Simpl
[U, s, V] = np.linalg.svd(R) [U, s, V] = np.linalg.svd(R)
...@@ -213,6 +218,21 @@ class align(object): ...@@ -213,6 +218,21 @@ class align(object):
T = X[3:] T = X[3:]
return (R, T) return (R, T)
@staticmethod
def point2point(mv, sv):
mCent = mv - mv.mean(axis=0)
sCent = sv - sv.mean(axis=0)
C = np.dot(mCent.T, sCent)
[U,_,V] = np.linalg.svd(C)
det = np.linalg.det(np.dot(U, V))
sign = np.eye(3)
sign[2,2] = np.sign(det)
R = np.dot(V.T, sign)
R = np.dot(R, U.T)
T = sv.mean(axis=0) - np.dot(R, mv.mean(axis=0))
return (R, T)
@staticmethod @staticmethod
def rot2quat(R): def rot2quat(R):
[[Qxx, Qxy, Qxz], [[Qxx, Qxy, Qxz],
...@@ -275,7 +295,7 @@ class align(object): ...@@ -275,7 +295,7 @@ class align(object):
self.s.actor.setOpacity(0.5) self.s.actor.setOpacity(0.5)
self.m.actor.setColor([0.0, 0.0, 1.0]) self.m.actor.setColor([0.0, 0.0, 1.0])
self.m.actor.setOpacity(0.5) self.m.actor.setOpacity(0.5)
win.renderActors([self.s.actor, self.m.actor], shading=True) win.renderActors([self.s.actor, self.m.actor])
win.Render() win.Render()
win.rens[0].GetActiveCamera().Azimuth(180) win.rens[0].GetActiveCamera().Azimuth(180)
win.rens[0].GetActiveCamera().SetParallelProjection(True) win.rens[0].GetActiveCamera().SetParallelProjection(True)
......
...@@ -356,6 +356,7 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, ...@@ -356,6 +356,7 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin,
self.vert[:, :] = np.dot(self.vert, R.T) self.vert[:, :] = np.dot(self.vert, R.T)
if norms is True: if norms is True:
self.norm[:, :] = np.dot(self.norm, R.T) self.norm[:, :] = np.dot(self.norm, R.T)
if hasattr(self, 'vNorm'):
self.vNorm[:, :] = np.dot(self.vNorm, R.T) self.vNorm[:, :] = np.dot(self.vNorm, R.T)
...@@ -373,7 +374,7 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin, ...@@ -373,7 +374,7 @@ class AmpObject(trimMixin, smoothMixin, analyseMixin,
""" """
if R is not None: if R is not None:
self.rotateMat(R, True) self.rotate(R, True)
if T is not None: if T is not None:
self.translate(T) self.translate(T)
......
...@@ -124,7 +124,8 @@ class AmpScanGUI(QMainWindow): ...@@ -124,7 +124,8 @@ class AmpScanGUI(QMainWindow):
moving = str(self.alCont.moving.currentText()) moving = str(self.alCont.moving.currentText())
self.fileManager.setTable(static, [1,0,0], 0.5, 2) self.fileManager.setTable(static, [1,0,0], 0.5, 2)
self.fileManager.setTable(moving, [0,0,1], 0.5, 2) self.fileManager.setTable(moving, [0,0,1], 0.5, 2)
al = align(self.files[moving], self.files[static]).m al = align(self.files[moving], self.files[static],
maxiter=10, method='linearICP',metric='point2plane').m
al.addActor() al.addActor()
alName = moving + '_al' alName = moving + '_al'
self.files[alName] = al self.files[alName] = al
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment