mirror of
https://github.com/13hannes11/situr.git
synced 2024-09-03 20:50:58 +02:00
rework transformations to resolve circular dependencies
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
import abc
|
import abc
|
||||||
|
from situr.transformation.transformation import Transform
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image, ImageDraw
|
from PIL import Image, ImageDraw
|
||||||
from skimage import img_as_float
|
from skimage import img_as_float
|
||||||
@@ -6,7 +7,7 @@ from skimage.feature import blob_dog
|
|||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from situr.transformation.channel_transformation import ChannelTransform, IdentityChannelTransform
|
from situr.transformation import Transform, IdentityTransform
|
||||||
|
|
||||||
|
|
||||||
def extend_dim(array: np.ndarray):
|
def extend_dim(array: np.ndarray):
|
||||||
@@ -63,7 +64,7 @@ class SituImage:
|
|||||||
self.data = None
|
self.data = None
|
||||||
self.nucleaus_channel = nucleaus_channel
|
self.nucleaus_channel = nucleaus_channel
|
||||||
self.channel_transformations = [
|
self.channel_transformations = [
|
||||||
IdentityChannelTransform() for file in file_list
|
IdentityTransform() for file in file_list
|
||||||
]
|
]
|
||||||
self.peak_finder = peak_finder
|
self.peak_finder = peak_finder
|
||||||
|
|
||||||
@@ -74,9 +75,17 @@ class SituImage:
|
|||||||
|
|
||||||
def apply_transformations(self):
|
def apply_transformations(self):
|
||||||
for i, transformation in enumerate(self.channel_transformations):
|
for i, transformation in enumerate(self.channel_transformations):
|
||||||
transformation.apply_transformation(self, i)
|
for focus_level in range(self.get_focus_level_count()):
|
||||||
|
img = self.get_focus_level(i, focus_level)
|
||||||
|
transformation.apply_tranformation(img)
|
||||||
|
|
||||||
def set_channel_transformation(self, channel: int, transformation: ChannelTransform):
|
def apply_transform_to_whole_image(self, transform: Transform):
|
||||||
|
for channel in range(self.get_channel_count()):
|
||||||
|
for focus_level in range(self.get_focus_level_count()):
|
||||||
|
img = self.get_focus_level(channel, focus_level)
|
||||||
|
transform.apply_tranformation(img)
|
||||||
|
|
||||||
|
def set_channel_transformation(self, channel: int, transformation: Transform):
|
||||||
self.channel_transformations[channel] = transformation
|
self.channel_transformations[channel] = transformation
|
||||||
|
|
||||||
def get_channel_count(self) -> int:
|
def get_channel_count(self) -> int:
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
from situr.transformation.round_transformation import RoundTransform
|
from situr.transformation import Transform, IdentityTransform
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from situr.image.situ_image import SituImage
|
from situr.image.situ_image import SituImage
|
||||||
from situr.transformation import IdentityRoundTransform
|
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
@@ -22,7 +21,7 @@ class Tile:
|
|||||||
for situ_image_list in file_list:
|
for situ_image_list in file_list:
|
||||||
self.images.append(
|
self.images.append(
|
||||||
SituImage(situ_image_list, nucleaus_channel=nucleaus_channel))
|
SituImage(situ_image_list, nucleaus_channel=nucleaus_channel))
|
||||||
self.round_transformations.append(IdentityRoundTransform())
|
self.round_transformations.append(IdentityTransform())
|
||||||
|
|
||||||
def apply_transformations(self):
|
def apply_transformations(self):
|
||||||
# first apply channel transformations then round transformations
|
# first apply channel transformations then round transformations
|
||||||
@@ -35,9 +34,9 @@ class Tile:
|
|||||||
|
|
||||||
def apply_round_transformations(self):
|
def apply_round_transformations(self):
|
||||||
for round, transformation in enumerate(self.round_transformations):
|
for round, transformation in enumerate(self.round_transformations):
|
||||||
transformation.apply_tranformation(self, round)
|
self.images[round].apply_transform_to_whole_image(transformation)
|
||||||
|
|
||||||
def set_round_transformation(self, round, transformation: RoundTransform):
|
def set_round_transformation(self, round, transformation: Transform):
|
||||||
self.round_transformations[round] = transformation
|
self.round_transformations[round] = transformation
|
||||||
|
|
||||||
def get_round_count(self) -> int:
|
def get_round_count(self) -> int:
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from .registration import Registration, RegistrationFunction, FilterregRegistrationFunction
|
from .registration import Registration, RegistrationFunction, FilterregRegistrationFunction
|
||||||
from .channel_registration import ChannelRegistration
|
from .channel_registration import ChannelRegistration
|
||||||
from .round_registration import RoundRegistration
|
from .round_registration import RoundRegistration
|
||||||
from .tile_registration import TileRegistration
|
from .tile_registration import CombinedRegistration
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
from situr.image.situ_image import SituImage
|
from situr.image.situ_image import SituImage
|
||||||
from situr.transformation.channel_transformation import ChannelTransform
|
|
||||||
from situr.registration import Registration, RegistrationFunction, FilterregRegistrationFunction
|
from situr.registration import Registration, RegistrationFunction, FilterregRegistrationFunction
|
||||||
from situr.transformation import ChannelTransform, ScaleRotateTranslateChannelTransform
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelRegistration(Registration):
|
class ChannelRegistration(Registration):
|
||||||
def __init__(self, registration_function: RegistrationFunction[ChannelTransform] = FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform)):
|
def __init__(self, registration_function: RegistrationFunction = FilterregRegistrationFunction()):
|
||||||
|
"""Initialize channel registration and tell which registration function to use.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registration_function (RegistrationFunction, optional): Registration function. Defaults to FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform).
|
||||||
|
"""
|
||||||
super().__init__(registration_function)
|
super().__init__(registration_function)
|
||||||
|
|
||||||
def do_channel_registration(self, situ_img: SituImage, reference_channel: int = 0):
|
def do_channel_registration(self, situ_img: SituImage, reference_channel: int = 0):
|
||||||
|
|||||||
@@ -1,26 +1,22 @@
|
|||||||
import abc
|
import abc
|
||||||
from situr.transformation.channel_transformation import ChannelTransform
|
|
||||||
from situr.transformation.round_transformation import RoundTransform
|
|
||||||
import open3d as o3
|
import open3d as o3
|
||||||
from probreg import filterreg
|
from probreg import filterreg
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from situr.image import extend_dim
|
from situr.image import extend_dim
|
||||||
from situr.transformation import Transform
|
from situr.transformation import Transform, ScaleRotateTranslateTransform
|
||||||
|
|
||||||
|
|
||||||
class RegistrationFunction:
|
class RegistrationFunction:
|
||||||
__metaclass__ = abc.ABCMeta
|
__metaclass__ = abc.ABCMeta
|
||||||
|
|
||||||
def __init__(self, transormation_type: Transform):
|
|
||||||
self.transormation_type = transormation_type
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def do_registration(self, data_peaks, reference_peaks):
|
def do_registration(self, data_peaks, reference_peaks) -> Transform:
|
||||||
raise NotImplementedError(self.__class__.__name__ + '.do_registration')
|
raise NotImplementedError(self.__class__.__name__ + '.do_registration')
|
||||||
|
|
||||||
|
|
||||||
class FilterregRegistrationFunction(RegistrationFunction):
|
class FilterregRegistrationFunction(RegistrationFunction):
|
||||||
def do_registration(self, data_peaks: np.ndarray, reference_peaks: np.ndarray) -> Transform:
|
def do_registration(self, data_peaks: np.ndarray, reference_peaks: np.ndarray) -> ScaleRotateTranslateTransform:
|
||||||
source = o3.geometry.PointCloud()
|
source = o3.geometry.PointCloud()
|
||||||
source.points = o3.utility.Vector3dVector(extend_dim(data_peaks))
|
source.points = o3.utility.Vector3dVector(extend_dim(data_peaks))
|
||||||
target = o3.geometry.PointCloud()
|
target = o3.geometry.PointCloud()
|
||||||
@@ -29,10 +25,11 @@ class FilterregRegistrationFunction(RegistrationFunction):
|
|||||||
registration_method = filterreg.registration_filterreg
|
registration_method = filterreg.registration_filterreg
|
||||||
tf_param, _, _ = filterreg.registration_filterreg(source, target)
|
tf_param, _, _ = filterreg.registration_filterreg(source, target)
|
||||||
|
|
||||||
return self.transormation_type(transform_matrix=tf_param.rot[0:2, 0:2], scale=tf_param.scale, offset=tf_param.t[0:2])
|
return ScaleRotateTranslateTransform(transform_matrix=tf_param.rot[0:2, 0:2], scale=tf_param.scale, offset=tf_param.t[0:2])
|
||||||
|
|
||||||
|
|
||||||
class Registration:
|
class Registration:
|
||||||
__metaclass__ = abc.ABCMeta
|
__metaclass__ = abc.ABCMeta
|
||||||
|
|
||||||
def __init__(self, registration_function: RegistrationFunction):
|
def __init__(self, registration_function: RegistrationFunction):
|
||||||
self.registration_function = registration_function
|
self.registration_function = registration_function
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
|
|||||||
from situr.registration import Registration, RegistrationFunction, FilterregRegistrationFunction
|
from situr.registration import Registration, RegistrationFunction, FilterregRegistrationFunction
|
||||||
from situr.transformation import RoundTransform, ScaleRotateTranslateRoundTransform
|
|
||||||
from situr.image import Tile
|
|
||||||
|
|
||||||
|
|
||||||
class RoundRegistration(Registration):
|
class RoundRegistration(Registration):
|
||||||
def __init__(self, registration_function: RegistrationFunction[RoundTransform] = FilterregRegistrationFunction(ScaleRotateTranslateRoundTransform)):
|
def __init__(self, registration_function: RegistrationFunction = FilterregRegistrationFunction()):
|
||||||
|
"""Initialize round registration and tell which registration function to use.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registration_function (RegistrationFunction[RoundTransform], optional): Registration function. Defaults to FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform).
|
||||||
|
"""
|
||||||
super().__init__(registration_function)
|
super().__init__(registration_function)
|
||||||
|
|
||||||
def do_round_registration(self, situ_tile: Tile, reference_round: int = 0, reference_channel: int = 0):
|
def do_round_registration(self, situ_tile, reference_round: int = 0, reference_channel: int = 0):
|
||||||
"""This method generates a round registration transformation for a tile and saves it in the tile.
|
"""This method generates a round registration transformation for a tile and saves it in the tile.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -17,11 +20,11 @@ class RoundRegistration(Registration):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO: instead of one reference channel use all channels (maybe without nucleus channel)
|
# TODO: instead of one reference channel use all channels (maybe without nucleus channel)
|
||||||
reference_peaks = situ_tile.get_image_round(
|
reference_peaks = situ_tile.get_round(
|
||||||
reference_round).get_channel_peaks(reference_channel)
|
reference_round).get_channel_peaks(reference_channel)
|
||||||
for round in range(situ_tile.get_roundcount()):
|
for round in range(situ_tile.get_round_count()):
|
||||||
if round != reference_channel:
|
if round != reference_channel:
|
||||||
current_round_peaks = situ_tile.get_image_round(
|
current_round_peaks = situ_tile.get_round(
|
||||||
round
|
round
|
||||||
).get_channel_peaks(reference_channel)
|
).get_channel_peaks(reference_channel)
|
||||||
transformation = self.registration_function.do_registration(
|
transformation = self.registration_function.do_registration(
|
||||||
|
|||||||
@@ -18,12 +18,12 @@ class CombinedRegistration:
|
|||||||
tile (Tile): The tile that the registration and transformations are to be performed on.
|
tile (Tile): The tile that the registration and transformations are to be performed on.
|
||||||
"""
|
"""
|
||||||
# Do channel registration
|
# Do channel registration
|
||||||
for round in range(tile.get_roundcount()):
|
for round in range(tile.get_round_count()):
|
||||||
img = tile.get_image_round(round)
|
img = tile.get_round(round)
|
||||||
self.channel_registration
|
self.channel_registration
|
||||||
|
|
||||||
tile.apply_channel_transformations()
|
tile.apply_channel_transformations()
|
||||||
|
|
||||||
round_registration.do_round_registration(tile)
|
self.round_registration.do_round_registration(tile)
|
||||||
|
|
||||||
tile.apply_round_transformations()
|
tile.apply_round_transformations()
|
||||||
|
|||||||
@@ -1,3 +1 @@
|
|||||||
from .channel_transformation import ChannelTransform, IdentityChannelTransform, ScaleRotateTranslateChannelTransform
|
from .transformation import Transform, IdentityTransform, ScaleRotateTranslateTransform
|
||||||
from .round_transformation import RoundTransform, IdentityRoundTransform, ScaleRotateTranslateRoundTransform
|
|
||||||
from .transformation import Transform
|
|
||||||
|
|||||||
@@ -1,50 +0,0 @@
|
|||||||
import abc
|
|
||||||
from situr.image.situ_image import SituImage
|
|
||||||
import numpy as np
|
|
||||||
import scipy
|
|
||||||
|
|
||||||
from situr.transformation import Transform
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelTransform(Transform):
|
|
||||||
__metaclass__ = abc.ABCMeta
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def apply_transformation(self, situ_img: SituImage, channel: int):
|
|
||||||
"""Performs a transformation on one channel, all focus_levels are transformed the same way"""
|
|
||||||
raise NotImplementedError(
|
|
||||||
self.__class__.__name__ + '.apply_transformation')
|
|
||||||
|
|
||||||
|
|
||||||
class IdentityChannelTransform(ChannelTransform):
|
|
||||||
def apply_transformation(self, situ_img: SituImage, channel: int):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ScaleRotateTranslateChannelTransform(ChannelTransform):
|
|
||||||
def __init__(self, transform_matrix: np.ndarray, scale: float = 1, offset: np.ndarray = np.array([0, 0])):
|
|
||||||
"""Constructor for a Transformation that supports rotation, translation and scaling on a channel
|
|
||||||
|
|
||||||
Args:
|
|
||||||
transform_matrix (np.ndarray): A matrix of shape (2,2)
|
|
||||||
scale (float, optional): The scale factor. Defaults to 1.
|
|
||||||
offset (np.ndarray, optional): The offset of shape (2,). Defaults to np.array([0, 0]).
|
|
||||||
"""
|
|
||||||
# TODO: check
|
|
||||||
# * transform matrix is 2x2
|
|
||||||
# * offset is array (2,)
|
|
||||||
self.transform_matrix = transform_matrix
|
|
||||||
self.offset = offset
|
|
||||||
self.scale = scale
|
|
||||||
|
|
||||||
def apply_tranformation(self, situ_img: SituImage, channel: int):
|
|
||||||
channel_img = situ_img.get_channel(channel)
|
|
||||||
focus_levels = channel_img.shape[0]
|
|
||||||
|
|
||||||
for focus_level in range(focus_levels):
|
|
||||||
img = channel_img[focus_level, :, :]
|
|
||||||
|
|
||||||
img[:, :] = scipy.ndimage.affine_transform(
|
|
||||||
img, self.transform_matrix)
|
|
||||||
img[:, :] = scipy.ndimage.zoom(img, self.scale)
|
|
||||||
img[:, :] = scipy.ndimage.shift(img, self.offset)
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
import abc
|
|
||||||
from situr.image.situ_tile import Tile
|
|
||||||
import scipy
|
|
||||||
import numpy as np
|
|
||||||
from situr.image import situ_image
|
|
||||||
from situr.transformation import Transform
|
|
||||||
|
|
||||||
|
|
||||||
class RoundTransform(Transform):
|
|
||||||
__metaclass__ = abc.ABCMeta
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def apply_transformation(self, situ_tile: Tile, round: int):
|
|
||||||
"""Performs a transformation on one round, all channels and focus_levels are transformed the same way
|
|
||||||
|
|
||||||
Args:
|
|
||||||
situ_tile (Tile): The tile the transformation is applied to.
|
|
||||||
round (int): The round that the transformation is to be applied to.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
NotImplementedError: This method is abstract and therefore raises an error
|
|
||||||
"""
|
|
||||||
raise NotImplementedError(
|
|
||||||
self.__class__.__name__ + '.apply_transformation')
|
|
||||||
|
|
||||||
|
|
||||||
class IdentityRoundTransform(RoundTransform):
|
|
||||||
def apply_transformation(self, situ_tile: Tile, round: Tile):
|
|
||||||
"""Performs the identity transformation (meaning no transformation)
|
|
||||||
Args:
|
|
||||||
situ_tile (Tile): The tile the transformation is applied to.
|
|
||||||
round (Tile): The round that the transformation is to be applied to.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ScaleRotateTranslateRoundTransform(RoundTransform):
|
|
||||||
def __init__(self, transform_matrix: np.ndarray, scale: int = 1, offset: np.array = np.array([0, 0])):
|
|
||||||
"""Constructor for a Transformation that supports rotation, translation and scaling on a channel
|
|
||||||
|
|
||||||
Args:
|
|
||||||
transform_matrix (np.ndarray): A matrix of shape (2,2)
|
|
||||||
scale (int, optional): The scale factor. Defaults to 1.
|
|
||||||
offset (np.array, optional): The offset of shape (2,). Defaults to np.array([0, 0]).
|
|
||||||
"""
|
|
||||||
# TODO: check
|
|
||||||
# * transform matrix is 2x2
|
|
||||||
# * offset is array (2,)
|
|
||||||
self.transform_matrix = transform_matrix
|
|
||||||
self.offset = offset
|
|
||||||
self.scale = scale
|
|
||||||
|
|
||||||
def apply_tranformation(self, situ_tile: Tile, round: int):
|
|
||||||
situ_image = situ_tile.get_image_round(round)
|
|
||||||
|
|
||||||
for channel in range(situ_image.get_channel_count()):
|
|
||||||
for focus_level in range(situ_image.get_focus_level_count()):
|
|
||||||
img = situ_image.get_focus_level(channel, focus_level)
|
|
||||||
img[:, :] = scipy.ndimage.affine_transform(
|
|
||||||
img, self.transform_matrix)
|
|
||||||
img[:, :] = scipy.ndimage.zoom(img, self.scale)
|
|
||||||
img[:, :] = scipy.ndimage.shift(img, self.offset)
|
|
||||||
@@ -1,4 +1,39 @@
|
|||||||
import abc
|
import abc
|
||||||
|
import numpy as np
|
||||||
|
import scipy
|
||||||
|
|
||||||
class Transform:
|
class Transform:
|
||||||
__metaclass__ = abc.ABCMeta
|
__metaclass__ = abc.ABCMeta
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def apply_tranformation(self, img: np.ndarray) -> np.ndarray:
|
||||||
|
raise NotImplementedError(
|
||||||
|
self.__class__.__name__ + '.apply_transformation')
|
||||||
|
|
||||||
|
|
||||||
|
class IdentityTransform(Transform):
|
||||||
|
def apply_tranformation(self, img: np.ndarray) -> np.ndarray:
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
class ScaleRotateTranslateTransform(Transform):
|
||||||
|
def __init__(self, transform_matrix: np.ndarray, scale: int = 1, offset: np.array = np.array([0, 0])):
|
||||||
|
"""Constructor for a Transformation that supports rotation, translation and scaling on an image
|
||||||
|
|
||||||
|
Args:
|
||||||
|
transform_matrix (np.ndarray): A matrix of shape (2,2)
|
||||||
|
scale (int, optional): The scale factor. Defaults to 1.
|
||||||
|
offset (np.array, optional): The offset of shape (2,). Defaults to np.array([0, 0]).
|
||||||
|
"""
|
||||||
|
# TODO: check
|
||||||
|
# * transform matrix is 2x2
|
||||||
|
# * offset is array (2,)
|
||||||
|
self.transform_matrix = transform_matrix
|
||||||
|
self.offset = offset
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
def apply_tranformation(self, img: np.ndarray) -> np.ndarray:
|
||||||
|
img[:, :] = scipy.ndimage.affine_transform(
|
||||||
|
img, self.transform_matrix)
|
||||||
|
img[:, :] = scipy.ndimage.zoom(img, self.scale)
|
||||||
|
img[:, :] = scipy.ndimage.shift(img, self.offset)
|
||||||
|
|||||||
Reference in New Issue
Block a user