redo registration to be more generic

This commit is contained in:
2021-07-12 18:39:42 +02:00
parent 5fc4b75b6c
commit 29c9555755
4 changed files with 46 additions and 29 deletions

View File

@@ -0,0 +1 @@
from .registration import Registration, RegistrationFunction, FilterregRegistrationFunction

View File

@@ -1,34 +1,9 @@
import abc
import open3d as o3
from probreg import filterreg
from situr.image import extend_dim
from situr.transformation import ScaleRotateTranslateChannelTransform
from situr.registration import Registration
class ChannelRegistration:
__metaclass__ = abc.ABCMeta
def do_registration(self, situ_img, reference_channel=0):
class ChannelRegistration(Registration):
def do_channel_registration(self, situ_img, reference_channel=0):
# For each channel (except nucleus) compute transform compared to reference_channel
# Add Channel transformation to Channel
# TODO: implement
pass
@abc.abstractmethod
def register_single_channel(self, peaks_data, reference_peaks):
"""Performs the channel registration on an image. Expects the peaks in each image as input."""
raise NotImplementedError(
self.__class__.__name__ + '.register_single_channel')
class FilterregChannelRegistration(ChannelRegistration):
def register_single_channel(self, data_peaks, reference_peaks):
source = o3.geometry.PointCloud()
source.points = o3.utility.Vector3dVector(extend_dim(data_peaks))
target = o3.geometry.PointCloud()
target.points = o3.utility.Vector3dVector(extend_dim(reference_peaks))
registration_method = filterreg.registration_filterreg
tf_param, _, _ = filterreg.registration_filterreg(source, target)
return ScaleRotateTranslateChannelTransform(transform_matrix=tf_param.rot[0:2, 0:2], scale=tf_param.scale, offset=tf_param.t[0:2])

View File

@@ -0,0 +1,35 @@
import abc
import open3d as o3
from probreg import filterreg
from situr.image import extend_dim
from situr.transformation import ScaleRotateTranslateChannelTransform
class RegistrationFunction:
__metaclass__ = abc.ABCMeta
def __init__(self, transormation_type=ScaleRotateTranslateChannelTransform):
self.transormation_type = transormation_type
@abc.abstractmethod
def do_registration(self, data_peaks, reference_peaks):
raise NotImplementedError(self.__class__.__name__ + '.do_registration')
class FilterregRegistrationFunction(RegistrationFunction):
def do_registration(self, data_peaks, reference_peaks):
source = o3.geometry.PointCloud()
source.points = o3.utility.Vector3dVector(extend_dim(data_peaks))
target = o3.geometry.PointCloud()
target.points = o3.utility.Vector3dVector(extend_dim(reference_peaks))
registration_method = filterreg.registration_filterreg
tf_param, _, _ = filterreg.registration_filterreg(source, target)
return self.transormation_type(transform_matrix=tf_param.rot[0:2, 0:2], scale=tf_param.scale, offset=tf_param.t[0:2])
class Registration:
def __init__(self, registration_function):
self.registration_function = registration_function

View File

@@ -0,0 +1,6 @@
class RoundRegistration:
def do_round_registration(self, situ_tile, reference_channel=0):
# For each channel (except nucleus) compute transform compared to reference_channel
# Add Channel transformation to Channel
# TODO: implement
pass