""":mod:`numpy.ma.mrecords`

Defines the equivalent of :class:`numpy.recarray` for masked arrays,
where fields can be accessed as attributes.
Note that :class:`numpy.ma.MaskedArray` already supports structured datatypes
and the masking of individual fields.

.. moduleauthor:: Pierre Gerard-Marchant

"""
#  We should make sure that no field is called '_mask', 'mask', '_fieldmask',
#  or any other restricted keyword.  An idea would be to not bother in the
#  first place, and then rename the invalid fields with a trailing
#  underscore. Maybe we could just overload the parser function?

import warnings

import numpy as np
import numpy.ma as ma
from numpy import (
        bool_, dtype, ndarray, recarray, array as narray
        )
from numpy.core.records import (
        fromarrays as recfromarrays, fromrecords as recfromrecords
        )
from numpy.ma import (
        MAError, MaskedArray, masked, nomask, masked_array, getdata,
        getmaskarray, filled
        )

_byteorderconv = np.core.records._byteorderconv
_check_fill_value = ma.core._check_fill_value

__all__ = [
    'MaskedRecords', 'mrecarray', 'fromarrays', 'fromrecords',
    'fromtextfile', 'addfield',
]

# Field names that clash with MaskedRecords' own attributes (_data, _mask,
# _fieldmask, dtype); _checknames substitutes default names for them.
reserved_fields = ['_data', '_mask', '_fieldmask', 'dtype']

def _checknames(descr, names=None):
    """
    Checks that the field names of ``descr`` are not reserved keywords.

    Every field is renamed: to the corresponding entry of `names` when one
    is given, otherwise to the default 'f%i'.  When a requested new name is
    reserved, the field keeps its current name, unless that name is also
    reserved, in which case the default 'f%i' is used.

    Parameters
    ----------
    descr : dtype
        Structured dtype whose fields are to be renamed.
    names : None, tuple/list of str, or str, optional
        New field names, as a sequence or as one comma-separated string.
        If fewer names than fields are given, the remaining fields get
        their default 'f%i' names.  The caller's sequence is not modified.

    Returns
    -------
    dtype
        A new structured dtype with the validated field names.

    Raises
    ------
    NameError
        If `names` is neither None, a tuple/list, nor a string.
    """
    ndescr = len(descr)
    default_names = ['f%i' % i for i in range(ndescr)]
    if names is None:
        new_names = default_names
    elif isinstance(names, (tuple, list)):
        # Copy: padding below must neither mutate the caller's list nor
        # fail on a tuple (tuple += list raises TypeError).
        new_names = list(names)
    elif isinstance(names, str):
        new_names = names.split(',')
    else:
        raise NameError("illegal input names %s" % repr(names))
    nnames = len(new_names)
    if nnames < ndescr:
        new_names += default_names[nnames:]
    fields = []
    for (name, default, field) in zip(new_names, default_names, descr.descr):
        # A descr entry is (name, format) or (name, format, shape): keep
        # everything after the name so subarray shapes are preserved.
        spec = tuple(field[1:])
        if name in reserved_fields:
            if field[0] in reserved_fields:
                fields.append((default,) + spec)
            else:
                fields.append(field)
        else:
            fields.append((name,) + spec)
    return np.dtype(fields)

def _get_fieldmask(self):
    """
    Build an all-False structured boolean mask with the shape of ``self``.

    The mask has one boolean field for each field name of ``self.dtype``.
    """
    bool_fields = [(name, '|b1') for name in self.dtype.names]
    return np.zeros(self.shape, dtype=bool_fields)

class MaskedRecords(MaskedArray):

    """
    Masked record array, where fields can be accessed as attributes.

    Attributes
    ----------
    _data : recarray
        Underlying data, as a record array.
    _mask : boolean array
        Mask of the records. A record is masked when all its fields are
        masked.
    _fieldmask : boolean recarray
        Record array of booleans, setting the mask of each individual field
        of each record.
    _fill_value : record
        Filling values for each field.

    """


    def __new__(cls, shape, dtype=None, buf=None, offset=0, strides=None,
                formats=None, names=None, titles=None,
                byteorder=None, aligned=False,
                mask=nomask, hard_mask=False, fill_value=None, keep_mask=True,
                copy=False,
                **options):
        """
        Create a new masked record array.

        The layout arguments (``shape`` through ``aligned``) are forwarded
        to :meth:`numpy.recarray.__new__`.

        Parameters
        ----------
        mask : array_like, optional
            Initial mask.  A mask with a single element is broadcast to
            ``shape``; a mask with as many elements as the data is reshaped.
            A non-structured mask is replicated over every field.
        keep_mask : bool, optional
            If False, an empty/``nomask`` mask resets the mask to all False,
            and a given mask is installed through ``__setmask__``.
        copy : bool, optional
            Whether to copy ``mask``.
        hard_mask, fill_value, **options
            Accepted but not referenced in this method.

        Raises
        ------
        MAError
            If the mask size is neither 1 nor the data size.
        """
        self = recarray.__new__(cls, shape, dtype=dtype, buf=buf, offset=offset,
                                strides=strides, formats=formats, names=names,
                                titles=titles, byteorder=byteorder,
                                aligned=aligned)

        mdtype = ma.make_mask_descr(self.dtype)
        if mask is nomask or not np.size(mask):
            # No usable mask: only reset when asked not to keep the current one.
            if not keep_mask:
                self._mask = tuple([False] * len(mdtype))
        else:
            # np.array(..., copy=False) raises under NumPy >= 2 when a copy is
            # required; asarray keeps the "copy only if needed" meaning.
            mask = np.array(mask, copy=True) if copy else np.asarray(mask)
            if mask.shape != self.shape:
                (nd, nm) = (self.size, mask.size)
                if nm == 1:
                    mask = np.resize(mask, self.shape)
                elif nm == nd:
                    mask = np.reshape(mask, self.shape)
                else:
                    msg = "Mask and data not compatible: data size is %i, " + \
                          "mask size is %i."
                    raise MAError(msg % (nd, nm))
            if not keep_mask:
                self.__setmask__(mask)
                self._sharedmask = True
            else:
                if mask.dtype == mdtype:
                    _mask = mask
                else:
                    # Plain boolean mask: replicate each value over all fields.
                    _mask = np.array([tuple([m] * len(mdtype)) for m in mask],
                                     dtype=mdtype)
                self._mask = _mask
        return self

    def __array_finalize__(self, obj):
        """
        Finalize a newly created view/instance so that it carries a mask.

        If ``obj`` has no ``_mask`` (or its ``_mask`` is None), an all-False
        mask built by ``ma.make_mask_none`` for ``self``'s dtype is used.
        The mask is stored directly in ``__dict__``, then
        ``MaskedArray._update_from(obj)`` is called, and a ``_baseclass`` of
        plain ``ndarray`` is upgraded to ``recarray``.
        """
        # Make sure we have a _fieldmask by default
        _mask = getattr(obj, '_mask', None)
        if _mask is None:
            # obj provides no usable mask: start from an all-False mask.
            _dtype = ndarray.__getattribute__(self, 'dtype')
            _mask = ma.make_mask_none(self.shape, dtype=_dtype)
        # Update some of the attributes; write to __dict__ directly so the
        # custom __setattr__ (field handling) is bypassed.
        _dict = self.__dict__
        _dict.update(_mask=_mask)
        # NOTE(review): _update_from presumably sets '_baseclass' in __dict__,
        # which the next lines read.
        self._update_from(obj)
        if _dict['_baseclass'] == ndarray:
            _dict['_baseclass'] = recarray

    @property
    def _data(self):
        """
        Returns the data as a recarray.

        A property (not a method): the class accesses it as
        ``len(self._data)`` and ``ndarray.__getattribute__(self, '_data')``.
        """
        return ndarray.view(self, recarray)

    @property
    def _fieldmask(self):
        """
        Alias to mask: returns ``self._mask``.

        """
        return self._mask

    def __len__(self):
        """
        Returns the length.

        For an array with at least one dimension this is the length of the
        underlying record array; for a 0-d array (a single record) it is the
        number of fields.
        """
        # We have more than one record
        if self.ndim:
            return len(self._data)
        # We have only one record: return the nb of fields
        return len(self.dtype)

    def __getattribute__(self, attr):
        """
        Return attribute ``attr``, falling back to a field of the records.

        Regular attributes are looked up first.  Otherwise ``attr`` is taken
        as a field name and the field's data is extracted.  If the field is
        not a scalar, or any of its values is masked, it is returned as a
        :class:`MaskedArray` carrying the field's mask and fill value.
        Otherwise the unmasked scalar is returned through ``item()``.

        Raises
        ------
        AttributeError
            If ``attr`` is neither an attribute nor a field name.
        NotImplementedError
            If the field is itself a structured (nested) field.
        """
        try:
            return object.__getattribute__(self, attr)
        except AttributeError:
            # attr must be a fieldname
            pass
        fielddict = ndarray.__getattribute__(self, 'dtype').fields
        try:
            res = fielddict[attr][:2]
        except (TypeError, KeyError):
            # TypeError: the dtype has no fields at all (fields is None).
            raise AttributeError(
                "record array has no attribute %s" % attr) from None
        # So far, so good
        _localdict = ndarray.__getattribute__(self, '__dict__')
        _data = ndarray.view(self, _localdict['_baseclass'])
        obj = _data.getfield(*res)
        if obj.dtype.names is not None:
            raise NotImplementedError("MaskedRecords is currently limited to "
                                      "simple records.")
        # Get some special attributes
        # Reset the object's mask
        hasmasked = False
        _mask = _localdict.get('_mask', None)
        if _mask is not None:
            try:
                _mask = _mask[attr]
            except IndexError:
                # Mask not indexable by field name: keep it unchanged.
                pass
            # View the mask as plain booleans (one per field if it is still
            # structured) so any() works in both cases.
            tp_len = len(_mask.dtype)
            hasmasked = _mask.view((bool, ((tp_len,) if tp_len else ()))).any()
        if obj.shape or hasmasked:
            obj = obj.view(MaskedArray)
            obj._baseclass = ndarray
            obj._isfield = True
            obj._mask = _mask
            # Reset the field values
            _fill_value = _localdict.get('_fill_value', None)
            if _fill_value is not None:
                try:
                    obj._fill_value = _fill_value[attr]
                except ValueError:
                    obj._fill_value = None
        else:
            obj = obj.item()
        return obj

    def __setattr__(self, attr, val):
        """
        Sets the attribute attr to the value val.

        ``mask`` and ``fieldmask`` are routed to ``__setmask__``.  Any other
        name is first set as a regular attribute.  If ``attr`` is a field
        name, the value is written into that field instead, and the field's
        mask is updated.  Assigning ``masked`` masks the whole field and
        writes the fill value when one is set.

        Raises
        ------
        Exception
            The original error of the plain assignment is re-raised when
            ``attr`` is neither a field name nor a key of ``_optinfo``.
        AttributeError
            If ``attr`` is in ``_optinfo`` but is not a field name.
        """
        # Should we call __setmask__ first ?
        if attr in ['mask', 'fieldmask']:
            self.__setmask__(val)
            return
        # Create a shortcut (so that we don't have to call getattr all the time)
        _localdict = object.__getattribute__(self, '__dict__')
        # Check whether we're creating a new field
        newattr = attr not in _localdict
        try:
            # Is attr a generic attribute ?
            ret = object.__setattr__(self, attr, val)
        except Exception:
            # Not a generic attribute: exit if it's not a valid field
            fielddict = ndarray.__getattribute__(self, 'dtype').fields or {}
            optinfo = ndarray.__getattribute__(self, '_optinfo') or {}
            if not (attr in fielddict or attr in optinfo):
                raise
        else:
            # Get the list of names
            fielddict = ndarray.__getattribute__(self, 'dtype').fields or {}
            # Check the attribute
            if attr not in fielddict:
                return ret
            if newattr:
                # We just added this one or this setattr worked on an
                # internal attribute: remove it so it won't shadow the field.
                try:
                    object.__delattr__(self, attr)
                except Exception:
                    return ret
        # Let's try to set the field
        try:
            res = fielddict[attr][:2]
        except (TypeError, KeyError):
            raise AttributeError(
                "record array has no attribute %s" % attr) from None

        if val is masked:
            _fill_value = _localdict['_fill_value']
            if _fill_value is not None:
                dval = _fill_value[attr]
            else:
                dval = val
            mval = True
        else:
            dval = filled(val)
            mval = getmaskarray(val)
        obj = ndarray.__getattribute__(self, '_data').setfield(dval, *res)
        _localdict['_mask'][attr] = mval
        return obj

    def __getitem__(self, indx):
        """
        Return a field (string index) or a selection of records.

        A string ``indx`` returns that field as a :class:`MaskedArray` that
        shares the field's mask and fill value.  A 0-d field whose mask is
        True is returned as ``masked``.  Any other index returns the selected
        records as an ``mrecarray`` with the matching part of the mask.
        """
        _localdict = self.__dict__
        _mask = ndarray.__getattribute__(self, '_mask')
        _data = ndarray.view(self, _localdict['_baseclass'])
        # We want a field
        if isinstance(indx, str):
            # Make sure _sharedmask is True to propagate back to _fieldmask
            # Don't use _set_mask, there are some copies being made that
            # break propagation Don't force the mask to nomask, that wreaks
            # easy masking
            obj = _data[indx].view(MaskedArray)
            obj._mask = _mask[indx]
            obj._sharedmask = True
            fval = _localdict['_fill_value']
            if fval is not None:
                obj._fill_value = fval[indx]
            # Force to masked if the mask is True
            if not obj.ndim and obj._mask:
                return masked
            return obj
        # We want some elements.
        # asarray == np.array(copy=False) on NumPy 1.x, and unlike it does
        # not raise on NumPy >= 2 when a copy is needed.
        obj = np.asarray(_data[indx]).view(mrecarray)
        obj._mask = np.asarray(_mask[indx]).view(recarray)
        return obj

    def __setitem__(self, indx, value):
        """
        Sets the given record to value.

        After the generic :meth:`MaskedArray.__setitem__`, a string ``indx``
        (a field name) also gets that field's mask replaced by
        ``ma.getmaskarray(value)``.
        """
        MaskedArray.__setitem__(self, indx, value)
        if isinstance(indx, str):
            self._mask[indx] = ma.getmaskarray(value)

    def __str__(self):
        """
        Calculates the string representation.

        With more than one element, records are shown as
        ``[(f0,f1,...), (f0,f1,...), ...]``.  Otherwise the field values
        are shown as ``(f0, f1, ...)``.
        """
        field_values = [getattr(self, f) for f in self.dtype.names]
        if self.size > 1:
            mstr = ["(%s)" % ",".join([str(i) for i in s])
                    for s in zip(*field_values)]
            return "[%s]" % ", ".join(mstr)
        # Single record (or empty): previously fell off the end and
        # returned None, which makes str() raise TypeError.
        mstr = [str(value) for value in field_values]
        return "(%s)" % ", ".join(mstr)
