Skip to content

smFISH decoding #82

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions examples/bioprotean_smFISH/smFISH_decoding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from merfish3danalysis.qi2labDataStore import qi2labDataStore
from merfish3danalysis.PixelDecoder import PixelDecoder
from pathlib import Path
# import numpy as np

root_path = Path(r"/data/smFISH/02202025_Bartelle_control_smFISH_TqIB")

# initialize datastore
datastore_path = root_path / Path(r"qi2labdatastore")
datastore = qi2labDataStore(datastore_path)
merfish_bits = datastore.num_bits

# initialize decodor class
decoder = PixelDecoder(
datastore=datastore,
use_mask=False,
merfish_bits=merfish_bits,
verbose=1,
smFISH = True
)

# decode one tile
decoder.decode_one_tile(
tile_idx=0, # Specify the tile index
display_results=True, # Set to True to visualize results in Napari
lowpass_sigma=(3, 1, 1), # Lowpass filter sigma
magnitude_threshold=0.75, # L2-norm threshold
upper_magnitude_threshold=1.75, # Upper L2-norm threshold
minimum_pixels=3.0, # Minimum number of pixels for a barcode
use_normalization=True, # Use normalization
ufish_threshold=0.5 # Ufish threshold
)

print("Decoding complete.")
82 changes: 64 additions & 18 deletions src/merfish3danalysis/PixelDecoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(
use_mask: Optional[bool] = False,
z_range: Optional[Sequence[int]] = None,
include_blanks: Optional[bool] = True,
smFISH: bool = False
):
self._datastore = datastore
self._verbose = verbose
Expand All @@ -79,6 +80,10 @@ def __init__(

self._n_merfish_bits = merfish_bits

# Is this data smFISH or MERFISH?
# Default is False, meaning data is MERFISH
self._smFISH = smFISH

if self._datastore.microscope_type == "2D":
self._is_3D = False
else:
Expand All @@ -104,8 +109,15 @@ def __init__(
self._optimize_normalization_weights = False
self._global_normalization_loaded = False
self._iterative_normalization_loaded = False
self._distance_threshold = 0.5172 # default for HW4D4 code. TO DO: calculate based on self._num_on-bits
self._magnitude_threshold = 0.9 # default for HW4D4 code

if self._smFISH:
# establish lower magnitude threshold for smFISH data
self._magnitude_threshold = 0.75
self._upper_magnitude_threshold = 1.75
self._distance_threshold = 1.0
else:
self._magnitude_threshold = 0.9 # default for HW4D4 code
self._distance_threshold = 0.5172 # default for HW4D4 code. TO DO: calculate based on self._num_on-bits

def _load_codebook(self):
"""Load and parse codebook into gene_id and codeword matrix."""
Expand Down Expand Up @@ -771,8 +783,9 @@ def _calculate_distances(
return min_distances, min_indices

def _decode_pixels(
self, distance_threshold: float = 0.5172,
magnitude_threshold: float = 1.0
self, distance_threshold: float = None,
magnitude_threshold: float = None,
upper_magnitude_threshold: float = None, # Only used for smFISH data
):
"""Decode pixels using the decoding matrix.

Expand All @@ -784,6 +797,8 @@ def _decode_pixels(
magnitude_threshold : float, default 1.0.
Magnitude threshold for decoding.
"""
if distance_threshold is None:
distance_threshold = self._distance_threshold

if self._filter_type == "lp":
original_shape = self._image_data_lp.shape
Expand Down Expand Up @@ -852,7 +867,13 @@ def _decode_pixels(
decoded_trace = cp.full(distance_trace.shape[0], -1, dtype=cp.int16)
mask_trace = distance_trace < distance_threshold
decoded_trace[mask_trace] = codebook_index_trace[mask_trace]
decoded_trace[pixel_magnitude_trace <= magnitude_threshold] = -1

# For smFISH data, we are adding an upper magnitude threshold and setting pixels above this threshold to -1.
if self._smFISH:
decoded_trace[pixel_magnitude_trace >= upper_magnitude_threshold] = -1
decoded_trace[pixel_magnitude_trace <= magnitude_threshold] = -1
else:
decoded_trace[pixel_magnitude_trace <= magnitude_threshold] = -1

self._decoded_image[z_idx, :] = cp.asnumpy(
cp.reshape(cp.round(decoded_trace, 3), z_plane_shape[1:])
Expand Down Expand Up @@ -1841,18 +1862,33 @@ def on_close_callback():
app = QApplication.instance()

app.lastWindowClosed.connect(on_close_callback)

if self._smFISH:
for bit in range(self._datastore.num_bits):
viewer.add_image(
self._scaled_pixel_images[bit],
scale=[self._axial_step, self._pixel_size, self._pixel_size],
name="pixels_" + str(int(bit)+1),
)
else:
viewer.add_image(
self._scaled_pixel_images,
scale=[self._axial_step, self._pixel_size, self._pixel_size],
name="pixels",
)

viewer.add_image(
self._scaled_pixel_images,
scale=[self._axial_step, self._pixel_size, self._pixel_size],
name="pixels",
)

viewer.add_image(
self._decoded_image,
scale=[self._axial_step, self._pixel_size, self._pixel_size],
name="decoded",
)
for bit in range(self._datastore.num_bits):
# Create a mask for pixels decoded as this bit
mask = self._decoded_image == bit
# Create an image with intensities where mask is True, zeros elsewhere
decoded_intensity_image = np.zeros_like(self._magnitude_image)
decoded_intensity_image[mask] = self._magnitude_image[mask]

viewer.add_image(
decoded_intensity_image,
scale=[self._axial_step, self._pixel_size, self._pixel_size],
name="decoded_" + str(int(bit)+1),
)

viewer.add_image(
self._magnitude_image,
Expand Down Expand Up @@ -1903,7 +1939,8 @@ def decode_one_tile(
tile_idx: int = 0,
display_results: bool = False,
lowpass_sigma: Optional[Sequence[float]] = (3, 1, 1),
magnitude_threshold: Optional[float] = 0.9,
magnitude_threshold: Optional[float] = None,
upper_magnitude_threshold: Optional[float] = None,
minimum_pixels: Optional[float] = 3.0,
use_normalization: Optional[bool] = True,
ufish_threshold: Optional[float] = 0.5,
Expand Down Expand Up @@ -1932,14 +1969,23 @@ def decode_one_tile(

if use_normalization:
self._load_iterative_normalization_vectors()
if magnitude_threshold is None:
magnitude_threshold = self._magnitude_threshold
if upper_magnitude_threshold is None:
upper_magnitude_threshold = getattr(self, "_upper_magnitude_threshold", None)

print(f"The distance threshold is {self._distance_threshold}")
print(f"The lower magnitude threshold is {magnitude_threshold}")
print(f"The upper magnitude threshold is {upper_magnitude_threshold}")

self._tile_idx = tile_idx
self._load_bit_data(ufish_threshold=ufish_threshold)
if not (np.any(lowpass_sigma == 0)):
self._lp_filter(sigma=lowpass_sigma)
self._decode_pixels(
distance_threshold=self._distance_threshold,
magnitude_threshold=magnitude_threshold,
magnitude_threshold=magnitude_threshold,
upper_magnitude_threshold=upper_magnitude_threshold,
)
if display_results:
self._display_results()
Expand Down