Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
Size: Mime:
import logging
import random
import numpy as np

# imports for deformed slice
from skimage.draw import line
from scipy.ndimage.measurements import label
from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.morphology import binary_dilation

from gunpowder.batch_request import BatchRequest
from gunpowder.coordinate import Coordinate
from .batch_filter import BatchFilter

# Module-level logger; DefectAugment uses it for debug tracing of the
# per-section augmentation choices and of ROI growth/reset.
logger = logging.getLogger(__name__)

class DefectAugment(BatchFilter):
    '''Augment intensity arrays section-wise with artifacts like missing
    sections, low-contrast sections, by blending in artifacts drawn from a
    separate source, or by deforming a section.

    Args:

        intensities (:class:`ArrayKey`):

            The key of the array of intensities to modify.

        prob_missing(``float``):
        prob_low_contrast(``float``):
        prob_artifact(``float``):
        prob_deform(``float``):

            Probabilities of having a missing section, low-contrast section, an
            artifact (see param ``artifact_source``) or a deformed slice. The
            sum should not exceed 1. Values in missing sections will be set to
            0.

        contrast_scale (``float``, optional):

            By how much to scale the intensities for a low-contrast section,
            used if ``prob_low_contrast`` > 0.

        artifact_source (:class:`BatchProvider`, optional):

            A gunpowder batch provider that delivers intensities (via
            :class:`ArrayKey` ``artifacts``) and an alpha mask (via
            :class:`ArrayKey` ``artifacts_mask``), used if ``prob_artifact`` > 0.

        artifacts(:class:`ArrayKey`, optional):

            The key to query ``artifact_source`` for to get the intensities
            of the artifacts.

        artifacts_mask(:class:`ArrayKey`, optional):

            The key to query ``artifact_source`` for to get the alpha mask
            of the artifacts to blend them with ``intensities``.

        deformation_strength (``int``, optional):

            Strength of the slice deformation in voxels, used if
            ``prob_deform`` > 0. The deformation models a fold by shifting the
            section contents towards a randomly oriented line in the section.
            The line itself will be drawn with a value of 0. Deformed sections
            are clipped to the range [0, 1].

        axis (``int``, optional):

            Along which axis sections are cut.
    '''

    def __init__(
            self,
            intensities,
            prob_missing=0.05,
            prob_low_contrast=0.05,
            prob_artifact=0.0,
            prob_deform=0.0,
            contrast_scale=0.1,
            artifact_source=None,
            artifacts=None,
            artifacts_mask=None,
            deformation_strength=20,
            axis=0):
        self.intensities = intensities
        self.prob_missing = prob_missing
        self.prob_low_contrast = prob_low_contrast
        self.prob_artifact = prob_artifact
        self.prob_deform = prob_deform
        self.contrast_scale = contrast_scale
        self.artifact_source = artifact_source
        self.artifacts = artifacts
        self.artifacts_mask = artifacts_mask
        self.deformation_strength = deformation_strength
        self.axis = axis

    def setup(self):
        '''Set up the artifact source, if one was given.'''

        if self.artifact_source is not None:
            self.artifact_source.setup()

    def teardown(self):
        '''Tear down the artifact source, if one was given.'''

        if self.artifact_source is not None:
            self.artifact_source.teardown()

    def prepare(self, request):
        '''Decide, per section, which augmentation (if any) to apply.

        The choice for every section index along ``self.axis`` is stored in
        ``self.slice_to_augmentation`` for :meth:`process`. For deformed
        sections the deformation fields are precomputed here, and the request
        ROI for ``self.intensities`` is grown in-plane (never along
        ``self.axis``) by ``deformation_strength`` voxels on each side, so
        the deformation can sample data from outside the requested ROI.

        Args:

            request (:class:`BatchRequest`):

                The downstream request; its spec for ``self.intensities`` is
                modified in place when a deformation is scheduled.
        '''

        # Cumulative thresholds: a single uniform draw per section selects at
        # most one augmentation.
        prob_missing_threshold = self.prob_missing
        prob_low_contrast_threshold = prob_missing_threshold + self.prob_low_contrast
        prob_artifact_threshold = prob_low_contrast_threshold + self.prob_artifact
        prob_deform_threshold = prob_artifact_threshold + self.prob_deform

        spec = request[self.intensities]
        roi = spec.roi
        logger.debug("downstream request ROI is %s", roi)
        raw_voxel_size = self.spec[self.intensities].voxel_size

        # Shape of the requested ROI in voxels, and the shape of one section
        # (section axis removed). Both are loop invariant.
        voxel_shape = (roi / raw_voxel_size).get_shape()
        slice_shape = voxel_shape[:self.axis] + voxel_shape[self.axis + 1:]

        # mapping: section index -> augmentation type
        self.slice_to_augmentation = {}
        # mapping: section index -> (flow_x, flow_y, line_mask) for deformations
        self.deform_slice_transformations = {}

        for c in range(voxel_shape[self.axis]):
            r = random.random()

            if r < prob_missing_threshold:
                logger.debug("Zero-out %s", c)
                self.slice_to_augmentation[c] = 'zero_out'

            elif r < prob_low_contrast_threshold:
                logger.debug("Lower contrast %s", c)
                # must match the key tested in process()
                self.slice_to_augmentation[c] = 'low_contrast'

            elif r < prob_artifact_threshold:
                logger.debug("Add artifact %s", c)
                self.slice_to_augmentation[c] = 'artifact'

            elif r < prob_deform_threshold:
                logger.debug("Add deformed slice %s", c)
                self.slice_to_augmentation[c] = 'deformed_slice'
                self.deform_slice_transformations[c] = self.__prepare_deform_slice(slice_shape)

        # Request a bigger upstream ROI so that deformed sections can pull in
        # data from beyond the downstream ROI.
        if 'deformed_slice' in self.slice_to_augmentation.values():

            logger.debug("before growth: %s", spec.roi)
            growth = Coordinate(
                tuple(0 if d == self.axis else raw_voxel_size[d] * self.deformation_strength
                      for d in range(spec.roi.dims()))
            )
            logger.debug("growing request by %s", str(growth))

            # update request ROI to get all voxels necessary to perform the
            # transformation
            spec.roi = roi.grow(growth, growth)
            logger.debug("upstream request roi is %s", spec.roi)

    def process(self, batch, request):
        '''Apply the augmentations chosen in :meth:`prepare` to the batch.

        If any section was deformed, the intensities array was delivered with
        a grown in-plane ROI; it is cropped back to the ROI of ``request``
        at the end.

        Raises:

            AssertionError: if the batch is not 3D, or (for artifacts) if the
                artifact alpha mask has a different voxel size or is not a
                ``float32`` array with values in [0, 1].
        '''

        assert batch.get_total_roi().dims() == 3, "defectaugment works on 3d batches only"

        raw = batch.arrays[self.intensities]
        raw_voxel_size = self.spec[self.intensities].voxel_size
        dims = raw.spec.roi.dims()

        for c, augmentation_type in self.slice_to_augmentation.items():

            # selects section c while keeping the section axis (size 1)
            section_selector = tuple(
                slice(c, c + 1) if d == self.axis else slice(None)
                for d in range(dims)
            )

            if augmentation_type == 'zero_out':
                raw.data[section_selector] = 0

            elif augmentation_type == 'low_contrast':
                section = raw.data[section_selector]
                mean = section.mean()
                # Computed out of place: in-place ops would raise for integer
                # data. The assignment casts back to the array's dtype.
                raw.data[section_selector] = (section - mean) * self.contrast_scale + mean

            elif augmentation_type == 'artifact':

                section = raw.data[section_selector]

                alpha_voxel_size = self.artifact_source.spec[self.artifacts_mask].voxel_size

                assert raw_voxel_size == alpha_voxel_size, ("Can only alpha blend RAW with "
                                                            "ALPHA_MASK if both have the same "
                                                            "voxel size")

                artifact_request = BatchRequest()
                artifact_request.add(self.artifacts, Coordinate(section.shape) * raw_voxel_size, voxel_size=raw_voxel_size)
                artifact_request.add(self.artifacts_mask, Coordinate(section.shape) * alpha_voxel_size, voxel_size=raw_voxel_size)
                logger.debug("Requesting artifact batch %s", artifact_request)

                artifact_batch = self.artifact_source.request_batch(artifact_request)
                artifact_alpha = artifact_batch.arrays[self.artifacts_mask].data
                artifact_raw = artifact_batch.arrays[self.artifacts].data

                assert artifact_alpha.dtype == np.float32
                assert artifact_alpha.min() >= 0.0
                assert artifact_alpha.max() <= 1.0

                # alpha blend: alpha == 1 shows only the artifact
                raw.data[section_selector] = section*(1.0 - artifact_alpha) + artifact_raw*artifact_alpha

            elif augmentation_type == 'deformed_slice':

                # Index the section axis with an integer to get a 2D view.
                # Unlike squeeze(), this never drops other singleton axes, and
                # the result can be written back for any self.axis.
                plane_selector = tuple(
                    c if d == self.axis else slice(None)
                    for d in range(dims)
                )
                section = raw.data[plane_selector]

                # cubic interpolation if the intensities are interpolatable,
                # nearest neighbor otherwise
                interpolation = 3 if self.spec[self.intensities].interpolatable else 0

                # load the deformation fields that were prepared for this slice
                flow_x, flow_y, line_mask = self.deform_slice_transformations[c]

                # apply the deformation fields (row coordinates first)
                shape = section.shape
                section = map_coordinates(
                    section, (flow_y, flow_x), mode='constant', order=interpolation
                ).reshape(shape)

                # things can get smaller than 0 at the boundary, so we clip
                section = np.clip(section, 0., 1.)

                # zero-out data below the line mask
                section[line_mask] = 0.

                raw.data[plane_selector] = section

        # In case we needed to change the ROI due to a deformation augment,
        # restore the original ROI and crop the array data.
        if 'deformed_slice' in self.slice_to_augmentation.values():
            old_roi = request[self.intensities].roi
            logger.debug("resetting roi to %s", old_roi)
            strength = self.deformation_strength
            # Explicit stop index: slice(strength, -strength) would be empty
            # for strength == 0.
            crop = tuple(
                slice(None) if d == self.axis
                else slice(strength, raw.data.shape[d] - strength)
                for d in range(dims)
            )
            raw.data = raw.data[crop]
            raw.spec.roi = old_roi

    def __prepare_deform_slice(self, slice_shape):
        '''Precompute the fold deformation for a single section.

        Args:

            slice_shape (``tuple`` of ``int``):

                2D shape (in voxels) of the downstream section.

        Returns:

            ``(flow_x, flow_y, line_mask)``. ``flow_x`` and ``flow_y`` are the
            column and row sampling coordinates, each of shape ``(N, 1)`` with
            ``N`` the number of voxels of the grown section, to be passed to
            :func:`map_coordinates`. ``line_mask`` is a boolean mask of the
            grown section shape marking the dilated fold line. The grown
            section is ``slice_shape`` plus ``deformation_strength`` on each
            side.
        '''

        # grow slice shape by 2 x deformation strength
        grow_by = 2 * self.deformation_strength
        shape = (slice_shape[0] + grow_by, slice_shape[1] + grow_by)

        # Randomly choose (p = 1/2) a line spanning all rows (fixed_x) or all
        # columns. x* are row indices and y* column indices here.
        fixed_x = random.random() < .5
        if fixed_x:
            x0, y0 = 0, np.random.randint(1, shape[1] - 2)
            x1, y1 = shape[0] - 1, np.random.randint(1, shape[1] - 2)
        else:
            x0, y0 = np.random.randint(1, shape[0] - 2), 0
            x1, y1 = np.random.randint(1, shape[0] - 2), shape[1] - 1

        # generate the mask of the line that should be blacked out
        line_mask = np.zeros(shape, dtype='bool')
        rr, cc = line(x0, y0, x1, y1)
        line_mask[rr, cc] = 1

        # generate vectorfield pointing towards the line to compress the image
        # first we get the unit vector representing the line
        line_vector = np.array([x1 - x0, y1 - y0], dtype='float32')
        line_vector /= np.linalg.norm(line_vector)
        # next, we generate the normal to the line (row, column components)
        normal_vector = np.zeros_like(line_vector)
        normal_vector[0] = - line_vector[1]
        normal_vector[1] = line_vector[0]

        # x holds column indices, y holds row indices
        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
        # per-voxel displacement, filled per side of the line below
        flow_x, flow_y = np.zeros(shape), np.zeros(shape)

        # find the 2 components where coordinates are bigger / smaller than the line
        # to apply normal vector in the correct direction
        components, n_components = label(np.logical_not(line_mask).view('uint8'))
        assert n_components == 2, "%i" % n_components
        neg_val = components[0, 0] if fixed_x else components[-1, -1]
        pos_val = components[-1, -1] if fixed_x else components[0, 0]

        flow_x[components == pos_val] = self.deformation_strength * normal_vector[1]
        flow_y[components == pos_val] = self.deformation_strength * normal_vector[0]
        flow_x[components == neg_val] = - self.deformation_strength * normal_vector[1]
        flow_y[components == neg_val] = - self.deformation_strength * normal_vector[0]

        # absolute sampling coordinates, flattened to (N, 1)
        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)

        # dilate the line mask
        line_mask = binary_dilation(line_mask, iterations=10)

        return flow_x, flow_y, line_mask