mirror of
https://github.com/13hannes11/situr.git
synced 2024-09-03 20:50:58 +02:00
91 lines
3.6 KiB
Python
91 lines
3.6 KiB
Python
import abc
from typing import Optional

import numpy as np
import open3d as o3
from probreg import filterreg

from situr.image import extend_dim
from situr.registration.peak_finder import PeakFinderDifferenceOfGaussian
from situr.transformation import Transform, ScaleRotateTranslateTransform
|
|
|
|
|
|
class RegistrationFunction:
    """Interface for algorithms that compute a transform between two peak arrays.

    Concrete registration functions override :meth:`do_registration`.
    """

    # NOTE(review): ``__metaclass__`` is only honoured by Python 2. Under
    # Python 3 this class is not abstract and ``@abc.abstractmethod`` is not
    # enforced when the class is instantiated. Check every place that
    # instantiates RegistrationFunction directly before switching to abc.ABC.
    __metaclass__ = abc.ABCMeta

    @abc.abstractmethod
    def do_registration(self, data_peaks: np.ndarray, reference_peaks: np.ndarray) -> Transform:
        """Compute the transform that registers ``data_peaks`` onto ``reference_peaks``.

        Args:
            data_peaks (np.ndarray): The peaks to be registered to the reference.
            reference_peaks (np.ndarray): The reference peaks.

        Raises:
            NotImplementedError: Always raised by this base implementation;
                subclasses must override this method.

        Returns:
            Transform: The transformation that can be used to register
            data_peaks to the reference.
        """
        # The message names the concrete class so the missing override is obvious.
        method_name = f"{self.__class__.__name__}.do_registration"
        raise NotImplementedError(method_name)
|
|
|
|
|
|
class FilterregRegistrationFunction(RegistrationFunction):
    """Registration function backed by probreg's FilterReg algorithm."""

    def do_registration(self,
                        data_peaks: np.ndarray,
                        reference_peaks: np.ndarray) -> ScaleRotateTranslateTransform:
        """Register ``data_peaks`` onto ``reference_peaks`` using FilterReg.

        Both peak arrays are passed through ``extend_dim`` and wrapped in Open3D
        point clouds. This presumably adds a third coordinate, since
        ``Vector3dVector`` expects 3-D points; confirm against
        ``situr.image.extend_dim``. The clouds are then handed to
        ``filterreg.registration_filterreg``.

        Args:
            data_peaks (np.ndarray): The peaks to be registered to the reference.
            reference_peaks (np.ndarray): The reference peaks.

        Returns:
            ScaleRotateTranslateTransform: The resulting transformation from the
            registration. It is built from the upper-left 2x2 block of the
            estimated rotation, the estimated scale and the first two
            translation components.
        """
        source = o3.geometry.PointCloud()
        source.points = o3.utility.Vector3dVector(extend_dim(data_peaks))
        target = o3.geometry.PointCloud()
        target.points = o3.utility.Vector3dVector(extend_dim(reference_peaks))

        # Only the transformation parameters are needed; the other two
        # return values are discarded.
        tf_param, _, _ = filterreg.registration_filterreg(source, target)

        # Keep only the first two dimensions of the 3-D result.
        return ScaleRotateTranslateTransform(transform_matrix=tf_param.rot[0:2, 0:2],
                                             scale=tf_param.scale,
                                             offset=tf_param.t[0:2])
|
|
|
|
|
|
class IcpRegistrationFunction(RegistrationFunction):
    """Registration function based on Open3D's ICP registration."""

    def __init__(self, max_correspondence_distance=50) -> None:
        """Remember the correspondence distance used for ICP.

        Args:
            max_correspondence_distance: Forwarded as the third argument of
                ``o3.pipelines.registration.registration_icp``. Defaults to 50.
        """
        self.max_distance = max_correspondence_distance

    def do_registration(self,
                        data_peaks: np.ndarray,
                        reference_peaks: np.ndarray) -> ScaleRotateTranslateTransform:
        """Register ``data_peaks`` onto ``reference_peaks`` with ICP.

        Args:
            data_peaks (np.ndarray): The peaks to be registered to the reference.
            reference_peaks (np.ndarray): The reference peaks.

        Returns:
            ScaleRotateTranslateTransform: Transform built from the upper-left
            2x2 block of the ICP result matrix and two entries of its fourth
            column.
        """
        # Build the source cloud first, then the target cloud (same order as before).
        clouds = []
        for peaks in (data_peaks, reference_peaks):
            cloud = o3.geometry.PointCloud()
            cloud.points = o3.utility.Vector3dVector(extend_dim(peaks))
            clouds.append(cloud)
        source, target = clouds

        icp_result = o3.pipelines.registration.registration_icp(
            source, target, self.max_distance)
        matrix = icp_result.transformation

        # NOTE(review): the offset reads translation rows 1 then 0 (swapped),
        # but the rotation block is used unswapped, and the filterreg-based
        # registration takes t[0:2] unswapped. Confirm which axis order
        # ScaleRotateTranslateTransform expects.
        return ScaleRotateTranslateTransform(
            matrix[0:2, 0:2],
            offset=matrix[[1, 0], 3])
|
|
|
|
|
|
class Registration:
    """Holds the registration function and peak finder used for a registration."""

    # NOTE(review): ``__metaclass__`` is only honoured by Python 2; under
    # Python 3 it has no effect. Kept unchanged for compatibility.
    __metaclass__ = abc.ABCMeta

    def __init__(self,
                 registration_function: Optional[RegistrationFunction] = None,
                 peak_finder=None):
        """Initialize channel registration and tell which registration function to use.

        Args:
            registration_function (RegistrationFunction, optional): Registration function.
                Defaults to a new FilterregRegistrationFunction() for each instance.
            peak_finder (PeakFinder, optional): The peak finder to be used for the registration.
                Defaults to a new PeakFinderDifferenceOfGaussian() for each instance.
        """
        # Default objects are created here rather than in the signature.
        # Default-argument expressions run only once, when the function is
        # defined, so instances created there would be shared by every
        # Registration constructed without arguments.
        if registration_function is None:
            registration_function = FilterregRegistrationFunction()
        if peak_finder is None:
            peak_finder = PeakFinderDifferenceOfGaussian()

        self.registration_function = registration_function
        self.peak_finder = peak_finder
|