Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
chaco / tests / test_image_plot.py
Size: Mime:
import os

import tempfile
from contextlib import contextmanager

import six

import unittest

import numpy as np

from traits.etsconfig.api import ETSConfig
from chaco.api import (PlotGraphicsContext, GridDataSource, GridMapper,
                       DataRange2D, ImageData, ImagePlot)


# The Quartz backend rescales pixel values, so use a higher threshold.
MAX_RMS_ERROR = 16 if ETSConfig.kiva_backend == 'quartz' else 1

IMAGE = np.random.random_integers(0, 255, size=(100, 200)).astype(np.uint8)
RGB = np.dstack([IMAGE] * 3)
# Rendering adds rows and columns for some reason.
TRIM_RENDERED = (slice(1, -1), slice(1, -1), 0)


@contextmanager
def temp_image_file(suffix='.tif', prefix='test', dir=None):
    fd, filename = tempfile.mkstemp(suffix=suffix, prefix=prefix, dir=dir)
    try:
        yield filename
    finally:
        os.close(fd)
        os.remove(filename)


def get_image_index_and_mapper(image):
    h, w = image.shape[:2]
    index = GridDataSource(np.arange(h+1), np.arange(w+1))
    index_mapper = GridMapper(range=DataRange2D(low=(0, 0), high=(h, w)))
    return index, index_mapper


def save_renderer_result(renderer, filename):
    renderer.padding = 0
    gc = PlotGraphicsContext(renderer.outer_bounds)
    with gc:
        gc.render_component(renderer)
        gc.save(filename)


def image_from_renderer(renderer, orientation):
    data = renderer.value
    # Set bounding box size and origin to align rendered image with array
    renderer.bounds = (data.get_width() + 1, data.get_height() + 1)
    if orientation == 'v':
        renderer.bounds = renderer.bounds[::-1]
    renderer.position = 0, 0

    with temp_image_file() as filename:
        save_renderer_result(renderer, filename)
        rendered_image = ImageData.fromfile(filename).data[TRIM_RENDERED]
    return rendered_image


def rendered_image_result(image, filename=None, **plot_kwargs):
    data_source = ImageData(data=image)
    index, index_mapper = get_image_index_and_mapper(image)
    renderer = ImagePlot(value=data_source, index=index,
                         index_mapper=index_mapper,
                         **plot_kwargs)
    orientation = plot_kwargs.get('orientation', 'h')
    return image_from_renderer(renderer, orientation)


def calculate_rms(image_result, expected_image):
    """Return root-mean-square error.

    Implementation taken from matplotlib.
    """
    # calculate the per-pixel errors, then compute the root mean square error
    num_values = np.prod(expected_image.shape)
    # Images may be e.g. 8-bit unsigned integer; upcast to default integer size
    # (32 or 64 bit) to reduce likelihood of over-/under-flow.
    abs_diff_image = abs(np.int_(expected_image) - np.int_(image_result))

    histogram = np.bincount(abs_diff_image.ravel(), minlength=256)
    sum_of_squares = np.sum(np.int64(histogram) * np.arange(len(histogram))**2)
    rms = np.sqrt(float(sum_of_squares) / num_values)
    return rms


def plot_comparison(input_image, expected_image, **plot_kwargs):
    import matplotlib.pyplot as plt

    image_result = rendered_image_result(input_image, **plot_kwargs)
    diff = np.int64(expected_image) - np.int64(image_result)
    max_diff = max(abs(diff.min()), abs(diff.max()), 1)

    fig, (ax0, ax1, ax2) = plt.subplots(ncols=3, sharex='all', sharey='all')
    ax0.imshow(expected_image)
    ax1.imshow(image_result)
    im_plot = ax2.imshow(diff, vmin=-max_diff, vmax=max_diff, cmap=plt.cm.bwr)
    fig.colorbar(im_plot)
    plt.show()


class TestResultImage(unittest.TestCase):

    def verify_result_image(self, input_image, expected_image, **plot_kwargs):
        # These tests were written assuming uint8 inputs.
        self.assertEqual(input_image.dtype, np.uint8)
        self.assertEqual(expected_image.dtype, np.uint8)
        image_result = rendered_image_result(input_image, **plot_kwargs)
        rms = calculate_rms(image_result, expected_image)
        self.assertLess(rms, MAX_RMS_ERROR)

    def test_horizontal_top_left(self):
        # Horizontal orientation with top left origin renders original image.
        self.verify_result_image(RGB, IMAGE, origin='top left')

    def test_horizontal_bottom_left(self):
        # Horizontal orientation with bottom left origin renders a vertically
        # flipped image.
        self.verify_result_image(RGB, IMAGE[::-1], origin='bottom left')

    def test_horizontal_top_right(self):
        # Horizontal orientation with top right origin renders a horizontally
        # flipped image.
        self.verify_result_image(RGB, IMAGE[:, ::-1], origin='top right')

    def test_horizontal_bottom_right(self):
        # Horizontal orientation with top right origin renders an image flipped
        # horizontally and vertically.
        self.verify_result_image(RGB, IMAGE[::-1, ::-1], origin='bottom right')

    def test_vertical_top_left(self):
        # Vertical orientation with top left origin renders transposed image.
        self.verify_result_image(RGB, IMAGE.T, origin='top left', orientation='v')

    def test_vertical_bottom_left(self):
        # Vertical orientation with bottom left origin renders transposed image
        # that is vertically flipped.
        self.verify_result_image(RGB, (IMAGE.T)[::-1],
                                 origin='bottom left', orientation='v')

    def test_vertical_top_right(self):
        # Vertical orientation with top right origin renders transposed image
        # that is horizontally flipped.
        self.verify_result_image(RGB, (IMAGE.T)[:, ::-1],
                                 origin='top right', orientation='v')

    def test_vertical_bottom_right(self):
        # Vertical orientation with bottom right origin renders transposed image
        # that is flipped vertically and horizontally.
        self.verify_result_image(RGB, (IMAGE.T)[::-1, ::-1],
                                 origin='bottom right', orientation='v')