diff --git a/situr/image/situ_image.py b/situr/image/situ_image.py index 5df6465..6d8599e 100644 --- a/situr/image/situ_image.py +++ b/situr/image/situ_image.py @@ -3,15 +3,17 @@ from PIL import Image, ImageDraw from skimage import img_as_float from skimage.feature import blob_dog -from situr.transformation.channel_transformation import IdentityChannelTransform +from typing import List + +from situr.transformation.channel_transformation import ChannelTransform, IdentityChannelTransform -def extend_dim(array): +def extend_dim(array: np.ndarray): ones = np.ones((array.shape[0], 1)) return np.append(array, ones, axis=1) -def remove_dim(array): +def remove_dim(array: np.ndarray): return array[:, :-1] @@ -21,7 +23,7 @@ class PeakFinderDifferenceOfGaussian: self.max_sigma = max_sigma self.threshold = threshold - def find_peaks(self, img_array): + def find_peaks(self, img_array: np.ndarray) -> np.ndarray: img = img_as_float(img_array) peaks = blob_dog(img, min_sigma=self.min_sigma, max_sigma=self.max_sigma, threshold=self.threshold) @@ -38,13 +40,13 @@ class SituImage: ---------- data : numpy.array the image data containing all the channels of shape (channels, focus_levels, image_size_y, image_size_x) - files: (list(list(str))) + files : List[List[str]] A list of lists. Each inner list corresponds to one focus level. Its contents correspons to a file for each channel. nucleaus_channel : int tells which channel is used for showing where the cell nucleuses are. """ - def __init__(self, file_list, nucleaus_channel=4): + def __init__(self, file_list: List[List[str]], nucleaus_channel: int = 4): self.files = file_list self.data = None self.nucleaus_channel = nucleaus_channel @@ -52,56 +54,53 @@ class SituImage: IdentityChannelTransform() for file in file_list ] self.peak_finder = PeakFinderDifferenceOfGaussian() + # TODO: make peak finder a constructor argument - def get_data(self): + def get_data(self) -> np.ndarray: if self.data is None: self._load_image() # TODO: apply transformations return self.data - def apply_transformations(): + def apply_transformations(self): # TODO: implement pass - def set_channel_transformation(self, channel, transformation): + def set_channel_transformation(self, channel: int, transformation: ChannelTransform): self.channel_transformations[channel] = transformation - def get_channel_count(self): + def get_channel_count(self) -> int: return self.get_data().shape[0] - def get_focus_level_count(self): + def get_focus_level_count(self) -> int: return self.get_data().shape[1] - def get_focus_level(self, channel, focus_level): + def get_focus_level(self, channel: int, focus_level: int) -> np.ndarray: """Loads channel and focus level of an image. Args: channel (int): The channel to be used focus_level (int): The focus level to be used - Returns: - numpy.array: The loaded image of shape (width, height) + + Returns: + np.ndarray: The loaded image of shape (width, height) """ return self.get_data()[channel, focus_level, :, :] - def get_channel(self, channel): - ''' - Loads and returns the specified channel for all focus_levels. + def get_channel(self, channel: int) -> np.ndarray: + """Loads and returns the specified channel for all focus_levels. - Returns: - numpy.array: The loaded image of shape (focus_level, width, height) - ''' + Args: + channel (int): The channel to be returned + + Returns: + np.ndarray: The loaded image of shape (focus_level, width, height) + """ return self.get_data()[channel, :, :, :] def _load_image(self): - ''' - Loads the channels of an image from seperate files and returns them as numpy array. - - Parameters: - channel (int): - The channel that should be used - Returns: - numpy.array: The loaded image of shape (channels, focus_level, width, height) - ''' + """Loads the whole image from files + """ image_list = [] for focus_level_list in self.files: channels = [] @@ -111,54 +110,46 @@ class SituImage: self.data = np.array(image_list) def unload_image(self): - ''' - Unloads the image data to free up memory - ''' + """Unloads the image data to free up memory + """ self.data = None - def show_channel(self, channel, focus_level=0): - ''' - Prints and returns the specified channel and focus_level of the image. + def show_channel(self, channel: int, focus_level: int = 0) -> Image: + """Prints and returns the specified channel and focus_level of the image. + + Args: + channel (int): The channel that should be used when printing + focus_level (int, optional): The focus level that should be used. Defaults to 0. - Parameters: - channel (int): - The channel that should be used when printing - focus_level (int) default: 0: - The focus level that should be used Returns: - image: The image of the specified focus level and channel - ''' - img = Image.fromarray(self.get_data()[0, 0, :, :]) + Image: The image of the specified focus level and channel + """ + img = Image.fromarray(self.get_data()[channel, focus_level, :, :]) img.show() return img - def get_channel_peaks(self, channel, focus_level=0, min_sigma=0.75, max_sigma=3, threshold=0.1): - ''' - Returns the coordinates of peaks (local maxima) in the specified channel and focus_level. - This method uses skimage blob_dog, therefore using difference of gaussian. + def get_channel_peaks(self, channel: int, focus_level: int = 0) -> np.ndarray: + """Returns the coordinates of peaks (local maxima) in the specified channel and focus_level. It uses the self. + + Args: + channel (int): The channel that should be used when printing + focus_level (int, optional): The focus level that should be used. Defaults to 0. - Parameters: - channel (int): - The channel that should be used when printing - focus_level (int) default: 0: - The focus level that should be used Returns: - np.array: The peaks found by this method as np.array of shape (n, 2) - ''' + np.ndarray: The peaks found by this method as np.array of shape (n, 2) + """ return self.peak_finder.find_peaks(self.get_data()[channel, focus_level, :, :]) - def show_channel_peaks(self, channel, focus_level=0): - ''' - Returns and shows the found. Uses get_channel_peaks internally. + def show_channel_peaks(self, channel: int, focus_level: int = 0) -> Image: + """Returns and shows the found peaks drawn onto the image. Uses get_channel_peaks internally. + + Args: + channel (int): The channel that should be used when printing + focus_level (int, optional): The focus level that should be used. Defaults to 0. - Parameters: - channel (int): - The channel that should be used when printing - focus_level (int) default: 0: - The focus level that should be used Returns: - image: The image of the specified focus level and channel with encircled peaks. - ''' + Image: The image of the specified focus level and channel with encircled peaks. + """ peaks = self.get_channel_peaks( channel, focus_level)