diff --git a/merlin/analysis/segment.py b/merlin/analysis/segment.py index 786cb0d9..8ce22e89 100755 --- a/merlin/analysis/segment.py +++ b/merlin/analysis/segment.py @@ -1,7 +1,10 @@ import cv2 import numpy as np from skimage import measure -from skimage import segmentation +from skimage import segmentation as skiseg +from skimage import morphology +from skimage import feature +from skimage import filters import rtree from shapely import geometry from typing import List, Dict @@ -10,10 +13,10 @@ from merlin.core import dataset from merlin.core import analysistask from merlin.util import spatialfeature -from merlin.util import watershed +from merlin.util import segmentation import pandas import networkx as nx - +import time class FeatureSavingAnalysisTask(analysistask.ParallelAnalysisTask): @@ -90,13 +93,13 @@ def _run_analysis(self, fragmentIndex): .get_data_channel_index(self.parameters['watershed_channel_name']) watershedImages = self._read_and_filter_image_stack(fragmentIndex, watershedIndex, 5) - seeds = watershed.separate_merged_seeds( - watershed.extract_seeds(seedImages)) - normalizedWatershed, watershedMask = watershed.prepare_watershed_images( - watershedImages) + seeds = segmentation.separate_merged_seeds( + segmentation.extract_seeds(seedImages)) + normalizedWatershed, watershedMask = segmentation\ + .prepare_watershed_images(watershedImages) seeds[np.invert(watershedMask)] = 0 - watershedOutput = segmentation.watershed( + watershedOutput = skiseg.watershed( normalizedWatershed, measure.label(seeds), mask=watershedMask, connectivity=np.ones((3, 3, 3)), watershed_line=True) @@ -120,6 +123,203 @@ def _read_and_filter_image_stack(self, fov: int, channelIndex: int, for z in range(len(self.dataSet.get_z_positions()))]) +class WatershedSegmentNucleiCV2(FeatureSavingAnalysisTask): + + """ + An analysis task that determines the boundaries of features in the + image data in each field of view using a watershed algorithm + implemented in CV2. 
+ + A tutorial explaining the general scheme of the method can be + found in https://opencv-python-tutroals.readthedocs.io/en/latest/ + py_tutorials/py_imgproc/py_watershed/py_watershed.html. + + The watershed segmentation is performed in each z-position + independently and combined into 3D objects in a later step + + The class can be used to segment either nuclear or cytoplasmic + compartments. If both the compartment and membrane channels are the + same, the membrane channel is calculated from the edge transform of + the provided channel. + + Since each field of view is analyzed individually, the segmentation + results should be cleaned in order to merge cells that cross the + field of view boundary. + """ + + def __init__(self, dataSet, parameters=None, analysisName=None): + super().__init__(dataSet, parameters, analysisName) + + if 'membrane_channel_name' not in self.parameters: + self.parameters['membrane_channel_name'] = 'DAPI' + if 'compartment_channel_name' not in self.parameters: + self.parameters['compartment_channel_name'] = 'DAPI' + + def fragment_count(self): + return len(self.dataSet.get_fovs()) + + def get_estimated_memory(self): + # TODO - refine estimate + return 2048 + + def get_estimated_time(self): + # TODO - refine estimate + return 5 + + def get_dependencies(self): + return [self.parameters['warp_task'], + self.parameters['global_align_task']] + + def get_cell_boundaries(self) -> List[spatialfeature.SpatialFeature]: + featureDB = self.get_feature_database() + return featureDB.read_features() + + def _run_analysis(self, fragmentIndex): + startTime = time.time() + + globalTask = self.dataSet.load_analysis_task( + self.parameters['global_align_task']) + + # read membrane and compartment indexes + membraneIndex = self.dataSet \ + .get_data_organization() \ + .get_data_channel_index( + self.parameters['membrane_channel_name']) + compartmentIndex = self.dataSet \ + .get_data_organization() \ + .get_data_channel_index( + 
self.parameters['compartment_channel_name']) + + # read membrane and compartment images + membraneImages = self._read_image_stack(fragmentIndex, membraneIndex) + compartmentImages = self._read_image_stack(fragmentIndex, + compartmentIndex) + + # Prepare masks for cv2 watershed + watershedMarkers = segmentation.get_cv2_watershed_markers( + compartmentImages, + membraneImages, + self.parameters['compartment_channel_name'], + self.parameters['membrane_channel_name']) + + # perform watershed in individual z positions + watershedOutput = segmentation.apply_cv2_watershed(compartmentImages, + watershedMarkers) + + # combine all z positions in watershed + watershedCombinedOutput = segmentation \ + .combine_2d_segmentation_masks_into_3d(watershedOutput) + + # get features from mask. This is the slowestart (6 min for the + # previous part, 15+ for the rest, for a 7 frame Image. + zPos = np.array(self.dataSet.get_data_organization().get_z_positions()) + featureList = [spatialfeature.SpatialFeature.feature_from_label_matrix( + (watershedCombinedOutput == i), fragmentIndex, + globalTask.fov_to_global_transform(fragmentIndex), zPos) + for i in np.unique(watershedCombinedOutput) if i != 0] + + featureDB = self.get_feature_database() + featureDB.write_features(featureList, fragmentIndex) + + def _read_image_stack(self, fov: int, channelIndex: int) -> np.ndarray: + warpTask = self.dataSet.load_analysis_task( + self.parameters['warp_task']) + return np.array([warpTask.get_aligned_image(fov, channelIndex, z) + for z in range(len(self.dataSet.get_z_positions()))]) + + +class MachineLearningSegment(FeatureSavingAnalysisTask): + """ + An analysis task that determines the boundaries of features in the + image data in each field of view using a the specified machine learning + method. The available method is cellpose (https://github.com/MouseLand/ + cellpose). 
+ + TODO: implement unets / Ilastik + """ + + def __init__(self, dataSet, parameters=None, analysisName=None): + super().__init__(dataSet, parameters, analysisName) + + if 'method' not in self.parameters: + self.parameters['method'] = 'cellpose' + if 'diameter' not in self.parameters: + self.parameters['diameter'] = 50 + if 'compartment_channel_name' not in self.parameters: + self.parameters['compartment_channel_name'] = 'DAPI' + if 'flow_threshold' not in self.parameters: + self.parameters['flow_threshold'] = 0.5 + if 'cellprob_threshold' not in self.parameters: + self.parameters['cellprob_threshold'] = 1 + + def fragment_count(self): + return len(self.dataSet.get_fovs()) + + def get_estimated_memory(self): + # TODO - refine estimate + return 2048 + + def get_estimated_time(self): + # TODO - refine estimate + return 5 + + def get_dependencies(self): + return [self.parameters['warp_task'], + self.parameters['global_align_task']] + + def get_cell_boundaries(self) -> List[spatialfeature.SpatialFeature]: + featureDB = self.get_feature_database() + return featureDB.read_features() + + def _run_analysis(self, fragmentIndex): + + globalTask = self.dataSet.load_analysis_task( + self.parameters['global_align_task']) + + # read membrane and compartment indexes + compartmentIndex = self.dataSet \ + .get_data_organization() \ + .get_data_channel_index( + self.parameters['compartment_channel_name']) + + # Read images and perform segmentation + compartmentImages = self._read_image_stack(fragmentIndex, + compartmentIndex) + + if self.parameters['method'] == 'cellpose': + segParameters = dict({ + 'method': 'cellpose', + 'diameter': self.parameters['diameter'], + 'channel': self.parameters['compartment_channel_name'], + 'flow_threshold': self.parameters['flow_threshold'], + 'cellprob_threshold': self.parameters['cellprob_threshold'] + }) + + segmentationOutput = segmentation.apply_machine_learning_segmentation( + compartmentImages, segParameters) + + # combine all z positions in 
watershed + watershedCombinedOutput = segmentation \ + .combine_2d_segmentation_masks_into_3d(segmentationOutput) + + # get features from mask. This is the slowestart (6 min for the + # previous part, 15+ for the rest, for a 7 frame Image. + zPos = np.array(self.dataSet.get_data_organization().get_z_positions()) + featureList = [spatialfeature.SpatialFeature.feature_from_label_matrix( + (watershedCombinedOutput == i), fragmentIndex, + globalTask.fov_to_global_transform(fragmentIndex), zPos) + for i in np.unique(watershedCombinedOutput) if i != 0] + + featureDB = self.get_feature_database() + featureDB.write_features(featureList, fragmentIndex) + + def _read_image_stack(self, fov: int, channelIndex: int) -> np.ndarray: + warpTask = self.dataSet.load_analysis_task( + self.parameters['warp_task']) + return np.array([warpTask.get_aligned_image(fov, channelIndex, z) + for z in range(len(self.dataSet.get_z_positions()))]) + + class CleanCellBoundaries(analysistask.ParallelAnalysisTask): ''' A task to construct a network graph where each cell is a node, and overlaps diff --git a/merlin/core/dataset.py b/merlin/core/dataset.py index bc120af6..1a411619 100755 --- a/merlin/core/dataset.py +++ b/merlin/core/dataset.py @@ -616,7 +616,7 @@ def load_analysis_task(self, analysisTaskName: str) \ -> analysistask.AnalysisTask: loadName = os.sep.join([self.get_task_subdirectory( analysisTaskName), 'task.json']) - + print(loadName) with open(loadName, 'r') as inFile: parameters = json.load(inFile) analysisModule = importlib.import_module(parameters['module']) diff --git a/merlin/merlin.py b/merlin/merlin.py index c892baa3..4cae293e 100755 --- a/merlin/merlin.py +++ b/merlin/merlin.py @@ -162,6 +162,12 @@ def run_with_snakemake( dataSet: dataset.MERFISHDataSet, snakefilePath: str, coreCount: int, snakemakeParameters: Dict = {}, report: bool = True): print('Running MERlin pipeline through snakemake') + ''' + if 'restart_times' not in snakemakeParameters: + 
snakemakeParameters['restart_times'] = 3 + if 'latency_wait' not in snakemakeParameters: + snakemakeParameters['latency_wait'] = 60 + ''' snakemake.snakemake(snakefilePath, cores=coreCount, workdir=dataSet.get_snakemake_path(), stats=snakefilePath + '.stats', lock=False, diff --git a/merlin/util/dataportal.py b/merlin/util/dataportal.py index 3c8d2bb6..bba38f97 100755 --- a/merlin/util/dataportal.py +++ b/merlin/util/dataportal.py @@ -375,6 +375,5 @@ def read_file_bytes(self, startByte, endByte): endByte=endByte-1) return file - def close(self) -> None: pass diff --git a/merlin/util/segmentation.py b/merlin/util/segmentation.py new file mode 100755 index 00000000..f4a2f1cc --- /dev/null +++ b/merlin/util/segmentation.py @@ -0,0 +1,590 @@ +import numpy as np +import cv2 +from scipy import ndimage +from scipy.ndimage.morphology import binary_fill_holes +from skimage import morphology +from skimage import filters +from skimage import measure +from skimage import feature +from pyclustering.cluster import kmedoids +from typing import Tuple + +from merlin.util import matlab + +from cellpose import models + + +""" +This module contains utility functions for preparing images for +watershed segmentation, as well as functions to perform segmentation +using machine learning approaches +""" + +# To match Matlab's strel('disk', 20) +diskStruct = morphology.diamond(28)[9:48, 9:48] + + +def extract_seeds(seedImageStackIn: np.ndarray) -> np.ndarray: + """Determine seed positions from the input images. + + The initial seeds are determined by finding the regional intensity maximums + after erosion and filtering with an adaptive threshold. These initial + seeds are then expanded by dilation. + + Args: + seedImageStackIn: a 3 dimensional numpy array arranged as (z,x,y) + Returns: a boolean numpy array with the same dimensions as seedImageStackIn + where a given (z,x,y) coordinate is True if it corresponds to a seed + position and false otherwise. 
+ """ + seedImages = seedImageStackIn.copy() + + seedImages = ndimage.grey_erosion( + seedImages, + footprint=ndimage.morphology.generate_binary_structure(3, 1)) + seedImages = np.array([cv2.erode(x, diskStruct, + borderType=cv2.BORDER_REFLECT) + for x in seedImages]) + + thresholdFilterSize = int(2 * np.floor(seedImages.shape[1] / 16) + 1) + seedMask = np.array([x < 1.1 * filters.threshold_local( + x, thresholdFilterSize, method='mean', mode='nearest') + for x in seedImages]) + + seedImages[seedMask] = 0 + + seeds = morphology.local_maxima(seedImages, allow_borders=True) + + seeds = ndimage.morphology.binary_dilation( + seeds, structure=ndimage.morphology.generate_binary_structure(3, 1)) + seeds = np.array([ndimage.morphology.binary_dilation( + x, structure=morphology.diamond(28)[9:48, 9:48]) for x in seeds]) + + return seeds + + +def separate_merged_seeds(seedsIn: np.ndarray) -> np.ndarray: + """Separate seeds that are merged in 3 dimensions but are separated + in some 2 dimensional slices. + + Args: + seedsIn: a 3 dimensional binary numpy array arranged as (z,x,y) where + True indicates the pixel corresponds with a seed. + Returns: a 3 dimensional binary numpy array of the same size as seedsIn + indicating the positions of seeds after processing. 
+ """ + + def create_region_image(shape, c): + region = np.zeros(shape) + for x in c.coords: + region[x[0], x[1], x[2]] = 1 + return region + + components = measure.regionprops(measure.label(seedsIn)) + seeds = np.zeros(seedsIn.shape) + for c in components: + seedImage = create_region_image(seeds.shape, c) + localProps = [measure.regionprops(measure.label(x)) for x in seedImage] + seedCounts = [len(x) for x in localProps] + + if all([x < 2 for x in seedCounts]): + goodFrames = [i for i, x in enumerate(seedCounts) if x == 1] + goodProperties = [y for x in goodFrames for y in localProps[x]] + seedPositions = np.round([np.median( + [x.centroid for x in goodProperties], axis=0)]).astype(int) + else: + goodFrames = [i for i, x in enumerate(seedCounts) if x > 1] + goodProperties = [y for x in goodFrames for y in localProps[x]] + goodCentroids = [x.centroid for x in goodProperties] + km = kmedoids.kmedoids( + goodCentroids, + np.random.choice(np.arange(len(goodCentroids)), + size=np.max(seedCounts))) + km.process() + seedPositions = np.round( + [goodCentroids[x] for x in km.get_medoids()]).astype(int) + + for s in seedPositions: + for f in goodFrames: + seeds[f, s[0], s[1]] = 1 + + seeds = ndimage.morphology.binary_dilation( + seeds, structure=ndimage.morphology.generate_binary_structure(3, 1)) + seeds = np.array([ndimage.morphology.binary_dilation( + x, structure=diskStruct) for x in seeds]) + + return seeds + + +def prepare_watershed_images(watershedImageStack: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray]: + """Prepare the given images as the input image for watershedding. + + A watershed mask is determined using an adaptive threshold and the watershed + images are inverted so the largest values in the watershed images become + minima and then the image stack is normalized to have values between 0 + and 1. + + Args: + watershedImageStack: a 3 dimensional numpy array containing the images + arranged as (z, x, y). 
+ Returns: a tuple containing the normalized watershed images and the + calculated watershed mask + """ + filterSize = int(2 * np.floor(watershedImageStack.shape[1] / 16) + 1) + + watershedMask = np.array([ndimage.morphology.binary_fill_holes( + x > 1.1 * filters.threshold_local(x, filterSize, method='mean', + mode='nearest')) + for x in watershedImageStack]) + + normalizedWatershed = 1 - (watershedImageStack + - np.min(watershedImageStack)) / \ + (np.max(watershedImageStack) + - np.min(watershedImageStack)) + normalizedWatershed[np.invert(watershedMask)] = 1 + + return normalizedWatershed, watershedMask + + +def get_membrane_mask(membraneImages: np.ndarray, + compartmentChannelName: str, + membraneChannelName: str) -> np.ndarray: + """Calculate binary mask with 1's in membrane pixels and 0 otherwise. + The images expected are some type of membrane label (WGA, ConA, + Lamin, Cadherins) or compartment images (DAPI, CD45, polyT) + + Args: + membraneImages: a 3 dimensional numpy array containing the images + arranged as (z, x, y). + membraneChannelName: A string with the name of a membrane channel. 
+ compartmentChannelName: A string with the name of the compartment + channel + Returns: + ndarray containing a 3 dimensional mask arranged as (z, x, y) + """ + mask = np.zeros(membraneImages.shape) + if membraneChannelName != compartmentChannelName: + fineBlockSize = 61 + for z in range(membraneImages.shape[0]): + mask[z, :, :] = (membraneImages[z, :, :] > + filters.threshold_local(membraneImages[z, :, :], + fineBlockSize, + offset=0)) + mask[z, :, :] = morphology.remove_small_objects( + mask[z, :, :].astype('bool'), + min_size=100, + connectivity=1) + mask[z, :, :] = morphology.binary_closing(mask[z, :, :], + morphology.selem.disk(5)) + mask[z, :, :] = morphology.skeletonize(mask[z, :, :]) + else: + filterSigma2 = 5 + filterSize2 = int(2*np.ceil(2*filterSigma2)+1) + edgeSigma = 2# 1 #2 + lowThresh = 0.1# 0.5 #0.2 + hiThresh = 0.5# 0.7 #0.6 + for z in range(membraneImages.shape[0]): + blurredImage = cv2.GaussianBlur(membraneImages[z, :, :], + (filterSize2, filterSize2), + filterSigma2) + edge0 = feature.canny(membraneImages[z, :, :], + sigma=edgeSigma, + use_quantiles=True, + low_threshold=lowThresh, + high_threshold=hiThresh) + edge0 = morphology.dilation(edge0, morphology.selem.disk(10)) + + edge1 = feature.canny(blurredImage, + sigma=edgeSigma, + use_quantiles=True, + low_threshold=lowThresh, + high_threshold=hiThresh) + edge1 = morphology.dilation(edge1, morphology.selem.disk(10)) + + mask[z, :, :] = edge0 + edge1 + + mask[z, :, :] = morphology.skeletonize(mask[z, :, :]) + + return mask + + +def get_compartment_mask(compartmentImages: np.ndarray) -> np.ndarray: + """Calculate binary mask with 1's in compartment (nuclei or cytoplasm) + pixels and 0 otherwise. The images expected are some type of compartment + label (e.g. Nuclei: DAPI, Cytoplasm: PolyT, CD45, etc) + + Args: + compartmentImages: a 3 dimensional numpy array containing the images + arranged as (z, x, y). 
+ Returns: + ndarray containing a 3 dimensional mask arranged as (z, x, y) + """ + + # generate compartment mask based on thresholding + thresholdingMask = np.zeros(compartmentImages.shape) + coarseBlockSize = 241 + fineBlockSize = 61 + for z in range(compartmentImages.shape[0]): + coarseThresholdingMask = (compartmentImages[z, :, :] > + filters.threshold_local( + compartmentImages[z, :, :], + coarseBlockSize, + offset=0)) + fineThresholdingMask = (compartmentImages[z, :, :] > + filters.threshold_local( + compartmentImages[z, :, :], + fineBlockSize, + offset=0)) + thresholdingMask[z, :, :] = (coarseThresholdingMask * + fineThresholdingMask) + thresholdingMask[z, :, :] = binary_fill_holes( + thresholdingMask[z, :, :]) + + # generate border mask, necessary to avoid making a single + # connected component when using binary_fill_holes below + borderMask = np.zeros((compartmentImages.shape[1], + compartmentImages.shape[2])) + borderMask[25:(compartmentImages.shape[1]-25), + 25:(compartmentImages.shape[2]-25)] = 1 + + # generate compartment mask from hessian, fine + fineHessianMask = np.zeros(compartmentImages.shape) + for z in range(compartmentImages.shape[0]): + fineHessian = filters.hessian(compartmentImages[z, :, :]) + fineHessianMask[z, :, :] = fineHessian == fineHessian.max() + fineHessianMask[z, :, :] = morphology.binary_closing( + fineHessianMask[z, :, :], + morphology.selem.disk(5)) + fineHessianMask[z, :, :] = fineHessianMask[z, :, :] * borderMask + fineHessianMask[z, :, :] = binary_fill_holes( + fineHessianMask[z, :, :]) + + # generate compartment mask from hessian, coarse + coarseHessianMask = np.zeros(compartmentImages.shape) + for z in range(compartmentImages.shape[0]): + coarseHessian = filters.hessian(compartmentImages[z, :, :] - + morphology.white_tophat( + compartmentImages[z, :, :], + morphology.selem.disk(20))) + coarseHessianMask[z, :, :] = coarseHessian == coarseHessian.max() + coarseHessianMask[z, :, :] = morphology.binary_closing( + 
coarseHessianMask[z, :, :], morphology.selem.disk(5)) + coarseHessianMask[z, :, :] = (coarseHessianMask[z, :, :] * + borderMask) + coarseHessianMask[z, :, :] = binary_fill_holes( + coarseHessianMask[z, :, :]) + + # combine masks + compartmentMask = thresholdingMask + fineHessianMask + coarseHessianMask + return binary_fill_holes(compartmentMask) + + +def get_cv2_watershed_markers(compartmentImages: np.ndarray, + membraneImages: np.ndarray, + compartmentChannelName: str, + membraneChannelName: str) -> np.ndarray: + """Combine membrane and compartment markers into a single multilabel mask + for CV2 watershed + + Args: + compartmentImages: a 3 dimensional numpy array containing the images + arranged as (z, x, y). + membraneImages: a 3 dimensional numpy array containing the images + arranged as (z, x, y). + compartmentChannelName: str with the name of the compartment channel + to use + membraneChannelName: str with the name of the membrane channel + to use + + Returns: + ndarray containing a 3 dimensional mask arranged as (z, x, y) of + cv2-compatible watershed markers + """ + + compartmentMask = get_compartment_mask(compartmentImages) + membraneMask = get_membrane_mask(membraneImages, + compartmentChannelName, + membraneChannelName) + + watershedMarker = np.zeros(compartmentMask.shape) + + for z in range(compartmentImages.shape[0]): + + # generate areas of sure bg and fg, as well as the area of + # unknown classification + background = morphology.dilation(compartmentMask[z, :, :], + morphology.selem.disk(15)) + membraneDilated = morphology.dilation( + membraneMask[z, :, :].astype('bool'), + morphology.selem.disk(10)) + foreground = morphology.erosion(compartmentMask[z, :, :] * ~ + membraneDilated, + morphology.selem.disk(5)) + unknown = background * ~ foreground + + background = np.uint8(background) * 255 + foreground = np.uint8(foreground) * 255 + unknown = np.uint8(unknown) * 255 + + # Marker labelling + ret, markers = cv2.connectedComponents(foreground) + + # Add 
def convert_grayscale_to_rgb(uint16Image: np.ndarray) -> np.ndarray:
    """Convert a 16 bit 2D grayscale image into a 3D 8-bit RGB image.

    cv2 only works in 8-bit. The image is inverted, rescaled so that the
    largest inverted value maps to 255, and copied into three identical
    colour channels. Based on https://stackoverflow.com/questions/
    25485886/how-to-convert-a-16-bit-to-an-8-bit-image-in-opencv3D

    Args:
        uint16Image: a 2 dimensional numpy array containing the 16-bit
            image
    Returns:
        ndarray of dtype uint8 with shape (x, y, 3), where the three
            channels hold the same inverted, rescaled image
    """
    # Invert in floating point. 2**16 does not fit in uint16, so doing the
    # subtraction directly on a uint16 array depends on NumPy's integer
    # promotion rules instead of on the data.
    inverted = 2.0 ** 16 - np.asarray(uint16Image, dtype=np.float64)

    # Nothing to rescale for an empty image.
    if inverted.size == 0:
        return np.zeros(inverted.shape + (3,), dtype=np.uint8)

    # Map the largest inverted value to 255. Dividing by (max / 256) sends
    # the maximum to 256, which wraps around to 0 when cast to uint8.
    maxValue = inverted.max()
    if maxValue > 0:
        scaled = np.clip(inverted * (255.0 / maxValue), 0, 255)
        uint8Image = scaled.astype(np.uint8)
    else:
        uint8Image = np.zeros(inverted.shape, dtype=np.uint8)

    # Replicate the single channel into the three channels cv2 expects.
    rgbImage = np.zeros(uint8Image.shape + (3,), dtype=np.uint8)
    rgbImage[:, :, 0] = uint8Image
    rgbImage[:, :, 1] = uint8Image
    rgbImage[:, :, 2] = uint8Image

    return rgbImage
def get_overlapping_objects(segmentationZ0: np.ndarray,
                            segmentationZ1: np.ndarray,
                            n0: int) -> Tuple[np.float64,
                                              np.float64, np.float64]:
    """Compare cell labels in adjacent image masks.

    Args:
        segmentationZ0: a 2 dimensional numpy array containing a
            segmentation mask in position Z
        segmentationZ1: a 2 dimensional numpy array containing a
            segmentation mask adjacent to segmentationZ0
        n0: an integer with the index of the object (cell/nuclei)
            to be compared between the provided segmentation masks

    Returns:
        a tuple (n1, f0, f1) containing the label of the cell in Z1
        overlapping n0 (n1), the fraction of n0 overlapping n1 (f0) and
        the fraction of n1 overlapping n0 (f1). (False, False, False) is
        returned when no object in Z1 overlaps n0, or when the best
        candidate does not satisfy f0 > 0.2 and f1 > 0.5.
    """
    # Build the mask of n0 once; every quantity below derives from it.
    n0Mask = segmentationZ0 == n0

    # Labels of Z1 found under n0, with the overlap area of each.
    z1Indexes, overlapArea = np.unique(segmentationZ1[n0Mask],
                                       return_counts=True)
    isObject = z1Indexes > 0  # label 0 (and below) is background
    z1Indexes = z1Indexes[isObject]
    overlapArea = overlapArea[isObject]

    if z1Indexes.shape[0] == 0:
        return (False, False, False)

    n0Area = np.count_nonzero(n0Mask)
    n1Area = np.array([np.count_nonzero(segmentationZ1 == n1)
                       for n1 in z1Indexes])

    n0OverlapFraction = overlapArea / n0Area
    n1OverlapFraction = overlapArea / n1Area

    # Select the candidate with the highest fraction in n0; ties are
    # broken by the fraction in n1 and then by position.
    best = max(range(len(z1Indexes)),
               key=lambda i: (n0OverlapFraction[i],
                              n1OverlapFraction[i], i))

    if n0OverlapFraction[best] > 0.2 and n1OverlapFraction[best] > 0.5:
        return (z1Indexes[best],
                n0OverlapFraction[best],
                n1OverlapFraction[best])

    return (False, False, False)


def combine_2d_segmentation_masks_into_3d(segmentationOutput:
                                          np.ndarray) -> np.ndarray:
    """Take a 3 dimensional segmentation mask and relabel it so that
    nuclei in adjacent sections have the same label if the area of their
    overlap surpasses the thresholds used by get_overlapping_objects.

    Labels are propagated from the last z position towards the first.
    Objects that cannot be linked to the plane above them are kept and
    given a new label that is not used by any other object.

    Args:
        segmentationOutput: a 3 dimensional numpy array containing the
            segmentation masks arranged as (z, x, y).
    Returns:
        ndarray containing a 3 dimensional mask arranged as (z, x, y) of
            relabeled segmented cells
    """

    # Initialize empty array with size as segmentationOutput array
    segmentationCombinedZ = np.zeros(segmentationOutput.shape)
    if segmentationOutput.shape[0] == 0:
        return segmentationCombinedZ

    # copy the mask of the section farthest to the coverslip to start
    segmentationCombinedZ[-1, :, :] = segmentationOutput[-1, :, :]

    # first label that is guaranteed not to be in use yet
    nextLabel = np.max(segmentationCombinedZ[-1, :, :], initial=0) + 1

    # starting far from coverslip
    for z in range(segmentationOutput.shape[0] - 1, 0, -1):
        upper = segmentationCombinedZ[z, :, :]
        lower = segmentationOutput[z - 1, :, :]

        # The labels to propagate must be read from the already relabeled
        # plane: the labels of the raw plane no longer identify the
        # objects stored in the combined array.
        for n0 in np.unique(upper[upper > 0]):
            n1, f0, f1 = get_overlapping_objects(upper, lower, n0)
            if n1 is not False:
                segmentationCombinedZ[z - 1, :, :][lower == n1] = n0

        # Objects of the lower plane that were not linked would otherwise
        # be lost; keep them under a fresh label so they cannot collide
        # with labels inherited from the planes above.
        unlinked = (lower > 0) & (segmentationCombinedZ[z - 1, :, :] == 0)
        for n1 in np.unique(lower[unlinked]):
            segmentationCombinedZ[z - 1, :, :][lower == n1] = nextLabel
            nextLabel += 1

    return segmentationCombinedZ


def segment_using_ilastik(imageStackIn: np.ndarray,
                          params: dict = None) -> np.ndarray:
    """Placeholder for Ilastik based segmentation; not implemented.

    Args:
        imageStackIn: a 3 dimensional numpy array containing the images
            arranged as (z, x, y).
        params: optional dictionary with segmentation parameters. It is
            accepted so the function can be called with the same
            arguments as segment_using_cellpose, and is ignored.
    Returns:
        None
    """
    return None


def segment_using_unet(imageStackIn: np.ndarray,
                       params: dict = None) -> np.ndarray:
    """Placeholder for U-Net based segmentation; not implemented.

    Args:
        imageStackIn: a 3 dimensional numpy array containing the images
            arranged as (z, x, y).
        params: optional dictionary with segmentation parameters. It is
            accepted so the function can be called with the same
            arguments as segment_using_cellpose, and is ignored.
    Returns:
        None
    """
    return None
def segment_using_cellpose(imageStackIn: np.ndarray,
                           params: dict) -> np.ndarray:
    """Perform segmentation using cellpose. Code adapted from
    https://nbviewer.jupyter.org/github/MouseLand/cellpose/blob/
    master/notebooks/run_cellpose.ipynb

    Args:
        imageStackIn: a 3 dimensional numpy array containing the images
            arranged as (z, x, y).
        params: a dictionary with the parameters for segmentation.
            Keys read by this function:
                channel: name of the imaged channel, used to choose
                    between the 'nuclei' and 'cyto' models
                    (case-insensitive)
                diameter
                flow_threshold
                cellprob_threshold

    Returns:
        ndarray containing a 3 dimensional mask arranged as (z, x, y)

    Raises:
        ValueError: if params['channel'] is not one of the channel names
            associated with a cellpose model.
    """
    nucleiChannels = ('dapi', 'lamin')
    cytoChannels = ('polyt', 'polya', 'ecadherin', 'cd45', 'wga', 'cona')

    # Pick the cellpose model from the channel name. An unknown name is
    # rejected here instead of failing later on an undefined model.
    channelName = params['channel'].lower()
    if channelName in nucleiChannels:
        modelType = 'nuclei'
    elif channelName in cytoChannels:
        modelType = 'cyto'
    else:
        raise ValueError(
            'No cellpose model is associated with channel %r; expected '
            'one of %s' % (params['channel'],
                           ', '.join(nucleiChannels + cytoChannels)))

    model = models.Cellpose(gpu=False, model_type=modelType)

    # Channels to run the segmentation on, as [cytoplasm, nucleus] with
    # grayscale=0, R=1, G=2, B=3. The images here are grayscale.
    channels = [0, 0]

    # put list of images in cellpose format
    imageList = np.split(imageStackIn, imageStackIn.shape[0])

    masks, _flows, _styles, _diams = model.eval(
        imageList,
        diameter=params['diameter'],
        channels=channels,
        flow_threshold=params['flow_threshold'],
        cellprob_threshold=params['cellprob_threshold'])

    # combine masks into array
    return np.stack(masks)


def apply_machine_learning_segmentation(imageStackIn: np.ndarray,
                                        params: dict) -> np.ndarray:
    """Select segmentation algorithm to use

    Args:
        imageStackIn: a 3 dimensional numpy array containing the images
            arranged as (z, x, y).
        params: dictionary with key:value pairs with parameters to be passed
            to the segmentation code. params['method'] selects the
            algorithm; the remaining keys are read by the selected
            segmentation function.

    Returns:
        ndarray containing a 3 dimensional mask arranged as (z, x, y)

    Raises:
        NotImplementedError: if params['method'] is 'ilastik' or 'unet',
            which do not have an implementation yet.
        ValueError: if params['method'] is not a known method.
    """
    method = params['method']

    if method == 'cellpose':
        return segment_using_cellpose(imageStackIn, params)

    # Fail loudly rather than hand a missing segmentation to the caller.
    if method in ('ilastik', 'unet'):
        raise NotImplementedError(
            'Segmentation method %r is not implemented yet' % (method,))

    raise ValueError(
        'Unknown segmentation method %r; expected one of cellpose, '
        'ilastik, unet' % (method,))
""" @@ -138,4 +140,4 @@ def prepare_watershed_images(watershedImageStack: np.ndarray - np.min(watershedImageStack)) normalizedWatershed[np.invert(watershedMask)] = 1 - return normalizedWatershed, watershedMask + return normalizedWatershed, watershedMask \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index c65e9715..e4dc2d75 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,4 +28,6 @@ tables boto3 xmltodict google-cloud-storage -docutils<0.16,>=0.10 \ No newline at end of file +docutils<0.16,>=0.10 +pillow<=7.0.0 +cellpose \ No newline at end of file diff --git a/test/auxiliary_files/test_analysis_segmentation_cellpose.json b/test/auxiliary_files/test_analysis_segmentation_cellpose.json new file mode 100755 index 00000000..15ac7dcc --- /dev/null +++ b/test/auxiliary_files/test_analysis_segmentation_cellpose.json @@ -0,0 +1,101 @@ +{ + "analysis_tasks": [ + { + "task": "FiducialCorrelationWarp", + "module": "merlin.analysis.warp", + "analysis_name": "CellposeFiducialCorrelationWarp", + "parameters": { + "write_aligned_images": true + } + }, + { + "task": "SimpleGlobalAlignment", + "module": "merlin.analysis.globalalign", + "analysis_name": "CellposeSimpleGlobalAlignment" + }, + { + "task": "MachineLearningSegment", + "module": "merlin.analysis.segment", + "analysis_name": "CellposeSegment", + "parameters": { + "warp_task": "CellposeFiducialCorrelationWarp", + "global_align_task": "CellposeSimpleGlobalAlignment" + } + }, + { + "task": "CleanCellBoundaries", + "module": "merlin.analysis.segment", + "analysis_name": "CellposeCleanCellBoundaries", + "parameters": { + "segment_task": "CellposeSegment", + "global_align_task": "CellposeSimpleGlobalAlignment" + } + }, + { + "task": "CombineCleanedBoundaries", + "module": "merlin.analysis.segment", + "analysis_name": "CellposeCombineCleanedBoundaries", + "parameters": { + "cleaning_task": "CellposeCleanCellBoundaries" + } + }, + { + "task": "RefineCellDatabases", + "module": 
"merlin.analysis.segment", + "analysis_name": "CellposeRefineCellDatabases", + "parameters": { + "segment_task": "CellposeSegment", + "combine_cleaning_task": "CellposeCombineCleanedBoundaries" + } + }, + { + "task": "PartitionBarcodes", + "module": "merlin.analysis.partition", + "analysis_name": "CellposePartitionBarcodes", + "parameters": { + "filter_task": "AdaptiveFilterBarcodes", + "assignment_task": "CellposeRefineCellDatabases", + "alignment_task": "CellposeSimpleGlobalAlignment" + } + }, + { + "task": "ExportPartitionedBarcodes", + "module": "merlin.analysis.partition", + "analysis_name": "CellposeExportPartitionedBarcodes", + "parameters": { + "partition_task": "CellposePartitionBarcodes" + } + }, + { + "task": "ExportCellMetadata", + "module": "merlin.analysis.segment", + "analysis_name": "CellposeExportCellMetadata", + "parameters": { + "segment_task": "CellposeRefineCellDatabases" + } + }, + { + "task": "SumSignal", + "module": "merlin.analysis.sequential", + "analysis_name": "CellposeSumSignal", + "parameters": { + "z_index": 0, + "apply_highpass": true, + "warp_task": "CellposeFiducialCorrelationWarp", + "highpass_sigma": 5, + "segment_task": "CellposeRefineCellDatabases", + "global_align_task": "CellposeSimpleGlobalAlignment" + } + }, + { + "task": "ExportSumSignals", + "module": "merlin.analysis.sequential", + "analysis_name": "CellposeExportSumSignals", + "parameters": { + "sequential_task": "CellposeSumSignal" + } + } + + ] + +} diff --git a/test/auxiliary_files/test_analysis_segmentation_cv2.json b/test/auxiliary_files/test_analysis_segmentation_cv2.json new file mode 100755 index 00000000..fb2e24c3 --- /dev/null +++ b/test/auxiliary_files/test_analysis_segmentation_cv2.json @@ -0,0 +1,101 @@ +{ + "analysis_tasks": [ + { + "task": "FiducialCorrelationWarp", + "module": "merlin.analysis.warp", + "analysis_name": "CV2FiducialCorrelationWarp", + "parameters": { + "write_aligned_images": true + } + }, + { + "task": "SimpleGlobalAlignment", + 
"module": "merlin.analysis.globalalign", + "analysis_name": "CV2SimpleGlobalAlignment" + }, + { + "task": "WatershedSegmentNucleiCV2", + "module": "merlin.analysis.segment", + "analysis_name": "CV2Segment", + "parameters": { + "warp_task": "CV2FiducialCorrelationWarp", + "global_align_task": "CV2SimpleGlobalAlignment" + } + }, + { + "task": "CleanCellBoundaries", + "module": "merlin.analysis.segment", + "analysis_name": "CV2CleanCellBoundaries", + "parameters": { + "segment_task": "CV2Segment", + "global_align_task": "CV2SimpleGlobalAlignment" + } + }, + { + "task": "CombineCleanedBoundaries", + "module": "merlin.analysis.segment", + "analysis_name": "CV2CombineCleanedBoundaries", + "parameters": { + "cleaning_task": "CV2CleanCellBoundaries" + } + }, + { + "task": "RefineCellDatabases", + "module": "merlin.analysis.segment", + "analysis_name": "CV2RefineCellDatabases", + "parameters": { + "segment_task": "CV2Segment", + "combine_cleaning_task": "CV2CombineCleanedBoundaries" + } + }, + { + "task": "PartitionBarcodes", + "module": "merlin.analysis.partition", + "analysis_name": "CV2PartitionBarcodes", + "parameters": { + "filter_task": "AdaptiveFilterBarcodes", + "assignment_task": "CV2RefineCellDatabases", + "alignment_task": "CV2SimpleGlobalAlignment" + } + }, + { + "task": "ExportPartitionedBarcodes", + "module": "merlin.analysis.partition", + "analysis_name": "CV2ExportPartitionedBarcodes", + "parameters": { + "partition_task": "CV2PartitionBarcodes" + } + }, + { + "task": "ExportCellMetadata", + "module": "merlin.analysis.segment", + "analysis_name": "CV2ExportCellMetadata", + "parameters": { + "segment_task": "CV2RefineCellDatabases" + } + }, + { + "task": "SumSignal", + "module": "merlin.analysis.sequential", + "analysis_name": "CV2SumSignal", + "parameters": { + "z_index": 0, + "apply_highpass": true, + "warp_task": "CV2FiducialCorrelationWarp", + "highpass_sigma": 5, + "segment_task": "CV2RefineCellDatabases", + "global_align_task": 
"CV2SimpleGlobalAlignment" + } + }, + { + "task": "ExportSumSignals", + "module": "merlin.analysis.sequential", + "analysis_name": "CV2ExportSumSignals", + "parameters": { + "sequential_task": "CV2SumSignal" + } + } + + ] + +} diff --git a/test/conftest.py b/test/conftest.py index 83714b35..6a00262e 100755 --- a/test/conftest.py +++ b/test/conftest.py @@ -57,6 +57,16 @@ def base_files(): [root, 'auxiliary_files', 'test_analysis_parameters.json']), os.sep.join( [merlin.ANALYSIS_PARAMETERS_HOME, 'test_analysis_parameters.json'])) + shutil.copyfile( + os.sep.join( + [root, 'auxiliary_files', 'test_analysis_segmentation_cellpose.json']), + os.sep.join( + [merlin.ANALYSIS_PARAMETERS_HOME, 'test_analysis_segmentation_cellpose.json'])) + shutil.copyfile( + os.sep.join( + [root, 'auxiliary_files', 'test_analysis_segmentation_cv2.json']), + os.sep.join( + [merlin.ANALYSIS_PARAMETERS_HOME, 'test_analysis_segmentation_cv2.json'])) shutil.copyfile( os.sep.join( [root, 'auxiliary_files', 'test_microscope_parameters.json']), diff --git a/test/test_merfish_segmentation_cellpose.py b/test/test_merfish_segmentation_cellpose.py new file mode 100755 index 00000000..81989cf7 --- /dev/null +++ b/test/test_merfish_segmentation_cellpose.py @@ -0,0 +1,15 @@ +import os +import pytest + +import merlin +from merlin import merlin as m + + +@pytest.mark.fullrun +@pytest.mark.slowtest +def test_cellpose_2d_local(simple_merfish_data): + with open(os.sep.join([merlin.ANALYSIS_PARAMETERS_HOME, + 'test_analysis_segmentation_cellpose.json']), 'r') as f: + snakefilePath = m.generate_analysis_tasks_and_snakefile( + simple_merfish_data, f) + m.run_with_snakemake(simple_merfish_data, snakefilePath, 5) diff --git a/test/test_merfish_segmentation_cv2.py b/test/test_merfish_segmentation_cv2.py new file mode 100755 index 00000000..ab41c8fd --- /dev/null +++ b/test/test_merfish_segmentation_cv2.py @@ -0,0 +1,15 @@ +import os +import pytest + +import merlin +from merlin import merlin as m + + 
@pytest.mark.fullrun
@pytest.mark.slowtest
def test_cv2_2d_local(simple_merfish_data):
    """Run the CV2 segmentation analysis end to end on the test data set.

    Opens ``test_analysis_segmentation_cv2.json`` from
    ``merlin.ANALYSIS_PARAMETERS_HOME``, asks merlin to generate the
    analysis tasks and Snakefile for ``simple_merfish_data``, and then
    executes that Snakefile through snakemake.

    Args:
        simple_merfish_data: the data set supplied by the pytest fixture
            of the same name.
    """
    parametersPath = os.sep.join(
        [merlin.ANALYSIS_PARAMETERS_HOME,
         'test_analysis_segmentation_cv2.json'])

    with open(parametersPath, 'r') as parametersFile:
        snakefilePath = m.generate_analysis_tasks_and_snakefile(
            simple_merfish_data, parametersFile)
        # NOTE(review): the literal 5 is presumably the number of cores
        # handed to snakemake - confirm against run_with_snakemake.
        m.run_with_snakemake(simple_merfish_data, snakefilePath, 5)