# Repository URL to install this package:
# Version: 0.36.2
from __future__ import print_function
import warnings
import numpy as np
import numba.unittest_support as unittest
from numba import jit
from numba.errors import NumbaWarning, deprecated
class TestBuiltins(unittest.TestCase):
    """Checks on the warnings recorded around ``jit`` and ``deprecated``."""

    def _expect_warning(self, record, category, *fragments):
        # Shared assertion: the recorded warning carries exactly `category`
        # and its message text contains each of `fragments`, checked in order.
        self.assertEqual(record.category, category)
        text = str(record.message)
        for fragment in fragments:
            self.assertIn(fragment, text)

    def test_type_infer_warning(self):
        # Expect two NumbaWarnings from a plain jit(): a type inference
        # failure followed by the object mode notice.
        def add(x, y):
            # Unused on purpose; presumably the dict is what makes type
            # inference fail -- the test only checks the resulting warnings.
            a = {}
            return x + y

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter('always', NumbaWarning)

            compiled = jit(add)
            compiled(1, 2)

            self.assertEqual(len(caught), 2)
            # Type inference failure
            self._expect_warning(caught[0], NumbaWarning, 'type inference')
            # Object mode
            self._expect_warning(caught[1], NumbaWarning, 'object mode')

    def test_return_type_warning(self):
        # Returning an array captured from the enclosing scope with
        # _nrt=False: expect a return type warning, then the object mode one.
        y = np.ones(4, dtype=np.float32)

        def return_external_array():
            return y

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter('always', NumbaWarning)

            compiled = jit(_nrt=False)(return_external_array)
            compiled()

            self.assertEqual(len(caught), 2)
            # Legal return value failure
            self._expect_warning(caught[0], NumbaWarning, 'return type')
            # Object mode
            self._expect_warning(caught[1], NumbaWarning, 'object mode')

    def test_return_type_warning_with_nrt(self):
        """
        Rerun test_return_type_warning with nrt

        Same function as in test_return_type_warning, but jitted without
        ``_nrt=False``; no warning at all is expected.
        """
        y = np.ones(4, dtype=np.float32)

        def return_external_array():
            return y

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter('always', NumbaWarning)

            compiled = jit(return_external_array)
            compiled()

            # No more warning
            self.assertEqual(len(caught), 0)

    def test_no_warning_with_forceobj(self):
        # With forceobj=True requested explicitly, nothing should be recorded
        # even though the function body holds an unused list local.
        def add(x, y):
            a = []
            return x + y

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter('always', NumbaWarning)

            compiled = jit(add, forceobj=True)
            compiled(1, 2)

            self.assertEqual(len(caught), 0)

    def test_loop_lift_warn(self):
        # A function with a loop over an array plus an unused dict local:
        # expect two type inference warnings and a final object mode warning
        # that also mentions lifted loops.
        def do_loop(x):
            a = {}
            for i in range(x.shape[0]):
                x[i] *= 2

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter('always', NumbaWarning)

            x = np.ones(4, dtype=np.float32)
            compiled = jit(do_loop)
            compiled(x)

            self.assertEqual(len(caught), 3)
            # Type inference failure (1st pass)
            self._expect_warning(caught[0], NumbaWarning, 'type inference')
            # Type inference failure (2nd pass, with lifted loops)
            self._expect_warning(caught[1], NumbaWarning, 'type inference')
            # Object mode
            self._expect_warning(caught[2], NumbaWarning,
                                 'object mode', 'lifted loops')

    def test_deprecated(self):
        # Calling a function wrapped by deprecated('foo') must record exactly
        # one DeprecationWarning whose message names both 'bar' and 'foo'.
        @deprecated('foo')
        def bar():
            pass

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter('always')
            bar()

            self.assertEqual(len(caught), 1)
            self._expect_warning(caught[0], DeprecationWarning, 'bar', 'foo')
# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()