From 53261f9ae1cc73d3b23626ff58430ba5def0ce0a Mon Sep 17 00:00:00 2001 From: "Hannes F. Kuchelmeister" Date: Mon, 19 Jul 2021 12:13:43 +0200 Subject: [PATCH] refactor: move peak_finder to registration --- situr/image/__init__.py | 2 +- situr/image/situ_image.py | 82 ++-------------------- situr/registration/__init__.py | 1 + situr/registration/channel_registration.py | 14 ++-- situr/registration/peak_finder.py | 77 ++++++++++++++++++++ situr/registration/registration.py | 10 ++- situr/registration/round_registration.py | 16 ++--- 7 files changed, 103 insertions(+), 99 deletions(-) create mode 100644 situr/registration/peak_finder.py diff --git a/situr/image/__init__.py b/situr/image/__init__.py index 628c2a9..350837f 100644 --- a/situr/image/__init__.py +++ b/situr/image/__init__.py @@ -1,3 +1,3 @@ from .situ_image import extend_dim, remove_dim -from .situ_image import SituImage, PeakFinderDifferenceOfGaussian +from .situ_image import SituImage from .situ_tile import Tile diff --git a/situr/image/situ_image.py b/situr/image/situ_image.py index 3522ad8..4f7355b 100644 --- a/situr/image/situ_image.py +++ b/situr/image/situ_image.py @@ -1,10 +1,6 @@ -import abc from situr.transformation.transformation import Transform import numpy as np -from PIL import Image, ImageDraw -from skimage import img_as_float -from skimage.feature import blob_dog - +from PIL import Image from typing import List from situr.transformation import Transform, IdentityTransform @@ -18,32 +14,6 @@ def extend_dim(array: np.ndarray): def remove_dim(array: np.ndarray): return array[:, :-1] -# TODO: move peak finder out of image and reverse relationship (peakfinder know about image not the other way around) -class PeakFinder: - __metaclass__ = abc.ABCMeta - - @abc.abstractmethod - def find_peaks(self, img_array: np.ndarray) -> np.ndarray: - """Finds the peaks in the input image""" - raise NotImplementedError( - self.__class__.__name__ + '.find_peaks') - - -class PeakFinderDifferenceOfGaussian(PeakFinder): - def __init__(self, min_sigma=0.75, max_sigma=3, threshold=0.1): - self.min_sigma = min_sigma - self.max_sigma = max_sigma - self.threshold = threshold - - def find_peaks(self, img_array: np.ndarray) -> np.ndarray: - img = img_as_float(img_array) - peaks = blob_dog(img, min_sigma=self.min_sigma, - max_sigma=self.max_sigma, threshold=self.threshold) - - # Swap x and y - peaks = peaks[:, [0, 1]] = peaks[:, [1, 0]] - return peaks - class SituImage: """ @@ -62,14 +32,13 @@ class SituImage: peak_finder : """ - def __init__(self, file_list: List[List[str]], nucleaus_channel: int = 4, peak_finder: PeakFinder = PeakFinderDifferenceOfGaussian()): + def __init__(self, file_list: List[List[str]], nucleaus_channel: int = 4): self.files = file_list self.data = None self.nucleaus_channel = nucleaus_channel self.channel_transformations = [ IdentityTransform() for file in file_list ] - self.peak_finder = peak_finder def get_data(self) -> np.ndarray: if self.data is None: @@ -136,58 +105,19 @@ class SituImage: """ self.data = None - def show_channel(self, channel: int, focus_level: int = 0) -> Image: + def show_channel(self, channel: int, focus_level: int = 0, img_show=True) -> Image: """Prints and returns the specified channel and focus_level of the image. Args: channel (int): The channel that should be used when printing focus_level (int, optional): The focus level that should be used. Defaults to 0. + img_show (bool, optional): Specifies if img.show is to be called or if just the image should be returned. Defaults to True. Returns: Image: The image of the specified focus level and channel """ img = Image.fromarray( self.get_data()[channel, focus_level, :, :].astype(np.uint8)) - img.show() - return img - - def get_channel_peaks(self, channel: int, focus_level: int = 0) -> np.ndarray: - """Returns the coordinates of peaks (local maxima) in the specified channel and focus_level. It uses the self. - - Args: - channel (int): The channel that should be used when printing - focus_level (int, optional): The focus level that should be used. Defaults to 0. - - Returns: - np.ndarray: The peaks found by this method as np.array of shape (n, 2) - """ - return self.peak_finder.find_peaks(self.get_data()[channel, focus_level, :, :]) - - def show_channel_peaks(self, channel: int, focus_level: int = 0) -> Image: - """Returns and shows the found peaks drawn onto the image. Uses get_channel_peaks internally. - - Args: - channel (int): The channel that should be used when printing - focus_level (int, optional): The focus level that should be used. Defaults to 0. - - Returns: - Image: The image of the specified focus level and channel with encircled peaks. - """ - peaks = self.get_channel_peaks( - channel, focus_level) - - img = Image.fromarray( - self.get_data()[channel, focus_level, :, :].astype(np.uint8)).convert('RGB') - draw = ImageDraw.Draw(img) - - width = 3 - inner_radius = 5 - outer_radius = inner_radius + width - - for x, y in zip(peaks[:, 0], peaks[:, 1]): - draw.ellipse((x - inner_radius, y - inner_radius, x + inner_radius, y + inner_radius), - outline='navy', width=width) - draw.ellipse((x - outer_radius, y - outer_radius, x + outer_radius, y + outer_radius), - outline='yellow', width=width) - img.show() + if img_show: + img.show() return img diff --git a/situr/registration/__init__.py b/situr/registration/__init__.py index 421fc2a..88400f1 100644 --- a/situr/registration/__init__.py +++ b/situr/registration/__init__.py @@ -2,3 +2,4 @@ from .registration import Registration, RegistrationFunction, FilterregRegistrat from .channel_registration import ChannelRegistration from .round_registration import RoundRegistration from .tile_registration import CombinedRegistration +from .peak_finder import PeakFinder, PeakFinderDifferenceOfGaussian diff --git a/situr/registration/channel_registration.py b/situr/registration/channel_registration.py index fc2047d..6b187d1 100644 --- a/situr/registration/channel_registration.py +++ b/situr/registration/channel_registration.py @@ -1,23 +1,19 @@ +from situr.registration.peak_finder import PeakFinder, PeakFinderDifferenceOfGaussian from situr.image.situ_image import SituImage from situr.registration import Registration, RegistrationFunction, FilterregRegistrationFunction class ChannelRegistration(Registration): - def __init__(self, registration_function: RegistrationFunction = FilterregRegistrationFunction()): - """Initialize channel registration and tell which registration function to use. - - Args: - registration_function (RegistrationFunction, optional): Registration function. Defaults to FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform). - """ - super().__init__(registration_function) def do_channel_registration(self, situ_img: SituImage, reference_channel: int = 0): # For each channel (except nucleus) compute transform compared to reference_channel # Add Channel transformation to Channel - reference_peaks = situ_img.get_channel_peaks(reference_channel) + reference_peaks = self.peak_finder.get_channel_peaks( + situ_img, reference_channel) for channel in range(situ_img.get_channel_count()): if channel != situ_img.nucleaus_channel and channel != reference_channel: - current_channel_peaks = situ_img.get_channel_peaks(channel) + current_channel_peaks = self.peak_finder.get_channel_peaks( + situ_img, channel) transformation = self.registration_function.do_registration( current_channel_peaks, reference_peaks) situ_img.set_channel_transformation(channel, transformation) diff --git a/situr/registration/peak_finder.py b/situr/registration/peak_finder.py new file mode 100644 index 0000000..1c4a915 --- /dev/null +++ b/situr/registration/peak_finder.py @@ -0,0 +1,77 @@ +import abc +from PIL import Image, ImageDraw +from skimage import img_as_float +from skimage.feature import blob_dog +import numpy as np + +from situr.image.situ_image import SituImage + + +class PeakFinder: + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def find_peaks(self, img_array: np.ndarray) -> np.ndarray: + """Finds the peaks in the input image""" + raise NotImplementedError( + self.__class__.__name__ + '.find_peaks') + + def get_channel_peaks(self, img: SituImage, channel: int, focus_level: int = 0) -> np.ndarray: + """Returns the coordinates of peaks (local maxima) in the specified channel and focus_level. It uses the self. + + Args: + img (SituImage): The image to find the peaks on. + channel (int): The channel that should be used when printing + focus_level (int, optional): The focus level that should be used. Defaults to 0. + + Returns: + np.ndarray: np.ndarray: The peaks found by this method as np.array of shape (n, 2) + """ + return self.find_peaks(img.get_data()[channel, focus_level, :, :]) + + def show_channel_peaks(self, img: SituImage, channel: int, focus_level: int = 0, img_show=True) -> Image: + """Returns and shows the found peaks drawn onto the image. Uses get_channel_peaks internally. + + Args: + img (SituImage): The image to find the peaks on. + channel (int): The channel that should be used when printing + focus_level (int, optional): The focus level that should be used. Defaults to 0. + img_show (bool, optional): Specifies if img.show is to be called or if just the image should be returned. Defaults to True. + + Returns: + Image: The image of the specified focus level and channel with encircled peaks. + """ + peaks = self.get_channel_peaks(img, channel, focus_level) + + img = img.show_channel( + channel, focus_level=focus_level, img_show=False).convert('RGB') + draw = ImageDraw.Draw(img) + + width = 3 + inner_radius = 5 + outer_radius = inner_radius + width + + for x, y in zip(peaks[:, 0], peaks[:, 1]): + draw.ellipse((x - inner_radius, y - inner_radius, x + inner_radius, y + inner_radius), + outline='navy', width=width) + draw.ellipse((x - outer_radius, y - outer_radius, x + outer_radius, y + outer_radius), + outline='yellow', width=width) + if img_show: + img.show() + return img + + +class PeakFinderDifferenceOfGaussian(PeakFinder): + def __init__(self, min_sigma=0.75, max_sigma=3, threshold=0.1): + self.min_sigma = min_sigma + self.max_sigma = max_sigma + self.threshold = threshold + + def find_peaks(self, img_array: np.ndarray) -> np.ndarray: + img = img_as_float(img_array) + peaks = blob_dog(img, min_sigma=self.min_sigma, + max_sigma=self.max_sigma, threshold=self.threshold) + + # Swap x and y + peaks = peaks[:, [0, 1]] = peaks[:, [1, 0]] + return peaks diff --git a/situr/registration/registration.py b/situr/registration/registration.py index f4db766..762e9b7 100644 --- a/situr/registration/registration.py +++ b/situr/registration/registration.py @@ -1,4 +1,5 @@ import abc +from situr.registration.peak_finder import PeakFinderDifferenceOfGaussian import open3d as o3 from probreg import filterreg import numpy as np @@ -31,5 +32,12 @@ class FilterregRegistrationFunction(RegistrationFunction): class Registration: __metaclass__ = abc.ABCMeta - def __init__(self, registration_function: RegistrationFunction): + def __init__(self, registration_function: RegistrationFunction() = FilterregRegistrationFunction(), peak_finder=PeakFinderDifferenceOfGaussian()): + """Initialize channel registration and tell which registration function to use. + + Args: + registration_function (RegistrationFunction, optional): Registration function. Defaults to FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform). + peak_finder (PeakFinder, optional): The peak finder to be used for the registration. Defaults to PeakFinderDifferenceOfGaussian(). + """ self.registration_function = registration_function + self.peak_finder = peak_finder diff --git a/situr/registration/round_registration.py b/situr/registration/round_registration.py index f5519eb..492cda7 100644 --- a/situr/registration/round_registration.py +++ b/situr/registration/round_registration.py @@ -2,13 +2,6 @@ from situr.registration import Registration, RegistrationFunction, FilterregRegi class RoundRegistration(Registration): - def __init__(self, registration_function: RegistrationFunction = FilterregRegistrationFunction()): - """Initialize round registration and tell which registration function to use. - - Args: - registration_function (RegistrationFunction[RoundTransform], optional): Registration function. Defaults to FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform). - """ - super().__init__(registration_function) def do_round_registration(self, situ_tile, reference_round: int = 0, reference_channel: int = 0): """This method generates a round registration transformation for a tile and saves it in the tile. @@ -20,13 +13,12 @@ class RoundRegistration(Registration): """ # TODO: instead of one reference channel use all channels (maybe without nucleus channel) - reference_peaks = situ_tile.get_round( - reference_round).get_channel_peaks(reference_channel) + reference_peaks = self.peak_finder.get_channel_peaks(situ_tile.get_round( + reference_round), reference_channel) for round in range(situ_tile.get_round_count()): if round != reference_channel: - current_round_peaks = situ_tile.get_round( - round - ).get_channel_peaks(reference_channel) + current_round_peaks = self.peak_finder.get_channel_peaks( + situ_tile.get_round(round), reference_channel) transformation = self.registration_function.do_registration( current_round_peaks, reference_peaks) situ_tile.set_round_transformation(round, transformation)