Repository URL to install this package:
Version:
5.0.0 ▾
|
# (C) Copyright 2005-2021 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!
import os
import tempfile
from contextlib import contextmanager
import unittest
import numpy as np
from enable.api import Component, ComponentEditor
from traits.api import HasTraits, Instance
from traits.etsconfig.api import ETSConfig
from traitsui.api import Group, UItem, View
from traitsui.testing.api import UITester
from chaco.api import (
ArrayPlotData,
PlotGraphicsContext,
GridDataSource,
GridMapper,
DataRange2D,
ImageData,
ImagePlot,
Plot,
)
# The Quartz backend rescales pixel values, so use a higher threshold.
MAX_RMS_ERROR = 16 if ETSConfig.kiva_backend == "quartz" else 1
IMAGE = np.random.random_integers(0, 255, size=(100, 200)).astype(np.uint8)
RGB = np.dstack([IMAGE] * 3)
# Rendering adds rows and columns for some reason.
TRIM_RENDERED = (slice(1, -1), slice(1, -1), 0)
is_null = ETSConfig.toolkit == "null"
def is_qt4():
if not ETSConfig.toolkit.startswith('qt'):
return False
# Only AFTER confirming Qt's availability...
# We lean on Pyface here since the check is complicated.
import pyface.qt
return pyface.qt.is_qt4
@contextmanager
def temp_image_file(suffix=".tif", prefix="test", dir=None):
fd, filename = tempfile.mkstemp(suffix=suffix, prefix=prefix, dir=dir)
try:
yield filename
finally:
os.close(fd)
os.remove(filename)
def get_image_index_and_mapper(image):
h, w = image.shape[:2]
index = GridDataSource(np.arange(h + 1), np.arange(w + 1))
index_mapper = GridMapper(range=DataRange2D(low=(0, 0), high=(h, w)))
return index, index_mapper
def save_renderer_result(renderer, filename):
renderer.padding = 0
gc = PlotGraphicsContext(renderer.outer_bounds)
with gc:
gc.render_component(renderer)
gc.save(filename)
def image_from_renderer(renderer, orientation):
data = renderer.value
# Set bounding box size and origin to align rendered image with array
renderer.bounds = (data.get_width() + 1, data.get_height() + 1)
if orientation == "v":
renderer.bounds = renderer.bounds[::-1]
renderer.position = 0, 0
with temp_image_file() as filename:
save_renderer_result(renderer, filename)
rendered_image = ImageData.fromfile(filename).data[TRIM_RENDERED]
return rendered_image
def rendered_image_result(image, filename=None, **plot_kwargs):
data_source = ImageData(data=image)
index, index_mapper = get_image_index_and_mapper(image)
renderer = ImagePlot(
value=data_source,
index=index,
index_mapper=index_mapper,
**plot_kwargs
)
orientation = plot_kwargs.get("orientation", "h")
return image_from_renderer(renderer, orientation)
def calculate_rms(image_result, expected_image):
"""Return root-mean-square error.
Implementation taken from matplotlib.
"""
# calculate the per-pixel errors, then compute the root mean square error
num_values = np.prod(expected_image.shape)
# Images may be e.g. 8-bit unsigned integer; upcast to default integer size
# (32 or 64 bit) to reduce likelihood of over-/under-flow.
abs_diff_image = abs(np.int_(expected_image) - np.int_(image_result))
histogram = np.bincount(abs_diff_image.ravel(), minlength=256)
sum_of_squares = np.sum(
np.int64(histogram) * np.arange(len(histogram)) ** 2
)
rms = np.sqrt(float(sum_of_squares) / num_values)
return rms
def plot_comparison(input_image, expected_image, **plot_kwargs):
import matplotlib.pyplot as plt
image_result = rendered_image_result(input_image, **plot_kwargs)
diff = np.int64(expected_image) - np.int64(image_result)
max_diff = max(abs(diff.min()), abs(diff.max()), 1)
fig, (ax0, ax1, ax2) = plt.subplots(ncols=3, sharex="all", sharey="all")
ax0.imshow(expected_image)
ax1.imshow(image_result)
im_plot = ax2.imshow(diff, vmin=-max_diff, vmax=max_diff, cmap=plt.cm.bwr)
fig.colorbar(im_plot)
plt.show()
class TestResultImage(unittest.TestCase):
def verify_result_image(self, input_image, expected_image, **plot_kwargs):
# These tests were written assuming uint8 inputs.
self.assertEqual(input_image.dtype, np.uint8)
self.assertEqual(expected_image.dtype, np.uint8)
image_result = rendered_image_result(input_image, **plot_kwargs)
rms = calculate_rms(image_result, expected_image)
self.assertLess(rms, MAX_RMS_ERROR)
def test_horizontal_top_left(self):
# Horizontal orientation with top left origin renders original image.
self.verify_result_image(RGB, IMAGE, origin="top left")
def test_horizontal_bottom_left(self):
# Horizontal orientation with bottom left origin renders a vertically
# flipped image.
self.verify_result_image(RGB, IMAGE[::-1], origin="bottom left")
def test_horizontal_top_right(self):
# Horizontal orientation with top right origin renders a horizontally
# flipped image.
self.verify_result_image(RGB, IMAGE[:, ::-1], origin="top right")
def test_horizontal_bottom_right(self):
# Horizontal orientation with top right origin renders an image flipped
# horizontally and vertically.
self.verify_result_image(RGB, IMAGE[::-1, ::-1], origin="bottom right")
def test_vertical_top_left(self):
# Vertical orientation with top left origin renders transposed image.
self.verify_result_image(
RGB, IMAGE.T, origin="top left", orientation="v"
)
def test_vertical_bottom_left(self):
# Vertical orientation with bottom left origin renders transposed image
# that is vertically flipped.
self.verify_result_image(
RGB, (IMAGE.T)[::-1], origin="bottom left", orientation="v"
)
def test_vertical_top_right(self):
# Vertical orientation with top right origin renders transposed image
# that is horizontally flipped.
self.verify_result_image(
RGB, (IMAGE.T)[:, ::-1], origin="top right", orientation="v"
)
def test_vertical_bottom_right(self):
# Vertical orientation with bottom right origin renders transposed
# image that is flipped vertically and horizontally.
self.verify_result_image(
RGB, (IMAGE.T)[::-1, ::-1], origin="bottom right", orientation="v"
)
# regression test for enthought/chaco#528
@unittest.skipIf(is_null or is_qt4(), "Skip on 'null' or 'qt4' toolkit")
def test_resize_to_zero(self):
class TestResize(HasTraits):
plot = Instance(Component)
# at the exact size of (101, 101) for the component editor the
# image is 0x0 on the screen and the failure was triggered
traits_view = View(
Group(
UItem("plot", editor=ComponentEditor(size=(101, 101))),
),
resizable=True,
)
def _plot_default(self):
pd = ArrayPlotData(image=RGB)
plot = Plot(pd)
plot.img_plot("image")
return plot
test_obj = TestResize()
tester = UITester()
# should not fail
with tester.create_ui(test_obj):
pass