diff --git a/situr/registration/__init__.py b/situr/registration/__init__.py index 88400f1..73fc395 100644 --- a/situr/registration/__init__.py +++ b/situr/registration/__init__.py @@ -1,5 +1,5 @@ from .registration import Registration, RegistrationFunction, FilterregRegistrationFunction from .channel_registration import ChannelRegistration -from .round_registration import RoundRegistration +from .round_registration import RoundRegistration, AllChannelRoundRegistration from .tile_registration import CombinedRegistration from .peak_finder import PeakFinder, PeakFinderDifferenceOfGaussian diff --git a/situr/registration/round_registration.py b/situr/registration/round_registration.py index 492cda7..1975b71 100644 --- a/situr/registration/round_registration.py +++ b/situr/registration/round_registration.py @@ -1,5 +1,6 @@ from situr.registration import Registration, RegistrationFunction, FilterregRegistrationFunction +import numpy as np class RoundRegistration(Registration): @@ -22,3 +23,35 @@ class RoundRegistration(Registration): transformation = self.registration_function.do_registration( current_round_peaks, reference_peaks) situ_tile.set_round_transformation(round, transformation) + + +class AllChannelRoundRegistration(RoundRegistration): + + def do_round_registration(self, situ_tile, reference_round: int = 0, reference_channel: int = 0): + """This method generates a round registration transformation for a tile and saves it in the tile. + + Args: + situ_tile (Tile): The tile that the transformation is to be performed on. + reference_round (int, optional): The round that is referenced and will not be changed. Defaults to 0. + reference_channel (int, optional): This parameter is ignored. + """ + reference_peaks = [] + for channel in range(situ_tile.get_channel_count()): + # TODO: possibly exclude nucleaus channel + reference_peaks.append(self.peak_finder.get_channel_peaks(situ_tile.get_round( + reference_round), channel)) + reference_peaks = np.concatenate(reference_peaks, axis=0) + + for round in range(situ_tile.get_round_count()): + if round != reference_channel: + current_round_peaks = [] + for channel in range(situ_tile.get_channel_count()): + # TODO: possibly exclude nucleaus channel + current_round_peaks.append(self.peak_finder.get_channel_peaks( + situ_tile.get_round(round), channel)) + current_round_peaks = np.concatenate( + current_round_peaks, axis=0) + + transformation = self.registration_function.do_registration( + current_round_peaks, reference_peaks) + situ_tile.set_round_transformation(round, transformation)