mirror of
https://github.com/13hannes11/situr.git
synced 2024-09-03 20:50:58 +02:00
implement registration fo rounds
This commit is contained in:
@@ -1,2 +1,3 @@
|
||||
from .situ_image import extend_dim, remove_dim
|
||||
from .situ_image import SituImage, PeakFinderDifferenceOfGaussian
|
||||
from .situ_tile import Tile
|
||||
|
||||
@@ -69,6 +69,20 @@ class SituImage:
|
||||
def get_channel_count(self):
|
||||
return self.get_data().shape[0]
|
||||
|
||||
def get_focus_level_count(self):
|
||||
return self.get_data().shape[1]
|
||||
|
||||
def get_focus_level(self, channel, focus_level):
|
||||
"""Loads channel and focus level of an image.
|
||||
|
||||
Args:
|
||||
channel (int): The channel to be used
|
||||
focus_level (int): The focus level to be used
|
||||
Returns:
|
||||
numpy.array: The loaded image of shape (width, height)
|
||||
"""
|
||||
return self.get_data()[channel, focus_level, :, :]
|
||||
|
||||
def get_channel(self, channel):
|
||||
'''
|
||||
Loads and returns the specified channel for all focus_levels.
|
||||
|
||||
@@ -25,6 +25,9 @@ class Tile:
|
||||
# TODO: implement (first apply channel transformations then round transformations)
|
||||
pass
|
||||
|
||||
def get_image_round(self, round):
|
||||
return self.images[round]
|
||||
|
||||
def set_round_transformation(self, round, transformation):
|
||||
self.round_transformations[round] = transformation
|
||||
|
||||
@@ -32,7 +35,7 @@ class Tile:
|
||||
return len(self.images)
|
||||
|
||||
def get_channel_count(self):
|
||||
return self.images.get_channel_count(self)
|
||||
return self.images[0].get_channel_count()
|
||||
|
||||
def get_round(self, round_number):
|
||||
"""This
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from situr.registration import Registration
|
||||
from situr.transformation import IdentityChannelTransform
|
||||
from situr.registration import Registration, FilterregRegistrationFunction
|
||||
from situr.transformation import ScaleRotateTranslateChannelTransform
|
||||
|
||||
|
||||
class ChannelRegistration(Registration):
|
||||
def __init__(self, registration_function=FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform)):
|
||||
super().__init__(registration_function)
|
||||
def do_channel_registration(self, situ_img, reference_channel=0):
|
||||
# For each channel (except nucleus) compute transform compared to reference_channel
|
||||
# Add Channel transformation to Channel
|
||||
|
||||
@@ -3,13 +3,12 @@ import open3d as o3
|
||||
from probreg import filterreg
|
||||
|
||||
from situr.image import extend_dim
|
||||
from situr.transformation import ScaleRotateTranslateChannelTransform
|
||||
|
||||
|
||||
class RegistrationFunction:
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
def __init__(self, transormation_type=ScaleRotateTranslateChannelTransform):
|
||||
def __init__(self, transormation_type):
|
||||
self.transormation_type = transormation_type
|
||||
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -1,6 +1,26 @@
|
||||
class RoundRegistration:
|
||||
def do_round_registration(self, situ_tile, reference_channel=0):
|
||||
# For each channel (except nucleus) compute transform compared to reference_channel
|
||||
# Add Channel transformation to Channel
|
||||
# TODO: implement
|
||||
pass
|
||||
from situr.registration import Registration, FilterregRegistrationFunction
|
||||
from situr.transformation import ScaleRotateTranslateRoundTransform
|
||||
|
||||
|
||||
class RoundRegistration(Registration):
|
||||
def __init__(self, registration_function=FilterregRegistrationFunction(ScaleRotateTranslateRoundTransform)):
|
||||
super().__init__(registration_function)
|
||||
|
||||
def do_round_registration(self, situ_tile, reference_round=0, reference_channel=0):
|
||||
"""This method generates a round registration transformation for a tile and saves it in the tile.
|
||||
|
||||
Args:
|
||||
situ_tile (Tile): The tile that the transformation is to be performed on.
|
||||
reference_round (int, optional): The round that is referenced and will not be changed. Defaults to 0.
|
||||
reference_channel (int, optional): The channel tha is used to compare rounds. Defaults to 0.
|
||||
"""
|
||||
reference_peaks = situ_tile.get_image_round(
|
||||
reference_round).get_channel_peaks(reference_channel)
|
||||
for round in range(situ_tile.get_roundcount()):
|
||||
if round != reference_channel:
|
||||
current_round_peaks = situ_tile.get_image_round(
|
||||
round
|
||||
).get_channel_peaks(reference_channel)
|
||||
transformation = self.registration_function.do_registration(
|
||||
current_round_peaks, reference_peaks)
|
||||
situ_tile.set_round_transformation(round, transformation)
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
from .channel_transformation import ChannelTransform, IdentityChannelTransform, ScaleRotateTranslateChannelTransform
|
||||
from .round_transformation import RoundTransform, IdentityRoundTransform
|
||||
from .round_transformation import RoundTransform, IdentityRoundTransform, ScaleRotateTranslateRoundTransform
|
||||
|
||||
@@ -1,16 +1,40 @@
|
||||
import abc
|
||||
import scipy
|
||||
import numpy as np
|
||||
from situr.image import situ_image
|
||||
|
||||
|
||||
class RoundTransform:
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
@abc.abstractmethod
|
||||
def apply_transformation(self, situ_tile, channel):
|
||||
"""Performs a transformation on one channel, all focus_levels are transformed the same way"""
|
||||
def apply_transformation(self, situ_tile, round):
|
||||
"""Performs a transformation on one round, all channels and focus_levels are transformed the same way"""
|
||||
raise NotImplementedError(
|
||||
self.__class__.__name__ + '.apply_transformation')
|
||||
|
||||
|
||||
class IdentityRoundTransform(RoundTransform):
|
||||
def apply_transformation(self, situ_tile, channel):
|
||||
def apply_transformation(self, situ_tile, round):
|
||||
pass
|
||||
|
||||
|
||||
class ScaleRotateTranslateRoundTransform(RoundTransform):
|
||||
def __init__(self, transform_matrix, scale=1, offset=np.array([0, 0])):
|
||||
# TODO: check
|
||||
# * transform matrix is 2x2
|
||||
# * offset is array (2,)
|
||||
self.transform_matrix = transform_matrix
|
||||
self.offset = offset
|
||||
self.scale = scale
|
||||
|
||||
def apply_tranformation(self, situ_tile, round):
|
||||
situ_image = situ_tile.get_image_round(round)
|
||||
|
||||
for channel in range(situ_image.get_channel_count()):
|
||||
for focus_level in range(situ_image.get_focus_level_count()):
|
||||
img = situ_image.get_focus_level(channel, focus_level)
|
||||
img[:, :] = scipy.ndimage.affine_transform(
|
||||
img, self.transform_matrix)
|
||||
img[:, :] = scipy.ndimage.zoom(img, self.scale)
|
||||
img[:, :] = scipy.ndimage.shift(img, self.offset)
|
||||
|
||||
Reference in New Issue
Block a user