# Why Gemfury? Push, build, and install: RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, NuGet packages.
#
# Repository URL to install this package:
#
# Details
# keras / distribute / dataset_creator_model_fit_test_base.py
# Size: Mime:
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `DatasetCreator` with `Model.fit` across usages and strategies."""

import os

import numpy as np
import tensorflow.compat.v2 as tf
from absl.testing import parameterized

import keras
from keras import callbacks as callbacks_lib
from keras.engine import sequential
from keras.layers import core as core_layers
from keras.layers.preprocessing import string_lookup
from keras.optimizers.legacy import gradient_descent
from keras.utils import dataset_creator

# isort: off
from tensorflow.python.platform import tf_logging as logging


class DatasetCreatorModelFitTestBase(tf.test.TestCase, parameterized.TestCase):
    """The base class for DatasetCreator with Model.fit tests."""

    def _get_dataset_fn(self, use_lookup_layer):
        """Returns a `dataset_fn(input_context)` producing a repeated dataset.

        Args:
          use_lookup_layer: If true, a vocabulary file is written to the test
            temp dir right away, and the returned function yields string
            features mapped through a `StringLookup` layer built from that
            file. Otherwise the returned function yields random uniform
            features and labels.

        Returns:
          A function that ignores its `input_context` argument and returns a
          shuffled, infinitely repeated `tf.data.Dataset` in batches of 2.
        """
        if not use_lookup_layer:

            def dataset_fn(input_context):
                del input_context
                features = tf.random.uniform((10, 10))
                labels = tf.random.uniform((10,))
                dataset = tf.data.Dataset.from_tensor_slices(
                    (features, labels)
                )
                return dataset.shuffle(10).repeat().batch(2)

            return dataset_fn

        # The vocabulary file is created when this factory is called, so it
        # already exists by the time `dataset_fn` is invoked.
        vocab_path = os.path.join(self.get_temp_dir(), "vocab")
        with open(vocab_path, "w") as vocab_file:
            vocab_file.write("\n".join(["earth", "wind", "and", "fire"]))

        def dataset_fn(input_context):
            del input_context
            lookup_layer = string_lookup.StringLookup(
                num_oov_indices=1, vocabulary=vocab_path
            )
            # "michigan" is not in the vocabulary file written above.
            features = np.array(
                [
                    ["earth", "wind", "and", "fire"],
                    ["fire", "and", "earth", "michigan"],
                ]
            )
            labels = np.array([0, 1])

            def apply_lookup(batch_features, batch_labels):
                return lookup_layer(batch_features), batch_labels

            dataset = tf.data.Dataset.from_tensor_slices((features, labels))
            return dataset.shuffle(10).repeat().batch(2).map(apply_lookup)

        return dataset_fn

    def _model_compile(
        self,
        strategy,
        steps_per_execution=1,
        run_eagerly=False,
        with_normalization_layer=False,
        jit_compile=None,
    ):
        """Builds and compiles a small `Sequential` model under `strategy`.

        The model is `Dense(10)`, optionally followed by a
        `BatchNormalization` layer, followed by a sigmoid `Dense(1)`. It is
        compiled with legacy SGD, binary cross-entropy loss and an `Accuracy`
        metric; the metric is also stored on `self._accuracy_metric`.

        Args:
          strategy: Object whose `scope()` context manager the model and the
            metric are created under.
          steps_per_execution: Forwarded to `Model.compile`.
          run_eagerly: Forwarded to `Model.compile`.
          with_normalization_layer: If true, a `BatchNormalization` layer is
            inserted between the two `Dense` layers.
          jit_compile: Forwarded to `Model.compile`.

        Returns:
          A `(model, callbacks)` tuple, where `callbacks` is a list holding a
          single callback that validates the logs at the end of each epoch.
        """

        class ResultAssertingCallback(callbacks_lib.Callback):
            """A callback that asserts the result of the tests."""

            def __init__(self):
                # Let the base `Callback` initialise its own attributes before
                # this subclass adds its state.
                super().__init__()
                self._prev_epoch = -1

            def on_epoch_end(self, epoch, logs=None):
                logging.info("testModelFit: epoch=%r, logs=%r", epoch, logs)
                if epoch <= self._prev_epoch:
                    raise RuntimeError(
                        "Epoch is supposed to be larger than previous."
                    )
                self._prev_epoch = epoch
                # `logs` may be None; treat that like a missing "loss" entry
                # so the failure is the RuntimeError below rather than an
                # AttributeError on `None.get`.
                loss = (logs or {}).get("loss")
                if not isinstance(loss, (float, np.floating)):
                    raise RuntimeError(
                        "loss is supposed to be in the logs and float."
                    )

        with strategy.scope():
            model = sequential.Sequential([core_layers.Dense(10)])
            if with_normalization_layer:
                norm = keras.layers.BatchNormalization(
                    axis=-1, input_shape=(4, 4, 3), momentum=0.8
                )
                model.add(norm)
            model.add(core_layers.Dense(1, activation="sigmoid"))
            self._accuracy_metric = keras.metrics.Accuracy()

        model.compile(
            gradient_descent.SGD(),
            loss="binary_crossentropy",
            metrics=[self._accuracy_metric],
            steps_per_execution=steps_per_execution,
            run_eagerly=run_eagerly,
            jit_compile=jit_compile,
        )
        return model, [ResultAssertingCallback()]

    def _model_fit(
        self,
        strategy,
        steps_per_execution=1,
        validation_data=None,
        x=None,
        y=None,
        shuffle=True,
        batch_size=None,
        steps_per_epoch=10,
        run_eagerly=False,
        with_normalization_layer=False,
        callbacks=None,
        use_lookup_layer=False,
        use_dataset_creator=True,
        verbose="auto",
        jit_compile=None,
    ):
        """Compiles a fresh model and runs `Model.fit` on it for 10 epochs.

        Args:
          strategy: Passed to `_model_compile`.
          steps_per_execution: Passed to `_model_compile`.
          validation_data: Validation input for `Model.fit`. When None, it is
            built the same way as the default `x`.
          x: Training input for `Model.fit`. When None, it is built from
            `_get_dataset_fn(use_lookup_layer)`: wrapped in a `DatasetCreator`
            if `use_dataset_creator` is true, otherwise called with `None` to
            obtain the dataset directly.
          y: Forwarded to `Model.fit`.
          shuffle: Forwarded to `Model.fit`.
          batch_size: Forwarded to `Model.fit`.
          steps_per_epoch: Forwarded to `Model.fit` as both `steps_per_epoch`
            and `validation_steps`.
          run_eagerly: Passed to `_model_compile`.
          with_normalization_layer: Passed to `_model_compile`.
          callbacks: Optional iterable of extra callbacks. It is not modified;
            the callbacks returned by `_model_compile` are appended to a copy.
          use_lookup_layer: Passed to `_get_dataset_fn` for default inputs.
          use_dataset_creator: Selects how default inputs are built (see `x`).
          verbose: Forwarded to `Model.fit`.
          jit_compile: Passed to `_model_compile`.

        Returns:
          The fitted model.
        """
        model, default_callbacks = self._model_compile(
            strategy,
            steps_per_execution,
            run_eagerly,
            with_normalization_layer,
            jit_compile,
        )
        # Build a new list rather than using `+=`, which would extend the
        # caller's list in place (and fail outright for a tuple).
        user_callbacks = [] if callbacks is None else list(callbacks)
        callbacks = user_callbacks + default_callbacks

        if x is None:
            if use_dataset_creator:
                x = dataset_creator.DatasetCreator(
                    self._get_dataset_fn(use_lookup_layer)
                )
            else:
                x = self._get_dataset_fn(use_lookup_layer)(None)

        if validation_data is None:
            if use_dataset_creator:
                validation_data = dataset_creator.DatasetCreator(
                    self._get_dataset_fn(use_lookup_layer)
                )
            else:
                validation_data = self._get_dataset_fn(use_lookup_layer)(None)

        model.fit(
            x,
            y,
            shuffle=shuffle,
            batch_size=batch_size,
            epochs=10,
            steps_per_epoch=steps_per_epoch,
            callbacks=callbacks,
            validation_data=validation_data,
            validation_steps=steps_per_epoch,
            verbose=verbose,
        )
        return model

    def _model_evaluate(
        self,
        strategy,
        steps_per_execution=1,
        x=None,
        y=None,
        batch_size=None,
        steps=10,
        run_eagerly=False,
        with_normalization_layer=False,
        callbacks=None,
        use_dataset_creator=True,
    ):
        """Compiles a fresh model and runs `Model.evaluate` on it.

        Args:
          strategy: Passed to `_model_compile`.
          steps_per_execution: Passed to `_model_compile`.
          x: Evaluation input. When None, a random uniform dataset in batches
            of 8 is used: wrapped in a `DatasetCreator` if
            `use_dataset_creator` is true, otherwise built directly.
          y: Forwarded to `Model.evaluate`.
          batch_size: Forwarded to `Model.evaluate`.
          steps: Forwarded to `Model.evaluate`.
          run_eagerly: Passed to `_model_compile`.
          with_normalization_layer: Passed to `_model_compile`.
          callbacks: Optional iterable of extra callbacks. It is not modified;
            the callbacks returned by `_model_compile` are appended to a copy.
          use_dataset_creator: Selects how the default input is built (see
            `x`).

        Returns:
          The evaluated model.
        """
        model, default_callbacks = self._model_compile(
            strategy,
            steps_per_execution,
            run_eagerly,
            with_normalization_layer,
        )
        # Build a new list rather than using `+=`, which would extend the
        # caller's list in place (and fail outright for a tuple).
        user_callbacks = [] if callbacks is None else list(callbacks)
        callbacks = user_callbacks + default_callbacks

        def dataset_fn(input_context):
            del input_context
            x = tf.random.uniform((10, 10))
            y = tf.random.uniform((10, 1))
            return (
                tf.data.Dataset.from_tensor_slices((x, y))
                .shuffle(10)
                .repeat()
                .batch(8)
            )

        if x is None:
            if use_dataset_creator:
                x = dataset_creator.DatasetCreator(dataset_fn)
            else:
                x = dataset_fn(None)

        model.evaluate(
            x=x, y=y, steps=steps, callbacks=callbacks, batch_size=batch_size
        )
        return model

    def _model_predict(
        self,
        strategy,
        model=None,
        steps_per_execution=1,
        test_data=None,
        steps=10,
        with_normalization_layer=False,
    ):
        """Runs `Model.predict` and returns the predictions rounded to 4 places.

        Args:
          strategy: Passed to `_model_compile` when a model has to be built.
          model: Model to predict with. When None, a new one is created via
            `_model_compile`, and its callbacks are passed to `predict`.
          steps_per_execution: Passed to `_model_compile` when a model has to
            be built.
          test_data: Input for `Model.predict`. When None, a repeated dataset
            of six fixed scalar samples in batches of 2 is used.
          steps: Forwarded to `Model.predict`.
          with_normalization_layer: Passed to `_model_compile` when a model
            has to be built.

        Returns:
          A `(model, predictions)` tuple, with `predictions` passed through
          `np.around(..., 4)`.
        """
        # A caller-supplied model runs without any callbacks.
        predict_callbacks = []
        if model is None:
            model, predict_callbacks = self._model_compile(
                strategy,
                steps_per_execution,
                with_normalization_layer=with_normalization_layer,
            )

        if test_data is None:
            samples = tf.constant([[1.0], [2.0], [3.0], [1.0], [5.0], [1.0]])
            test_data = (
                tf.data.Dataset.from_tensor_slices(samples).repeat().batch(2)
            )

        raw_predictions = model.predict(
            x=test_data, steps=steps, callbacks=predict_callbacks
        )
        return model, np.around(raw_predictions, 4)