# Repository URL to install this package:
# Version: 0.15.1
from __future__ import absolute_import, print_function
import os
import sys
import tempfile
import numpy
from numpy.testing import TestCase, assert_, run_module_suite
from scipy.weave import inline_tools, ext_tools
from scipy.weave.build_tools import msvc_exists, gcc_exists
from scipy.weave.catalog import unique_file
from scipy.weave.numpy_scalar_spec import numpy_complex_scalar_converter
from weave_test_utils import dec
def unique_mod(d, file_name):
    """Return a module name based on ``unique_file(d, file_name)``.

    The path produced by ``unique_file`` is reduced to its final path
    component and the file extension is dropped, leaving a bare name that
    can be handed to ``ext_tools.ext_module``.
    """
    base_name = os.path.basename(unique_file(d, file_name))
    stem, _extension = os.path.splitext(base_name)
    return stem
#----------------------------------------------------------------------------
# Scalar conversion test classes
# int, float, complex
#----------------------------------------------------------------------------
class NumpyComplexScalarConverter(TestCase):
    """Tests for ``numpy_complex_scalar_converter``.

    ``compiler`` is the compiler name passed to ``ext_module.compile`` and is
    also appended to generated module names.  Compiler-specific subclasses
    override it; the empty string used here is passed through unchanged
    (presumably selecting the platform default -- confirm against weave).
    """

    compiler = ''

    def setUp(self):
        # A fresh converter per test keeps the type_match checks independent.
        self.converter = numpy_complex_scalar_converter()

    @dec.slow
    def test_type_match_string(self):
        """A plain string must not match the complex scalar converter."""
        assert_(not self.converter.type_match('string'))

    @dec.slow
    def test_type_match_int(self):
        """A Python int must not match the complex scalar converter."""
        assert_(not self.converter.type_match(5))

    @dec.slow
    def test_type_match_float(self):
        """A Python float must not match the complex scalar converter."""
        assert_(not self.converter.type_match(5.))

    @dec.slow
    def test_type_match_complex128(self):
        """A ``numpy.complex128`` scalar must match the converter."""
        assert_(self.converter.type_match(numpy.complex128(5.+1j)))

    @dec.slow
    def test_complex_var_in(self):
        """Compile an extension taking a complex argument and call it."""
        import importlib

        # The module name is derived from this test's own function name.
        mod_name = sys._getframe().f_code.co_name + self.compiler
        mod_name = unique_mod(test_dir, mod_name)
        mod = ext_tools.ext_module(mod_name)
        # ``a`` is referenced only by name in the argument list below
        # (presumably looked up in this frame by ext_function -- keep it a
        # local of this method).  ``numpy.complex`` was just an alias of the
        # builtin ``complex`` and no longer exists in current NumPy.
        a = complex(1.+1j)
        code = "a=std::complex<double>(2.,2.);"
        test = ext_tools.ext_function('test', code, ['a'])
        mod.add_function(test)
        mod.compile(location=test_dir, compiler=self.compiler)
        # Import the freshly built module explicitly; an exec'd import
        # statement cannot rebind the local ``test`` on Python 3.
        test = importlib.import_module(mod_name).test
        b = numpy.complex128(1.+1j)
        test(b)
        # Non-complex arguments may be rejected with TypeError; the test
        # tolerates both rejection and acceptance.
        try:
            b = 1.
            test(b)
        except TypeError:
            pass
        try:
            b = 'abc'
            test(b)
        except TypeError:
            pass

    @dec.slow
    def test_complex_return(self):
        """Compile an extension that returns a complex value and check it."""
        import importlib

        mod_name = sys._getframe().f_code.co_name + self.compiler
        mod_name = unique_mod(test_dir, mod_name)
        mod = ext_tools.ext_module(mod_name)
        # Referenced by name in the ext_function argument list below.
        a = 1.+1j
        code = ("\n"
                "a= a + std::complex<double>(2.,2.);\n"
                "return_val = PyComplex_FromDoubles(a.real(),a.imag());\n")
        test = ext_tools.ext_function('test', code, ['a'])
        mod.add_function(test)
        mod.compile(location=test_dir, compiler=self.compiler)
        # See test_complex_var_in for why exec is not used here.
        test = importlib.import_module(mod_name).test
        b = 1.+1j
        c = test(b)
        assert_(c == 3.+3j)

    @dec.slow
    def test_inline(self):
        """``inline`` must accept a ``numpy.complex128`` scalar argument."""
        a = numpy.complex128(1+1j)
        result = inline_tools.inline("return_val=1.0/a;", ['a'])
        assert_(result == .5-.5j)
# For every module-level name ending in 'Converter', generate subclasses
# named Test<Name><Compiler> that differ only in their ``compiler`` attribute.
for _n in dir():
    if not _n.endswith('Converter'):
        continue
    # The MSVC and "Unix" variants are mutually exclusive.
    if msvc_exists():
        exec("class Test%sMsvc(%s):\n compiler = 'msvc'" % (_n,_n))
    else:
        exec("class Test%sUnix(%s):\n compiler = ''" % (_n,_n))
    # The gcc variant is added in addition whenever gcc is found.
    if gcc_exists():
        exec("class Test%sGcc(%s):\n compiler = 'gcc'" % (_n,_n))
def setup_test_location():
    """Create a fresh temporary directory and make it importable.

    The directory is placed at the front of ``sys.path`` so that modules
    written into it can be imported by name.

    Returns the path of the new directory.
    """
    location = tempfile.mkdtemp()
    sys.path[:0] = [location]
    return location
# Created once when the module is imported; the test methods read this
# module-level name when compiling and importing their extension modules.
test_dir = setup_test_location()
def teardown_test_location():
    """Remove the test build directory from ``sys.path``.

    The directory registered by ``setup_test_location`` is stored in the
    module-level ``test_dir``.  That value is used when it exists; otherwise
    the legacy ``<tempdir>/test_files`` location is assumed.  The entry is
    removed wherever it sits in ``sys.path``, not only at index 0, because
    other code may have prepended entries since setup ran.  The directory
    itself is left on disk.

    Returns the path that was (or would have been) removed.
    """
    import tempfile

    location = globals().get('test_dir')
    if location is None:
        location = os.path.join(tempfile.gettempdir(), 'test_files')
    if location in sys.path:
        sys.path.remove(location)
    return location
# Drop test classes for compilers that cannot be used on this machine.
# NOTE(review): the classes generated earlier in this file are named
# Test<Base>Msvc / Test<Base>Unix / Test<Base>Gcc, so these *prefix* checks
# do not appear to match any of them -- confirm whether suffix matching was
# intended before relying on this pruning.
if msvc_exists():
    # With MSVC present the "Unix" variants are not wanted.
    for _n in dir():
        if _n.startswith('TestUnix'):
            del globals()[_n]
else:
    for _n in dir():
        if _n.startswith('TestMsvc'):
            del globals()[_n]

# gcc variants are kept only on win32 when both gcc and MSVC are available.
if not gcc_exists() or not msvc_exists() or sys.platform != 'win32':
    for _n in dir():
        if _n.startswith('TestGcc'):
            del globals()[_n]
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    run_module_suite()