Why Gemfury? Push, build, and install: RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages, RPM packages, NuGet packages

arrow-nightlies / pyarrow   python

Repository URL to install this package:

Version: 19.0.0.dev246 

/ src / arrow / python / numpy_internal.h

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

// Internal utilities for dealing with NumPy

#pragma once

#include "arrow/python/numpy_init.h"
#include "arrow/python/numpy_interop.h"

#include "arrow/status.h"

#include "arrow/python/platform.h"

#include <cstdint>
#include <sstream>
#include <string>

namespace arrow {
namespace py {

/// \brief Indexing convenience for interacting with strided 1-dim ndarray objects
///
/// The indexer keeps the raw PyArrayObject pointer plus the data pointer and
/// the byte stride of dimension 0, both read once at construction time.
/// Element `i` is addressed as `data + i * stride`, reinterpreted as `T`.
///
/// No reference is taken on the array in this class.
/// NOTE(review): presumably the caller keeps the array alive for the lifetime
/// of the indexer -- confirm at call sites.
template <typename T>
class Ndarray1DIndexer {
 public:
  using size_type = int64_t;

  /// Construct an indexer that is not attached to any array.
  ///
  /// All members are given a defined value (stride 0) so that reading them,
  /// e.g. through is_strided(), is never undefined behaviour. size() and
  /// operator[] still require an attached array.
  Ndarray1DIndexer() : arr_(NULLPTR), data_(NULLPTR), stride_(0) {}

  /// Attach to `arr`, capturing its data pointer and first-dimension stride.
  ///
  /// Asserts (via DCHECK_EQ) that `arr` is 1-dimensional.
  explicit Ndarray1DIndexer(PyArrayObject* arr) : Ndarray1DIndexer() {
    arr_ = arr;
    DCHECK_EQ(1, PyArray_NDIM(arr)) << "Only works with 1-dimensional arrays";
    data_ = reinterpret_cast<uint8_t*>(PyArray_DATA(arr));
    stride_ = PyArray_STRIDES(arr)[0];
  }

  ~Ndarray1DIndexer() = default;

  /// Number of elements as reported by PyArray_SIZE on the attached array.
  /// Must not be called on a default-constructed indexer (arr_ is null).
  int64_t size() const { return PyArray_SIZE(arr_); }

  /// Pointer to the first element, viewed as `const T`.
  const T* data() const { return reinterpret_cast<const T*>(data_); }

  /// True when the byte stride differs from sizeof(T), i.e. consecutive
  /// elements are not packed back to back in increasing address order.
  bool is_strided() const {
    // Compare as signed values: the stride may be negative.
    return stride_ != static_cast<int64_t>(sizeof(T));
  }

  /// Mutable access to element `index`; no bounds checking is performed.
  T& operator[](size_type index) {
    return *reinterpret_cast<T*>(data_ + index * stride_);
  }
  /// Read-only access to element `index`; no bounds checking is performed.
  const T& operator[](size_type index) const {
    return *reinterpret_cast<const T*>(data_ + index * stride_);
  }

 private:
  PyArrayObject* arr_;  // attached array, or null when default-constructed
  uint8_t* data_;       // byte-addressed view of the array's data buffer
  int64_t stride_;      // stride of dimension 0 in bytes; 0 when unattached
};

// Handling of NumPy types by their static type numbers
// (the NPY_TYPES enum and related defines)

/// Map a NumPy type number to a short, human-readable type name.
///
/// \param npy_type a NumPy type number (one of the NPY_* constants)
/// \return the name associated with `npy_type`, or the diagnostic text
/// "unrecognized type (<number>) in GetNumPyTypeName" when there is no
/// entry for it below.
static inline std::string GetNumPyTypeName(int npy_type) {
#define NUMPY_TYPE_NAME_CASE(TYPE, NAME) \
  case NPY_##TYPE:                       \
    return NAME;

  switch (npy_type) {
    NUMPY_TYPE_NAME_CASE(BOOL, "bool")

    // Fixed-width integers, signed and unsigned side by side
    NUMPY_TYPE_NAME_CASE(INT8, "int8")
    NUMPY_TYPE_NAME_CASE(UINT8, "uint8")
    NUMPY_TYPE_NAME_CASE(INT16, "int16")
    NUMPY_TYPE_NAME_CASE(UINT16, "uint16")
    NUMPY_TYPE_NAME_CASE(INT32, "int32")
    NUMPY_TYPE_NAME_CASE(UINT32, "uint32")
    NUMPY_TYPE_NAME_CASE(INT64, "int64")
    NUMPY_TYPE_NAME_CASE(UINT64, "uint64")

    // NOTE(review): these guards presumably exist because the C-named types
    // alias the fixed-width ones on some platforms, where a second label
    // would duplicate a case above -- confirm where the macros are defined.
#if !NPY_INT32_IS_INT
    NUMPY_TYPE_NAME_CASE(INT, "intc")
    NUMPY_TYPE_NAME_CASE(UINT, "uintc")
#endif
#if !NPY_INT64_IS_LONG_LONG
    NUMPY_TYPE_NAME_CASE(LONGLONG, "longlong")
    NUMPY_TYPE_NAME_CASE(ULONGLONG, "ulonglong")
#endif

    // Floating point
    NUMPY_TYPE_NAME_CASE(FLOAT16, "float16")
    NUMPY_TYPE_NAME_CASE(FLOAT32, "float32")
    NUMPY_TYPE_NAME_CASE(FLOAT64, "float64")

    // Temporal and generic types
    NUMPY_TYPE_NAME_CASE(DATETIME, "datetime64")
    NUMPY_TYPE_NAME_CASE(TIMEDELTA, "timedelta64")
    NUMPY_TYPE_NAME_CASE(OBJECT, "object")
    NUMPY_TYPE_NAME_CASE(VOID, "void")

    default: {
      // No entry for this type number: report the raw number instead.
      std::stringstream message;
      message << "unrecognized type (" << npy_type << ") in GetNumPyTypeName";
      return message.str();
    }
  }

#undef NUMPY_TYPE_NAME_CASE
}

// Expands to one switch label that forwards the array to the visitor's
// Visit<> member template, instantiated with the matching NPY_* constant.
#define VISIT_NUMPY_TYPE_INLINE(TYPE) \
  case NPY_##TYPE:                    \
    return visitor->template Visit<NPY_##TYPE>(arr);

/// Dispatch on the type number of `arr` to `visitor->Visit<NPY_TYPE>(arr)`.
///
/// \param arr the array whose type number (PyArray_TYPE) selects the overload
/// \param visitor object providing a `Visit<int>(PyArrayObject*)` member
/// template whose result is returned unchanged
/// \return the visitor's result, or Status::NotImplemented naming the type
/// when the type number has no entry in the switch below
template <typename VISITOR>
inline Status VisitNumpyArrayInline(PyArrayObject* arr, VISITOR* visitor) {
  const int type_num = PyArray_TYPE(arr);

  switch (type_num) {
    VISIT_NUMPY_TYPE_INLINE(BOOL)
    VISIT_NUMPY_TYPE_INLINE(INT8)
    VISIT_NUMPY_TYPE_INLINE(UINT8)
    VISIT_NUMPY_TYPE_INLINE(INT16)
    VISIT_NUMPY_TYPE_INLINE(UINT16)
    VISIT_NUMPY_TYPE_INLINE(INT32)
    VISIT_NUMPY_TYPE_INLINE(UINT32)
    VISIT_NUMPY_TYPE_INLINE(INT64)
    VISIT_NUMPY_TYPE_INLINE(UINT64)
#if !NPY_INT32_IS_INT
    VISIT_NUMPY_TYPE_INLINE(INT)
    VISIT_NUMPY_TYPE_INLINE(UINT)
#endif
#if !NPY_INT64_IS_LONG_LONG
    VISIT_NUMPY_TYPE_INLINE(LONGLONG)
    VISIT_NUMPY_TYPE_INLINE(ULONGLONG)
#endif
    VISIT_NUMPY_TYPE_INLINE(FLOAT16)
    VISIT_NUMPY_TYPE_INLINE(FLOAT32)
    VISIT_NUMPY_TYPE_INLINE(FLOAT64)
    VISIT_NUMPY_TYPE_INLINE(DATETIME)
    VISIT_NUMPY_TYPE_INLINE(TIMEDELTA)
    VISIT_NUMPY_TYPE_INLINE(OBJECT)
    default:
      // Not a visitable type; report it below.
      break;
  }

  return Status::NotImplemented("NumPy type not implemented: ",
                                GetNumPyTypeName(type_num));
}

#undef VISIT_NUMPY_TYPE_INLINE

namespace internal {

/// Return true if `obj` passes PyFloat_Check or, when has_numpy() is true,
/// is a NumPy scalar of the Floating kind.
inline bool PyFloatScalar_Check(PyObject* obj) {
  // The NumPy scalar test is only attempted when has_numpy() says so.
  if (!has_numpy()) {
    return PyFloat_Check(obj);
  }
  return PyFloat_Check(obj) || PyArray_IsScalar(obj, Floating);
}

/// Return true if `obj` passes PyLong_Check or, when has_numpy() is true,
/// is a NumPy scalar of the Integer kind.
inline bool PyIntScalar_Check(PyObject* obj) {
  // The NumPy scalar test is only attempted when has_numpy() says so.
  if (!has_numpy()) {
    return PyLong_Check(obj);
  }
  return PyLong_Check(obj) || PyArray_IsScalar(obj, Integer);
}

/// Return true if `obj` passes PyBool_Check or, when has_numpy() is true,
/// is a NumPy scalar of the Bool kind.
inline bool PyBoolScalar_Check(PyObject* obj) {
  // The NumPy scalar test is only attempted when has_numpy() says so.
  if (!has_numpy()) {
    return PyBool_Check(obj);
  }
  return PyBool_Check(obj) || PyArray_IsScalar(obj, Bool);
}

/// Obtain a dtype descriptor for the NumPy type number `type`.
///
/// For NPY_DATETIME and NPY_TIMEDELTA the descriptor comes from
/// PyArray_DescrNewFromType; every other type number goes through
/// PyArray_DescrFromType.
static inline PyArray_Descr* GetSafeNumPyDtype(int type) {
  // It is not safe to mutate the result of DescrFromType for datetime and
  // timedelta descriptors, so those two get a descriptor from
  // DescrNewFromType instead.
  const bool is_temporal = (type == NPY_DATETIME) || (type == NPY_TIMEDELTA);
  return is_temporal ? PyArray_DescrNewFromType(type) : PyArray_DescrFromType(type);
}

}  // namespace internal

}  // namespace py
}  // namespace arrow