## @package schema
# Module caffe2.python.schema
"""
Defines a minimal set of data types for representing datasets with
arbitrary nested structure, including objects of variable length, such as
maps and lists.

This defines a columnar storage format for such datasets on top of caffe2
tensors. In terms of capacity of representation, it can represent most of
the data types supported by the Parquet, ORC, and DWRF file formats.
See comments in operator_test/dataset_ops_test.py for an example and
walkthrough on how to use schema to store and iterate through a structured
in-memory dataset.
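
Example (a hedged sketch; the names 'id' and 'scores' are illustrative only):
a nested record and the flattened column names it produces.

>>> import numpy as np
>>> from caffe2.python import schema
>>> record = schema.Struct(
>>>     ('id', schema.Scalar(np.int64)),
>>>     ('scores', schema.List(schema.Scalar(np.float32))),
>>> )
>>> record.field_names()
['id', 'scores:lengths', 'scores:values']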
"""
import logging
import numpy as np
from caffe2.python import core
from caffe2.python import workspace
from caffe2.python.core import BlobReference
from collections import OrderedDict, namedtuple
from past.builtins import basestring
from future.utils import viewitems, viewkeys, viewvalues
from itertools import islice
from six import StringIO
logger = logging.getLogger(__name__)
FIELD_SEPARATOR = ':'
def _join_field_name(prefix, suffix):
if prefix and suffix:
return '{}{}{}'.format(prefix, FIELD_SEPARATOR, suffix)
elif prefix:
return prefix
elif suffix:
return suffix
else:
return ''
def _normalize_field(field_or_type_or_blob, keep_blobs=True):
"""Clones/normalizes a field before adding it to a container."""
if isinstance(field_or_type_or_blob, Field):
return field_or_type_or_blob.clone(keep_blobs=keep_blobs)
elif type(field_or_type_or_blob) in (type, np.dtype):
return Scalar(dtype=field_or_type_or_blob)
else:
return Scalar(blob=field_or_type_or_blob)
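# A hedged illustration of the dispatch above (the inputs are examples, not an
# exhaustive list):
#   _normalize_field(np.int32)          # type/np.dtype -> Scalar(dtype=np.int32)
#   _normalize_field(np.array([1, 2]))  # blob contents -> Scalar(blob=<ndarray>)
#   _normalize_field(some_field)        # Field         -> some_field.clone()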
FeatureSpec = namedtuple(
'FeatureSpec',
[
'feature_type',
'feature_names',
'feature_ids',
'feature_is_request_only',
'desired_hash_size',
'feature_to_index',
]
)
FeatureSpec.__new__.__defaults__ = (None, None, None, None, None, None)
class Metadata(
namedtuple(
'Metadata', ['categorical_limit', 'expected_value', 'feature_specs']
)
):
"""Represents additional information associated with a scalar in schema.
`categorical_limit` - for fields of integral type that are guaranteed to be
non-negative it specifies the maximum possible value plus one. It's often
used as a size of an embedding table.
`expected_value` - anticipated average value of elements in the field.
Usually makes sense for length fields of lists.
`feature_specs` - information about the features that contained in this
field. For example if field have more than 1 feature it can have list of
feature names contained in this field."""
__slots__ = ()
Metadata.__new__.__defaults__ = (None, None, None)
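# Example (illustrative values only): metadata for a non-negative integral
# scalar with values known to be below 1000, e.g. ids feeding an embedding
# table of size 1000:
#   meta = Metadata(categorical_limit=1000)
#   assert meta.expected_value is None  # unspecified fields default to None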
class Field(object):
"""Represents an abstract field type in a dataset.
"""
__slots__ = ("_parent", "_field_offsets")
def __init__(self, children):
"""Derived classes must call this after their initialization."""
self._parent = (None, 0)
offset = 0
self._field_offsets = []
for child in children:
self._field_offsets.append(offset)
offset += len(child.field_names())
self._field_offsets.append(offset)
def clone_schema(self):
return self.clone(keep_blobs=False)
def field_names(self):
"""Return the children field names for this field."""
raise NotImplementedError('Field is an abstract class.')
def field_types(self):
"""Return the numpy.dtype for each of the children fields."""
raise NotImplementedError('Field is an abstract class.')
def field_metadata(self):
"""Return the Metadata for each of the children fields."""
raise NotImplementedError('Field is an abstract class.')
def field_blobs(self):
"""Return the list of blobs with contents for this Field.
Values can either be all numpy.ndarray or BlobReference.
If any of the fields doesn't have a blob, throws.
"""
raise NotImplementedError('Field is an abstract class.')
def all_scalars(self):
"""Return the list of all Scalar instances in the Field.
The order is the same as for field_names() or field_blobs()"""
raise NotImplementedError('Field is an abstract class.')
def has_blobs(self):
"""Return True if every scalar of this field has blobs."""
raise NotImplementedError('Field is an abstract class.')
def clone(self, keep_blobs=True):
"""Clone this Field along with its children."""
raise NotImplementedError('Field is an abstract class.')
def _set_parent(self, parent, relative_id):
self._parent = (parent, relative_id)
def slice(self):
"""
Returns a slice representing the range of field ids that belong to
this field. This slice can be used to index a list of fields.
E.g.:
>>> s = Struct(
>>> ('a', Scalar()),
>>> ('b', Struct(
>>> ('b1', Scalar()),
>>> ('b2', Scalar()),
>>> )),
>>> ('c', Scalar()),
>>> )
>>> field_data = ['da', 'db1', 'db2', 'dc']
        >>> field_data[s.b.slice()]
['db1', 'db2']
"""
base_id = self._child_base_id()
return slice(base_id, base_id + len(self.field_names()))
def _child_base_id(self, child_index=None):
"""Get the base id of the given child"""
p, i = self._parent
pos = 0 if child_index is None else self._field_offsets[child_index]
if p:
pos += p._child_base_id(i)
return pos
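    # A worked illustration (hypothetical schema):
    #   s = Struct(('a', Scalar()),
    #              ('b', Struct(('b1', Scalar()), ('b2', Scalar()))))
    # gives s._field_offsets == [0, 1, 3] and s.b._parent == (s, 1), so
    # s.b._child_base_id() == 1: the fields of `b` start right after `a`.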
def __eq__(self, other):
"""Equivalance of two schemas"""
return (
(self.field_names() == other.field_names()) and
(self.field_types() == other.field_types()) and
(self.field_metadata() == other.field_metadata())
)
def _pprint_impl(self, indent, str_buffer):
raise NotImplementedError('Field is an abstract class.')
def __repr__(self):
str_buffer = StringIO()
self._pprint_impl(0, str_buffer)
contents = str_buffer.getvalue()
str_buffer.close()
return contents
class List(Field):
"""Represents a variable-length list.
Values of a list can also be complex fields such as Lists and Structs.
In addition to the fields exposed by its `values` field, a List exposes an
additional `lengths` field, which will contain the size of each list under
the parent domain.
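
    Example (a minimal sketch):

    >>> l = List(Scalar(np.float32))
    >>> l.field_names()
    ['lengths', 'values']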
"""
__slots__ = ("lengths", "_items")
def __init__(self, values, lengths_blob=None):
if isinstance(lengths_blob, Field):
assert isinstance(lengths_blob, Scalar)
self.lengths = _normalize_field(lengths_blob)
else:
self.lengths = Scalar(np.int32, lengths_blob)
self._items = _normalize_field(values)
self.lengths._set_parent(self, 0)
self._items._set_parent(self, 1)
super(List, self).__init__([self.lengths, self._items])
def field_names(self):
value_fields = self._items.field_names()
return (
['lengths'] + [_join_field_name('values', v) for v in value_fields]
)
def field_types(self):
return self.lengths.field_types() + self._items.field_types()
def field_metadata(self):
return self.lengths.field_metadata() + self._items.field_metadata()
def field_blobs(self):
return self.lengths.field_blobs() + self._items.field_blobs()
def all_scalars(self):
return self.lengths.all_scalars() + self._items.all_scalars()
def has_blobs(self):
return self.lengths.has_blobs() and self._items.has_blobs()
def clone(self, keep_blobs=True):
return type(self)(
_normalize_field(self._items, keep_blobs=keep_blobs),
_normalize_field(self.lengths, keep_blobs=keep_blobs)
)
def _pprint_impl(self, indent, str_buffer):
str_buffer.write(' ' * indent + "List(\n")
str_buffer.write(' ' * (indent + 1) + "lengths=\n")
self.lengths._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
str_buffer.write(' ' * (indent + 1) + "_items=\n")
self._items._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
str_buffer.write(' ' * indent + ")\n")
def __getattr__(self, item):
"""If the value of this list is a struct,
allow to introspect directly into its fields."""
if item.startswith('__'):
raise AttributeError(item)
if isinstance(self._items, Struct):
return getattr(self._items, item)
elif item == 'value' or item == 'items':
return self._items
else:
raise AttributeError('Field not found in list: %s.' % item)
def __getitem__(self, item):
names = item.split(FIELD_SEPARATOR, 1)
if len(names) == 1:
if item == 'lengths':
return self.lengths
elif item == 'values':
return self._items
else:
if names[0] == 'values':
return self._items[names[1]]
raise KeyError('Field not found in list: %s.' % item)
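# A hedged usage sketch of nested access on List (names are illustrative):
#   l = List(Struct(('x', Scalar(np.int32))))
#   l['lengths']   # the lengths Scalar
#   l['values:x']  # __getitem__ splits on FIELD_SEPARATOR and recurses
#   l.x            # __getattr__ forwards into the Struct held by `values`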
class ListWithEvicted(List):
"""
    This class is similar to List, but carries an extra field, evicted_values,
    for LRU hashing.
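
    Example (a minimal sketch):

    >>> lwe = ListWithEvicted(Scalar(np.float32))
    >>> lwe.field_names()
    ['lengths', 'values', '_evicted_values']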
"""
__slots__ = ("_evicted_values",)
def __init__(self, values, lengths_blob=None, evicted_values=None):
if isinstance(evicted_values, Field):
assert isinstance(evicted_values, Scalar)
self._evicted_values = _normalize_field(evicted_values)
else:
self._evicted_values = Scalar(np.int64, evicted_values)
super(ListWithEvicted, self).__init__(values, lengths_blob=lengths_blob)
def field_names(self):
value_fields = self._items.field_names()
return (
['lengths'] + [_join_field_name('values', v) for v in value_fields] + ["_evicted_values"]
)
def field_types(self):
return self.lengths.field_types() + self._items.field_types() + self._evicted_values.field_types()
def field_metadata(self):
return self.lengths.field_metadata() + self._items.field_metadata() + self._evicted_values.field_metadata()
def field_blobs(self):
return self.lengths.field_blobs() + self._items.field_blobs() + self._evicted_values.field_blobs()
def all_scalars(self):
return self.lengths.all_scalars() + self._items.all_scalars() + self._evicted_values.all_scalars()
def has_blobs(self):
        return (
            self.lengths.has_blobs() and self._items.has_blobs() and
            self._evicted_values.has_blobs()
        )
def clone(self, keep_blobs=True):
return type(self)(
_normalize_field(self._items, keep_blobs=keep_blobs),
_normalize_field(self.lengths, keep_blobs=keep_blobs),
_normalize_field(self._evicted_values, keep_blobs=keep_blobs)
)
def _pprint_impl(self, indent, str_buffer):
str_buffer.write(' ' * indent + "ListWithEvicted(\n")
str_buffer.write(' ' * (indent + 1) + "lengths=\n")
self.lengths._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
str_buffer.write(' ' * (indent + 1) + "_items=\n")
self._items._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
str_buffer.write(' ' * (indent + 1) + "_evicted_values=\n")
self._evicted_values._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
str_buffer.write(' ' * indent + ")\n")
def __getattr__(self, item):
"""If the value of this list is a struct,
allow to introspect directly into its fields."""
if item.startswith('__'):
raise AttributeError(item)
if item == "_evicted_values":
return self._evicted_values
if isinstance(self._items, Struct):
return getattr(self._items, item)
        elif item == 'value' or item == 'items':
            return self._items
        else:
            raise AttributeError('Field not found in list: %s.' % item)