autoformat code and add missing imports

This commit is contained in:
2021-07-12 15:54:50 +02:00
parent 85e2edb8a5
commit 93c84e731d
4 changed files with 47 additions and 24 deletions

View File

@@ -8,8 +8,9 @@ def extend_dim(array):
ones = np.ones((array.shape[0], 1)) ones = np.ones((array.shape[0], 1))
return np.append(array, ones, axis=1) return np.append(array, ones, axis=1)
def remove_dim(array): def remove_dim(array):
return array[:,:-1] return array[:, :-1]
class SituImage: class SituImage:
@@ -45,7 +46,7 @@ class SituImage:
Returns: Returns:
numpy.array: The loaded image of shape (focus_level, width, height) numpy.array: The loaded image of shape (focus_level, width, height)
''' '''
return self.get_data()[channel,:,:,:] return self.get_data()[channel, :, :, :]
def _load_image(self): def _load_image(self):
''' '''
@@ -83,7 +84,7 @@ class SituImage:
Returns: Returns:
image: The image of the specified focus level and channel image: The image of the specified focus level and channel
''' '''
img = Image.fromarray(self.get_data()[0,0,:,:]) img = Image.fromarray(self.get_data()[0, 0, :, :])
img.show() img.show()
return img return img
@@ -107,7 +108,8 @@ class SituImage:
np.array: The peaks found by this method as np.array of shape (n, 2) np.array: The peaks found by this method as np.array of shape (n, 2)
''' '''
img = img_as_float(self.get_data()[channel, focus_level, :, :]) img = img_as_float(self.get_data()[channel, focus_level, :, :])
peaks = blob_dog(img, min_sigma=min_sigma, max_sigma=max_sigma, threshold=threshold) peaks = blob_dog(img, min_sigma=min_sigma,
max_sigma=max_sigma, threshold=threshold)
return peaks[:, 0:2] return peaks[:, 0:2]
def show_channel_peaks(self, channel, focus_level=0, min_sigma=0.75, max_sigma=3, threshold=0.1): def show_channel_peaks(self, channel, focus_level=0, min_sigma=0.75, max_sigma=3, threshold=0.1):
@@ -128,11 +130,13 @@ class SituImage:
Returns: Returns:
image: The image of the specified focus level and channel with encircled peaks. image: The image of the specified focus level and channel with encircled peaks.
''' '''
peaks = self.get_channel_peaks(channel, focus_level, min_sigma, max_sigma, threshold) peaks = self.get_channel_peaks(
channel, focus_level, min_sigma, max_sigma, threshold)
img = Image.fromarray(self.get_data()[channel, focus_level, :, :]) img = Image.fromarray(self.get_data()[channel, focus_level, :, :])
draw = ImageDraw.Draw(img) draw = ImageDraw.Draw(img)
for x, y in zip(peaks[:,0], peaks[:,1]): for x, y in zip(peaks[:, 0], peaks[:, 1]):
draw.ellipse((x - 5, y - 5, x + 5, y + 5), outline ='white', width = 3) draw.ellipse((x - 5, y - 5, x + 5, y + 5),
outline='white', width=3)
img.show() img.show()
return img return img

View File

@@ -1,14 +1,25 @@
import abc
import open3d as o3
from probreg import filterreg
from situr.image import extend_dim
from situr.transformation import ScaleRotateTranslateChannelTransform
class ChannelRegistration: class ChannelRegistration:
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def do_registration(self, situ_img , reference_channel=0): def do_registration(self, situ_img, reference_channel=0):
# For each channel (except nucleus) compute transform compared to reference_channel # For each channel (except nucleus) compute transform compared to reference_channel
# Add Channel transformation to Channel # Add Channel transformation to Channel
pass pass
@abc.abstractmethod @abc.abstractmethod
def register_single_channel(self, peaks_data, reference_peaks): def register_single_channel(self, peaks_data, reference_peaks):
"""Performs the channel registration on an image. Expects the peaks in each image as input.""" """Performs the channel registration on an image. Expects the peaks in each image as input."""
raise NotImplementedError(self.__class__.__name__ + '.register_single_channel') raise NotImplementedError(
self.__class__.__name__ + '.register_single_channel')
class FilterregChannelRegistration(ChannelRegistration): class FilterregChannelRegistration(ChannelRegistration):
def register_single_channel(self, data_peaks, reference_peaks): def register_single_channel(self, data_peaks, reference_peaks):
@@ -17,7 +28,7 @@ class FilterregChannelRegistration(ChannelRegistration):
target = o3.geometry.PointCloud() target = o3.geometry.PointCloud()
target.points = o3.utility.Vector3dVector(extend_dim(reference_peaks)) target.points = o3.utility.Vector3dVector(extend_dim(reference_peaks))
registration_method=filterreg.registration_filterreg registration_method = filterreg.registration_filterreg
tf_param, _, _ = filterreg.registration_filterreg(source, target) tf_param, _, _ = filterreg.registration_filterreg(source, target)
return ScaleRotateTranslateChannelTransform(transform_matrix=tf_param.rot[0:2, 0:2], scale=tf_param.scale, offset=tf_param.t[0:2]) return ScaleRotateTranslateChannelTransform(transform_matrix=tf_param.rot[0:2, 0:2], scale=tf_param.scale, offset=tf_param.t[0:2])

View File

@@ -0,0 +1 @@
from .channel_transformation import ChannelTransform, ScaleRotateTranslateChannelTransform

View File

@@ -1,10 +1,16 @@
import abc
import numpy as np
import scipy
class ChannelTransform: class ChannelTransform:
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
@abc.abstractmethod @abc.abstractmethod
def apply_transformation(self, situ_img , channel): def apply_transformation(self, situ_img, channel):
"""Performs a transformation on one channel, all focus_levels are transformed the same way""" """Performs a transformation on one channel, all focus_levels are transformed the same way"""
raise NotImplementedError(self.__class__.__name__ + '.apply_transformation') raise NotImplementedError(
self.__class__.__name__ + '.apply_transformation')
class ScaleRotateTranslateChannelTransform(ChannelTransform): class ScaleRotateTranslateChannelTransform(ChannelTransform):
@@ -16,13 +22,14 @@ class ScaleRotateTranslateChannelTransform(ChannelTransform):
self.offset = offset self.offset = offset
self.scale = scale self.scale = scale
def apply_tranformation(self, situ_img , channel): def apply_tranformation(self, situ_img, channel):
channel_img = situ_img.get_channel(channel) channel_img = situ_img.get_channel(channel)
focus_levels = channel_img.shape[0] focus_levels = channel_img.shape[0]
for focus_level in range(focus_levels): for focus_level in range(focus_levels):
img = channel_img [focus_level, :, :] img = channel_img[focus_level, :, :]
img [:, :] = scipy.ndimage.affine_transform(img, self.transform_matrix) img[:, :] = scipy.ndimage.affine_transform(
img [:, :] = scipy.ndimage.zoom(img, self.scale) img, self.transform_matrix)
img [:, :] = scipy.ndimage.shift(img, self.offset) img[:, :] = scipy.ndimage.zoom(img, self.scale)
img[:, :] = scipy.ndimage.shift(img, self.offset)