rework transformations to resolve circular dependencies

This commit is contained in:
2021-07-15 09:44:40 +02:00
parent 8ff57e9c8d
commit b70b5a6a67
11 changed files with 79 additions and 147 deletions

View File

@@ -1,4 +1,5 @@
import abc import abc
from situr.transformation.transformation import Transform
import numpy as np import numpy as np
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
from skimage import img_as_float from skimage import img_as_float
@@ -6,7 +7,7 @@ from skimage.feature import blob_dog
from typing import List from typing import List
from situr.transformation.channel_transformation import ChannelTransform, IdentityChannelTransform from situr.transformation import Transform, IdentityTransform
def extend_dim(array: np.ndarray): def extend_dim(array: np.ndarray):
@@ -63,7 +64,7 @@ class SituImage:
self.data = None self.data = None
self.nucleaus_channel = nucleaus_channel self.nucleaus_channel = nucleaus_channel
self.channel_transformations = [ self.channel_transformations = [
IdentityChannelTransform() for file in file_list IdentityTransform() for file in file_list
] ]
self.peak_finder = peak_finder self.peak_finder = peak_finder
@@ -74,9 +75,17 @@ class SituImage:
def apply_transformations(self): def apply_transformations(self):
for i, transformation in enumerate(self.channel_transformations): for i, transformation in enumerate(self.channel_transformations):
transformation.apply_transformation(self, i) for focus_level in range(self.get_focus_level_count()):
img = self.get_focus_level(i, focus_level)
transformation.apply_tranformation(img)
def set_channel_transformation(self, channel: int, transformation: ChannelTransform): def apply_transform_to_whole_image(self, transform: Transform):
for channel in range(self.get_channel_count()):
for focus_level in range(self.get_focus_level_count()):
img = self.get_focus_level(channel, focus_level)
transform.apply_tranformation(img)
def set_channel_transformation(self, channel: int, transformation: Transform):
self.channel_transformations[channel] = transformation self.channel_transformations[channel] = transformation
def get_channel_count(self) -> int: def get_channel_count(self) -> int:

View File

@@ -1,8 +1,7 @@
from situr.transformation.round_transformation import RoundTransform from situr.transformation import Transform, IdentityTransform
import numpy as np import numpy as np
from situr.image.situ_image import SituImage from situr.image.situ_image import SituImage
from situr.transformation import IdentityRoundTransform
from typing import List from typing import List
@@ -22,7 +21,7 @@ class Tile:
for situ_image_list in file_list: for situ_image_list in file_list:
self.images.append( self.images.append(
SituImage(situ_image_list, nucleaus_channel=nucleaus_channel)) SituImage(situ_image_list, nucleaus_channel=nucleaus_channel))
self.round_transformations.append(IdentityRoundTransform()) self.round_transformations.append(IdentityTransform())
def apply_transformations(self): def apply_transformations(self):
# first apply channel transformations then round transformations # first apply channel transformations then round transformations
@@ -35,9 +34,9 @@ class Tile:
def apply_round_transformations(self): def apply_round_transformations(self):
for round, transformation in enumerate(self.round_transformations): for round, transformation in enumerate(self.round_transformations):
transformation.apply_tranformation(self, round) self.images[round].apply_transform_to_whole_image(transformation)
def set_round_transformation(self, round, transformation: RoundTransform): def set_round_transformation(self, round, transformation: Transform):
self.round_transformations[round] = transformation self.round_transformations[round] = transformation
def get_round_count(self) -> int: def get_round_count(self) -> int:

View File

@@ -1,4 +1,4 @@
from .registration import Registration, RegistrationFunction, FilterregRegistrationFunction from .registration import Registration, RegistrationFunction, FilterregRegistrationFunction
from .channel_registration import ChannelRegistration from .channel_registration import ChannelRegistration
from .round_registration import RoundRegistration from .round_registration import RoundRegistration
from .tile_registration import TileRegistration from .tile_registration import CombinedRegistration

View File

@@ -1,11 +1,14 @@
from situr.image.situ_image import SituImage from situr.image.situ_image import SituImage
from situr.transformation.channel_transformation import ChannelTransform
from situr.registration import Registration, RegistrationFunction, FilterregRegistrationFunction from situr.registration import Registration, RegistrationFunction, FilterregRegistrationFunction
from situr.transformation import ChannelTransform, ScaleRotateTranslateChannelTransform
class ChannelRegistration(Registration): class ChannelRegistration(Registration):
def __init__(self, registration_function: RegistrationFunction[ChannelTransform] = FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform)): def __init__(self, registration_function: RegistrationFunction = FilterregRegistrationFunction()):
"""Initialize channel registration and tell which registration function to use.
Args:
registration_function (RegistrationFunction, optional): Registration function. Defaults to FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform).
"""
super().__init__(registration_function) super().__init__(registration_function)
def do_channel_registration(self, situ_img: SituImage, reference_channel: int = 0): def do_channel_registration(self, situ_img: SituImage, reference_channel: int = 0):

View File

@@ -1,26 +1,22 @@
import abc import abc
from situr.transformation.channel_transformation import ChannelTransform
from situr.transformation.round_transformation import RoundTransform
import open3d as o3 import open3d as o3
from probreg import filterreg from probreg import filterreg
import numpy as np import numpy as np
from situr.image import extend_dim from situr.image import extend_dim
from situr.transformation import Transform from situr.transformation import Transform, ScaleRotateTranslateTransform
class RegistrationFunction: class RegistrationFunction:
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, transormation_type: Transform):
self.transormation_type = transormation_type
@abc.abstractmethod @abc.abstractmethod
def do_registration(self, data_peaks, reference_peaks): def do_registration(self, data_peaks, reference_peaks) -> Transform:
raise NotImplementedError(self.__class__.__name__ + '.do_registration') raise NotImplementedError(self.__class__.__name__ + '.do_registration')
class FilterregRegistrationFunction(RegistrationFunction): class FilterregRegistrationFunction(RegistrationFunction):
def do_registration(self, data_peaks: np.ndarray, reference_peaks: np.ndarray) -> Transform: def do_registration(self, data_peaks: np.ndarray, reference_peaks: np.ndarray) -> ScaleRotateTranslateTransform:
source = o3.geometry.PointCloud() source = o3.geometry.PointCloud()
source.points = o3.utility.Vector3dVector(extend_dim(data_peaks)) source.points = o3.utility.Vector3dVector(extend_dim(data_peaks))
target = o3.geometry.PointCloud() target = o3.geometry.PointCloud()
@@ -29,10 +25,11 @@ class FilterregRegistrationFunction(RegistrationFunction):
registration_method = filterreg.registration_filterreg registration_method = filterreg.registration_filterreg
tf_param, _, _ = filterreg.registration_filterreg(source, target) tf_param, _, _ = filterreg.registration_filterreg(source, target)
return self.transormation_type(transform_matrix=tf_param.rot[0:2, 0:2], scale=tf_param.scale, offset=tf_param.t[0:2]) return ScaleRotateTranslateTransform(transform_matrix=tf_param.rot[0:2, 0:2], scale=tf_param.scale, offset=tf_param.t[0:2])
class Registration: class Registration:
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, registration_function: RegistrationFunction): def __init__(self, registration_function: RegistrationFunction):
self.registration_function = registration_function self.registration_function = registration_function

View File

@@ -1,13 +1,16 @@
from situr.registration import Registration, RegistrationFunction, FilterregRegistrationFunction from situr.registration import Registration, RegistrationFunction, FilterregRegistrationFunction
from situr.transformation import RoundTransform, ScaleRotateTranslateRoundTransform
from situr.image import Tile
class RoundRegistration(Registration): class RoundRegistration(Registration):
def __init__(self, registration_function: RegistrationFunction[RoundTransform] = FilterregRegistrationFunction(ScaleRotateTranslateRoundTransform)): def __init__(self, registration_function: RegistrationFunction = FilterregRegistrationFunction()):
"""Initialize round registration and tell which registration function to use.
Args:
registration_function (RegistrationFunction[RoundTransform], optional): Registration function. Defaults to FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform).
"""
super().__init__(registration_function) super().__init__(registration_function)
def do_round_registration(self, situ_tile: Tile, reference_round: int = 0, reference_channel: int = 0): def do_round_registration(self, situ_tile, reference_round: int = 0, reference_channel: int = 0):
"""This method generates a round registration transformation for a tile and saves it in the tile. """This method generates a round registration transformation for a tile and saves it in the tile.
Args: Args:
@@ -17,11 +20,11 @@ class RoundRegistration(Registration):
""" """
# TODO: instead of one reference channel use all channels (maybe without nucleus channel) # TODO: instead of one reference channel use all channels (maybe without nucleus channel)
reference_peaks = situ_tile.get_image_round( reference_peaks = situ_tile.get_round(
reference_round).get_channel_peaks(reference_channel) reference_round).get_channel_peaks(reference_channel)
for round in range(situ_tile.get_roundcount()): for round in range(situ_tile.get_round_count()):
if round != reference_channel: if round != reference_channel:
current_round_peaks = situ_tile.get_image_round( current_round_peaks = situ_tile.get_round(
round round
).get_channel_peaks(reference_channel) ).get_channel_peaks(reference_channel)
transformation = self.registration_function.do_registration( transformation = self.registration_function.do_registration(

View File

@@ -18,12 +18,12 @@ class CombinedRegistration:
tile (Tile): The tile that the registration and transformations are to be performed on. tile (Tile): The tile that the registration and transformations are to be performed on.
""" """
# Do channel registration # Do channel registration
for round in range(tile.get_roundcount()): for round in range(tile.get_round_count()):
img = tile.get_image_round(round) img = tile.get_round(round)
self.channel_registration self.channel_registration
tile.apply_channel_transformations() tile.apply_channel_transformations()
round_registration.do_round_registration(tile) self.round_registration.do_round_registration(tile)
tile.apply_round_transformations() tile.apply_round_transformations()

View File

@@ -1,3 +1 @@
from .channel_transformation import ChannelTransform, IdentityChannelTransform, ScaleRotateTranslateChannelTransform from .transformation import Transform, IdentityTransform, ScaleRotateTranslateTransform
from .round_transformation import RoundTransform, IdentityRoundTransform, ScaleRotateTranslateRoundTransform
from .transformation import Transform

View File

@@ -1,50 +0,0 @@
import abc
from situr.image.situ_image import SituImage
import numpy as np
import scipy
from situr.transformation import Transform
class ChannelTransform(Transform):
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def apply_transformation(self, situ_img: SituImage, channel: int):
"""Performs a transformation on one channel, all focus_levels are transformed the same way"""
raise NotImplementedError(
self.__class__.__name__ + '.apply_transformation')
class IdentityChannelTransform(ChannelTransform):
def apply_transformation(self, situ_img: SituImage, channel: int):
pass
class ScaleRotateTranslateChannelTransform(ChannelTransform):
def __init__(self, transform_matrix: np.ndarray, scale: float = 1, offset: np.ndarray = np.array([0, 0])):
"""Constructor for a Transformation that supports rotation, translation and scaling on a channel
Args:
transform_matrix (np.ndarray): A matrix of shape (2,2)
scale (float, optional): The scale factor. Defaults to 1.
offset (np.ndarray, optional): The offset of shape (2,). Defaults to np.array([0, 0]).
"""
# TODO: check
# * transform matrix is 2x2
# * offset is array (2,)
self.transform_matrix = transform_matrix
self.offset = offset
self.scale = scale
def apply_tranformation(self, situ_img: SituImage, channel: int):
channel_img = situ_img.get_channel(channel)
focus_levels = channel_img.shape[0]
for focus_level in range(focus_levels):
img = channel_img[focus_level, :, :]
img[:, :] = scipy.ndimage.affine_transform(
img, self.transform_matrix)
img[:, :] = scipy.ndimage.zoom(img, self.scale)
img[:, :] = scipy.ndimage.shift(img, self.offset)

View File

@@ -1,62 +0,0 @@
import abc
from situr.image.situ_tile import Tile
import scipy
import numpy as np
from situr.image import situ_image
from situr.transformation import Transform
class RoundTransform(Transform):
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def apply_transformation(self, situ_tile: Tile, round: int):
"""Performs a transformation on one round, all channels and focus_levels are transformed the same way
Args:
situ_tile (Tile): The tile the transformation is applied to.
round (int): The round that the transformation is to be applied to.
Raises:
NotImplementedError: This method is abstract and therefore raises an error
"""
raise NotImplementedError(
self.__class__.__name__ + '.apply_transformation')
class IdentityRoundTransform(RoundTransform):
def apply_transformation(self, situ_tile: Tile, round: Tile):
"""Performs the identity transformation (meaning no transformation)
Args:
situ_tile (Tile): The tile the transformation is applied to.
round (Tile): The round that the transformation is to be applied to.
"""
pass
class ScaleRotateTranslateRoundTransform(RoundTransform):
def __init__(self, transform_matrix: np.ndarray, scale: int = 1, offset: np.array = np.array([0, 0])):
"""Constructor for a Transformation that supports rotation, translation and scaling on a channel
Args:
transform_matrix (np.ndarray): A matrix of shape (2,2)
scale (int, optional): The scale factor. Defaults to 1.
offset (np.array, optional): The offset of shape (2,). Defaults to np.array([0, 0]).
"""
# TODO: check
# * transform matrix is 2x2
# * offset is array (2,)
self.transform_matrix = transform_matrix
self.offset = offset
self.scale = scale
def apply_tranformation(self, situ_tile: Tile, round: int):
situ_image = situ_tile.get_image_round(round)
for channel in range(situ_image.get_channel_count()):
for focus_level in range(situ_image.get_focus_level_count()):
img = situ_image.get_focus_level(channel, focus_level)
img[:, :] = scipy.ndimage.affine_transform(
img, self.transform_matrix)
img[:, :] = scipy.ndimage.zoom(img, self.scale)
img[:, :] = scipy.ndimage.shift(img, self.offset)

View File

@@ -1,4 +1,39 @@
import abc import abc
import numpy as np
import scipy
class Transform: class Transform:
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
@abc.abstractmethod
def apply_tranformation(self, img: np.ndarray) -> np.ndarray:
raise NotImplementedError(
self.__class__.__name__ + '.apply_transformation')
class IdentityTransform(Transform):
def apply_tranformation(self, img: np.ndarray) -> np.ndarray:
return img
class ScaleRotateTranslateTransform(Transform):
def __init__(self, transform_matrix: np.ndarray, scale: int = 1, offset: np.array = np.array([0, 0])):
"""Constructor for a Transformation that supports rotation, translation and scaling on an image
Args:
transform_matrix (np.ndarray): A matrix of shape (2,2)
scale (int, optional): The scale factor. Defaults to 1.
offset (np.array, optional): The offset of shape (2,). Defaults to np.array([0, 0]).
"""
# TODO: check
# * transform matrix is 2x2
# * offset is array (2,)
self.transform_matrix = transform_matrix
self.offset = offset
self.scale = scale
def apply_tranformation(self, img: np.ndarray) -> np.ndarray:
img[:, :] = scipy.ndimage.affine_transform(
img, self.transform_matrix)
img[:, :] = scipy.ndimage.zoom(img, self.scale)
img[:, :] = scipy.ndimage.shift(img, self.offset)