add documentation to all registration classes

This commit is contained in:
2021-07-22 11:34:05 +02:00
parent 334f908e82
commit b32451c73c
4 changed files with 108 additions and 20 deletions

View File

@@ -8,9 +8,18 @@ from situr.registration import Registration, RegistrationFunction, FilterregRegi
class SituImageChannelRegistration(Registration): class SituImageChannelRegistration(Registration):
"""This class is meant for channel registrations that are to be performed directly on a
SituImage and not on a Tile.
It inherits from Registration.
"""
def do_channel_registration(self, situ_img: SituImage, reference_channel: int = 0): def do_channel_registration(self, situ_img: SituImage, reference_channel: int = 0):
# For each channel (except nucleus) compute transform compared to reference_channel """This function performs a registration for each channel (except the nucleaus channel).
# Add Channel transformation to Channel
Args:
situ_img (SituImage): The image that should be registered
reference_channel (int, optional): the reference channel that all channels are
registered against. Defaults to 0.
"""
reference_peaks = self.peak_finder.get_channel_peaks( reference_peaks = self.peak_finder.get_channel_peaks(
situ_img, reference_channel) situ_img, reference_channel)
for channel in range(situ_img.get_channel_count()): for channel in range(situ_img.get_channel_count()):
@@ -24,7 +33,18 @@ class SituImageChannelRegistration(Registration):
class ChannelRegistration(Registration): class ChannelRegistration(Registration):
"""This class performs a simple channel registration on a Tile. Each round is looked at
seperately and registered with the reference channel.
It inherits from Registration.
"""
def do_channel_registration(self, tile: Tile, reference_channel: int = 0): def do_channel_registration(self, tile: Tile, reference_channel: int = 0):
"""Perform a SituImageChannelRegistration for each Image.
Args:
tile (Tile): the tile that the registration is supposed to be on.
reference_channel (int, optional): the reference channel that all channels are
registered against. Defaults to 0.
"""
registration = SituImageChannelRegistration() registration = SituImageChannelRegistration()
# For each channel (except nucleus) compute transform compared to reference_channel # For each channel (except nucleus) compute transform compared to reference_channel
# Add Channel transformation to Channel # Add Channel transformation to Channel
@@ -34,7 +54,18 @@ class ChannelRegistration(Registration):
class AcrossRoundChannelRegistration(ChannelRegistration): class AcrossRoundChannelRegistration(ChannelRegistration):
"""This class is a registration that uses rounds across images to do the registration.
Inherits from ChannelRegistration.
"""
def do_channel_registration(self, tile: Tile, reference_channel: int = 0): def do_channel_registration(self, tile: Tile, reference_channel: int = 0):
"""Performs a registration, where a channel is merged across rounds to give more datapoints.
This, however, makes it slower.
Args:
tile (Tile): the tile that the registration is supposed to be on.
reference_channel (int, optional): the reference channel that all channels are
registered against. Defaults to 0.
"""
reference_peaks = [] reference_peaks = []
for round in range(tile.get_round_count()): for round in range(tile.get_round_count()):
reference_peaks.append(self.peak_finder.get_channel_peaks( reference_peaks.append(self.peak_finder.get_channel_peaks(

View File

@@ -12,12 +12,36 @@ class RegistrationFunction:
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
@abc.abstractmethod @abc.abstractmethod
def do_registration(self, data_peaks, reference_peaks) -> Transform: def do_registration(self, data_peaks: np.ndarray, reference_peaks: np.ndarray) -> Transform:
"""Method that does the registration on two arrays of peaks.
Args:
data_peaks (np.ndarray): [description]
reference_peaks (np.ndarray): [description]
Raises:
NotImplementedError: This method is abstract and therefore calling
it results in an error.
Returns:
Transform: The transformation that can be used to register data_peaks to the reference.
"""
raise NotImplementedError(self.__class__.__name__ + '.do_registration') raise NotImplementedError(self.__class__.__name__ + '.do_registration')
class FilterregRegistrationFunction(RegistrationFunction): class FilterregRegistrationFunction(RegistrationFunction):
def do_registration(self, data_peaks: np.ndarray, reference_peaks: np.ndarray) -> ScaleRotateTranslateTransform: def do_registration(self,
data_peaks: np.ndarray,
reference_peaks: np.ndarray) -> ScaleRotateTranslateTransform:
"""Method that uses filterregregistration to register the data_peaks.
Args:
data_peaks (np.ndarray): The peaks to be registered to the reference
reference_peaks (np.ndarray): The reference peaks
Returns:
ScaleRotateTranslateTransform: the resulting transformaton from the registration
"""
source = o3.geometry.PointCloud() source = o3.geometry.PointCloud()
source.points = o3.utility.Vector3dVector(extend_dim(data_peaks)) source.points = o3.utility.Vector3dVector(extend_dim(data_peaks))
target = o3.geometry.PointCloud() target = o3.geometry.PointCloud()
@@ -26,18 +50,23 @@ class FilterregRegistrationFunction(RegistrationFunction):
registration_method = filterreg.registration_filterreg registration_method = filterreg.registration_filterreg
tf_param, _, _ = filterreg.registration_filterreg(source, target) tf_param, _, _ = filterreg.registration_filterreg(source, target)
return ScaleRotateTranslateTransform(transform_matrix=tf_param.rot[0:2, 0:2], scale=tf_param.scale, offset=tf_param.t[0:2]) return ScaleRotateTranslateTransform(transform_matrix=tf_param.rot[0:2, 0:2],
scale=tf_param.scale, offset=tf_param.t[0:2])
class Registration: class Registration:
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, registration_function: RegistrationFunction() = FilterregRegistrationFunction(), peak_finder=PeakFinderDifferenceOfGaussian()): def __init__(self,
registration_function: RegistrationFunction() = FilterregRegistrationFunction(),
peak_finder=PeakFinderDifferenceOfGaussian()):
"""Initialize channel registration and tell which registration function to use. """Initialize channel registration and tell which registration function to use.
Args: Args:
registration_function (RegistrationFunction, optional): Registration function. Defaults to FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform). registration_function (RegistrationFunction, optional): Registration function.
peak_finder (PeakFinder, optional): The peak finder to be used for the registration. Defaults to PeakFinderDifferenceOfGaussian(). Defaults to FilterregRegistrationFunction().
peak_finder (PeakFinder, optional): The peak finder to be used for the registration.
Defaults to PeakFinderDifferenceOfGaussian().
""" """
self.registration_function = registration_function self.registration_function = registration_function
self.peak_finder = peak_finder self.peak_finder = peak_finder

View File

@@ -1,19 +1,26 @@
from situr.image.situ_tile import Tile
import numpy as np
from situr.registration import Registration, RegistrationFunction, FilterregRegistrationFunction from situr.registration import Registration, RegistrationFunction, FilterregRegistrationFunction
import numpy as np
class RoundRegistration(Registration): class RoundRegistration(Registration):
def do_round_registration(self, situ_tile, reference_round: int = 0, reference_channel: int = 0): def do_round_registration(self,
"""This method generates a round registration transformation for a tile and saves it in the tile. situ_tile,
reference_round: int = 0,
reference_channel: int = 0):
"""This method generates a round registration transformation for a tile
and saves it in the tile.
Args: Args:
situ_tile (Tile): The tile that the transformation is to be performed on. situ_tile (Tile): The tile that the transformation is to be performed on.
reference_round (int, optional): The round that is referenced and will not be changed. Defaults to 0. reference_round (int, optional): The round that is referenced and will not be changed.
reference_channel (int, optional): The channel tha is used to compare rounds. Defaults to 0. Defaults to 0.
reference_channel (int, optional): The channel that is used to compare rounds.
Defaults to 0.
""" """
# TODO: instead of one reference channel use all channels (maybe without nucleus channel)
reference_peaks = self.peak_finder.get_channel_peaks(situ_tile.get_round( reference_peaks = self.peak_finder.get_channel_peaks(situ_tile.get_round(
reference_round), reference_channel) reference_round), reference_channel)
for round in range(situ_tile.get_round_count()): for round in range(situ_tile.get_round_count()):
@@ -26,13 +33,21 @@ class RoundRegistration(Registration):
class AllChannelRoundRegistration(RoundRegistration): class AllChannelRoundRegistration(RoundRegistration):
"""This class perofrms a round registration using all channels instead of just the reference
channel. It inherits from RoundRegistration.
"""
def do_round_registration(self, situ_tile, reference_round: int = 0, reference_channel: int = 0): def do_round_registration(self,
"""This method generates a round registration transformation for a tile and saves it in the tile. situ_tile: Tile,
reference_round: int = 0,
reference_channel: int = 0):
"""This method generates a round registration transformation for a tile and saves it in
the tile.
Args: Args:
situ_tile (Tile): The tile that the transformation is to be performed on. situ_tile (Tile): The tile that the transformation is to be performed on.
reference_round (int, optional): The round that is referenced and will not be changed. Defaults to 0. reference_round (int, optional): The round that is referenced and will not be changed.
Defaults to 0.
reference_channel (int, optional): This parameter is ignored. reference_channel (int, optional): This parameter is ignored.
""" """
reference_peaks = [] reference_peaks = []

View File

@@ -3,9 +3,19 @@ from situr.registration import RoundRegistration, ChannelRegistration, round_reg
class CombinedRegistration: class CombinedRegistration:
def __init__(self, round_registration: RoundRegistration = RoundRegistration(), channel_registration: ChannelRegistration = ChannelRegistration(), reference_channel=0) -> None: """CombinedRegistration is a registration that performs a channel and a round transformaton
after each other. Also the transformations are directly applied after each registration.
"""
def __init__(self,
round_registration: RoundRegistration = RoundRegistration(),
channel_registration: ChannelRegistration = ChannelRegistration(),
reference_channel: int = 0,
reference_round: int = 0) -> None:
self.round_registration = round_registration self.round_registration = round_registration
self.channel_registration = channel_registration self.channel_registration = channel_registration
self.reference_channel = reference_channel
self.reference_round = reference_round
def do_registration_and_transform(self, tile: Tile): def do_registration_and_transform(self, tile: Tile):
""" This function applies the registration in the following order: """ This function applies the registration in the following order:
@@ -17,10 +27,13 @@ class CombinedRegistration:
Args: Args:
tile (Tile): The tile that the registration and transformations are to be performed on. tile (Tile): The tile that the registration and transformations are to be performed on.
""" """
self.channel_registration.do_channel_registration(tile) self.channel_registration.do_channel_registration(
tile, self.reference_channel)
tile.apply_channel_transformations() tile.apply_channel_transformations()
self.round_registration.do_round_registration(tile) self.round_registration.do_round_registration(tile,
self.reference_round,
self.reference_channel)
tile.apply_round_transformations() tile.apply_round_transformations()