Source code for pycroglia.core.labeled_cells

import numpy as np
from typing import Set, Optional

from abc import ABC, abstractmethod

import skimage.segmentation
from numpy.typing import NDArray
from skimage import measure

from pycroglia.core.enums import SkimageCellConnectivity
from pycroglia.core.errors.errors import PycrogliaException


[docs] class LabelingStrategy(ABC): """Abstract base class for labeling strategies. Subclasses must implement the label method to generate labeled arrays from input images. Attributes: ARRAY_ELEMENTS_TYPE (type): Data type for output arrays. """ ARRAY_ELEMENTS_TYPE = np.uint8
[docs] @abstractmethod def label(self, img: NDArray) -> NDArray: """Labels the input image according to the strategy. Args: img (NDArray): Input image to label. Returns: NDArray: Labeled array. """ pass
[docs] class SkimageImgLabeling(LabelingStrategy): """Labeling strategy using skimage.measure.label. Attributes: connectivity (pycroglia.core.enums.SkimageCellConnectivity): Connectivity rule for labeling. """ def __init__(self, connectivity: SkimageCellConnectivity): """Initializes SkimageImgLabeling with the given connectivity. Args: connectivity (SkimageCellConnectivity): Connectivity rule for labeling. """ self.connectivity = connectivity
[docs] def label(self, img: NDArray) -> NDArray: """Labels the input image using skimage.measure.label. Args: img (NDArray): Input image to label. Returns: NDArray: Labeled array. """ labels = measure.label(img, connectivity=self.connectivity.value) return labels
[docs] class MaskListLabeling(LabelingStrategy): """Labeling strategy using a list of binary masks. Each mask should have the same shape as the target image. """ def __init__(self, masks: list[NDArray]): """ Args: masks (list[NDArray]): List of binary masks (same shape as the target image). """ self.masks = masks
[docs] def label(self, img: NDArray) -> NDArray: """ Args: img (NDArray): Reference image to determine output shape. Returns: NDArray: Labeled array. """ labels = np.zeros_like(img, dtype=self.ARRAY_ELEMENTS_TYPE) for idx, mask in enumerate(self.masks, start=1): labels[mask > 0] = idx relabeled, _, _ = skimage.segmentation.relabel_sequential(labels) return relabeled
[docs] class LabeledCells: """Represents labeled connected cell components in a 3D image. Provides methods to access individual cells, their sizes, and 2D projections. Attributes: ARRAY_ELEMENTS_TYPE (type): Data type for output arrays. z (int): Depth of the image. y (int): Height of the image. x (int): Width of the image. labels (NDArray): Labeled 3D array. """ ARRAY_ELEMENTS_TYPE = np.uint8 def __init__(self, img: NDArray, labeling_strategy: LabelingStrategy): """Initializes LabeledCells with a 3D image and a labeling strategy. Args: img (NDArray): 3D binary image. labeling_strategy (LabelingStrategy): Strategy for labeling connected components. """ self.z, self.y, self.x = img.shape self.labels = labeling_strategy.label(img) self._cell_sizes = np.bincount(self.labels.ravel()) self._n_cells = len(self._cell_sizes) - 1 # Buffer self._buffer: Optional[NDArray] = None def _get_buffer(self) -> NDArray: if self._buffer is None: self._buffer = np.empty(self.labels.shape, dtype=self.ARRAY_ELEMENTS_TYPE)
[docs] def len(self) -> int: """Returns the number of labeled cells. Returns: int: Number of labeled cells (excluding background). """ return self._n_cells
def _is_valid_index(self, index: int) -> bool: """Checks if the given index is a valid cell label. Args: index (int): Cell label index. Returns: bool: True if valid, False otherwise. """ return 0 < index <= self.len()
[docs] def get_cell(self, index: int) -> NDArray: """Returns a binary mask for the specified cell. Args: index (int): Cell label index. Returns: NDArray: 3D binary mask for the cell. Raises: PycrogliaException: If the index is invalid (error_code=2000). """ if not self._is_valid_index(index): raise PycrogliaException(error_code=2000) return np.equal(self.labels, index, out=self._get_buffer()).copy()
[docs] def get_cells_list(self) -> list[NDArray]: number_of_cells = self.len() if number_of_cells == 0: return [] masks = [self.get_cell(i) for i in range(1, number_of_cells + 1)] return masks
[docs] def get_cell_size(self, index: int) -> int: """Returns the size (number of voxels) of the specified cell. Args: index (int): Cell label index. Returns: int: Number of voxels in the cell. Raises: PycrogliaException: If the index is invalid (error_code=2000). """ if not self._is_valid_index(index): raise PycrogliaException(error_code=2000) return int(self._cell_sizes[index])
[docs] def labels_to_2d(self) -> NDArray: """Projects the labeled 3D array to 2D by taking the maximum label along the z-axis. Returns: NDArray: 2D array with the maximum label for each (y, x) position. """ return self.labels.max(axis=0)
[docs] def cell_to_2d(self, index: int) -> NDArray: """Projects a 3D cell to 2D by summing along the z-axis. Args: index (int): Cell label index. Returns: NDArray: 2D projection of the cell. Raises: PycrogliaException: If the index is invalid (error_code=2000). """ if not self._is_valid_index(index): raise PycrogliaException(error_code=2000) cell_matrix = np.zeros((self.z, self.y, self.x), dtype=self.ARRAY_ELEMENTS_TYPE) cell_matrix[self.labels == index] = 1 flatten = cell_matrix.sum(axis=0) return flatten
[docs] def all_cells_to_2d(self) -> NDArray: """Projects all labeled cells to 2D and stacks them along a new axis. Returns: NDArray: 3D array where each slice is the 2D projection of a cell. """ all_cells_matrix = np.zeros( (self.len(), self.y, self.x), dtype=self.ARRAY_ELEMENTS_TYPE ) for i in range(1, self.len() + 1): cell_array = self.cell_to_2d(i) all_cells_matrix[i - 1, :, :] = cell_array return all_cells_matrix
[docs] def get_border_cells(self) -> Set[int]: """Detects cells that touch the image borders in any Z slice. Identifies cells whose labels appear on any edge of the 3D image volume. A cell is considered a border cell if any of its voxels touch the top, bottom, left, or right edges of any Z slice. Returns: Set[int]: Set of cell IDs that touch the image borders. """ border_cells = set() border_labels = set() z, y, x = self.labels.shape for z_slice in range(z): # Top and bottom borders for this Z slice border_labels.update(np.unique(self.labels[z_slice, 0, :])) border_labels.update(np.unique(self.labels[z_slice, y - 1, :])) # Left and right borders for this Z slice border_labels.update(np.unique(self.labels[z_slice, :, 0])) border_labels.update(np.unique(self.labels[z_slice, :, x - 1])) # Remove background (label 0) border_labels.discard(0) # Filter to only include valid cell IDs for label in border_labels: if 1 <= label <= self.len(): border_cells.add(label) return border_cells
[docs] def selected_cells_mask(self, selected_ids: Set[int]) -> NDArray: """Return a combined 3D binary mask for the provided cell IDs. Performs a vectorized single-pass operation using np.isin over the internal labels array to build the combined mask. Args: selected_ids (Set[int]): Set of cell label IDs to include. Returns: NDArray: 3D binary mask (z, y, x) with 1 where any of the selected cells are present. """ if not selected_ids: return np.zeros_like(self.labels, dtype=self.ARRAY_ELEMENTS_TYPE) mask = np.isin(self.labels, list(selected_ids)) return mask.astype(self.ARRAY_ELEMENTS_TYPE)