From 1d0385e2f3d858ced3c1cdb9964b96193572c72b Mon Sep 17 00:00:00 2001 From: "Hannes F. Kuchelmeister" Date: Tue, 13 Jul 2021 13:18:41 +0200 Subject: [PATCH] implement registration fo rounds --- situr/image/__init__.py | 1 + situr/image/situ_image.py | 14 +++++++++ situr/image/situ_tile.py | 5 ++- situr/registration/channel_registration.py | 6 ++-- situr/registration/registration.py | 3 +- situr/registration/round_registration.py | 32 ++++++++++++++++---- situr/transformation/__init__.py | 2 +- situr/transformation/round_transformation.py | 30 ++++++++++++++++-- 8 files changed, 78 insertions(+), 15 deletions(-) diff --git a/situr/image/__init__.py b/situr/image/__init__.py index 06cc6a9..628c2a9 100644 --- a/situr/image/__init__.py +++ b/situr/image/__init__.py @@ -1,2 +1,3 @@ from .situ_image import extend_dim, remove_dim from .situ_image import SituImage, PeakFinderDifferenceOfGaussian +from .situ_tile import Tile diff --git a/situr/image/situ_image.py b/situr/image/situ_image.py index 9350e26..5df6465 100644 --- a/situr/image/situ_image.py +++ b/situr/image/situ_image.py @@ -69,6 +69,20 @@ class SituImage: def get_channel_count(self): return self.get_data().shape[0] + def get_focus_level_count(self): + return self.get_data().shape[1] + + def get_focus_level(self, channel, focus_level): + """Loads channel and focus level of an image. + + Args: + channel (int): The channel to be used + focus_level (int): The focus level to be used + Returns: + numpy.array: The loaded image of shape (width, height) + """ + return self.get_data()[channel, focus_level, :, :] + def get_channel(self, channel): ''' Loads and returns the specified channel for all focus_levels. diff --git a/situr/image/situ_tile.py b/situr/image/situ_tile.py index 353836e..131ef6d 100644 --- a/situr/image/situ_tile.py +++ b/situr/image/situ_tile.py @@ -25,6 +25,9 @@ class Tile: # TODO: implement (first apply channel transformations then round transformations) pass + def get_image_round(self, round): + return self.images[round] + def set_round_transformation(self, round, transformation): self.round_transformations[round] = transformation @@ -32,7 +35,7 @@ class Tile: return len(self.images) def get_channel_count(self): - return self.images.get_channel_count(self) + return self.images[0].get_channel_count() def get_round(self, round_number): """This diff --git a/situr/registration/channel_registration.py b/situr/registration/channel_registration.py index 61af21e..35fdd8a 100644 --- a/situr/registration/channel_registration.py +++ b/situr/registration/channel_registration.py @@ -1,8 +1,10 @@ -from situr.registration import Registration -from situr.transformation import IdentityChannelTransform +from situr.registration import Registration, FilterregRegistrationFunction +from situr.transformation import ScaleRotateTranslateChannelTransform class ChannelRegistration(Registration): + def __init__(self, registration_function=FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform)): + super().__init__(registration_function) def do_channel_registration(self, situ_img, reference_channel=0): # For each channel (except nucleus) compute transform compared to reference_channel # Add Channel transformation to Channel diff --git a/situr/registration/registration.py b/situr/registration/registration.py index 19cc00c..a5c4626 100644 --- a/situr/registration/registration.py +++ b/situr/registration/registration.py @@ -3,13 +3,12 @@ import open3d as o3 from probreg import filterreg from situr.image import extend_dim -from situr.transformation import ScaleRotateTranslateChannelTransform class RegistrationFunction: __metaclass__ = abc.ABCMeta - def __init__(self, transormation_type=ScaleRotateTranslateChannelTransform): + def __init__(self, transormation_type): self.transormation_type = transormation_type @abc.abstractmethod diff --git a/situr/registration/round_registration.py b/situr/registration/round_registration.py index 3b0cb8f..01c8681 100644 --- a/situr/registration/round_registration.py +++ b/situr/registration/round_registration.py @@ -1,6 +1,26 @@ -class RoundRegistration: - def do_round_registration(self, situ_tile, reference_channel=0): - # For each channel (except nucleus) compute transform compared to reference_channel - # Add Channel transformation to Channel - # TODO: implement - pass +from situr.registration import Registration, FilterregRegistrationFunction +from situr.transformation import ScaleRotateTranslateRoundTransform + + +class RoundRegistration(Registration): + def __init__(self, registration_function=FilterregRegistrationFunction(ScaleRotateTranslateRoundTransform)): + super().__init__(registration_function) + + def do_round_registration(self, situ_tile, reference_round=0, reference_channel=0): + """This method generates a round registration transformation for a tile and saves it in the tile. + + Args: + situ_tile (Tile): The tile that the transformation is to be performed on. + reference_round (int, optional): The round that is referenced and will not be changed. Defaults to 0. + reference_channel (int, optional): The channel tha is used to compare rounds. Defaults to 0. + """ + reference_peaks = situ_tile.get_image_round( + reference_round).get_channel_peaks(reference_channel) + for round in range(situ_tile.get_roundcount()): + if round != reference_channel: + current_round_peaks = situ_tile.get_image_round( + round + ).get_channel_peaks(reference_channel) + transformation = self.registration_function.do_registration( + current_round_peaks, reference_peaks) + situ_tile.set_round_transformation(round, transformation) diff --git a/situr/transformation/__init__.py b/situr/transformation/__init__.py index fd0d4c7..505f82e 100644 --- a/situr/transformation/__init__.py +++ b/situr/transformation/__init__.py @@ -1,2 +1,2 @@ from .channel_transformation import ChannelTransform, IdentityChannelTransform, ScaleRotateTranslateChannelTransform -from .round_transformation import RoundTransform, IdentityRoundTransform +from .round_transformation import RoundTransform, IdentityRoundTransform, ScaleRotateTranslateRoundTransform diff --git a/situr/transformation/round_transformation.py b/situr/transformation/round_transformation.py index 245f691..9e52b1b 100644 --- a/situr/transformation/round_transformation.py +++ b/situr/transformation/round_transformation.py @@ -1,16 +1,40 @@ import abc +import scipy +import numpy as np +from situr.image import situ_image class RoundTransform: __metaclass__ = abc.ABCMeta @abc.abstractmethod - def apply_transformation(self, situ_tile, channel): - """Performs a transformation on one channel, all focus_levels are transformed the same way""" + def apply_transformation(self, situ_tile, round): + """Performs a transformation on one round, all channels and focus_levels are transformed the same way""" raise NotImplementedError( self.__class__.__name__ + '.apply_transformation') class IdentityRoundTransform(RoundTransform): - def apply_transformation(self, situ_tile, channel): + def apply_transformation(self, situ_tile, round): pass + + +class ScaleRotateTranslateRoundTransform(RoundTransform): + def __init__(self, transform_matrix, scale=1, offset=np.array([0, 0])): + # TODO: check + # * transform matrix is 2x2 + # * offset is array (2,) + self.transform_matrix = transform_matrix + self.offset = offset + self.scale = scale + + def apply_tranformation(self, situ_tile, round): + situ_image = situ_tile.get_image_round(round) + + for channel in range(situ_image.get_channel_count()): + for focus_level in range(situ_image.get_focus_level_count()): + img = situ_image.get_focus_level(channel, focus_level) + img[:, :] = scipy.ndimage.affine_transform( + img, self.transform_matrix) + img[:, :] = scipy.ndimage.zoom(img, self.scale) + img[:, :] = scipy.ndimage.shift(img, self.offset)