diff --git a/situr/registration/channel_registration.py b/situr/registration/channel_registration.py index 72ea10d..d422c2f 100644 --- a/situr/registration/channel_registration.py +++ b/situr/registration/channel_registration.py @@ -8,9 +8,18 @@ from situr.registration import Registration, RegistrationFunction, FilterregRegi class SituImageChannelRegistration(Registration): + """This class is meant for channel registrations that are to be performed directly on a + SituImage and not on a Tile. + It inherits from Registration. + """ def do_channel_registration(self, situ_img: SituImage, reference_channel: int = 0): - # For each channel (except nucleus) compute transform compared to reference_channel - # Add Channel transformation to Channel + """This function performs a registration for each channel (except the nucleaus channel). + + Args: + situ_img (SituImage): The image that should be registered + reference_channel (int, optional): the reference channel that all channels are + registered against. Defaults to 0. + """ reference_peaks = self.peak_finder.get_channel_peaks( situ_img, reference_channel) for channel in range(situ_img.get_channel_count()): @@ -24,7 +33,18 @@ class SituImageChannelRegistration(Registration): class ChannelRegistration(Registration): + """This class performs a simple channel registration on a Tile. Each round is looked at + seperately and registered with the reference channel. + It inherits from Registration. + """ def do_channel_registration(self, tile: Tile, reference_channel: int = 0): + """Perform a SituImageChannelRegistration for each Image. + + Args: + tile (Tile): the tile that the registration is supposed to be on. + reference_channel (int, optional): the reference channel that all channels are + registered against. Defaults to 0. + """ registration = SituImageChannelRegistration() # For each channel (except nucleus) compute transform compared to reference_channel # Add Channel transformation to Channel @@ -34,7 +54,18 @@ class ChannelRegistration(Registration): class AcrossRoundChannelRegistration(ChannelRegistration): + """This class is a registration that uses rounds across images to do the registration. + Inherits from ChannelRegistration. + """ def do_channel_registration(self, tile: Tile, reference_channel: int = 0): + """Performs a registration, where a channel is merged across rounds to give more datapoints. + This, however, makes it slower. + + Args: + tile (Tile): the tile that the registration is supposed to be on. + reference_channel (int, optional): the reference channel that all channels are + registered against. Defaults to 0. + """ reference_peaks = [] for round in range(tile.get_round_count()): reference_peaks.append(self.peak_finder.get_channel_peaks( diff --git a/situr/registration/registration.py b/situr/registration/registration.py index 762e9b7..fb3b15e 100644 --- a/situr/registration/registration.py +++ b/situr/registration/registration.py @@ -12,12 +12,36 @@ class RegistrationFunction: __metaclass__ = abc.ABCMeta @abc.abstractmethod - def do_registration(self, data_peaks, reference_peaks) -> Transform: + def do_registration(self, data_peaks: np.ndarray, reference_peaks: np.ndarray) -> Transform: + """Method that does the registration on two arrays of peaks. + + Args: + data_peaks (np.ndarray): [description] + reference_peaks (np.ndarray): [description] + + Raises: + NotImplementedError: This method is abstract and therefore calling + it results in an error. + + Returns: + Transform: The transformation that can be used to register data_peaks to the reference. + """ raise NotImplementedError(self.__class__.__name__ + '.do_registration') class FilterregRegistrationFunction(RegistrationFunction): - def do_registration(self, data_peaks: np.ndarray, reference_peaks: np.ndarray) -> ScaleRotateTranslateTransform: + def do_registration(self, + data_peaks: np.ndarray, + reference_peaks: np.ndarray) -> ScaleRotateTranslateTransform: + """Method that uses filterregregistration to register the data_peaks. + + Args: + data_peaks (np.ndarray): The peaks to be registered to the reference + reference_peaks (np.ndarray): The reference peaks + + Returns: + ScaleRotateTranslateTransform: the resulting transformaton from the registration + """ source = o3.geometry.PointCloud() source.points = o3.utility.Vector3dVector(extend_dim(data_peaks)) target = o3.geometry.PointCloud() @@ -26,18 +50,23 @@ class FilterregRegistrationFunction(RegistrationFunction): registration_method = filterreg.registration_filterreg tf_param, _, _ = filterreg.registration_filterreg(source, target) - return ScaleRotateTranslateTransform(transform_matrix=tf_param.rot[0:2, 0:2], scale=tf_param.scale, offset=tf_param.t[0:2]) + return ScaleRotateTranslateTransform(transform_matrix=tf_param.rot[0:2, 0:2], + scale=tf_param.scale, offset=tf_param.t[0:2]) class Registration: __metaclass__ = abc.ABCMeta - def __init__(self, registration_function: RegistrationFunction() = FilterregRegistrationFunction(), peak_finder=PeakFinderDifferenceOfGaussian()): + def __init__(self, + registration_function: RegistrationFunction() = FilterregRegistrationFunction(), + peak_finder=PeakFinderDifferenceOfGaussian()): """Initialize channel registration and tell which registration function to use. Args: - registration_function (RegistrationFunction, optional): Registration function. Defaults to FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform). - peak_finder (PeakFinder, optional): The peak finder to be used for the registration. Defaults to PeakFinderDifferenceOfGaussian(). + registration_function (RegistrationFunction, optional): Registration function. + Defaults to FilterregRegistrationFunction(). + peak_finder (PeakFinder, optional): The peak finder to be used for the registration. + Defaults to PeakFinderDifferenceOfGaussian(). """ self.registration_function = registration_function self.peak_finder = peak_finder diff --git a/situr/registration/round_registration.py b/situr/registration/round_registration.py index 1975b71..bb7b061 100644 --- a/situr/registration/round_registration.py +++ b/situr/registration/round_registration.py @@ -1,19 +1,26 @@ +from situr.image.situ_tile import Tile +import numpy as np + from situr.registration import Registration, RegistrationFunction, FilterregRegistrationFunction -import numpy as np class RoundRegistration(Registration): - def do_round_registration(self, situ_tile, reference_round: int = 0, reference_channel: int = 0): - """This method generates a round registration transformation for a tile and saves it in the tile. + def do_round_registration(self, + situ_tile, + reference_round: int = 0, + reference_channel: int = 0): + """This method generates a round registration transformation for a tile + and saves it in the tile. Args: situ_tile (Tile): The tile that the transformation is to be performed on. - reference_round (int, optional): The round that is referenced and will not be changed. Defaults to 0. - reference_channel (int, optional): The channel tha is used to compare rounds. Defaults to 0. + reference_round (int, optional): The round that is referenced and will not be changed. + Defaults to 0. + reference_channel (int, optional): The channel that is used to compare rounds. + Defaults to 0. """ - # TODO: instead of one reference channel use all channels (maybe without nucleus channel) reference_peaks = self.peak_finder.get_channel_peaks(situ_tile.get_round( reference_round), reference_channel) for round in range(situ_tile.get_round_count()): @@ -26,13 +33,21 @@ class RoundRegistration(Registration): class AllChannelRoundRegistration(RoundRegistration): + """This class perofrms a round registration using all channels instead of just the reference + channel. It inherits from RoundRegistration. + """ - def do_round_registration(self, situ_tile, reference_round: int = 0, reference_channel: int = 0): - """This method generates a round registration transformation for a tile and saves it in the tile. + def do_round_registration(self, + situ_tile: Tile, + reference_round: int = 0, + reference_channel: int = 0): + """This method generates a round registration transformation for a tile and saves it in + the tile. Args: situ_tile (Tile): The tile that the transformation is to be performed on. - reference_round (int, optional): The round that is referenced and will not be changed. Defaults to 0. + reference_round (int, optional): The round that is referenced and will not be changed. + Defaults to 0. reference_channel (int, optional): This parameter is ignored. """ reference_peaks = [] diff --git a/situr/registration/tile_registration.py b/situr/registration/tile_registration.py index 99e8820..4ce8b8d 100644 --- a/situr/registration/tile_registration.py +++ b/situr/registration/tile_registration.py @@ -3,9 +3,19 @@ from situr.registration import RoundRegistration, ChannelRegistration, round_reg class CombinedRegistration: - def __init__(self, round_registration: RoundRegistration = RoundRegistration(), channel_registration: ChannelRegistration = ChannelRegistration(), reference_channel=0) -> None: + """CombinedRegistration is a registration that performs a channel and a round transformaton + after each other. Also the transformations are directly applied after each registration. + """ + + def __init__(self, + round_registration: RoundRegistration = RoundRegistration(), + channel_registration: ChannelRegistration = ChannelRegistration(), + reference_channel: int = 0, + reference_round: int = 0) -> None: self.round_registration = round_registration self.channel_registration = channel_registration + self.reference_channel = reference_channel + self.reference_round = reference_round def do_registration_and_transform(self, tile: Tile): """ This function applies the registration in the following order: @@ -17,10 +27,13 @@ class CombinedRegistration: Args: tile (Tile): The tile that the registration and transformations are to be performed on. """ - self.channel_registration.do_channel_registration(tile) + self.channel_registration.do_channel_registration( + tile, self.reference_channel) tile.apply_channel_transformations() - self.round_registration.do_round_registration(tile) + self.round_registration.do_round_registration(tile, + self.reference_round, + self.reference_channel) tile.apply_round_transformations()