Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
keras / distribute / keras_image_model_correctness_test.py
Size: Mime:
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Correctness tests for tf.keras CNN models using DistributionStrategy."""

import numpy as np
import tensorflow.compat.v2 as tf

import keras
from keras.distribute import keras_correctness_test_base
from keras.optimizers.legacy import gradient_descent
from keras.testing_infra import test_utils


@test_utils.run_all_without_tensor_float_32(
    "Uses Dense layers, which call matmul. Even if Dense layers run in "
    "float64, the test sometimes fails with TensorFloat-32 enabled for unknown "
    "reasons"
)
@test_utils.run_v2_only()
class DistributionStrategyCnnCorrectnessTest(
    keras_correctness_test_base.TestDistributionStrategyCorrectnessBase
):
    """Correctness tests for a small CNN under distribution strategies.

    Supplies the model and the synthetic data; every test delegates the
    actual checking to `run_correctness_test` on the base class.
    """

    def get_model(
        self, initial_weights=None, distribution=None, input_shapes=None
    ):
        """Build and compile the CNN used by the correctness tests.

        The network is: 28x28x3 input -> Conv2D -> optional batch norm ->
        max pooling -> flatten -> 10-way softmax Dense layer.

        Args:
            initial_weights: Optional weights; when truthy they are loaded
                into the model with `set_weights` before compiling.
            distribution: Handed to `MaybeDistributionScope`, inside which
                the model is built and compiled.
            input_shapes: Ignored; the input shape is fixed.

        Returns:
            The compiled `keras.Model`.
        """
        del input_shapes  # Unused: the input shape is hard-coded below.
        with keras_correctness_test_base.MaybeDistributionScope(distribution):
            inputs = keras.layers.Input(shape=(28, 28, 3), name="image")
            conv = keras.layers.Conv2D(
                filters=16,
                kernel_size=(3, 3),
                strides=(4, 4),
                kernel_regularizer=keras.regularizers.l2(1e-4),
                name="conv1",
            )
            net = conv(inputs)

            bn_mode = self.with_batch_norm
            if bn_mode == "sync":
                # Two parallel synchronized batch norms over the same conv
                # output, summed, to verify all-reduce works OK.
                branches = [
                    keras.layers.BatchNormalization(
                        name=bn_name, synchronized=True
                    )(net)
                    for bn_name in ("bn1", "bn2")
                ]
                net = keras.layers.Add()(branches)
            elif bn_mode == "regular":
                net = keras.layers.BatchNormalization(name="bn1")(net)

            pooled = keras.layers.MaxPooling2D(pool_size=(2, 2))(net)
            head = keras.layers.Dense(10, activation="softmax", name="pred")
            flat = keras.layers.Flatten()(pooled)
            probs = head(flat)
            model = keras.Model(inputs=[inputs], outputs=[probs])

            if initial_weights:
                model.set_weights(initial_weights)

            sgd = gradient_descent.SGD(learning_rate=0.1)
            model.compile(
                optimizer=sgd,
                loss="sparse_categorical_crossentropy",
                metrics=["sparse_categorical_accuracy"],
            )

        return model

    def _get_data(self, count, shape=(28, 28, 3), num_classes=10):
        """Generate a synthetic, class-clustered dataset.

        One random center per class is drawn with `np.random.randn`; each
        sample is the center of a randomly chosen class plus Gaussian noise
        with standard deviation 0.1.

        Args:
            count: Number of samples to generate.
            shape: Shape of a single sample.
            num_classes: Number of distinct labels.

        Returns:
            A tuple `(x, y)` of float32 arrays: `x` stacks the `count`
            samples and `y` holds the labels with shape `(count, 1)`.
        """
        class_centers = np.random.randn(num_classes, *shape)
        flat_size = np.prod(shape)

        samples, targets = [], []
        for _ in range(count):
            # Draw the label first, then the noise, for each sample in turn.
            (target,) = np.random.randint(0, num_classes, size=1)
            noise = np.random.normal(loc=0, scale=0.1, size=flat_size)
            targets.append(target)
            samples.append(class_centers[target] + noise.reshape(shape))

        x = np.asarray(samples, dtype=np.float32)
        y = np.asarray(targets, dtype=np.float32).reshape((count, 1))
        return x, y

    def get_data(self):
        """Return `(x_train, y_train, x_predict)` for the standard tests.

        The sample count is `_GLOBAL_BATCH_SIZE * _EVAL_STEPS`, and the
        prediction inputs are the training features themselves.
        """
        num_samples = (
            keras_correctness_test_base._GLOBAL_BATCH_SIZE
            * keras_correctness_test_base._EVAL_STEPS
        )
        x_train, y_train = self._get_data(count=num_samples)
        return x_train, y_train, x_train

    def get_data_with_partial_last_batch_eval(self):
        """Return train/eval/predict data for the partial-last-batch tests.

        Uses 1280 training samples and 1000 evaluation samples; the
        evaluation features double as the prediction inputs.
        NOTE(review): presumably 1000 is not a multiple of the global batch
        size, which is what yields the partial final batch -- confirm.
        """
        train_x, train_y = self._get_data(count=1280)
        eval_x, eval_y = self._get_data(count=1000)
        return train_x, train_y, eval_x, eval_y, eval_x

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations()
        + keras_correctness_test_base.multi_worker_mirrored_eager()
    )
    def test_cnn_correctness(
        self, distribution, use_numpy, use_validation_data
    ):
        """Run the correctness check on the CNN without batch norm."""
        combos = tf.__internal__.distribute.combinations
        if distribution == combos.central_storage_strategy_with_gpu_and_cpu:
            # Skipped for this strategy; see b/183958183.
            self.skipTest("b/183958183")
        self.run_correctness_test(distribution, use_numpy, use_validation_data)

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations()
        + keras_correctness_test_base.multi_worker_mirrored_eager()
    )
    def test_cnn_with_batch_norm_correctness(
        self, distribution, use_numpy, use_validation_data
    ):
        """Run the correctness check with a regular BatchNormalization."""
        self.run_correctness_test(
            distribution,
            use_numpy,
            use_validation_data,
            with_batch_norm="regular",
        )

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations()
        + keras_correctness_test_base.multi_worker_mirrored_eager()
    )
    def test_cnn_with_sync_batch_norm_correctness(
        self, distribution, use_numpy, use_validation_data
    ):
        """Run the correctness check with synchronized batch norm layers.

        Only runs when executing eagerly; otherwise the test is skipped.
        """
        if tf.executing_eagerly():
            self.run_correctness_test(
                distribution,
                use_numpy,
                use_validation_data,
                with_batch_norm="sync",
            )
        else:
            self.skipTest(
                "BatchNorm with `synchronized` is not enabled in graph mode."
            )

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations_eager()  # noqa: E501
        + keras_correctness_test_base.multi_worker_mirrored_eager()
        + keras_correctness_test_base.test_combinations_with_tpu_strategies_graph()  # noqa: E501
    )
    def test_cnn_correctness_with_partial_last_batch_eval(
        self, distribution, use_numpy, use_validation_data
    ):
        """Run the correctness check with a partial last batch, one epoch."""
        self.run_correctness_test(
            distribution,
            use_numpy,
            use_validation_data,
            partial_last_batch=True,
            training_epochs=1,
        )

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations_eager()  # noqa: E501
        + keras_correctness_test_base.multi_worker_mirrored_eager()
        + keras_correctness_test_base.test_combinations_with_tpu_strategies_graph()  # noqa: E501
    )
    def test_cnn_with_batch_norm_correctness_and_partial_last_batch_eval(
        self, distribution, use_numpy, use_validation_data
    ):
        """Run the correctness check with batch norm and a partial batch."""
        self.run_correctness_test(
            distribution,
            use_numpy,
            use_validation_data,
            with_batch_norm="regular",
            partial_last_batch=True,
        )


# Script entry point: hand control to the multi-process runner's test_main().
# NOTE(review): presumably needed so the multi-worker combinations above can
# spawn their worker processes -- confirm against the runner's documentation.
if __name__ == "__main__":
    tf.__internal__.distribute.multi_process_runner.test_main()