Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
Size: Mime:
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reuters topic classification dataset.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json

import numpy as np

from tensorflow.python.keras.preprocessing.sequence import _remove_long_seq
from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.datasets.reuters.load_data')
def load_data(path='reuters.npz',
              num_words=None,
              skip_top=0,
              maxlen=None,
              test_split=0.2,
              seed=113,
              start_char=1,
              oov_char=2,
              index_from=3,
              **kwargs):
  """Loads the Reuters newswire classification dataset.

  Arguments:
      path: where to cache the data (relative to `~/.keras/datasets`).
      num_words: max number of words to include. Words are ranked
          by how often they occur (in the training set) and only
          the most frequent words are kept. Only indices strictly below
          `num_words` are kept. When falsy, it defaults to the largest
          index present in the (shifted) data.
      skip_top: skip the top N most frequently occurring words
          (which may not be informative). Indices below `skip_top` are
          treated like out-of-vocabulary words.
      maxlen: maximum sequence length. When set, the sequences and labels
          are passed through `_remove_long_seq`, which filters out
          over-long sequences (together with their labels) instead of
          shortening them.
      test_split: Fraction of the dataset to be used as test data.
      seed: random seed for sample shuffling.
      start_char: The start of a sequence will be marked with this character.
          Set to 1 because 0 is usually the padding character.
          Pass `None` to add no start marker.
      oov_char: words that were cut out because of the `num_words`
          or `skip_top` limit will be replaced with this character.
          Pass `None` to drop those words instead of replacing them.
      index_from: index actual words with this index and higher.
      **kwargs: Used for backwards compatibility. Only the legacy
          `nb_words` alias of `num_words` is accepted.

  Returns:
      Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
      The `x_*` arrays are built with `dtype=object` because the sequences
      generally differ in length.

  Raises:
      TypeError: if `kwargs` contains anything other than `nb_words`.

  Note that the 'out of vocabulary' character is only used for
  words that were present in the training set but are not included
  because they're not making the `num_words` cut here.
  Words that were not seen in the training set but are in the test set
  have simply been skipped.
  """
  # Legacy support
  if 'nb_words' in kwargs:
    logging.warning('The `nb_words` argument in `load_data` '
                    'has been renamed `num_words`.')
    num_words = kwargs.pop('nb_words')
  if kwargs:
    raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))

  origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
  path = get_file(
      path,
      origin=origin_folder + 'reuters.npz',
      file_hash='87aedbeb0cb229e378797a632c1997b6')
  # `x` is expected to be an object array of variable-length lists, which
  # NumPy only deserializes when `allow_pickle=True` is passed explicitly
  # (newer releases default it to False).
  # SECURITY NOTE(review): unpickling can execute arbitrary code, so this
  # relies on the file being the expected archive. It is requested with a
  # fixed `file_hash` above; confirm `get_file` enforces that hash.
  with np.load(path, allow_pickle=True) as f:
    xs, labels = f['x'], f['y']

  # Shuffle with a private generator so the caller's global `np.random`
  # state is left untouched. For a given seed this yields the same
  # permutation as seeding the global generator would.
  rng = np.random.RandomState(seed)
  indices = np.arange(len(xs))
  rng.shuffle(indices)
  xs = xs[indices]
  labels = labels[indices]

  # Shift every word index by `index_from` to make room for the reserved
  # markers, optionally prefixing each sequence with `start_char`.
  if start_char is not None:
    xs = [[start_char] + [w + index_from for w in x] for x in xs]
  elif index_from:
    xs = [[w + index_from for w in x] for x in xs]

  if maxlen:
    xs, labels = _remove_long_seq(maxlen, xs, labels)

  if not num_words:
    # NOTE: the bound below is exclusive (`w < num_words`), so with this
    # default the single largest index is itself treated as out of range.
    num_words = max(max(x) for x in xs)

  # by convention, use 2 as OOV word
  # reserve 'index_from' (=3 by default) characters:
  # 0 (padding), 1 (start), 2 (OOV)
  if oov_char is not None:
    xs = [[w if skip_top <= w < num_words else oov_char for w in x] for x in xs]
  else:
    xs = [[w for w in x if skip_top <= w < num_words] for x in xs]

  # The first `1 - test_split` share of the shuffled samples is the training
  # set; the remainder is the test set.
  idx = int(len(xs) * (1 - test_split))
  # `dtype=object` is required for ragged sequences: recent NumPy versions
  # refuse to build an array from lists of differing lengths without it.
  x_train = np.array(xs[:idx], dtype=object)
  y_train = np.array(labels[:idx])
  x_test = np.array(xs[idx:], dtype=object)
  y_test = np.array(labels[idx:])

  return (x_train, y_train), (x_test, y_test)


@keras_export('keras.datasets.reuters.get_word_index')
def get_word_index(path='reuters_word_index.json'):
  """Fetches the Reuters word index file and returns its decoded content.

  Arguments:
      path: where to cache the data; forwarded to `get_file` as the
          local file name.

  Returns:
      The word index dictionary, exactly as parsed from the JSON file.
      NOTE(review): the upstream Keras docs describe the keys as words and
      the values as their integer indices - confirm against the data file.
  """
  base_url = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
  index_url = base_url + 'reuters_word_index.json'
  # `get_file` returns the location of the file to read; keep it in its own
  # name rather than overwriting the `path` argument.
  cached_file = get_file(
      path,
      origin=index_url,
      file_hash='4d44cc38712099c9e383dc6e5f11a921')
  with open(cached_file) as handle:
    word_index = json.load(handle)
  return word_index