mirror of
https://github.com/13hannes11/situr.git
synced 2024-09-03 20:50:58 +02:00
make more classes abstract
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import abc
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw
|
||||
from skimage import img_as_float
|
||||
@@ -17,7 +18,17 @@ def remove_dim(array: np.ndarray):
|
||||
return array[:, :-1]
|
||||
|
||||
|
||||
class PeakFinderDifferenceOfGaussian:
|
||||
class PeakFinder:
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
@abc.abstractmethod
|
||||
def find_peaks(self, img_array: np.ndarray) -> np.ndarray:
|
||||
"""Finds the peaks in the input image"""
|
||||
raise NotImplementedError(
|
||||
self.__class__.__name__ + '.find_peaks')
|
||||
|
||||
|
||||
class PeakFinderDifferenceOfGaussian(PeakFinder):
|
||||
def __init__(self, min_sigma=0.75, max_sigma=3, threshold=0.1):
|
||||
self.min_sigma = min_sigma
|
||||
self.max_sigma = max_sigma
|
||||
@@ -44,22 +55,21 @@ class SituImage:
|
||||
A list of lists. Each inner list corresponds to one focus level. Its contents correspons to a file for each channel.
|
||||
nucleaus_channel : int
|
||||
tells which channel is used for showing where the cell nucleuses are.
|
||||
peak_finder :
|
||||
"""
|
||||
|
||||
def __init__(self, file_list: List[List[str]], nucleaus_channel: int = 4):
|
||||
def __init__(self, file_list: List[List[str]], nucleaus_channel: int = 4, peak_finder: PeakFinder = PeakFinderDifferenceOfGaussian()):
|
||||
self.files = file_list
|
||||
self.data = None
|
||||
self.nucleaus_channel = nucleaus_channel
|
||||
self.channel_transformations = [
|
||||
IdentityChannelTransform() for file in file_list
|
||||
]
|
||||
self.peak_finder = PeakFinderDifferenceOfGaussian()
|
||||
# TODO: make peak finder a constructor argument
|
||||
self.peak_finder = peak_finder
|
||||
|
||||
def get_data(self) -> np.ndarray:
|
||||
if self.data is None:
|
||||
self._load_image()
|
||||
# TODO: apply transformations
|
||||
return self.data
|
||||
|
||||
def apply_transformations(self):
|
||||
|
||||
@@ -33,5 +33,6 @@ class FilterregRegistrationFunction(RegistrationFunction):
|
||||
|
||||
|
||||
class Registration:
|
||||
__metaclass__ = abc.ABCMeta
|
||||
def __init__(self, registration_function: RegistrationFunction):
|
||||
self.registration_function = registration_function
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
import abc
|
||||
|
||||
class Transform:
|
||||
pass
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
Reference in New Issue
Block a user