diff --git a/situr/image/situ_image.py b/situr/image/situ_image.py index df6bd20..ddab26a 100644 --- a/situr/image/situ_image.py +++ b/situr/image/situ_image.py @@ -3,6 +3,8 @@ from PIL import Image, ImageDraw from skimage import img_as_float from skimage.feature import blob_dog +from situr.transformation.channel_transformation import IdentityChannelTransform + def extend_dim(array): ones = np.ones((array.shape[0], 1)) @@ -33,12 +35,23 @@ class SituImage: self.files = file_list self.data = None self.nucleaus_channel = nucleaus_channel + self.channel_transformations = [ + IdentityChannelTransform() for file in file_list + ] def get_data(self): if self.data is None: self._load_image() + # TODO: apply transformations return self.data + def apply_transformations(): + # TODO: implement + pass + + def set_channel_transformation(self, channel, transformation): + self.channel_transformations[channel] = transformation + def get_channel_count(self): return self.get_data().shape[0] @@ -110,6 +123,8 @@ class SituImage: Returns: np.array: The peaks found by this method as np.array of shape (n, 2) ''' + # TODO: think of a better way to declare peak finding parameters (so they don't need to be passedaround as much) + img = img_as_float(self.get_data()[channel, focus_level, :, :]) peaks = blob_dog(img, min_sigma=min_sigma, max_sigma=max_sigma, threshold=threshold) diff --git a/situr/registration/channel_registration.py b/situr/registration/channel_registration.py index be6050d..61af21e 100644 --- a/situr/registration/channel_registration.py +++ b/situr/registration/channel_registration.py @@ -1,9 +1,15 @@ from situr.registration import Registration +from situr.transformation import IdentityChannelTransform class ChannelRegistration(Registration): def do_channel_registration(self, situ_img, reference_channel=0): # For each channel (except nucleus) compute transform compared to reference_channel # Add Channel transformation to Channel - # TODO: implement - pass + reference_peaks = situ_img.get_channel_peaks(reference_channel) + for channel in range(situ_img.get_channel_count()): + if channel != situ_img.nucleaus_channel and channel != reference_channel: + current_channel_peaks = situ_img.get_channel_peaks(channel) + transformation = self.registration_function.do_registration( + current_channel_peaks, reference_peaks) + situ_img.set_channel_transformation(channel, transformation) diff --git a/situr/transformation/__init__.py b/situr/transformation/__init__.py index 46ddf85..71b3be3 100644 --- a/situr/transformation/__init__.py +++ b/situr/transformation/__init__.py @@ -1 +1 @@ -from .channel_transformation import ChannelTransform, ScaleRotateTranslateChannelTransform +from .channel_transformation import ChannelTransform, IdentityChannelTransform, ScaleRotateTranslateChannelTransform diff --git a/situr/transformation/channel_transformation.py b/situr/transformation/channel_transformation.py index 189e899..a12fb19 100644 --- a/situr/transformation/channel_transformation.py +++ b/situr/transformation/channel_transformation.py @@ -13,6 +13,10 @@ class ChannelTransform: self.__class__.__name__ + '.apply_transformation') +class IdentityChannelTransform(ChannelTransform): + def apply_transformation(self, situ_img, channel): + pass + class ScaleRotateTranslateChannelTransform(ChannelTransform): def __init__(self, transform_matrix, scale=1, offset=np.array([0, 0])): # TODO: check