refactor: move peak_finder to registration

This commit is contained in:
2021-07-19 12:13:43 +02:00
parent f04ff2a296
commit 53261f9ae1
7 changed files with 103 additions and 99 deletions

View File

@@ -1,3 +1,3 @@
from .situ_image import extend_dim, remove_dim from .situ_image import extend_dim, remove_dim
from .situ_image import SituImage, PeakFinderDifferenceOfGaussian from .situ_image import SituImage
from .situ_tile import Tile from .situ_tile import Tile

View File

@@ -1,10 +1,6 @@
import abc
from situr.transformation.transformation import Transform from situr.transformation.transformation import Transform
import numpy as np import numpy as np
from PIL import Image, ImageDraw from PIL import Image
from skimage import img_as_float
from skimage.feature import blob_dog
from typing import List from typing import List
from situr.transformation import Transform, IdentityTransform from situr.transformation import Transform, IdentityTransform
@@ -18,32 +14,6 @@ def extend_dim(array: np.ndarray):
def remove_dim(array: np.ndarray): def remove_dim(array: np.ndarray):
return array[:, :-1] return array[:, :-1]
# TODO: move peak finder out of image and reverse relationship (peakfinder know about image not the other way around)
class PeakFinder:
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def find_peaks(self, img_array: np.ndarray) -> np.ndarray:
"""Finds the peaks in the input image"""
raise NotImplementedError(
self.__class__.__name__ + '.find_peaks')
class PeakFinderDifferenceOfGaussian(PeakFinder):
def __init__(self, min_sigma=0.75, max_sigma=3, threshold=0.1):
self.min_sigma = min_sigma
self.max_sigma = max_sigma
self.threshold = threshold
def find_peaks(self, img_array: np.ndarray) -> np.ndarray:
img = img_as_float(img_array)
peaks = blob_dog(img, min_sigma=self.min_sigma,
max_sigma=self.max_sigma, threshold=self.threshold)
# Swap x and y
peaks = peaks[:, [0, 1]] = peaks[:, [1, 0]]
return peaks
class SituImage: class SituImage:
""" """
@@ -62,14 +32,13 @@ class SituImage:
peak_finder : peak_finder :
""" """
def __init__(self, file_list: List[List[str]], nucleaus_channel: int = 4, peak_finder: PeakFinder = PeakFinderDifferenceOfGaussian()): def __init__(self, file_list: List[List[str]], nucleaus_channel: int = 4):
self.files = file_list self.files = file_list
self.data = None self.data = None
self.nucleaus_channel = nucleaus_channel self.nucleaus_channel = nucleaus_channel
self.channel_transformations = [ self.channel_transformations = [
IdentityTransform() for file in file_list IdentityTransform() for file in file_list
] ]
self.peak_finder = peak_finder
def get_data(self) -> np.ndarray: def get_data(self) -> np.ndarray:
if self.data is None: if self.data is None:
@@ -136,58 +105,19 @@ class SituImage:
""" """
self.data = None self.data = None
def show_channel(self, channel: int, focus_level: int = 0) -> Image: def show_channel(self, channel: int, focus_level: int = 0, img_show=True) -> Image:
"""Prints and returns the specified channel and focus_level of the image. """Prints and returns the specified channel and focus_level of the image.
Args: Args:
channel (int): The channel that should be used when printing channel (int): The channel that should be used when printing
focus_level (int, optional): The focus level that should be used. Defaults to 0. focus_level (int, optional): The focus level that should be used. Defaults to 0.
img_show (bool, optional): Specifies if img.show is to be called or if just the image should be returned. Defaults to True.
Returns: Returns:
Image: The image of the specified focus level and channel Image: The image of the specified focus level and channel
""" """
img = Image.fromarray( img = Image.fromarray(
self.get_data()[channel, focus_level, :, :].astype(np.uint8)) self.get_data()[channel, focus_level, :, :].astype(np.uint8))
img.show() if img_show:
return img
def get_channel_peaks(self, channel: int, focus_level: int = 0) -> np.ndarray:
"""Returns the coordinates of peaks (local maxima) in the specified channel and focus_level. It uses the self.
Args:
channel (int): The channel that should be used when printing
focus_level (int, optional): The focus level that should be used. Defaults to 0.
Returns:
np.ndarray: The peaks found by this method as np.array of shape (n, 2)
"""
return self.peak_finder.find_peaks(self.get_data()[channel, focus_level, :, :])
def show_channel_peaks(self, channel: int, focus_level: int = 0) -> Image:
"""Returns and shows the found peaks drawn onto the image. Uses get_channel_peaks internally.
Args:
channel (int): The channel that should be used when printing
focus_level (int, optional): The focus level that should be used. Defaults to 0.
Returns:
Image: The image of the specified focus level and channel with encircled peaks.
"""
peaks = self.get_channel_peaks(
channel, focus_level)
img = Image.fromarray(
self.get_data()[channel, focus_level, :, :].astype(np.uint8)).convert('RGB')
draw = ImageDraw.Draw(img)
width = 3
inner_radius = 5
outer_radius = inner_radius + width
for x, y in zip(peaks[:, 0], peaks[:, 1]):
draw.ellipse((x - inner_radius, y - inner_radius, x + inner_radius, y + inner_radius),
outline='navy', width=width)
draw.ellipse((x - outer_radius, y - outer_radius, x + outer_radius, y + outer_radius),
outline='yellow', width=width)
img.show() img.show()
return img return img

View File

@@ -2,3 +2,4 @@ from .registration import Registration, RegistrationFunction, FilterregRegistrat
from .channel_registration import ChannelRegistration from .channel_registration import ChannelRegistration
from .round_registration import RoundRegistration from .round_registration import RoundRegistration
from .tile_registration import CombinedRegistration from .tile_registration import CombinedRegistration
from .peak_finder import PeakFinder, PeakFinderDifferenceOfGaussian

View File

@@ -1,23 +1,19 @@
from situr.registration.peak_finder import PeakFinder, PeakFinderDifferenceOfGaussian
from situr.image.situ_image import SituImage from situr.image.situ_image import SituImage
from situr.registration import Registration, RegistrationFunction, FilterregRegistrationFunction from situr.registration import Registration, RegistrationFunction, FilterregRegistrationFunction
class ChannelRegistration(Registration): class ChannelRegistration(Registration):
def __init__(self, registration_function: RegistrationFunction = FilterregRegistrationFunction()):
"""Initialize channel registration and tell which registration function to use.
Args:
registration_function (RegistrationFunction, optional): Registration function. Defaults to FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform).
"""
super().__init__(registration_function)
def do_channel_registration(self, situ_img: SituImage, reference_channel: int = 0): def do_channel_registration(self, situ_img: SituImage, reference_channel: int = 0):
# For each channel (except nucleus) compute transform compared to reference_channel # For each channel (except nucleus) compute transform compared to reference_channel
# Add Channel transformation to Channel # Add Channel transformation to Channel
reference_peaks = situ_img.get_channel_peaks(reference_channel) reference_peaks = self.peak_finder.get_channel_peaks(
situ_img, reference_channel)
for channel in range(situ_img.get_channel_count()): for channel in range(situ_img.get_channel_count()):
if channel != situ_img.nucleaus_channel and channel != reference_channel: if channel != situ_img.nucleaus_channel and channel != reference_channel:
current_channel_peaks = situ_img.get_channel_peaks(channel) current_channel_peaks = self.peak_finder.get_channel_peaks(
situ_img, channel)
transformation = self.registration_function.do_registration( transformation = self.registration_function.do_registration(
current_channel_peaks, reference_peaks) current_channel_peaks, reference_peaks)
situ_img.set_channel_transformation(channel, transformation) situ_img.set_channel_transformation(channel, transformation)

View File

@@ -0,0 +1,77 @@
import abc
from PIL import Image, ImageDraw
from skimage import img_as_float
from skimage.feature import blob_dog
import numpy as np
from situr.image.situ_image import SituImage
class PeakFinder:
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def find_peaks(self, img_array: np.ndarray) -> np.ndarray:
"""Finds the peaks in the input image"""
raise NotImplementedError(
self.__class__.__name__ + '.find_peaks')
def get_channel_peaks(self, img: SituImage, channel: int, focus_level: int = 0) -> np.ndarray:
"""Returns the coordinates of peaks (local maxima) in the specified channel and focus_level. It uses the self.
Args:
img (SituImage): The image to find the peaks on.
channel (int): The channel that should be used when printing
focus_level (int, optional): The focus level that should be used. Defaults to 0.
Returns:
np.ndarray: np.ndarray: The peaks found by this method as np.array of shape (n, 2)
"""
return self.find_peaks(img.get_data()[channel, focus_level, :, :])
def show_channel_peaks(self, img: SituImage, channel: int, focus_level: int = 0, img_show=True) -> Image:
"""Returns and shows the found peaks drawn onto the image. Uses get_channel_peaks internally.
Args:
img (SituImage): The image to find the peaks on.
channel (int): The channel that should be used when printing
focus_level (int, optional): The focus level that should be used. Defaults to 0.
img_show (bool, optional): Specifies if img.show is to be called or if just the image should be returned. Defaults to True.
Returns:
Image: The image of the specified focus level and channel with encircled peaks.
"""
peaks = self.get_channel_peaks(img, channel, focus_level)
img = img.show_channel(
channel, focus_level=focus_level, img_show=False).convert('RGB')
draw = ImageDraw.Draw(img)
width = 3
inner_radius = 5
outer_radius = inner_radius + width
for x, y in zip(peaks[:, 0], peaks[:, 1]):
draw.ellipse((x - inner_radius, y - inner_radius, x + inner_radius, y + inner_radius),
outline='navy', width=width)
draw.ellipse((x - outer_radius, y - outer_radius, x + outer_radius, y + outer_radius),
outline='yellow', width=width)
if img_show:
img.show()
return img
class PeakFinderDifferenceOfGaussian(PeakFinder):
def __init__(self, min_sigma=0.75, max_sigma=3, threshold=0.1):
self.min_sigma = min_sigma
self.max_sigma = max_sigma
self.threshold = threshold
def find_peaks(self, img_array: np.ndarray) -> np.ndarray:
img = img_as_float(img_array)
peaks = blob_dog(img, min_sigma=self.min_sigma,
max_sigma=self.max_sigma, threshold=self.threshold)
# Swap x and y
peaks = peaks[:, [0, 1]] = peaks[:, [1, 0]]
return peaks

View File

@@ -1,4 +1,5 @@
import abc import abc
from situr.registration.peak_finder import PeakFinderDifferenceOfGaussian
import open3d as o3 import open3d as o3
from probreg import filterreg from probreg import filterreg
import numpy as np import numpy as np
@@ -31,5 +32,12 @@ class FilterregRegistrationFunction(RegistrationFunction):
class Registration: class Registration:
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, registration_function: RegistrationFunction): def __init__(self, registration_function: RegistrationFunction() = FilterregRegistrationFunction(), peak_finder=PeakFinderDifferenceOfGaussian()):
"""Initialize channel registration and tell which registration function to use.
Args:
registration_function (RegistrationFunction, optional): Registration function. Defaults to FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform).
peak_finder (PeakFinder, optional): The peak finder to be used for the registration. Defaults to PeakFinderDifferenceOfGaussian().
"""
self.registration_function = registration_function self.registration_function = registration_function
self.peak_finder = peak_finder

View File

@@ -2,13 +2,6 @@ from situr.registration import Registration, RegistrationFunction, FilterregRegi
class RoundRegistration(Registration): class RoundRegistration(Registration):
def __init__(self, registration_function: RegistrationFunction = FilterregRegistrationFunction()):
"""Initialize round registration and tell which registration function to use.
Args:
registration_function (RegistrationFunction[RoundTransform], optional): Registration function. Defaults to FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform).
"""
super().__init__(registration_function)
def do_round_registration(self, situ_tile, reference_round: int = 0, reference_channel: int = 0): def do_round_registration(self, situ_tile, reference_round: int = 0, reference_channel: int = 0):
"""This method generates a round registration transformation for a tile and saves it in the tile. """This method generates a round registration transformation for a tile and saves it in the tile.
@@ -20,13 +13,12 @@ class RoundRegistration(Registration):
""" """
# TODO: instead of one reference channel use all channels (maybe without nucleus channel) # TODO: instead of one reference channel use all channels (maybe without nucleus channel)
reference_peaks = situ_tile.get_round( reference_peaks = self.peak_finder.get_channel_peaks(situ_tile.get_round(
reference_round).get_channel_peaks(reference_channel) reference_round), reference_channel)
for round in range(situ_tile.get_round_count()): for round in range(situ_tile.get_round_count()):
if round != reference_channel: if round != reference_channel:
current_round_peaks = situ_tile.get_round( current_round_peaks = self.peak_finder.get_channel_peaks(
round situ_tile.get_round(round), reference_channel)
).get_channel_peaks(reference_channel)
transformation = self.registration_function.do_registration( transformation = self.registration_function.do_registration(
current_round_peaks, reference_peaks) current_round_peaks, reference_peaks)
situ_tile.set_round_transformation(round, transformation) situ_tile.set_round_transformation(round, transformation)