Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
pyarrow / include / arrow / testing / generator.h
Size: Mime:
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "arrow/array/array_base.h"
#include "arrow/array/util.h"
#include "arrow/buffer_builder.h"
#include "arrow/compute/type_fwd.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/visibility.h"
#include "arrow/type_fwd.h"

namespace arrow {

/// \brief Static factory helpers producing arrays, record batches and record batch
/// readers filled with a single repeated value.
class ARROW_TESTING_EXPORT ConstantArrayGenerator {
 public:
  /// \brief Generates a constant BooleanArray
  ///
  /// \param[in] size the size of the array to generate
  /// \param[in] value to repeat
  ///
  /// \return a generated Array
  static std::shared_ptr<Array> Boolean(int64_t size, bool value = false);

  /// \brief Generates a constant UInt8Array
  ///
  /// \param[in] size the size of the array to generate
  /// \param[in] value to repeat
  ///
  /// \return a generated Array
  static std::shared_ptr<Array> UInt8(int64_t size, uint8_t value = 0);

  /// \brief Generates a constant Int8Array
  ///
  /// \param[in] size the size of the array to generate
  /// \param[in] value to repeat
  ///
  /// \return a generated Array
  static std::shared_ptr<Array> Int8(int64_t size, int8_t value = 0);

  /// \brief Generates a constant UInt16Array
  ///
  /// \param[in] size the size of the array to generate
  /// \param[in] value to repeat
  ///
  /// \return a generated Array
  static std::shared_ptr<Array> UInt16(int64_t size, uint16_t value = 0);

  /// \brief Generates a constant Int16Array
  ///
  /// \param[in] size the size of the array to generate
  /// \param[in] value to repeat
  ///
  /// \return a generated Array
  static std::shared_ptr<Array> Int16(int64_t size, int16_t value = 0);

  /// \brief Generates a constant UInt32Array
  ///
  /// \param[in] size the size of the array to generate
  /// \param[in] value to repeat
  ///
  /// \return a generated Array
  static std::shared_ptr<Array> UInt32(int64_t size, uint32_t value = 0);

  /// \brief Generates a constant Int32Array
  ///
  /// \param[in] size the size of the array to generate
  /// \param[in] value to repeat
  ///
  /// \return a generated Array
  static std::shared_ptr<Array> Int32(int64_t size, int32_t value = 0);

  /// \brief Generates a constant UInt64Array
  ///
  /// \param[in] size the size of the array to generate
  /// \param[in] value to repeat
  ///
  /// \return a generated Array
  static std::shared_ptr<Array> UInt64(int64_t size, uint64_t value = 0);

  /// \brief Generates a constant Int64Array
  ///
  /// \param[in] size the size of the array to generate
  /// \param[in] value to repeat
  ///
  /// \return a generated Array
  static std::shared_ptr<Array> Int64(int64_t size, int64_t value = 0);

  /// \brief Generates a constant Float16Array
  ///
  /// \param[in] size the size of the array to generate
  /// \param[in] value to repeat, passed as a 16-bit unsigned integer (presumably the
  ///            raw half-float bit pattern -- confirm against the implementation)
  ///
  /// \return a generated Array
  static std::shared_ptr<Array> Float16(int64_t size, uint16_t value = 0);

  /// \brief Generates a constant Float32Array
  ///
  /// \param[in] size the size of the array to generate
  /// \param[in] value to repeat
  ///
  /// \return a generated Array
  static std::shared_ptr<Array> Float32(int64_t size, float value = 0);

  /// \brief Generates a constant Float64Array
  ///
  /// \param[in] size the size of the array to generate
  /// \param[in] value to repeat
  ///
  /// \return a generated Array
  static std::shared_ptr<Array> Float64(int64_t size, double value = 0);

  /// \brief Generates a constant StringArray
  ///
  /// \param[in] size the size of the array to generate
  /// \param[in] value to repeat
  ///
  /// \return a generated Array
  static std::shared_ptr<Array> String(int64_t size, std::string value = "");

  /// \brief Generates a constant array for the Arrow type given as template argument
  ///
  /// Dispatches on ArrowType::type_id to the matching typed generator above. Temporal
  /// types are produced by generating an Int32 or Int64 array and viewing it as the
  /// target type; a failed View is reported through the EXPECT_OK_AND_ASSIGN macro.
  /// TIME32 and TIMESTAMP are created with TimeUnit::SECOND, TIME64 with the default
  /// constructor of Time64Type.
  ///
  /// \tparam ArrowType the Arrow type class selecting the kind of array
  /// \tparam CType the C type of `value`, defaults to ArrowType::c_type
  /// \param[in] size the size of the array to generate
  /// \param[in] value to repeat; it is cast to the storage type of the selected array
  ///
  /// \return a generated Array, or nullptr if the type id is not handled
  template <typename ArrowType, typename CType = typename ArrowType::c_type>
  static std::shared_ptr<Array> Numeric(int64_t size, CType value = 0) {
    switch (ArrowType::type_id) {
      case Type::BOOL:
        return Boolean(size, static_cast<bool>(value));
      case Type::UINT8:
        return UInt8(size, static_cast<uint8_t>(value));
      case Type::INT8:
        return Int8(size, static_cast<int8_t>(value));
      case Type::UINT16:
        return UInt16(size, static_cast<uint16_t>(value));
      case Type::INT16:
        return Int16(size, static_cast<int16_t>(value));
      case Type::UINT32:
        return UInt32(size, static_cast<uint32_t>(value));
      case Type::INT32:
        return Int32(size, static_cast<int32_t>(value));
      case Type::UINT64:
        return UInt64(size, static_cast<uint64_t>(value));
      case Type::INT64:
        return Int64(size, static_cast<int64_t>(value));
      case Type::HALF_FLOAT:
        return Float16(size, static_cast<uint16_t>(value));
      case Type::FLOAT:
        return Float32(size, static_cast<float>(value));
      case Type::DOUBLE:
        return Float64(size, static_cast<double>(value));
      // NOTE(review): INTERVAL_DAY_TIME shares the DATE32 branch and therefore yields a
      // date32-typed view rather than a day-time interval array -- confirm this is
      // intended.
      case Type::INTERVAL_DAY_TIME:
      case Type::DATE32: {
        // Cast to the signed storage type expected by Int32 (no unsigned round trip).
        EXPECT_OK_AND_ASSIGN(auto viewed,
                             Int32(size, static_cast<int32_t>(value))->View(date32()));
        return viewed;
      }
      case Type::INTERVAL_MONTHS: {
        EXPECT_OK_AND_ASSIGN(auto viewed,
                             Int32(size, static_cast<int32_t>(value))
                                 ->View(std::make_shared<MonthIntervalType>()));
        return viewed;
      }
      case Type::TIME32: {
        EXPECT_OK_AND_ASSIGN(auto viewed,
                             Int32(size, static_cast<int32_t>(value))
                                 ->View(std::make_shared<Time32Type>(TimeUnit::SECOND)));
        return viewed;
      }
      case Type::TIME64: {
        // Cast to the signed storage type expected by Int64, as TIMESTAMP does below.
        EXPECT_OK_AND_ASSIGN(auto viewed, Int64(size, static_cast<int64_t>(value))
                                              ->View(std::make_shared<Time64Type>()));
        return viewed;
      }
      case Type::DATE64: {
        EXPECT_OK_AND_ASSIGN(auto viewed,
                             Int64(size, static_cast<int64_t>(value))->View(date64()));
        return viewed;
      }
      case Type::TIMESTAMP: {
        EXPECT_OK_AND_ASSIGN(
            auto viewed, Int64(size, static_cast<int64_t>(value))
                             ->View(std::make_shared<TimestampType>(TimeUnit::SECOND)));
        return viewed;
      }
      default:
        // Type ids without a branch above are not supported by this helper.
        return nullptr;
    }
  }

  /// \brief Generates a constant Array of zeroes
  ///
  /// \param[in] size the size of the array to generate
  /// \param[in] type the type of the Array
  ///
  /// \return a generated Array
  static std::shared_ptr<Array> Zeroes(int64_t size,
                                       const std::shared_ptr<DataType>& type);

  /// \brief Generates a RecordBatch of zeroes
  ///
  /// \param[in] size the size of the array to generate
  /// \param[in] schema to conform to
  ///
  /// This function is handy to return a RecordBatch of a desired shape.
  ///
  /// \return a generated RecordBatch
  static std::shared_ptr<RecordBatch> Zeroes(int64_t size,
                                             const std::shared_ptr<Schema>& schema);

  /// \brief Generates a RecordBatchReader by repeating a RecordBatch
  ///
  /// \param[in] n_batch the number of times it repeats batch
  /// \param[in] batch the RecordBatch to repeat
  ///
  /// \return a generated RecordBatchReader
  static std::shared_ptr<RecordBatchReader> Repeat(
      int64_t n_batch, const std::shared_ptr<RecordBatch> batch);

  /// \brief Generates a RecordBatchReader of zeroes batches
  ///
  /// \param[in] n_batch the number of RecordBatch
  /// \param[in] batch_size the size of each RecordBatch
  /// \param[in] schema to conform to
  ///
  /// \return a generated RecordBatchReader
  static std::shared_ptr<RecordBatchReader> Zeroes(int64_t n_batch, int64_t batch_size,
                                                   const std::shared_ptr<Schema>& schema);
};

/// \brief Build an Array from a vector of scalars
///
/// NOTE(review): only the declaration is visible in this header. Presumably each
/// scalar becomes one element of the returned array and failures are reported through
/// the Result -- confirm against the definition.
///
/// \param[in] scalars the scalars to convert
///
/// \return the resulting Array wrapped in a Result
ARROW_TESTING_EXPORT
Result<std::shared_ptr<Array>> ScalarVectorToArray(const ScalarVector& scalars);

namespace gen {

/// \brief Abstract interface for objects that produce arrays on demand
class ARROW_TESTING_EXPORT ArrayGenerator {
 public:
  virtual ~ArrayGenerator() = default;
  /// \brief Produce an array
  ///
  /// Non-const: implementations may keep state between calls.
  ///
  /// \param[in] num_rows the requested number of rows
  /// \return the generated Array, or an error
  virtual Result<std::shared_ptr<Array>> Generate(int64_t num_rows) = 0;
  /// \brief The data type associated with this generator
  virtual std::shared_ptr<DataType> type() const = 0;
};

// Same as DataGenerator below, but instead of returning a Result an OK status is
// EXPECT'd, so every method returns the plain value.
class ARROW_TESTING_EXPORT GTestDataGenerator {
 public:
  virtual ~GTestDataGenerator() = default;
  /// \brief Generate a single record batch
  /// \param[in] num_rows the requested number of rows
  virtual std::shared_ptr<::arrow::RecordBatch> RecordBatch(int64_t num_rows) = 0;
  /// \brief Generate several record batches
  /// \param[in] rows_per_batch the requested number of rows in each batch
  /// \param[in] num_batches the requested number of batches
  virtual std::vector<std::shared_ptr<::arrow::RecordBatch>> RecordBatches(
      int64_t rows_per_batch, int num_batches) = 0;

  /// \brief Generate a single compute ExecBatch
  /// \param[in] num_rows the requested number of rows
  virtual ::arrow::compute::ExecBatch ExecBatch(int64_t num_rows) = 0;
  /// \brief Generate several compute ExecBatches
  /// \param[in] rows_per_batch the requested number of rows in each batch
  /// \param[in] num_batches the requested number of batches
  virtual std::vector<::arrow::compute::ExecBatch> ExecBatches(int64_t rows_per_batch,
                                                               int num_batches) = 0;

  /// \brief Generate a table
  /// \param[in] rows_per_chunk the requested number of rows in each chunk
  /// \param[in] num_chunks the requested number of chunks, 1 by default
  virtual std::shared_ptr<::arrow::Table> Table(int64_t rows_per_chunk,
                                                int num_chunks = 1) = 0;
  /// \brief The schema associated with this generator
  virtual std::shared_ptr<::arrow::Schema> Schema() = 0;
};

/// @brief Abstract interface for generating record batches, exec batches and tables
///
/// Data-producing methods return a Result so that errors are reported to the caller;
/// see FailOnError() for a variant that returns plain values.
class ARROW_TESTING_EXPORT DataGenerator {
 public:
  virtual ~DataGenerator() = default;
  /// @brief Generate a single record batch
  /// @param num_rows the requested number of rows
  virtual Result<std::shared_ptr<::arrow::RecordBatch>> RecordBatch(int64_t num_rows) = 0;
  /// @brief Generate several record batches
  /// @param rows_per_batch the requested number of rows in each batch
  /// @param num_batches the requested number of batches
  virtual Result<std::vector<std::shared_ptr<::arrow::RecordBatch>>> RecordBatches(
      int64_t rows_per_batch, int num_batches) = 0;

  /// @brief Generate a single compute ExecBatch
  /// @param num_rows the requested number of rows
  virtual Result<::arrow::compute::ExecBatch> ExecBatch(int64_t num_rows) = 0;
  /// @brief Generate several compute ExecBatches
  /// @param rows_per_batch the requested number of rows in each batch
  /// @param num_batches the requested number of batches
  virtual Result<std::vector<::arrow::compute::ExecBatch>> ExecBatches(
      int64_t rows_per_batch, int num_batches) = 0;

  /// @brief Generate a table
  /// @param rows_per_chunk the requested number of rows in each chunk
  /// @param num_chunks the requested number of chunks, 1 by default
  virtual Result<std::shared_ptr<::arrow::Table>> Table(int64_t rows_per_chunk,
                                                        int num_chunks = 1) = 0;
  /// @brief The schema associated with this generator
  virtual std::shared_ptr<::arrow::Schema> Schema() = 0;
  /// @brief Converts this generator to a variant that fails (in a googletest sense)
  ///        if any error is encountered.
  virtual std::unique_ptr<GTestDataGenerator> FailOnError() = 0;
};

/// @brief A potentially named field
///
/// If name is not specified then a name will be generated automatically (e.g. f0, f1)
struct ARROW_TESTING_EXPORT GeneratorField {
 public:
  GeneratorField(std::shared_ptr<ArrayGenerator> gen)  // NOLINT implicit conversion
      : name(), gen(std::move(gen)) {}
  GeneratorField(std::string name, std::shared_ptr<ArrayGenerator> gen)
      : name(std::move(name)), gen(std::move(gen)) {}

  std::optional<std::string> name;
  std::shared_ptr<ArrayGenerator> gen;
};

/// @brief Create a table generator with the given fields
///
/// @param column_gens the fields (generator plus optional name) making up the table
/// @return a DataGenerator over those fields
ARROW_TESTING_EXPORT std::shared_ptr<DataGenerator> Gen(
    std::vector<GeneratorField> column_gens);

/// @brief Make a generator that returns a constant value
///
/// @param value the scalar to repeat
/// @return an ArrayGenerator producing that value
ARROW_TESTING_EXPORT std::shared_ptr<ArrayGenerator> Constant(
    std::shared_ptr<Scalar> value);

/// @brief Make a generator that returns an incrementing value
///
/// The first generated value is `start` and each following value is the previous one
/// plus `step`. The generator is stateful: every call to Generate() continues from
/// where the previous call stopped. Generated arrays contain no nulls.
///
/// Note: overflow is not prevented. For integral T the running value wraps around
/// (the addition is carried out in the corresponding unsigned type, so it is well
/// defined for signed T too); for floating-point T ordinary floating-point addition
/// applies.
///
/// @tparam T the C type of the generated values; must map to a numeric Arrow type
/// @param start the first value to generate
/// @param step the increment between consecutive values
/// @return an ArrayGenerator producing the sequence
template <typename T = uint32_t>
std::shared_ptr<ArrayGenerator> Step(T start = 0, T step = 1) {
  class StepGenerator : public ArrayGenerator {
   public:
    // Use [[maybe_unused]] to avoid a compiler warning in Clang versions before 15 that
    // incorrectly reports 'unused type alias'.
    using ArrowType [[maybe_unused]] = typename CTypeTraits<T>::ArrowType;
    static_assert(is_number_type<ArrowType>::value,
                  "Step generator only supports numeric types");

    StepGenerator(T start, T step) : start_(start), step_(step) {}

    Result<std::shared_ptr<Array>> Generate(int64_t num_rows) override {
      TypedBufferBuilder<T> builder;
      // Reserve up front so that the unchecked appends below are safe.
      ARROW_RETURN_NOT_OK(builder.Reserve(num_rows));
      T val = start_;
      for (int64_t i = 0; i < num_rows; i++) {
        builder.UnsafeAppend(val);
        val = Advance(val, step_);
      }
      // Remember where this batch stopped so the next call continues the sequence.
      start_ = val;
      ARROW_ASSIGN_OR_RAISE(auto buf, builder.Finish());
      // No validity buffer is needed because the array has no nulls.
      return MakeArray(ArrayData::Make(TypeTraits<ArrowType>::type_singleton(), num_rows,
                                       {NULLPTR, std::move(buf)}, /*null_count=*/0));
    }

    std::shared_ptr<DataType> type() const override {
      return TypeTraits<ArrowType>::type_singleton();
    }

   private:
    // Adds `step` to `val`. Integral types go through the unsigned counterpart so that
    // overflow wraps around instead of being undefined behavior for signed T.
    static T Advance(T val, T step) {
      if constexpr (std::is_integral_v<T>) {
        using Unsigned = std::make_unsigned_t<T>;
        return static_cast<T>(static_cast<Unsigned>(val) + static_cast<Unsigned>(step));
      } else {
        return val + step;
      }
    }

    T start_;
    T step_;
  };

  return std::make_shared<StepGenerator>(start, step);
}

/// @brief Make a generator that returns a random value
///
/// @param type the data type of the arrays to generate
/// @return an ArrayGenerator producing random values
ARROW_TESTING_EXPORT std::shared_ptr<ArrayGenerator> Random(
    std::shared_ptr<DataType> type);
// TODO(if-needed) could add a repeat-scalars generator, e.g. Repeat({1, 2, 3}) for
// 1,2,3,1,2,3,1
//
// TODO(if-needed) could add a repeat-from-json generator, e.g.
// Repeat(int32(), "[1, 2, 3]"), same behavior as repeat-scalars

}  // namespace gen

}  // namespace arrow