diff --git a/situr/image/situ_image.py b/situr/image/situ_image.py index a4b7b8f..1fc1b49 100644 --- a/situr/image/situ_image.py +++ b/situr/image/situ_image.py @@ -1,4 +1,5 @@ import abc +from situr.transformation.transformation import Transform import numpy as np from PIL import Image, ImageDraw from skimage import img_as_float @@ -6,7 +7,7 @@ from skimage.feature import blob_dog from typing import List -from situr.transformation.channel_transformation import ChannelTransform, IdentityChannelTransform +from situr.transformation import Transform, IdentityTransform def extend_dim(array: np.ndarray): @@ -63,7 +64,7 @@ class SituImage: self.data = None self.nucleaus_channel = nucleaus_channel self.channel_transformations = [ - IdentityChannelTransform() for file in file_list + IdentityTransform() for file in file_list ] self.peak_finder = peak_finder @@ -74,9 +75,17 @@ class SituImage: def apply_transformations(self): for i, transformation in enumerate(self.channel_transformations): - transformation.apply_transformation(self, i) + for focus_level in range(self.get_focus_level_count()): + img = self.get_focus_level(i, focus_level) + transformation.apply_tranformation(img) - def set_channel_transformation(self, channel: int, transformation: ChannelTransform): + def apply_transform_to_whole_image(self, transform: Transform): + for channel in range(self.get_channel_count()): + for focus_level in range(self.get_focus_level_count()): + img = self.get_focus_level(channel, focus_level) + transform.apply_tranformation(img) + + def set_channel_transformation(self, channel: int, transformation: Transform): self.channel_transformations[channel] = transformation def get_channel_count(self) -> int: diff --git a/situr/image/situ_tile.py b/situr/image/situ_tile.py index ba97095..e1e2796 100644 --- a/situr/image/situ_tile.py +++ b/situr/image/situ_tile.py @@ -1,8 +1,7 @@ -from situr.transformation.round_transformation import RoundTransform +from situr.transformation import Transform, IdentityTransform import numpy as np from situr.image.situ_image import SituImage -from situr.transformation import IdentityRoundTransform from typing import List @@ -22,7 +21,7 @@ class Tile: for situ_image_list in file_list: self.images.append( SituImage(situ_image_list, nucleaus_channel=nucleaus_channel)) - self.round_transformations.append(IdentityRoundTransform()) + self.round_transformations.append(IdentityTransform()) def apply_transformations(self): # first apply channel transformations then round transformations @@ -35,9 +34,9 @@ class Tile: def apply_round_transformations(self): for round, transformation in enumerate(self.round_transformations): - transformation.apply_tranformation(self, round) + self.images[round].apply_transform_to_whole_image(transformation) - def set_round_transformation(self, round, transformation: RoundTransform): + def set_round_transformation(self, round, transformation: Transform): self.round_transformations[round] = transformation def get_round_count(self) -> int: diff --git a/situr/registration/__init__.py b/situr/registration/__init__.py index 1d11602..421fc2a 100644 --- a/situr/registration/__init__.py +++ b/situr/registration/__init__.py @@ -1,4 +1,4 @@ from .registration import Registration, RegistrationFunction, FilterregRegistrationFunction from .channel_registration import ChannelRegistration from .round_registration import RoundRegistration -from .tile_registration import TileRegistration +from .tile_registration import CombinedRegistration diff --git a/situr/registration/channel_registration.py b/situr/registration/channel_registration.py index 8e356ba..fc2047d 100644 --- a/situr/registration/channel_registration.py +++ b/situr/registration/channel_registration.py @@ -1,11 +1,14 @@ from situr.image.situ_image import SituImage -from situr.transformation.channel_transformation import ChannelTransform from situr.registration import Registration, RegistrationFunction, FilterregRegistrationFunction -from situr.transformation import ChannelTransform, ScaleRotateTranslateChannelTransform class ChannelRegistration(Registration): - def __init__(self, registration_function: RegistrationFunction[ChannelTransform] = FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform)): + def __init__(self, registration_function: RegistrationFunction = FilterregRegistrationFunction()): + """Initialize channel registration and tell which registration function to use. + + Args: + registration_function (RegistrationFunction, optional): Registration function. Defaults to FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform). + """ super().__init__(registration_function) def do_channel_registration(self, situ_img: SituImage, reference_channel: int = 0): diff --git a/situr/registration/registration.py b/situr/registration/registration.py index c12f7a9..f4db766 100644 --- a/situr/registration/registration.py +++ b/situr/registration/registration.py @@ -1,26 +1,22 @@ import abc -from situr.transformation.channel_transformation import ChannelTransform -from situr.transformation.round_transformation import RoundTransform import open3d as o3 from probreg import filterreg import numpy as np from situr.image import extend_dim -from situr.transformation import Transform +from situr.transformation import Transform, ScaleRotateTranslateTransform + class RegistrationFunction: __metaclass__ = abc.ABCMeta - def __init__(self, transormation_type: Transform): - self.transormation_type = transormation_type - @abc.abstractmethod - def do_registration(self, data_peaks, reference_peaks): + def do_registration(self, data_peaks, reference_peaks) -> Transform: raise NotImplementedError(self.__class__.__name__ + '.do_registration') class FilterregRegistrationFunction(RegistrationFunction): - def do_registration(self, data_peaks: np.ndarray, reference_peaks: np.ndarray) -> Transform: + def do_registration(self, data_peaks: np.ndarray, reference_peaks: np.ndarray) -> ScaleRotateTranslateTransform: source = o3.geometry.PointCloud() source.points = o3.utility.Vector3dVector(extend_dim(data_peaks)) target = o3.geometry.PointCloud() @@ -29,10 +25,11 @@ class FilterregRegistrationFunction(RegistrationFunction): registration_method = filterreg.registration_filterreg tf_param, _, _ = filterreg.registration_filterreg(source, target) - return self.transormation_type(transform_matrix=tf_param.rot[0:2, 0:2], scale=tf_param.scale, offset=tf_param.t[0:2]) + return ScaleRotateTranslateTransform(transform_matrix=tf_param.rot[0:2, 0:2], scale=tf_param.scale, offset=tf_param.t[0:2]) class Registration: __metaclass__ = abc.ABCMeta + def __init__(self, registration_function: RegistrationFunction): self.registration_function = registration_function diff --git a/situr/registration/round_registration.py b/situr/registration/round_registration.py index 38cccb5..f5519eb 100644 --- a/situr/registration/round_registration.py +++ b/situr/registration/round_registration.py @@ -1,13 +1,16 @@ from situr.registration import Registration, RegistrationFunction, FilterregRegistrationFunction -from situr.transformation import RoundTransform, ScaleRotateTranslateRoundTransform -from situr.image import Tile class RoundRegistration(Registration): - def __init__(self, registration_function: RegistrationFunction[RoundTransform] = FilterregRegistrationFunction(ScaleRotateTranslateRoundTransform)): + def __init__(self, registration_function: RegistrationFunction = FilterregRegistrationFunction()): + """Initialize round registration and tell which registration function to use. + + Args: + registration_function (RegistrationFunction[RoundTransform], optional): Registration function. Defaults to FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform). + """ super().__init__(registration_function) - def do_round_registration(self, situ_tile: Tile, reference_round: int = 0, reference_channel: int = 0): + def do_round_registration(self, situ_tile, reference_round: int = 0, reference_channel: int = 0): """This method generates a round registration transformation for a tile and saves it in the tile. Args: @@ -17,11 +20,11 @@ class RoundRegistration(Registration): """ # TODO: instead of one reference channel use all channels (maybe without nucleus channel) - reference_peaks = situ_tile.get_image_round( + reference_peaks = situ_tile.get_round( reference_round).get_channel_peaks(reference_channel) - for round in range(situ_tile.get_roundcount()): + for round in range(situ_tile.get_round_count()): if round != reference_channel: - current_round_peaks = situ_tile.get_image_round( + current_round_peaks = situ_tile.get_round( round ).get_channel_peaks(reference_channel) transformation = self.registration_function.do_registration( diff --git a/situr/registration/tile_registration.py b/situr/registration/tile_registration.py index 5d54ed0..507b107 100644 --- a/situr/registration/tile_registration.py +++ b/situr/registration/tile_registration.py @@ -18,12 +18,12 @@ class CombinedRegistration: tile (Tile): The tile that the registration and transformations are to be performed on. """ # Do channel registration - for round in range(tile.get_roundcount()): - img = tile.get_image_round(round) + for round in range(tile.get_round_count()): + img = tile.get_round(round) self.channel_registration tile.apply_channel_transformations() - round_registration.do_round_registration(tile) + self.round_registration.do_round_registration(tile) tile.apply_round_transformations() diff --git a/situr/transformation/__init__.py b/situr/transformation/__init__.py index 899e19a..badd4c0 100644 --- a/situr/transformation/__init__.py +++ b/situr/transformation/__init__.py @@ -1,3 +1 @@ -from .channel_transformation import ChannelTransform, IdentityChannelTransform, ScaleRotateTranslateChannelTransform -from .round_transformation import RoundTransform, IdentityRoundTransform, ScaleRotateTranslateRoundTransform -from .transformation import Transform +from .transformation import Transform, IdentityTransform, ScaleRotateTranslateTransform diff --git a/situr/transformation/channel_transformation.py b/situr/transformation/channel_transformation.py deleted file mode 100644 index 48e0e5d..0000000 --- a/situr/transformation/channel_transformation.py +++ /dev/null @@ -1,50 +0,0 @@ -import abc -from situr.image.situ_image import SituImage -import numpy as np -import scipy - -from situr.transformation import Transform - - -class ChannelTransform(Transform): - __metaclass__ = abc.ABCMeta - - @abc.abstractmethod - def apply_transformation(self, situ_img: SituImage, channel: int): - """Performs a transformation on one channel, all focus_levels are transformed the same way""" - raise NotImplementedError( - self.__class__.__name__ + '.apply_transformation') - - -class IdentityChannelTransform(ChannelTransform): - def apply_transformation(self, situ_img: SituImage, channel: int): - pass - - -class ScaleRotateTranslateChannelTransform(ChannelTransform): - def __init__(self, transform_matrix: np.ndarray, scale: float = 1, offset: np.ndarray = np.array([0, 0])): - """Constructor for a Transformation that supports rotation, translation and scaling on a channel - - Args: - transform_matrix (np.ndarray): A matrix of shape (2,2) - scale (float, optional): The scale factor. Defaults to 1. - offset (np.ndarray, optional): The offset of shape (2,). Defaults to np.array([0, 0]). - """ - # TODO: check - # * transform matrix is 2x2 - # * offset is array (2,) - self.transform_matrix = transform_matrix - self.offset = offset - self.scale = scale - - def apply_tranformation(self, situ_img: SituImage, channel: int): - channel_img = situ_img.get_channel(channel) - focus_levels = channel_img.shape[0] - - for focus_level in range(focus_levels): - img = channel_img[focus_level, :, :] - - img[:, :] = scipy.ndimage.affine_transform( - img, self.transform_matrix) - img[:, :] = scipy.ndimage.zoom(img, self.scale) - img[:, :] = scipy.ndimage.shift(img, self.offset) diff --git a/situr/transformation/round_transformation.py b/situr/transformation/round_transformation.py deleted file mode 100644 index ea13ace..0000000 --- a/situr/transformation/round_transformation.py +++ /dev/null @@ -1,62 +0,0 @@ -import abc -from situr.image.situ_tile import Tile -import scipy -import numpy as np -from situr.image import situ_image -from situr.transformation import Transform - - -class RoundTransform(Transform): - __metaclass__ = abc.ABCMeta - - @abc.abstractmethod - def apply_transformation(self, situ_tile: Tile, round: int): - """Performs a transformation on one round, all channels and focus_levels are transformed the same way - - Args: - situ_tile (Tile): The tile the transformation is applied to. - round (int): The round that the transformation is to be applied to. - - Raises: - NotImplementedError: This method is abstract and therefore raises an error - """ - raise NotImplementedError( - self.__class__.__name__ + '.apply_transformation') - - -class IdentityRoundTransform(RoundTransform): - def apply_transformation(self, situ_tile: Tile, round: Tile): - """Performs the identity transformation (meaning no transformation) - Args: - situ_tile (Tile): The tile the transformation is applied to. - round (Tile): The round that the transformation is to be applied to. - """ - pass - - -class ScaleRotateTranslateRoundTransform(RoundTransform): - def __init__(self, transform_matrix: np.ndarray, scale: int = 1, offset: np.array = np.array([0, 0])): - """Constructor for a Transformation that supports rotation, translation and scaling on a channel - - Args: - transform_matrix (np.ndarray): A matrix of shape (2,2) - scale (int, optional): The scale factor. Defaults to 1. - offset (np.array, optional): The offset of shape (2,). Defaults to np.array([0, 0]). - """ - # TODO: check - # * transform matrix is 2x2 - # * offset is array (2,) - self.transform_matrix = transform_matrix - self.offset = offset - self.scale = scale - - def apply_tranformation(self, situ_tile: Tile, round: int): - situ_image = situ_tile.get_image_round(round) - - for channel in range(situ_image.get_channel_count()): - for focus_level in range(situ_image.get_focus_level_count()): - img = situ_image.get_focus_level(channel, focus_level) - img[:, :] = scipy.ndimage.affine_transform( - img, self.transform_matrix) - img[:, :] = scipy.ndimage.zoom(img, self.scale) - img[:, :] = scipy.ndimage.shift(img, self.offset) diff --git a/situr/transformation/transformation.py b/situr/transformation/transformation.py index f6e953b..f2c11f3 100644 --- a/situr/transformation/transformation.py +++ b/situr/transformation/transformation.py @@ -1,4 +1,39 @@ import abc +import numpy as np +import scipy class Transform: __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def apply_tranformation(self, img: np.ndarray) -> np.ndarray: + raise NotImplementedError( + self.__class__.__name__ + '.apply_transformation') + + +class IdentityTransform(Transform): + def apply_tranformation(self, img: np.ndarray) -> np.ndarray: + return img + + +class ScaleRotateTranslateTransform(Transform): + def __init__(self, transform_matrix: np.ndarray, scale: int = 1, offset: np.array = np.array([0, 0])): + """Constructor for a Transformation that supports rotation, translation and scaling on an image + + Args: + transform_matrix (np.ndarray): A matrix of shape (2,2) + scale (int, optional): The scale factor. Defaults to 1. + offset (np.array, optional): The offset of shape (2,). Defaults to np.array([0, 0]). + """ + # TODO: check + # * transform matrix is 2x2 + # * offset is array (2,) + self.transform_matrix = transform_matrix + self.offset = offset + self.scale = scale + + def apply_tranformation(self, img: np.ndarray) -> np.ndarray: + img[:, :] = scipy.ndimage.affine_transform( + img, self.transform_matrix) + img[:, :] = scipy.ndimage.zoom(img, self.scale) + img[:, :] = scipy.ndimage.shift(img, self.offset)