mirror of
https://github.com/13hannes11/situr.git
synced 2024-09-03 20:50:58 +02:00
refactor: move peak_finder to registration
This commit is contained in:
@@ -1,3 +1,3 @@
|
|||||||
from .situ_image import extend_dim, remove_dim
|
from .situ_image import extend_dim, remove_dim
|
||||||
from .situ_image import SituImage, PeakFinderDifferenceOfGaussian
|
from .situ_image import SituImage
|
||||||
from .situ_tile import Tile
|
from .situ_tile import Tile
|
||||||
|
|||||||
@@ -1,10 +1,6 @@
|
|||||||
import abc
|
|
||||||
from situr.transformation.transformation import Transform
|
from situr.transformation.transformation import Transform
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image, ImageDraw
|
from PIL import Image
|
||||||
from skimage import img_as_float
|
|
||||||
from skimage.feature import blob_dog
|
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from situr.transformation import Transform, IdentityTransform
|
from situr.transformation import Transform, IdentityTransform
|
||||||
@@ -18,32 +14,6 @@ def extend_dim(array: np.ndarray):
|
|||||||
def remove_dim(array: np.ndarray):
|
def remove_dim(array: np.ndarray):
|
||||||
return array[:, :-1]
|
return array[:, :-1]
|
||||||
|
|
||||||
# TODO: move peak finder out of image and reverse relationship (peakfinder know about image not the other way around)
|
|
||||||
class PeakFinder:
|
|
||||||
__metaclass__ = abc.ABCMeta
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def find_peaks(self, img_array: np.ndarray) -> np.ndarray:
|
|
||||||
"""Finds the peaks in the input image"""
|
|
||||||
raise NotImplementedError(
|
|
||||||
self.__class__.__name__ + '.find_peaks')
|
|
||||||
|
|
||||||
|
|
||||||
class PeakFinderDifferenceOfGaussian(PeakFinder):
|
|
||||||
def __init__(self, min_sigma=0.75, max_sigma=3, threshold=0.1):
|
|
||||||
self.min_sigma = min_sigma
|
|
||||||
self.max_sigma = max_sigma
|
|
||||||
self.threshold = threshold
|
|
||||||
|
|
||||||
def find_peaks(self, img_array: np.ndarray) -> np.ndarray:
|
|
||||||
img = img_as_float(img_array)
|
|
||||||
peaks = blob_dog(img, min_sigma=self.min_sigma,
|
|
||||||
max_sigma=self.max_sigma, threshold=self.threshold)
|
|
||||||
|
|
||||||
# Swap x and y
|
|
||||||
peaks = peaks[:, [0, 1]] = peaks[:, [1, 0]]
|
|
||||||
return peaks
|
|
||||||
|
|
||||||
|
|
||||||
class SituImage:
|
class SituImage:
|
||||||
"""
|
"""
|
||||||
@@ -62,14 +32,13 @@ class SituImage:
|
|||||||
peak_finder :
|
peak_finder :
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, file_list: List[List[str]], nucleaus_channel: int = 4, peak_finder: PeakFinder = PeakFinderDifferenceOfGaussian()):
|
def __init__(self, file_list: List[List[str]], nucleaus_channel: int = 4):
|
||||||
self.files = file_list
|
self.files = file_list
|
||||||
self.data = None
|
self.data = None
|
||||||
self.nucleaus_channel = nucleaus_channel
|
self.nucleaus_channel = nucleaus_channel
|
||||||
self.channel_transformations = [
|
self.channel_transformations = [
|
||||||
IdentityTransform() for file in file_list
|
IdentityTransform() for file in file_list
|
||||||
]
|
]
|
||||||
self.peak_finder = peak_finder
|
|
||||||
|
|
||||||
def get_data(self) -> np.ndarray:
|
def get_data(self) -> np.ndarray:
|
||||||
if self.data is None:
|
if self.data is None:
|
||||||
@@ -136,58 +105,19 @@ class SituImage:
|
|||||||
"""
|
"""
|
||||||
self.data = None
|
self.data = None
|
||||||
|
|
||||||
def show_channel(self, channel: int, focus_level: int = 0) -> Image:
|
def show_channel(self, channel: int, focus_level: int = 0, img_show=True) -> Image:
|
||||||
"""Prints and returns the specified channel and focus_level of the image.
|
"""Prints and returns the specified channel and focus_level of the image.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
channel (int): The channel that should be used when printing
|
channel (int): The channel that should be used when printing
|
||||||
focus_level (int, optional): The focus level that should be used. Defaults to 0.
|
focus_level (int, optional): The focus level that should be used. Defaults to 0.
|
||||||
|
img_show (bool, optional): Specifies if img.show is to be called or if just the image should be returned. Defaults to True.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Image: The image of the specified focus level and channel
|
Image: The image of the specified focus level and channel
|
||||||
"""
|
"""
|
||||||
img = Image.fromarray(
|
img = Image.fromarray(
|
||||||
self.get_data()[channel, focus_level, :, :].astype(np.uint8))
|
self.get_data()[channel, focus_level, :, :].astype(np.uint8))
|
||||||
img.show()
|
if img_show:
|
||||||
return img
|
|
||||||
|
|
||||||
def get_channel_peaks(self, channel: int, focus_level: int = 0) -> np.ndarray:
|
|
||||||
"""Returns the coordinates of peaks (local maxima) in the specified channel and focus_level. It uses the self.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel (int): The channel that should be used when printing
|
|
||||||
focus_level (int, optional): The focus level that should be used. Defaults to 0.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: The peaks found by this method as np.array of shape (n, 2)
|
|
||||||
"""
|
|
||||||
return self.peak_finder.find_peaks(self.get_data()[channel, focus_level, :, :])
|
|
||||||
|
|
||||||
def show_channel_peaks(self, channel: int, focus_level: int = 0) -> Image:
|
|
||||||
"""Returns and shows the found peaks drawn onto the image. Uses get_channel_peaks internally.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel (int): The channel that should be used when printing
|
|
||||||
focus_level (int, optional): The focus level that should be used. Defaults to 0.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Image: The image of the specified focus level and channel with encircled peaks.
|
|
||||||
"""
|
|
||||||
peaks = self.get_channel_peaks(
|
|
||||||
channel, focus_level)
|
|
||||||
|
|
||||||
img = Image.fromarray(
|
|
||||||
self.get_data()[channel, focus_level, :, :].astype(np.uint8)).convert('RGB')
|
|
||||||
draw = ImageDraw.Draw(img)
|
|
||||||
|
|
||||||
width = 3
|
|
||||||
inner_radius = 5
|
|
||||||
outer_radius = inner_radius + width
|
|
||||||
|
|
||||||
for x, y in zip(peaks[:, 0], peaks[:, 1]):
|
|
||||||
draw.ellipse((x - inner_radius, y - inner_radius, x + inner_radius, y + inner_radius),
|
|
||||||
outline='navy', width=width)
|
|
||||||
draw.ellipse((x - outer_radius, y - outer_radius, x + outer_radius, y + outer_radius),
|
|
||||||
outline='yellow', width=width)
|
|
||||||
img.show()
|
img.show()
|
||||||
return img
|
return img
|
||||||
|
|||||||
@@ -2,3 +2,4 @@ from .registration import Registration, RegistrationFunction, FilterregRegistrat
|
|||||||
from .channel_registration import ChannelRegistration
|
from .channel_registration import ChannelRegistration
|
||||||
from .round_registration import RoundRegistration
|
from .round_registration import RoundRegistration
|
||||||
from .tile_registration import CombinedRegistration
|
from .tile_registration import CombinedRegistration
|
||||||
|
from .peak_finder import PeakFinder, PeakFinderDifferenceOfGaussian
|
||||||
|
|||||||
@@ -1,23 +1,19 @@
|
|||||||
|
from situr.registration.peak_finder import PeakFinder, PeakFinderDifferenceOfGaussian
|
||||||
from situr.image.situ_image import SituImage
|
from situr.image.situ_image import SituImage
|
||||||
from situr.registration import Registration, RegistrationFunction, FilterregRegistrationFunction
|
from situr.registration import Registration, RegistrationFunction, FilterregRegistrationFunction
|
||||||
|
|
||||||
|
|
||||||
class ChannelRegistration(Registration):
|
class ChannelRegistration(Registration):
|
||||||
def __init__(self, registration_function: RegistrationFunction = FilterregRegistrationFunction()):
|
|
||||||
"""Initialize channel registration and tell which registration function to use.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
registration_function (RegistrationFunction, optional): Registration function. Defaults to FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform).
|
|
||||||
"""
|
|
||||||
super().__init__(registration_function)
|
|
||||||
|
|
||||||
def do_channel_registration(self, situ_img: SituImage, reference_channel: int = 0):
|
def do_channel_registration(self, situ_img: SituImage, reference_channel: int = 0):
|
||||||
# For each channel (except nucleus) compute transform compared to reference_channel
|
# For each channel (except nucleus) compute transform compared to reference_channel
|
||||||
# Add Channel transformation to Channel
|
# Add Channel transformation to Channel
|
||||||
reference_peaks = situ_img.get_channel_peaks(reference_channel)
|
reference_peaks = self.peak_finder.get_channel_peaks(
|
||||||
|
situ_img, reference_channel)
|
||||||
for channel in range(situ_img.get_channel_count()):
|
for channel in range(situ_img.get_channel_count()):
|
||||||
if channel != situ_img.nucleaus_channel and channel != reference_channel:
|
if channel != situ_img.nucleaus_channel and channel != reference_channel:
|
||||||
current_channel_peaks = situ_img.get_channel_peaks(channel)
|
current_channel_peaks = self.peak_finder.get_channel_peaks(
|
||||||
|
situ_img, channel)
|
||||||
transformation = self.registration_function.do_registration(
|
transformation = self.registration_function.do_registration(
|
||||||
current_channel_peaks, reference_peaks)
|
current_channel_peaks, reference_peaks)
|
||||||
situ_img.set_channel_transformation(channel, transformation)
|
situ_img.set_channel_transformation(channel, transformation)
|
||||||
|
|||||||
77
situr/registration/peak_finder.py
Normal file
77
situr/registration/peak_finder.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import abc
|
||||||
|
from PIL import Image, ImageDraw
|
||||||
|
from skimage import img_as_float
|
||||||
|
from skimage.feature import blob_dog
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from situr.image.situ_image import SituImage
|
||||||
|
|
||||||
|
|
||||||
|
class PeakFinder:
|
||||||
|
__metaclass__ = abc.ABCMeta
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def find_peaks(self, img_array: np.ndarray) -> np.ndarray:
|
||||||
|
"""Finds the peaks in the input image"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
self.__class__.__name__ + '.find_peaks')
|
||||||
|
|
||||||
|
def get_channel_peaks(self, img: SituImage, channel: int, focus_level: int = 0) -> np.ndarray:
|
||||||
|
"""Returns the coordinates of peaks (local maxima) in the specified channel and focus_level. It uses the self.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img (SituImage): The image to find the peaks on.
|
||||||
|
channel (int): The channel that should be used when printing
|
||||||
|
focus_level (int, optional): The focus level that should be used. Defaults to 0.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: np.ndarray: The peaks found by this method as np.array of shape (n, 2)
|
||||||
|
"""
|
||||||
|
return self.find_peaks(img.get_data()[channel, focus_level, :, :])
|
||||||
|
|
||||||
|
def show_channel_peaks(self, img: SituImage, channel: int, focus_level: int = 0, img_show=True) -> Image:
|
||||||
|
"""Returns and shows the found peaks drawn onto the image. Uses get_channel_peaks internally.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img (SituImage): The image to find the peaks on.
|
||||||
|
channel (int): The channel that should be used when printing
|
||||||
|
focus_level (int, optional): The focus level that should be used. Defaults to 0.
|
||||||
|
img_show (bool, optional): Specifies if img.show is to be called or if just the image should be returned. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image: The image of the specified focus level and channel with encircled peaks.
|
||||||
|
"""
|
||||||
|
peaks = self.get_channel_peaks(img, channel, focus_level)
|
||||||
|
|
||||||
|
img = img.show_channel(
|
||||||
|
channel, focus_level=focus_level, img_show=False).convert('RGB')
|
||||||
|
draw = ImageDraw.Draw(img)
|
||||||
|
|
||||||
|
width = 3
|
||||||
|
inner_radius = 5
|
||||||
|
outer_radius = inner_radius + width
|
||||||
|
|
||||||
|
for x, y in zip(peaks[:, 0], peaks[:, 1]):
|
||||||
|
draw.ellipse((x - inner_radius, y - inner_radius, x + inner_radius, y + inner_radius),
|
||||||
|
outline='navy', width=width)
|
||||||
|
draw.ellipse((x - outer_radius, y - outer_radius, x + outer_radius, y + outer_radius),
|
||||||
|
outline='yellow', width=width)
|
||||||
|
if img_show:
|
||||||
|
img.show()
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
class PeakFinderDifferenceOfGaussian(PeakFinder):
|
||||||
|
def __init__(self, min_sigma=0.75, max_sigma=3, threshold=0.1):
|
||||||
|
self.min_sigma = min_sigma
|
||||||
|
self.max_sigma = max_sigma
|
||||||
|
self.threshold = threshold
|
||||||
|
|
||||||
|
def find_peaks(self, img_array: np.ndarray) -> np.ndarray:
|
||||||
|
img = img_as_float(img_array)
|
||||||
|
peaks = blob_dog(img, min_sigma=self.min_sigma,
|
||||||
|
max_sigma=self.max_sigma, threshold=self.threshold)
|
||||||
|
|
||||||
|
# Swap x and y
|
||||||
|
peaks = peaks[:, [0, 1]] = peaks[:, [1, 0]]
|
||||||
|
return peaks
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import abc
|
import abc
|
||||||
|
from situr.registration.peak_finder import PeakFinderDifferenceOfGaussian
|
||||||
import open3d as o3
|
import open3d as o3
|
||||||
from probreg import filterreg
|
from probreg import filterreg
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -31,5 +32,12 @@ class FilterregRegistrationFunction(RegistrationFunction):
|
|||||||
class Registration:
|
class Registration:
|
||||||
__metaclass__ = abc.ABCMeta
|
__metaclass__ = abc.ABCMeta
|
||||||
|
|
||||||
def __init__(self, registration_function: RegistrationFunction):
|
def __init__(self, registration_function: RegistrationFunction() = FilterregRegistrationFunction(), peak_finder=PeakFinderDifferenceOfGaussian()):
|
||||||
|
"""Initialize channel registration and tell which registration function to use.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registration_function (RegistrationFunction, optional): Registration function. Defaults to FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform).
|
||||||
|
peak_finder (PeakFinder, optional): The peak finder to be used for the registration. Defaults to PeakFinderDifferenceOfGaussian().
|
||||||
|
"""
|
||||||
self.registration_function = registration_function
|
self.registration_function = registration_function
|
||||||
|
self.peak_finder = peak_finder
|
||||||
|
|||||||
@@ -2,13 +2,6 @@ from situr.registration import Registration, RegistrationFunction, FilterregRegi
|
|||||||
|
|
||||||
|
|
||||||
class RoundRegistration(Registration):
|
class RoundRegistration(Registration):
|
||||||
def __init__(self, registration_function: RegistrationFunction = FilterregRegistrationFunction()):
|
|
||||||
"""Initialize round registration and tell which registration function to use.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
registration_function (RegistrationFunction[RoundTransform], optional): Registration function. Defaults to FilterregRegistrationFunction(ScaleRotateTranslateChannelTransform).
|
|
||||||
"""
|
|
||||||
super().__init__(registration_function)
|
|
||||||
|
|
||||||
def do_round_registration(self, situ_tile, reference_round: int = 0, reference_channel: int = 0):
|
def do_round_registration(self, situ_tile, reference_round: int = 0, reference_channel: int = 0):
|
||||||
"""This method generates a round registration transformation for a tile and saves it in the tile.
|
"""This method generates a round registration transformation for a tile and saves it in the tile.
|
||||||
@@ -20,13 +13,12 @@ class RoundRegistration(Registration):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO: instead of one reference channel use all channels (maybe without nucleus channel)
|
# TODO: instead of one reference channel use all channels (maybe without nucleus channel)
|
||||||
reference_peaks = situ_tile.get_round(
|
reference_peaks = self.peak_finder.get_channel_peaks(situ_tile.get_round(
|
||||||
reference_round).get_channel_peaks(reference_channel)
|
reference_round), reference_channel)
|
||||||
for round in range(situ_tile.get_round_count()):
|
for round in range(situ_tile.get_round_count()):
|
||||||
if round != reference_channel:
|
if round != reference_channel:
|
||||||
current_round_peaks = situ_tile.get_round(
|
current_round_peaks = self.peak_finder.get_channel_peaks(
|
||||||
round
|
situ_tile.get_round(round), reference_channel)
|
||||||
).get_channel_peaks(reference_channel)
|
|
||||||
transformation = self.registration_function.do_registration(
|
transformation = self.registration_function.do_registration(
|
||||||
current_round_peaks, reference_peaks)
|
current_round_peaks, reference_peaks)
|
||||||
situ_tile.set_round_transformation(round, transformation)
|
situ_tile.set_round_transformation(round, transformation)
|
||||||
|
|||||||
Reference in New Issue
Block a user