Source code for wsinfer.patchlib.segment

"""Segment thumbnail of a whole slide image."""

from __future__ import annotations

import cv2 as cv
import numpy as np
import numpy.typing as npt
from skimage.morphology import binary_closing
from skimage.morphology import remove_small_holes
from skimage.morphology import remove_small_objects


[docs] def segment_tissue( im_arr: npt.NDArray, median_filter_size: int = 7, binary_threshold: int = 7, closing_kernel_size: int = 6, min_object_size_px: int = 512, min_hole_size_px: int = 1024, ) -> npt.NDArray[np.bool_]: """Create a binary tissue mask from an image. Parameters ---------- im_arr : array-like RGB image array (uint8) with shape (rows, cols, 3). median_filter_size : int The kernel size for median filtering. Must be odd and greater than one. binary_threshold : int The pixel threshold for image binarization. closing_kernel_size : int The kernel size for morphological closing (in pixel units). min_object_size_px : int The minimum area of an object in pixels. If an object is smaller than this area, it is removed and is made into background. min_hole_size_px : int The minimum area of a hole in pixels. If a hole is smaller than this area, it is filled and is made into foreground. Returns ------- mask Boolean array, where True values indicate presence of tissue. """ im_arr = np.asarray(im_arr) assert im_arr.ndim == 3 assert im_arr.shape[2] == 3 # Convert to HSV color space. im_arr = cv.cvtColor(im_arr, cv.COLOR_RGB2HSV) im_arr = im_arr[:, :, 1] # Keep saturation channel only. # Use median blurring to smooth the image. if median_filter_size <= 1 or median_filter_size % 2 == 0: raise ValueError( "median_filter_size must be greater than 1 and odd, but got" f" {median_filter_size}" ) # We use opencv here instead of PIL because opencv is _much_ faster. We use skimage # further down for artifact removal (hole filling, object removal) because skimage # provides easy to use methods for those. im_arr = cv.medianBlur(im_arr, median_filter_size) # Binarize image. _, im_arr = cv.threshold( im_arr, thresh=binary_threshold, maxval=255, type=cv.THRESH_BINARY ) # Convert to boolean dtype. This helps with static type analysis because at this # point, im_arr is a uint8 array. im_arr_binary: npt.NDArray[np.bool_] = im_arr > 0 # type: ignore # Closing. This removes small holes. It might not be entirely necessary because # we have hole removal below. im_arr_binary = binary_closing( im_arr_binary, footprint=np.ones((closing_kernel_size, closing_kernel_size)) ) # Remove small objects. im_arr_binary = remove_small_objects(im_arr_binary, min_size=min_object_size_px) # Remove small holes. im_arr_binary = remove_small_holes(im_arr_binary, area_threshold=min_hole_size_px) return im_arr_binary