Repository URL to install this package:
|
Version:
0.36.2 ▾
|
from __future__ import absolute_import, print_function
from collections import OrderedDict, Sequence
import types as pytypes
import inspect
from llvmlite import ir as llvmir
from numba import types
from numba.targets.registry import cpu_target
from numba import njit
from numba.typing import templates
from numba.datamodel import default_manager, models
from numba.targets import imputils
from numba import cgutils, utils
from numba.six import exec_
from . import _box
##############################################################################
# Data model
class InstanceModel(models.StructModel):
def __init__(self, dmm, fe_typ):
cls_data_ty = types.ClassDataType(fe_typ)
# MemInfoPointer uses the `dtype` attribute to traverse for nested
# NRT MemInfo. Since we handle nested NRT MemInfo ourselves,
# we will replace provide MemInfoPointer with an opaque type
# so that it does not raise exception for nested meminfo.
dtype = types.Opaque('Opaque.' + str(cls_data_ty))
members = [
('meminfo', types.MemInfoPointer(dtype)),
('data', types.CPointer(cls_data_ty)),
]
super(InstanceModel, self).__init__(dmm, fe_typ, members)
class InstanceDataModel(models.StructModel):
def __init__(self, dmm, fe_typ):
clsty = fe_typ.class_type
members = [(_mangle_attr(k), v) for k, v in clsty.struct.items()]
super(InstanceDataModel, self).__init__(dmm, fe_typ, members)
default_manager.register(types.ClassInstanceType, InstanceModel)
default_manager.register(types.ClassDataType, InstanceDataModel)
default_manager.register(types.ClassType, models.OpaqueModel)
def _mangle_attr(name):
"""
Mangle attributes.
The resulting name does not startswith an underscore '_'.
"""
return 'm_' + name
##############################################################################
# Class object
_ctor_template = """
def ctor({args}):
return __numba_cls_({args})
"""
def _getargs(fn):
"""
Returns list of positional and keyword argument names in order.
"""
sig = utils.pysignature(fn)
params = sig.parameters
args = [k for k, v in params.items()
if (v.kind & v.POSITIONAL_OR_KEYWORD) == v.POSITIONAL_OR_KEYWORD]
return args
class JitClassType(type):
"""
The type of any jitclass.
"""
def __new__(cls, name, bases, dct):
if len(bases) != 1:
raise TypeError("must have exactly one base class")
[base] = bases
if isinstance(base, JitClassType):
raise TypeError("cannot subclass from a jitclass")
assert 'class_type' in dct, 'missing "class_type" attr'
outcls = type.__new__(cls, name, bases, dct)
outcls._set_init()
return outcls
def _set_init(cls):
"""
Generate a wrapper for calling the constructor from pure Python.
Note the wrapper will only accept positional arguments.
"""
init = cls.class_type.instance_type.methods['__init__']
# get postitional and keyword arguments
# offset by one to exclude the `self` arg
args = _getargs(init)[1:]
ctor_source = _ctor_template.format(args=', '.join(args))
glbls = {"__numba_cls_": cls}
exec_(ctor_source, glbls)
ctor = glbls['ctor']
cls._ctor = njit(ctor)
def __instancecheck__(cls, instance):
if isinstance(instance, _box.Box):
return instance._numba_type_.class_type is cls.class_type
return False
def __call__(cls, *args, **kwargs):
return cls._ctor(*args, **kwargs)
##############################################################################
# Registration utils
def _validate_spec(spec):
for k, v in spec.items():
if not isinstance(k, str):
raise TypeError("spec keys should be strings, got %r" % (k,))
if not isinstance(v, types.Type):
raise TypeError("spec values should be Numba type instances, got %r"
% (v,))
def _fix_up_private_attr(clsname, spec):
"""
Apply the same changes to dunder names as CPython would.
"""
out = OrderedDict()
for k, v in spec.items():
if k.startswith('__') and not k.endswith('__'):
k = '_' + clsname + k
out[k] = v
return out
def register_class_type(cls, spec, class_ctor, builder):
"""
Internal function to create a jitclass.
Args
----
cls: the original class object (used as the prototype)
spec: the structural specification contains the field types.
class_ctor: the numba type to represent the jitclass
builder: the internal jitclass builder
"""
# Normalize spec
if isinstance(spec, Sequence):
spec = OrderedDict(spec)
_validate_spec(spec)
# Fix up private attribute names
spec = _fix_up_private_attr(cls.__name__, spec)
# Copy methods from base classes
clsdct = {}
for basecls in reversed(inspect.getmro(cls)):
clsdct.update(basecls.__dict__)
methods = dict((k, v) for k, v in clsdct.items()
if isinstance(v, pytypes.FunctionType))
props = dict((k, v) for k, v in clsdct.items()
if isinstance(v, property))
others = dict((k, v) for k, v in clsdct.items()
if k not in methods and k not in props)
# Check for name shadowing
shadowed = (set(methods) | set(props)) & set(spec)
if shadowed:
raise NameError("name shadowing: {0}".format(', '.join(shadowed)))
docstring = others.pop('__doc__', "")
_drop_ignored_attrs(others)
if others:
msg = "class members are not yet supported: {0}"
members = ', '.join(others.keys())
raise TypeError(msg.format(members))
for k, v in props.items():
if v.fdel is not None:
raise TypeError("deleter is not supported: {0}".format(k))
jitmethods = {}
for k, v in methods.items():
jitmethods[k] = njit(v)
jitprops = {}
for k, v in props.items():
dct = {}
if v.fget:
dct['get'] = njit(v.fget)
if v.fset:
dct['set'] = njit(v.fset)
jitprops[k] = dct
# Instantiate class type
class_type = class_ctor(cls, ConstructorTemplate, spec, jitmethods,
jitprops)
cls = JitClassType(cls.__name__, (cls,), dict(class_type=class_type,
__doc__=docstring))
# Register resolution of the class object
typingctx = cpu_target.typing_context
typingctx.insert_global(cls, class_type)
# Register class
targetctx = cpu_target.target_context
builder(class_type, methods, typingctx, targetctx).register()
return cls
class ConstructorTemplate(templates.AbstractTemplate):
"""
Base class for jitclass constructor templates.
"""
def generic(self, args, kws):
# Redirect resolution to __init__
instance_type = self.key.instance_type
ctor = instance_type.jitmethods['__init__']
boundargs = (instance_type.get_reference_type(),) + args
disp_type = types.Dispatcher(ctor)
sig = disp_type.get_call_type(self.context, boundargs, kws)
# Actual constructor returns an instance value (not None)
out = templates.signature(instance_type, *sig.args[1:])
return out
def _drop_ignored_attrs(dct):
# ignore anything defined by object
drop = set(['__weakref__',
'__module__',
'__dict__'])
for k, v in dct.items():
if isinstance(v, (pytypes.BuiltinFunctionType,
pytypes.BuiltinMethodType)):
drop.add(k)
elif getattr(v, '__objclass__', None) is object:
drop.add(k)
for k in drop:
del dct[k]
class ClassBuilder(object):
"""
A jitclass builder for a mutable jitclass. This will register
typing and implementation hooks to the given typing and target contexts.
"""
class_impl_registry = imputils.Registry()
implemented_methods = set()
def __init__(self, class_type, methods, typingctx, targetctx):
self.class_type = class_type
self.methods = methods
self.typingctx = typingctx
self.targetctx = targetctx
def register(self):
"""
Register to the frontend and backend.
"""
# Register generic implementations for all jitclasses
self._register_methods(self.class_impl_registry,
self.class_type.instance_type)
# NOTE other registrations are done at the top-level
# (see ctor_impl and attr_impl below)
self.targetctx.install_registry(self.class_impl_registry)
def _register_methods(self, registry, instance_type):
"""
Register method implementations for the given instance type.
"""
for meth in instance_type.jitmethods:
# There's no way to retrive the particular method name
# inside the implementation function, so we have to register a
# specific closure for each different name
if meth not in self.implemented_methods:
self._implement_method(registry, meth)
self.implemented_methods.add(meth)
def _implement_method(self, registry, attr):
@registry.lower((types.ClassInstanceType, attr),
types.ClassInstanceType, types.VarArg(types.Any))
def imp(context, builder, sig, args):
instance_type = sig.args[0]
method = instance_type.jitmethods[attr]
disp_type = types.Dispatcher(method)
call = context.get_function(disp_type, sig)
out = call(builder, args)
return imputils.impl_ret_new_ref(context, builder,
sig.return_type, out)
@templates.infer_getattr
class ClassAttribute(templates.AttributeTemplate):
key = types.ClassInstanceType
def generic_resolve(self, instance, attr):
if attr in instance.struct:
# It's a struct field => the type is well-known
return instance.struct[attr]
elif attr in instance.jitmethods:
# It's a jitted method => typeinfer it
meth = instance.jitmethods[attr]
disp_type = types.Dispatcher(meth)
class MethodTemplate(templates.AbstractTemplate):
key = (self.key, attr)
def generic(self, args, kws):
args = (instance,) + tuple(args)
sig = disp_type.get_call_type(self.context, args, kws)
return sig.as_method()
return types.BoundFunction(MethodTemplate, instance)
elif attr in instance.jitprops:
# It's a jitted property => typeinfer its getter
impdct = instance.jitprops[attr]
getter = impdct['get']
disp_type = types.Dispatcher(getter)
sig = disp_type.get_call_type(self.context, (instance,), {})
return sig.return_type
@ClassBuilder.class_impl_registry.lower_getattr_generic(types.ClassInstanceType)
def attr_impl(context, builder, typ, value, attr):
"""
Generic getattr() for @jitclass instances.
"""
if attr in typ.struct:
# It's a struct field
inst = context.make_helper(builder, typ, value=value)
data_pointer = inst.data
data = context.make_data_helper(builder, typ.get_data_type(),
ref=data_pointer)
return imputils.impl_ret_borrowed(context, builder,
typ.struct[attr],
getattr(data, _mangle_attr(attr)))
elif attr in typ.jitprops:
# It's a jitted property
getter = typ.jitprops[attr]['get']
sig = templates.signature(None, typ)
dispatcher = types.Dispatcher(getter)
sig = dispatcher.get_call_type(context.typing_context, [typ], {})
call = context.get_function(dispatcher, sig)
out = call(builder, [value])
return imputils.impl_ret_new_ref(context, builder, sig.return_type, out)
raise NotImplementedError('attribute {0!r} not implemented'.format(attr))
@ClassBuilder.class_impl_registry.lower_setattr_generic(types.ClassInstanceType)
def attr_impl(context, builder, sig, args, attr):
"""
Generic setattr() for @jitclass instances.
"""
typ, valty = sig.args
target, val = args
if attr in typ.struct:
# It's a struct member
inst = context.make_helper(builder, typ, value=target)
data_ptr = inst.data
data = context.make_data_helper(builder, typ.get_data_type(),
ref=data_ptr)
# Get old value
attr_type = typ.struct[attr]
oldvalue = getattr(data, _mangle_attr(attr))
# Store n
setattr(data, _mangle_attr(attr), val)
context.nrt.incref(builder, attr_type, val)
# Delete old value
context.nrt.decref(builder, attr_type, oldvalue)
elif attr in typ.jitprops:
# It's a jitted property
setter = typ.jitprops[attr]['set']
disp_type = types.Dispatcher(setter)
sig = disp_type.get_call_type(context.typing_context,
(typ, valty), {})
call = context.get_function(disp_type, sig)
call(builder, (target, val))
else:
raise NotImplementedError('attribute {0!r} not implemented'.format(attr))
def imp_dtor(context, module, instance_type):
llvoidptr = context.get_value_type(types.voidptr)
llsize = context.get_value_type(types.uintp)
dtor_ftype = llvmir.FunctionType(llvmir.VoidType(),
[llvoidptr, llsize, llvoidptr])
fname = "_Dtor.{0}".format(instance_type.name)
dtor_fn = module.get_or_insert_function(dtor_ftype,
name=fname)
if dtor_fn.is_declaration:
# Define
builder = llvmir.IRBuilder(dtor_fn.append_basic_block())
alloc_fe_type = instance_type.get_data_type()
alloc_type = context.get_value_type(alloc_fe_type)
ptr = builder.bitcast(dtor_fn.args[0], alloc_type.as_pointer())
data = context.make_helper(builder, alloc_fe_type, ref=ptr)
context.nrt.decref(builder, alloc_fe_type, data._getvalue())
builder.ret_void()
return dtor_fn
@ClassBuilder.class_impl_registry.lower(types.ClassType, types.VarArg(types.Any))
def ctor_impl(context, builder, sig, args):
"""
Generic constructor (__new__) for jitclasses.
"""
# Allocate the instance
inst_typ = sig.return_type
alloc_type = context.get_data_type(inst_typ.get_data_type())
alloc_size = context.get_abi_sizeof(alloc_type)
meminfo = context.nrt.meminfo_alloc_dtor(
builder,
context.get_constant(types.uintp, alloc_size),
imp_dtor(context, builder.module, inst_typ),
)
data_pointer = context.nrt.meminfo_data(builder, meminfo)
data_pointer = builder.bitcast(data_pointer,
alloc_type.as_pointer())
# Nullify all data
builder.store(cgutils.get_null_value(alloc_type),
data_pointer)
inst_struct = context.make_helper(builder, inst_typ)
inst_struct.meminfo = meminfo
inst_struct.data = data_pointer
# Call the jitted __init__
# TODO: extract the following into a common util
init_sig = (sig.return_type,) + sig.args
init = inst_typ.jitmethods['__init__']
disp_type = types.Dispatcher(init)
call = context.get_function(disp_type, types.void(*init_sig))
realargs = [inst_struct._getvalue()] + list(args)
call(builder, realargs)
# Prepare return value
ret = inst_struct._getvalue()
return imputils.impl_ret_new_ref(context, builder, inst_typ, ret)