# Repository URL to install this package: (URL not captured)
# Version: 0.36.2
from __future__ import print_function
from numba import jit
import numba.unittest_support as unittest
import numpy as np
import copy
from .support import MemoryLeakMixin
# Python 2/3 compatibility: Python 3 has no ``xrange`` builtin (its
# ``range`` is already lazy), so alias the name to ``range`` when the
# lookup fails.
try:
    xrange
except NameError:  # Python 3
    xrange = range
@jit
def inc(a):
    """Increment each entry ``a[i]`` (``i`` in ``range(len(a))``) in place.

    Returns ``a`` itself (the mutated object, not a copy), so a caller can
    bind the same array to several names in one chained assignment.
    """
    n = len(a)
    for idx in xrange(n):
        a[idx] += 1
    return a
@jit
def inc1(a):
    """Increment ``a[0]`` in place and return its updated value."""
    a[0] = a[0] + 1
    return a[0]
@jit
def inc2(a):
    """Increment ``a[0]`` in place and return ``(a[0], a[0] + 1)``.

    The updated element is read back from ``a`` (not computed separately),
    so the returned value is exactly what was stored.
    """
    a[0] += 1
    updated = a[0]
    return updated, updated + 1
def chain1(a):
    """Chained assignment whose right-hand side is a single call.

    ``x``, ``y`` and ``z`` are all bound to the object returned by the one
    ``inc(a)`` call. ``inc`` mutates ``a`` in place and returns it, so the
    result is three times the incremented ``a``.
    """
    # Deliberately written as a chained assignment: this is the syntax the
    # TestChainedAssign cases compile with ``jit``; do not split it.
    x = y = z = inc(a)
    return x + y + z
def chain2(v):
    """Chained assignment mixing array-element targets and a plain name.

    ``a[0]``, ``x`` and ``a[1]`` all receive ``v``. The two array slots
    store it in the array returned by ``np.zeros(2)``. Returns
    ``a[0] + a[1] + x / 2``.
    """
    a = np.zeros(2)
    # Subscript and name targets in one chain: the construct under test.
    a[0] = x = a[1] = v
    return a[0] + a[1] + (x / 2)
def unpack1(x, y):
    """Plain two-element tuple unpacking; returns ``x + y / 2``."""
    a, b = x, y
    return a + b / 2
def unpack2(x, y):
    """Chained assignment where every target is a tuple unpacking.

    ``inc1`` runs once on ``x`` and once on ``y``, incrementing ``x[0]``
    and ``y[0]`` in place. Both ``(a, b)`` and ``(c, d)`` receive that
    same pair of results.
    """
    a, b = c, d = inc1(x), inc1(y)
    return a + c / 2, b + d / 2
def chain3(x, y):
    """Chained assignments that combine a whole-tuple name with unpacking.

    The first statement binds the whole tuple to ``a`` and unpacks it into
    ``b, c``. The second unpacks into ``d, e`` and binds the whole tuple to
    ``f`` (name target last instead of first). ``inc1`` is called twice per
    argument, so ``x[0]`` and ``y[0]`` each end up incremented by 2, and
    the second pair reflects the later values.
    """
    a = (b, c) = (inc1(x), inc1(y))
    (d, e) = f = (inc1(x), inc1(y))
    return (a[0] + b / 2 + d + f[0]), (a[1] + c + e / 2 + f[1])
def unpack3(x):
    """Unpack the 2-tuple returned by ``inc2`` (which bumps ``x[0]`` in place).

    Returns ``a + b / 2`` for ``(a, b) = inc2(x)``.
    """
    a, b = inc2(x)
    return a + b / 2
def unpack4(x):
    """Chained tuple unpacking of a single ``inc2`` result.

    ``inc2`` runs once, and ``(a, b)`` and ``(c, d)`` receive the same pair.
    """
    a, b = c, d = inc2(x)
    return a + c / 2, b + d / 2
def unpack5(x):
    """Mix whole-tuple binding with unpacking of two ``inc2`` results.

    ``inc2`` is called twice, so ``x[0]`` is incremented twice. The first
    pair (``a`` / ``b, c``) holds the earlier values and the second pair
    (``d, e`` / ``f``) holds the later ones.
    """
    a = b, c = inc2(x)
    d, e = f = inc2(x)
    return (a[0] + b / 2 + d + f[0]), (a[1] + c + e / 2 + f[1])
def unpack6(x, y):
    """Nested tuple unpacking in a single statement.

    Binds ``(a, b) = (x, y)`` and ``(c, d) = (y + 1, x + 1)``, then
    returns ``(a + c / 2, b / 2 + d)``.
    """
    (a, b), (c, d) = (x, y), (y + 1, x + 1)
    return a + c / 2, b / 2 + d
class TestChainedAssign(MemoryLeakMixin, unittest.TestCase):
    """Check that ``jit``-compiled chained / tuple-unpacking assignments
    produce the same results as the Python interpreter."""

    def test_chain1(self):
        args = [
            [np.arange(2)],
            [np.arange(4, dtype=np.double)],
        ]
        self._test_template(chain1, args)

    def test_chain2(self):
        args = [
            [3],
            [3.0],
        ]
        self._test_template(chain2, args)

    def test_unpack1(self):
        args = [
            [1, 3.0],
            [1.0, 3],
        ]
        self._test_template(unpack1, args)

    def test_unpack2(self):
        args = [
            [np.array([2]), np.array([4.0])],
            [np.array([2.0]), np.array([4])],
        ]
        self._test_template(unpack2, args)

    def test_chain3(self):
        args = [
            [np.array([0]), np.array([1.5])],
            [np.array([0.5]), np.array([1])],
        ]
        self._test_template(chain3, args)

    def test_unpack3(self):
        args = [
            [np.array([1])],
            [np.array([1.0])],
        ]
        self._test_template(unpack3, args)

    def test_unpack4(self):
        args = [
            [np.array([1])],
            [np.array([1.0])],
        ]
        self._test_template(unpack4, args)

    def test_unpack5(self):
        args = [
            [np.array([2])],
            [np.array([2.0])],
        ]
        self._test_template(unpack5, args)

    def test_unpack6(self):
        # Same cases as before (float/int and float/float scalar pairs),
        # in the list-of-argument-lists shape the other tests use.
        args = [
            [3.0, 2],
            [3.0, 2.0],
        ]
        self._test_template(unpack6, args)

    def _test_template(self, pyfunc, argcases):
        """Compare ``pyfunc`` against ``jit(pyfunc)`` on every argument case.

        :param pyfunc: plain Python function under test.
        :param argcases: iterable of positional-argument sequences.

        Each case is deep-copied separately for the interpreter call and
        the compiled call, because several functions under test mutate
        their array arguments in place.
        """
        cfunc = jit(pyfunc)
        for args in argcases:
            a1 = copy.deepcopy(args)
            a2 = copy.deepcopy(args)
            # Interpreter first, then compiled (same order as before).
            expected = pyfunc(*a1)
            got = cfunc(*a2)
            # The interpreter result is the reference ("desired"). ``args``
            # itself is never mutated, so it names the failing case.
            np.testing.assert_allclose(
                got, expected,
                err_msg="%s%r" % (pyfunc.__name__, tuple(args)),
            )
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()