add type hints to tile

This commit is contained in:
2021-07-14 11:30:57 +02:00
parent c2b3b46fc4
commit 10f652ede1

View File

@@ -1,8 +1,11 @@
from situr.transformation.round_transformation import RoundTransform
import numpy as np import numpy as np
from situr.image.situ_image import SituImage from situr.image.situ_image import SituImage
from situr.transformation import IdentityRoundTransform from situr.transformation import IdentityRoundTransform
from typing import List
class Tile: class Tile:
''' '''
@@ -13,7 +16,7 @@ class Tile:
X 2048 X 2048
''' '''
def __init__(self, file_list, nucleaus_channel=4): def __init__(self, file_list: List[List[List[str]]], nucleaus_channel: int = 4):
self.images = [] self.images = []
self.round_transformations = [] self.round_transformations = []
for situ_image_list in file_list: for situ_image_list in file_list:
@@ -33,35 +36,31 @@ class Tile:
# TODO: implement # TODO: implement
pass pass
def get_image_round(self, round): def set_round_transformation(self, round, transformation: RoundTransform):
return self.images[round]
def set_round_transformation(self, round, transformation):
self.round_transformations[round] = transformation self.round_transformations[round] = transformation
def get_round_count(self): def get_round_count(self) -> int:
return len(self.images) return len(self.images)
def get_channel_count(self): def get_channel_count(self) -> int:
return self.images[0].get_channel_count() return self.images[0].get_channel_count()
def get_round(self, round_number): def get_round(self, round_number: int) -> SituImage:
"""This """This methods returns the round based on round number
Args: Args:
round_number (integer): The round number starting with 0 round_number (int): The round number (starting with index 0)
Returns: Returns:
SituImage: The image corresponding to the requested round number. SituImage: The image corresponding to the requested round number.
""" """
return self.images[id] return self.images[round_number]
def to_numpy_array(self): def to_numpy_array(self) -> np.ndarray:
tmp_list = [] tmp_list = []
for image in self.images: for image in self.images:
tmp_list.append(image.get_data()) tmp_list.append(image.get_data())
return np.array(tmp_list) return np.array(tmp_list)
def get_channel(self, round, channel, focus_level=0): def get_channel(self, round: int, channel: int, focus_level: int = 0) -> np.ndarray:
return self.images[round].get_channel(channel, focus_level=focus_level) return self.images[round].get_channel(channel, focus_level=focus_level)