diff --git a/situr/registration/__init__.py b/situr/registration/__init__.py index 73fc395..3637663 100644 --- a/situr/registration/__init__.py +++ b/situr/registration/__init__.py @@ -1,5 +1,5 @@ from .registration import Registration, RegistrationFunction, FilterregRegistrationFunction -from .channel_registration import ChannelRegistration +from .channel_registration import SituImageChannelRegistration, ChannelRegistration, AcrossRoundChannelRegistration from .round_registration import RoundRegistration, AllChannelRoundRegistration from .tile_registration import CombinedRegistration from .peak_finder import PeakFinder, PeakFinderDifferenceOfGaussian diff --git a/situr/registration/channel_registration.py b/situr/registration/channel_registration.py index 6b187d1..72ea10d 100644 --- a/situr/registration/channel_registration.py +++ b/situr/registration/channel_registration.py @@ -1,10 +1,13 @@ +import numpy as np +from situr.image import situ_image + +from situr.image.situ_tile import Tile from situr.registration.peak_finder import PeakFinder, PeakFinderDifferenceOfGaussian from situr.image.situ_image import SituImage from situr.registration import Registration, RegistrationFunction, FilterregRegistrationFunction -class ChannelRegistration(Registration): - +class SituImageChannelRegistration(Registration): def do_channel_registration(self, situ_img: SituImage, reference_channel: int = 0): # For each channel (except nucleus) compute transform compared to reference_channel # Add Channel transformation to Channel @@ -16,4 +19,38 @@ class ChannelRegistration(Registration): situ_img, channel) transformation = self.registration_function.do_registration( current_channel_peaks, reference_peaks) - situ_img.set_channel_transformation(channel, transformation) + situ_img.set_channel_transformation( + channel, transformation) + + +class ChannelRegistration(Registration): + def do_channel_registration(self, tile: Tile, reference_channel: int = 0): + registration = SituImageChannelRegistration() + # For each channel (except nucleus) compute transform compared to reference_channel + # Add Channel transformation to Channel + for round in range(tile.get_round_count()): + situ_img = tile.get_round(round) + registration.do_channel_registration(situ_img, reference_channel) + + +class AcrossRoundChannelRegistration(ChannelRegistration): + def do_channel_registration(self, tile: Tile, reference_channel: int = 0): + reference_peaks = [] + for round in range(tile.get_round_count()): + reference_peaks.append(self.peak_finder.get_channel_peaks( + tile.get_round(round), reference_channel)) + reference_peaks = np.concatenate(reference_peaks, axis=0) + for channel in range(tile.get_channel_count()): + if channel != tile.get_round(0).nucleaus_channel and channel != reference_channel: + current_channel_peaks = [] + for round in range(tile.get_round_count()): + current_channel_peaks.append( + self.peak_finder.get_channel_peaks(tile.get_round(round), channel)) + current_channel_peaks = np.concatenate( + current_channel_peaks, axis=0) + + transformation = self.registration_function.do_registration( + current_channel_peaks, reference_peaks) + for round in range(tile.get_round_count()): + tile.get_round(round).set_channel_transformation( + channel, transformation) diff --git a/situr/registration/tile_registration.py b/situr/registration/tile_registration.py index e41f3c0..99e8820 100644 --- a/situr/registration/tile_registration.py +++ b/situr/registration/tile_registration.py @@ -17,10 +17,7 @@ class CombinedRegistration: Args: tile (Tile): The tile that the registration and transformations are to be performed on. """ - # Do channel registration - for round in range(tile.get_round_count()): - img = tile.get_round(round) - self.channel_registration.do_channel_registration(img) + self.channel_registration.do_channel_registration(tile) tile.apply_channel_transformations()