implement registration fo rounds

This commit is contained in:
2021-07-13 13:18:41 +02:00
parent c2c39dc462
commit 1d0385e2f3
8 changed files with 78 additions and 15 deletions

View File

@@ -1,2 +1,3 @@
from .situ_image import extend_dim, remove_dim from .situ_image import extend_dim, remove_dim
from .situ_image import SituImage, PeakFinderDifferenceOfGaussian from .situ_image import SituImage, PeakFinderDifferenceOfGaussian
from .situ_tile import Tile

View File

@@ -69,6 +69,20 @@ class SituImage:
def get_channel_count(self): def get_channel_count(self):
return self.get_data().shape[0] return self.get_data().shape[0]
def get_focus_level_count(self):
return self.get_data().shape[1]
def get_focus_level(self, channel, focus_level):
"""Loads channel and focus level of an image.
Args:
channel (int): The channel to be used
focus_level (int): The focus level to be used
Returns:
numpy.array: The loaded image of shape (width, height)
"""
return self.get_data()[channel, focus_level, :, :]
def get_channel(self, channel): def get_channel(self, channel):
''' '''
Loads and returns the specified channel for all focus_levels. Loads and returns the specified channel for all focus_levels.

View File

@@ -25,6 +25,9 @@ class Tile:
# TODO: implement (first apply channel transformations then round transformations) # TODO: implement (first apply channel transformations then round transformations)
pass pass
def get_image_round(self, round):
return self.images[round]
def set_round_transformation(self, round, transformation): def set_round_transformation(self, round, transformation):
self.round_transformations[round] = transformation self.round_transformations[round] = transformation
@@ -32,7 +35,7 @@ class Tile:
return len(self.images) return len(self.images)
def get_channel_count(self): def get_channel_count(self):
return self.images.get_channel_count(self) return self.images[0].get_channel_count()
def get_round(self, round_number): def get_round(self, round_number):
"""This """This

View File

@@ -1,8 +1,10 @@
from situr.registration import Registration from situr.registration import Registration, FilterregRegistrationFunction
from situr.transformation import IdentityChannelTransform from situr.transformation import ScaleRotateTranslateChannelTransform
class ChannelRegistration(Registration): class ChannelRegistration(Registration):
def __init__(self, registration_function=FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform)):
super().__init__(registration_function)
def do_channel_registration(self, situ_img, reference_channel=0): def do_channel_registration(self, situ_img, reference_channel=0):
# For each channel (except nucleus) compute transform compared to reference_channel # For each channel (except nucleus) compute transform compared to reference_channel
# Add Channel transformation to Channel # Add Channel transformation to Channel

View File

@@ -3,13 +3,12 @@ import open3d as o3
from probreg import filterreg from probreg import filterreg
from situr.image import extend_dim from situr.image import extend_dim
from situr.transformation import ScaleRotateTranslateChannelTransform
class RegistrationFunction: class RegistrationFunction:
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, transormation_type=ScaleRotateTranslateChannelTransform): def __init__(self, transormation_type):
self.transormation_type = transormation_type self.transormation_type = transormation_type
@abc.abstractmethod @abc.abstractmethod

View File

@@ -1,6 +1,26 @@
class RoundRegistration: from situr.registration import Registration, FilterregRegistrationFunction
def do_round_registration(self, situ_tile, reference_channel=0): from situr.transformation import ScaleRotateTranslateRoundTransform
# For each channel (except nucleus) compute transform compared to reference_channel
# Add Channel transformation to Channel
# TODO: implement class RoundRegistration(Registration):
pass def __init__(self, registration_function=FilterregRegistrationFunction(ScaleRotateTranslateRoundTransform)):
super().__init__(registration_function)
def do_round_registration(self, situ_tile, reference_round=0, reference_channel=0):
"""This method generates a round registration transformation for a tile and saves it in the tile.
Args:
situ_tile (Tile): The tile that the transformation is to be performed on.
reference_round (int, optional): The round that is referenced and will not be changed. Defaults to 0.
reference_channel (int, optional): The channel tha is used to compare rounds. Defaults to 0.
"""
reference_peaks = situ_tile.get_image_round(
reference_round).get_channel_peaks(reference_channel)
for round in range(situ_tile.get_roundcount()):
if round != reference_channel:
current_round_peaks = situ_tile.get_image_round(
round
).get_channel_peaks(reference_channel)
transformation = self.registration_function.do_registration(
current_round_peaks, reference_peaks)
situ_tile.set_round_transformation(round, transformation)

View File

@@ -1,2 +1,2 @@
from .channel_transformation import ChannelTransform, IdentityChannelTransform, ScaleRotateTranslateChannelTransform from .channel_transformation import ChannelTransform, IdentityChannelTransform, ScaleRotateTranslateChannelTransform
from .round_transformation import RoundTransform, IdentityRoundTransform from .round_transformation import RoundTransform, IdentityRoundTransform, ScaleRotateTranslateRoundTransform

View File

@@ -1,16 +1,40 @@
import abc import abc
import scipy
import numpy as np
from situr.image import situ_image
class RoundTransform: class RoundTransform:
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
@abc.abstractmethod @abc.abstractmethod
def apply_transformation(self, situ_tile, channel): def apply_transformation(self, situ_tile, round):
"""Performs a transformation on one channel, all focus_levels are transformed the same way""" """Performs a transformation on one round, all channels and focus_levels are transformed the same way"""
raise NotImplementedError( raise NotImplementedError(
self.__class__.__name__ + '.apply_transformation') self.__class__.__name__ + '.apply_transformation')
class IdentityRoundTransform(RoundTransform): class IdentityRoundTransform(RoundTransform):
def apply_transformation(self, situ_tile, channel): def apply_transformation(self, situ_tile, round):
pass pass
class ScaleRotateTranslateRoundTransform(RoundTransform):
def __init__(self, transform_matrix, scale=1, offset=np.array([0, 0])):
# TODO: check
# * transform matrix is 2x2
# * offset is array (2,)
self.transform_matrix = transform_matrix
self.offset = offset
self.scale = scale
def apply_tranformation(self, situ_tile, round):
situ_image = situ_tile.get_image_round(round)
for channel in range(situ_image.get_channel_count()):
for focus_level in range(situ_image.get_focus_level_count()):
img = situ_image.get_focus_level(channel, focus_level)
img[:, :] = scipy.ndimage.affine_transform(
img, self.transform_matrix)
img[:, :] = scipy.ndimage.zoom(img, self.scale)
img[:, :] = scipy.ndimage.shift(img, self.offset)