add channel transformations into situ_image class

This commit is contained in:
2021-07-13 10:11:41 +02:00
parent 5bac4eb2bd
commit f447157d35
4 changed files with 28 additions and 3 deletions

View File

@@ -3,6 +3,8 @@ from PIL import Image, ImageDraw
from skimage import img_as_float from skimage import img_as_float
from skimage.feature import blob_dog from skimage.feature import blob_dog
from situr.transformation.channel_transformation import IdentityChannelTransform
def extend_dim(array): def extend_dim(array):
ones = np.ones((array.shape[0], 1)) ones = np.ones((array.shape[0], 1))
@@ -33,12 +35,23 @@ class SituImage:
self.files = file_list self.files = file_list
self.data = None self.data = None
self.nucleaus_channel = nucleaus_channel self.nucleaus_channel = nucleaus_channel
self.channel_transformations = [
IdentityChannelTransform() for file in file_list
]
def get_data(self): def get_data(self):
if self.data is None: if self.data is None:
self._load_image() self._load_image()
# TODO: apply transformations
return self.data return self.data
def apply_transformations():
# TODO: implement
pass
def set_channel_transformation(self, channel, transformation):
self.channel_transformations[channel] = transformation
def get_channel_count(self): def get_channel_count(self):
return self.get_data().shape[0] return self.get_data().shape[0]
@@ -110,6 +123,8 @@ class SituImage:
Returns: Returns:
np.array: The peaks found by this method as np.array of shape (n, 2) np.array: The peaks found by this method as np.array of shape (n, 2)
''' '''
# TODO: think of a better way to declare peak finding parameters (so they don't need to be passedaround as much)
img = img_as_float(self.get_data()[channel, focus_level, :, :]) img = img_as_float(self.get_data()[channel, focus_level, :, :])
peaks = blob_dog(img, min_sigma=min_sigma, peaks = blob_dog(img, min_sigma=min_sigma,
max_sigma=max_sigma, threshold=threshold) max_sigma=max_sigma, threshold=threshold)

View File

@@ -1,9 +1,15 @@
from situr.registration import Registration from situr.registration import Registration
from situr.transformation import IdentityChannelTransform
class ChannelRegistration(Registration): class ChannelRegistration(Registration):
def do_channel_registration(self, situ_img, reference_channel=0): def do_channel_registration(self, situ_img, reference_channel=0):
# For each channel (except nucleus) compute transform compared to reference_channel # For each channel (except nucleus) compute transform compared to reference_channel
# Add Channel transformation to Channel # Add Channel transformation to Channel
# TODO: implement reference_peaks = situ_img.get_channel_peaks(reference_channel)
pass for channel in range(situ_img.get_channel_count()):
if channel != situ_img.nucleaus_channel and channel != reference_channel:
current_channel_peaks = situ_img.get_channel_peaks(channel)
transformation = self.registration_function.do_registration(
current_channel_peaks, reference_peaks)
situ_img.set_channel_transformation(channel, transformation)

View File

@@ -1 +1 @@
from .channel_transformation import ChannelTransform, ScaleRotateTranslateChannelTransform from .channel_transformation import ChannelTransform, IdentityChannelTransform, ScaleRotateTranslateChannelTransform

View File

@@ -13,6 +13,10 @@ class ChannelTransform:
self.__class__.__name__ + '.apply_transformation') self.__class__.__name__ + '.apply_transformation')
class IdentityChannelTransform(ChannelTransform):
def apply_transformation(self, situ_img, channel):
pass
class ScaleRotateTranslateChannelTransform(ChannelTransform): class ScaleRotateTranslateChannelTransform(ChannelTransform):
def __init__(self, transform_matrix, scale=1, offset=np.array([0, 0])): def __init__(self, transform_matrix, scale=1, offset=np.array([0, 0])):
# TODO: check # TODO: check