diff --git a/examples/bioprotean_smFISH/smFISH_decoding.py b/examples/bioprotean_smFISH/smFISH_decoding.py new file mode 100644 index 0000000..4b54833 --- /dev/null +++ b/examples/bioprotean_smFISH/smFISH_decoding.py @@ -0,0 +1,34 @@ +from merfish3danalysis.qi2labDataStore import qi2labDataStore +from merfish3danalysis.PixelDecoder import PixelDecoder +from pathlib import Path +# import numpy as np + +root_path = Path(r"/data/smFISH/02202025_Bartelle_control_smFISH_TqIB") + +# initialize datastore +datastore_path = root_path / Path(r"qi2labdatastore") +datastore = qi2labDataStore(datastore_path) +merfish_bits = datastore.num_bits + +# initialize decodor class +decoder = PixelDecoder( + datastore=datastore, + use_mask=False, + merfish_bits=merfish_bits, + verbose=1, + smFISH = True +) + +# decode one tile +decoder.decode_one_tile( + tile_idx=0, # Specify the tile index + display_results=True, # Set to True to visualize results in Napari + lowpass_sigma=(3, 1, 1), # Lowpass filter sigma + magnitude_threshold=0.75, # L2-norm threshold + upper_magnitude_threshold=1.75, # Upper L2-norm threshold + minimum_pixels=3.0, # Minimum number of pixels for a barcode + use_normalization=True, # Use normalization + ufish_threshold=0.5 # Ufish threshold +) + +print("Decoding complete.") \ No newline at end of file diff --git a/src/merfish3danalysis/PixelDecoder.py b/src/merfish3danalysis/PixelDecoder.py index 1fcea1f..addaf9d 100644 --- a/src/merfish3danalysis/PixelDecoder.py +++ b/src/merfish3danalysis/PixelDecoder.py @@ -71,6 +71,7 @@ def __init__( use_mask: Optional[bool] = False, z_range: Optional[Sequence[int]] = None, include_blanks: Optional[bool] = True, + smFISH: bool = False ): self._datastore = datastore self._verbose = verbose @@ -79,6 +80,10 @@ def __init__( self._n_merfish_bits = merfish_bits + # Is this data smFISH or MERFISH? + # Default is False, meaning data is MERFISH + self._smFISH = smFISH + if self._datastore.microscope_type == "2D": self._is_3D = False else: @@ -104,8 +109,15 @@ def __init__( self._optimize_normalization_weights = False self._global_normalization_loaded = False self._iterative_normalization_loaded = False - self._distance_threshold = 0.5172 # default for HW4D4 code. TO DO: calculate based on self._num_on-bits - self._magnitude_threshold = 0.9 # default for HW4D4 code + + if self._smFISH: + # establish lower magnitude threshold for smFISH data + self._magnitude_threshold = 0.75 + self._upper_magnitude_threshold = 1.75 + self._distance_threshold = 1.0 + else: + self._magnitude_threshold = 0.9 # default for HW4D4 code + self._distance_threshold = 0.5172 # default for HW4D4 code. TO DO: calculate based on self._num_on-bits def _load_codebook(self): """Load and parse codebook into gene_id and codeword matrix.""" @@ -771,8 +783,9 @@ def _calculate_distances( return min_distances, min_indices def _decode_pixels( - self, distance_threshold: float = 0.5172, - magnitude_threshold: float = 1.0 + self, distance_threshold: float = None, + magnitude_threshold: float = None, + upper_magnitude_threshold: float = None, # Only used for smFISH data ): """Decode pixels using the decoding matrix. @@ -784,6 +797,8 @@ def _decode_pixels( magnitude_threshold : float, default 1.0. Magnitude threshold for decoding. """ + if distance_threshold is None: + distance_threshold = self._distance_threshold if self._filter_type == "lp": original_shape = self._image_data_lp.shape @@ -852,7 +867,13 @@ def _decode_pixels( decoded_trace = cp.full(distance_trace.shape[0], -1, dtype=cp.int16) mask_trace = distance_trace < distance_threshold decoded_trace[mask_trace] = codebook_index_trace[mask_trace] - decoded_trace[pixel_magnitude_trace <= magnitude_threshold] = -1 + + # For smFISH data, we are adding an upper magnitude threshold and setting pixels above this threshold to -1. + if self._smFISH: + decoded_trace[pixel_magnitude_trace >= upper_magnitude_threshold] = -1 + decoded_trace[pixel_magnitude_trace <= magnitude_threshold] = -1 + else: + decoded_trace[pixel_magnitude_trace <= magnitude_threshold] = -1 self._decoded_image[z_idx, :] = cp.asnumpy( cp.reshape(cp.round(decoded_trace, 3), z_plane_shape[1:]) @@ -1841,18 +1862,33 @@ def on_close_callback(): app = QApplication.instance() app.lastWindowClosed.connect(on_close_callback) + + if self._smFISH: + for bit in range(self._datastore.num_bits): + viewer.add_image( + self._scaled_pixel_images[bit], + scale=[self._axial_step, self._pixel_size, self._pixel_size], + name="pixels_" + str(int(bit)+1), + ) + else: + viewer.add_image( + self._scaled_pixel_images, + scale=[self._axial_step, self._pixel_size, self._pixel_size], + name="pixels", + ) - viewer.add_image( - self._scaled_pixel_images, - scale=[self._axial_step, self._pixel_size, self._pixel_size], - name="pixels", - ) - - viewer.add_image( - self._decoded_image, - scale=[self._axial_step, self._pixel_size, self._pixel_size], - name="decoded", - ) + for bit in range(self._datastore.num_bits): + # Create a mask for pixels decoded as this bit + mask = self._decoded_image == bit + # Create an image with intensities where mask is True, zeros elsewhere + decoded_intensity_image = np.zeros_like(self._magnitude_image) + decoded_intensity_image[mask] = self._magnitude_image[mask] + + viewer.add_image( + decoded_intensity_image, + scale=[self._axial_step, self._pixel_size, self._pixel_size], + name="decoded_" + str(int(bit)+1), + ) viewer.add_image( self._magnitude_image, @@ -1903,7 +1939,8 @@ def decode_one_tile( tile_idx: int = 0, display_results: bool = False, lowpass_sigma: Optional[Sequence[float]] = (3, 1, 1), - magnitude_threshold: Optional[float] = 0.9, + magnitude_threshold: Optional[float] = None, + upper_magnitude_threshold: Optional[float] = None, minimum_pixels: Optional[float] = 3.0, use_normalization: Optional[bool] = True, ufish_threshold: Optional[float] = 0.5, @@ -1932,6 +1969,14 @@ def decode_one_tile( if use_normalization: self._load_iterative_normalization_vectors() + if magnitude_threshold is None: + magnitude_threshold = self._magnitude_threshold + if upper_magnitude_threshold is None: + upper_magnitude_threshold = getattr(self, "_upper_magnitude_threshold", None) + + print(f"The distance threshold is {self._distance_threshold}") + print(f"The lower magnitude threshold is {magnitude_threshold}") + print(f"The upper magnitude threshold is {upper_magnitude_threshold}") self._tile_idx = tile_idx self._load_bit_data(ufish_threshold=ufish_threshold) @@ -1939,7 +1984,8 @@ def decode_one_tile( self._lp_filter(sigma=lowpass_sigma) self._decode_pixels( distance_threshold=self._distance_threshold, - magnitude_threshold=magnitude_threshold, + magnitude_threshold=magnitude_threshold, + upper_magnitude_threshold=upper_magnitude_threshold, ) if display_results: self._display_results()