
## @package schema
# Module caffe2.python.schema
"""
Defines a minimal set of data types that allow representing datasets with
arbitrary nested structure, including objects of variable length, such as
maps and lists.

This defines a columnar storage format for such datasets on top of caffe2
tensors. In terms of representational capacity, it can express most of the
data types supported by the Parquet, ORC, and DWRF file formats.

See comments in operator_test/dataset_ops_test.py for an example and
walkthrough on how to use schema to store and iterate through a structured
in-memory dataset.
"""

import logging
import numpy as np
from caffe2.python import core
from caffe2.python import workspace
from caffe2.python.core import BlobReference
from collections import OrderedDict, namedtuple
from past.builtins import basestring
from future.utils import viewitems, viewkeys, viewvalues
from itertools import islice
from six import StringIO

logger = logging.getLogger(__name__)

FIELD_SEPARATOR = ':'


def _join_field_name(prefix, suffix):
    if prefix and suffix:
        return '{}{}{}'.format(prefix, FIELD_SEPARATOR, suffix)
    elif prefix:
        return prefix
    elif suffix:
        return suffix
    else:
        return ''
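
# For example:
#   _join_field_name('values', 'a')  == 'values:a'
#   _join_field_name('', 'a')        == 'a'
#   _join_field_name('values', '')   == 'values'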


def _normalize_field(field_or_type_or_blob, keep_blobs=True):
    """Clones/normalizes a field before adding it to a container."""
    if isinstance(field_or_type_or_blob, Field):
        return field_or_type_or_blob.clone(keep_blobs=keep_blobs)
    elif type(field_or_type_or_blob) in (type, np.dtype):
        return Scalar(dtype=field_or_type_or_blob)
    else:
        return Scalar(blob=field_or_type_or_blob)
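
# The three cases above, sketched (`some_blob` is illustrative):
#   _normalize_field(Scalar(np.int32))  # a Field: cloned via Field.clone()
#   _normalize_field(np.int32)          # a type/dtype: becomes Scalar(dtype=...)
#   _normalize_field(some_blob)         # anything else: becomes Scalar(blob=...)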


FeatureSpec = namedtuple(
    'FeatureSpec',
    [
        'feature_type',
        'feature_names',
        'feature_ids',
        'feature_is_request_only',
        'desired_hash_size',
        'feature_to_index',
    ]
)

FeatureSpec.__new__.__defaults__ = (None, None, None, None, None, None)
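
# All six fields default to None, so a partially specified spec is valid,
# e.g. (a sketch): FeatureSpec(feature_names=['f1', 'f2'])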


class Metadata(
    namedtuple(
        'Metadata', ['categorical_limit', 'expected_value', 'feature_specs']
    )
):
    """Represents additional information associated with a scalar in schema.

    `categorical_limit` - for fields of integral type that are guaranteed to
    be non-negative, this specifies the maximum possible value plus one. It is
    often used as the size of an embedding table.

    `expected_value` - the anticipated average value of elements in the field.
    Usually makes sense for the length fields of lists.

    `feature_specs` - information about the features contained in this field.
    For example, if the field holds more than one feature, this can carry the
    list of feature names contained in the field."""
    __slots__ = ()


Metadata.__new__.__defaults__ = (None, None, None)
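
# For example, a sketch of metadata for an id feature whose values index a
# 10000-row embedding table (Scalar, defined below, accepts a `metadata`
# keyword argument):
#   ids = Scalar(np.int64, metadata=Metadata(categorical_limit=10000))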


class Field(object):
    """Represents an abstract field type in a dataset.
    """

    __slots__ = ("_parent", "_field_offsets")

    def __init__(self, children):
        """Derived classes must call this after their initialization."""
        self._parent = (None, 0)
        offset = 0
        self._field_offsets = []
        for child in children:
            self._field_offsets.append(offset)
            offset += len(child.field_names())
        self._field_offsets.append(offset)

    def clone_schema(self):
        return self.clone(keep_blobs=False)

    def field_names(self):
        """Return the children field names for this field."""
        raise NotImplementedError('Field is an abstract class.')

    def field_types(self):
        """Return the numpy.dtype for each of the children fields."""
        raise NotImplementedError('Field is an abstract class.')

    def field_metadata(self):
        """Return the Metadata for each of the children fields."""
        raise NotImplementedError('Field is an abstract class.')

    def field_blobs(self):
        """Return the list of blobs with contents for this Field.
        Values can be either all numpy.ndarray or all BlobReference.
        Raises if any of the fields doesn't have a blob.
        """
        raise NotImplementedError('Field is an abstract class.')

    def all_scalars(self):
        """Return the list of all Scalar instances in the Field.
        The order is the same as for field_names() or field_blobs()."""
        raise NotImplementedError('Field is an abstract class.')

    def has_blobs(self):
        """Return True if every scalar of this field has blobs."""
        raise NotImplementedError('Field is an abstract class.')

    def clone(self, keep_blobs=True):
        """Clone this Field along with its children."""
        raise NotImplementedError('Field is an abstract class.')

    def _set_parent(self, parent, relative_id):
        self._parent = (parent, relative_id)

    def slice(self):
        """
        Returns a slice representing the range of field ids that belong to
        this field. This slice can be used to index a list of fields.

        E.g.:

        >>> s = Struct(
        ...     ('a', Scalar()),
        ...     ('b', Struct(
        ...         ('b1', Scalar()),
        ...         ('b2', Scalar()),
        ...     )),
        ...     ('c', Scalar()),
        ... )
        >>> field_data = ['da', 'db1', 'db2', 'dc']
        >>> field_data[s.b.slice()]
        ['db1', 'db2']
        """
        base_id = self._child_base_id()
        return slice(base_id, base_id + len(self.field_names()))

    def _child_base_id(self, child_index=None):
        """Get the base id of the given child"""
        p, i = self._parent
        pos = 0 if child_index is None else self._field_offsets[child_index]
        if p:
            pos += p._child_base_id(i)
        return pos

    def __eq__(self, other):
        """Equivalance of two schemas"""
        return (
            (self.field_names() == other.field_names()) and
            (self.field_types() == other.field_types()) and
            (self.field_metadata() == other.field_metadata())
        )

    def _pprint_impl(self, indent, str_buffer):
        raise NotImplementedError('Field is an abstract class.')

    def __repr__(self):
        str_buffer = StringIO()
        self._pprint_impl(0, str_buffer)
        contents = str_buffer.getvalue()
        str_buffer.close()
        return contents


class List(Field):
    """Represents a variable-length list.

    Values of a list can also be complex fields such as Lists and Structs.
    In addition to the fields exposed by its `values` field, a List exposes an
    additional `lengths` field, which will contain the size of each list under
    the parent domain.
    """

    __slots__ = ("lengths", "_items")

    def __init__(self, values, lengths_blob=None):
        if isinstance(lengths_blob, Field):
            assert isinstance(lengths_blob, Scalar)
            self.lengths = _normalize_field(lengths_blob)
        else:
            self.lengths = Scalar(np.int32, lengths_blob)
        self._items = _normalize_field(values)
        self.lengths._set_parent(self, 0)
        self._items._set_parent(self, 1)
        super(List, self).__init__([self.lengths, self._items])

    def field_names(self):
        value_fields = self._items.field_names()
        return (
            ['lengths'] + [_join_field_name('values', v) for v in value_fields]
        )

    def field_types(self):
        return self.lengths.field_types() + self._items.field_types()

    def field_metadata(self):
        return self.lengths.field_metadata() + self._items.field_metadata()

    def field_blobs(self):
        return self.lengths.field_blobs() + self._items.field_blobs()

    def all_scalars(self):
        return self.lengths.all_scalars() + self._items.all_scalars()

    def has_blobs(self):
        return self.lengths.has_blobs() and self._items.has_blobs()

    def clone(self, keep_blobs=True):
        return type(self)(
            _normalize_field(self._items, keep_blobs=keep_blobs),
            _normalize_field(self.lengths, keep_blobs=keep_blobs)
        )

    def _pprint_impl(self, indent, str_buffer):
        str_buffer.write('  ' * indent + "List(\n")
        str_buffer.write('  ' * (indent + 1) + "lengths=\n")
        self.lengths._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write('  ' * (indent + 1) + "_items=\n")
        self._items._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write('  ' * indent + ")\n")

    def __getattr__(self, item):
        """If the value of this list is a struct,
        allow to introspect directly into its fields."""
        if item.startswith('__'):
            raise AttributeError(item)
        if isinstance(self._items, Struct):
            return getattr(self._items, item)
        elif item == 'value' or item == 'items':
            return self._items
        else:
            raise AttributeError('Field not found in list: %s.' % item)

    def __getitem__(self, item):
        names = item.split(FIELD_SEPARATOR, 1)

        if len(names) == 1:
            if item == 'lengths':
                return self.lengths
            elif item == 'values':
                return self._items
        else:
            if names[0] == 'values':
                return self._items[names[1]]
        raise KeyError('Field not found in list: %s.' % item)
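
    # A sketch of key-based access, including nested names:
    #   l = List(Struct(('a', Scalar())))
    #   l['lengths']    # the lengths Scalar
    #   l['values:a']   # drills into the Struct behind 'values'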


class ListWithEvicted(List):
    """
    This class is similar with List, but containing extra field evicted_values for
    LRU Hashing.
    """

    __slots__ = ("_evicted_values",)

    def __init__(self, values, lengths_blob=None, evicted_values=None):
        if isinstance(evicted_values, Field):
            assert isinstance(evicted_values, Scalar)
            self._evicted_values = _normalize_field(evicted_values)
        else:
            self._evicted_values = Scalar(np.int64, evicted_values)
        super(ListWithEvicted, self).__init__(values, lengths_blob=lengths_blob)

    def field_names(self):
        value_fields = self._items.field_names()
        return (
            ['lengths'] + [_join_field_name('values', v) for v in value_fields] + ["_evicted_values"]
        )

    def field_types(self):
        return self.lengths.field_types() + self._items.field_types() + self._evicted_values.field_types()

    def field_metadata(self):
        return self.lengths.field_metadata() + self._items.field_metadata() + self._evicted_values.field_metadata()

    def field_blobs(self):
        return self.lengths.field_blobs() + self._items.field_blobs() + self._evicted_values.field_blobs()

    def all_scalars(self):
        return self.lengths.all_scalars() + self._items.all_scalars() + self._evicted_values.all_scalars()

    def has_blobs(self):
        return (self.lengths.has_blobs() and self._items.has_blobs() and
                self._evicted_values.has_blobs())

    def clone(self, keep_blobs=True):
        return type(self)(
            _normalize_field(self._items, keep_blobs=keep_blobs),
            _normalize_field(self.lengths, keep_blobs=keep_blobs),
            _normalize_field(self._evicted_values, keep_blobs=keep_blobs)
        )

    def _pprint_impl(self, indent, str_buffer):
        str_buffer.write('  ' * indent + "ListWithEvicted(\n")
        str_buffer.write('  ' * (indent + 1) + "lengths=\n")
        self.lengths._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write('  ' * (indent + 1) + "_items=\n")
        self._items._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write('  ' * (indent + 1) + "_evicted_values=\n")
        self._evicted_values._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write('  ' * indent + ")\n")

    def __getattr__(self, item):
        """If the value of this list is a struct,
        allow to introspect directly into its fields."""
        if item.startswith('__'):
            raise AttributeError(item)
        if item == "_evicted_values":
            return self._evicted_values
        if isinstance(self._items, Struct):
            return getattr(self._items, item)
        elif item == 'value' or item == 'items':
            return self._items
        else:
            raise AttributeError('Field not found in list: %s.' % item)