Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ include / caffe2 / utils / math / utils.h

#ifndef CAFFE2_UTILS_MATH_UTILS_H_
#define CAFFE2_UTILS_MATH_UTILS_H_

#include <vector>

#include "caffe2/core/common.h"

// MATH_UTILS_DECL is the declaration prefix used by the small inline helpers
// defined in this header. When the CUDA/HIP compiler condition tested below
// holds, it expands to `inline __host__ __device__`; otherwise it expands to
// plain `inline`.
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) || \
    defined(__HIP__) || (defined(__clang__) && defined(__CUDA__))
#define MATH_UTILS_DECL inline __host__ __device__
#else
#define MATH_UTILS_DECL inline
#endif

namespace caffe2 {
namespace math {

namespace utils {

// Logical negation of x. The result of `!x` is converted back to T on
// return, so for arithmetic T the result is T(1) when x is zero and T(0)
// otherwise.
template <typename T>
MATH_UTILS_DECL T Not(const T x) {
  const auto negated = !x;
  return negated;
}

// Signum of x: T(1) for positive values, T(-1) for negative values, and
// T(0) when neither comparison against zero holds.
template <typename T>
MATH_UTILS_DECL T Sign(const T x) {
  if (x > 0) {
    return T(1);
  }
  if (x < 0) {
    return T(-1);
  }
  return T(0);
}

// Arithmetic negation: returns the result of unary minus applied to x,
// converted back to T.
template <typename T>
MATH_UTILS_DECL T Negate(const T x) {
  const auto negated = -x;
  return negated;
}

// Reciprocal of x, computed as T(1) / x. The division is performed in T's
// arithmetic, so for integer T it truncates (e.g. Inv(2) == 0). No check
// is made for x == 0.
template <typename T>
MATH_UTILS_DECL T Inv(const T x) {
  const T one = T(1);
  return one / x;
}

// Returns x multiplied by itself (x^2), converted back to T.
template <typename T>
MATH_UTILS_DECL T Square(const T x) {
  const auto squared = x * x;
  return squared;
}

// Returns x raised to the third power, evaluated as (x * x) * x and
// converted back to T.
template <typename T>
MATH_UTILS_DECL T Cube(const T x) {
  const auto squared = x * x;
  return squared * x;
}

// Range check `0 <= a && a < b` done with a single comparison.
//
// Both arguments are converted to unsigned int before comparing. The caller
// is expected to pass a positive b, so its unsigned value stays below
// 0x800...; a negative a, once converted, becomes a value above 0x800...
// and therefore can never compare less than b. One unsigned comparison thus
// replaces the two signed ones.
MATH_UTILS_DECL bool IsAGeZeroAndALtB(const int a, const int b) {
  const auto unsigned_a = static_cast<unsigned int>(a);
  const auto unsigned_b = static_cast<unsigned int>(b);
  return unsigned_a < unsigned_b;
}

// Increase the index digits by one based on dims.
//
// Declaration only; the definition is not part of this header.
//   ndim:  number of dimensions (presumably the length of both arrays --
//          TODO confirm against the definition).
//   dims:  read-only per-dimension sizes.
//   index: the only mutable argument; the function returns void, so the
//          advanced index is presumably written back here in place.
// NOTE(review): which digit is incremented first and how carries wrap is
// decided by the definition -- confirm there before relying on it.
template <typename TIndex>
TORCH_API void
IncreaseIndexInDims(int ndim, const TIndex* dims, TIndex* index);

// Get index value from dims and index digits.
//
// Declaration only; the definition is not part of this header.
//   n:     number of dimensions (presumably the length of both arrays --
//          TODO confirm against the definition).
//   dims:  read-only per-dimension sizes.
//   index: read-only per-dimension index digits.
// Returns a single TIndex value. NOTE(review): presumably the flattened
// offset of `index` within `dims`; the exact layout convention is decided by
// the definition -- confirm there.
template <typename TIndex>
TORCH_API TIndex
GetIndexFromDims(const int n, const TIndex* dims, const TIndex* index);

// Checks if the input permutation is an identity permutation.
//
// Declaration only; the definition is not part of this header.
//   n:    number of entries in perm (presumably -- TODO confirm).
//   perm: read-only permutation to test.
TORCH_API bool IsIdentityPermutation(const int n, const int* perm);

// Declaration only; the definition is not part of this header.
// Takes a dimension count and two read-only dims arrays (X and Y) and
// returns a bool.
// NOTE(review): presumably reports whether Y_dims is a valid reduction
// output shape for X_dims -- confirm against the definition.
TORCH_API bool
CheckReduceDims(const int ndim, const int* X_dims, const int* Y_dims);

// Declaration only; the definition is not part of this header.
// Takes a dimension count, the read-only dims of X and Y, and two mutable
// int pointers, and returns a bool.
// NOTE(review): presumably returns true when reducing X_dims to Y_dims is a
// row-wise reduction and writes the equivalent 2D shape to *rows / *cols --
// confirm against the definition, including whether the outputs are written
// when the result is false.
TORCH_API bool IsRowwiseReduce(
    const int ndim,
    const int* X_dims,
    const int* Y_dims,
    int* rows,
    int* cols);

// Declaration only; the definition is not part of this header.
// Takes a dimension count, the read-only dims of X and Y, and two mutable
// int pointers, and returns a bool.
// NOTE(review): presumably returns true when reducing X_dims to Y_dims is a
// column-wise reduction and writes the equivalent 2D shape to *rows / *cols
// -- confirm against the definition, including whether the outputs are
// written when the result is false.
TORCH_API bool IsColwiseReduce(
    const int ndim,
    const int* X_dims,
    const int* Y_dims,
    int* rows,
    int* cols);

// Declaration only; the definition is not part of this header.
// Takes a dimension count, the read-only dims of X and Y, and three mutable
// int pointers, and returns a bool.
// NOTE(review): presumably returns true when the reduction covers the
// leading and trailing dimensions only, and writes the sizes of the three
// resulting segments to *pre / *mid / *nxt -- confirm against the
// definition.
TORCH_API bool IsBothEndsReduce(
    const int ndim,
    const int* X_dims,
    const int* Y_dims,
    int* pre,
    int* mid,
    int* nxt);

// Computes the broadcast binary operation dims.
//
// Declaration only; the definition is not part of this header.
//   A_ndim, A_dims: dimension count and read-only dims of operand A.
//   B_ndim, B_dims: dimension count and read-only dims of operand B.
//   A_broadcast_dims, B_broadcast_dims, C_broadcast_dims: mutable arrays;
//     the function returns void, so results are presumably written here.
// NOTE(review): the required length of the three output arrays and the
// alignment rule applied when A_ndim != B_ndim are decided by the
// definition -- confirm there.
template <typename TIndex>
TORCH_API void ComputeBroadcastBinaryOpDims(
    const int A_ndim,
    const TIndex* A_dims,
    const int B_ndim,
    const TIndex* B_dims,
    TIndex* A_broadcast_dims,
    TIndex* B_broadcast_dims,
    TIndex* C_broadcast_dims);

// Declaration only; the definition is not part of this header.
// Takes a dimension count, the read-only dims of A and B, and three mutable
// outputs (two ints and a bool), and returns a bool.
// NOTE(review): presumably returns true when the A/B broadcast is
// equivalent to a row-wise 2D broadcast, writes that 2D shape to
// *rows / *cols, and sets *broadcast_1st to indicate whether the first
// operand is the one being broadcast -- confirm against the definition.
TORCH_API bool IsRowwiseBroadcastBinaryOp(
    const int ndim,
    const int* A_dims,
    const int* B_dims,
    int* rows,
    int* cols,
    bool* broadcast_1st);

// Declaration only; the definition is not part of this header.
// Takes a dimension count, the read-only dims of A and B, and three mutable
// outputs (two ints and a bool), and returns a bool.
// NOTE(review): presumably returns true when the A/B broadcast is
// equivalent to a column-wise 2D broadcast, writes that 2D shape to
// *rows / *cols, and sets *broadcast_1st to indicate whether the first
// operand is the one being broadcast -- confirm against the definition.
TORCH_API bool IsColwiseBroadcastBinaryOp(
    const int ndim,
    const int* A_dims,
    const int* B_dims,
    int* rows,
    int* cols,
    bool* broadcast_1st);

// Declaration only; the definition is not part of this header.
// Takes a dimension count, the read-only dims of A and B, and four mutable
// outputs (three ints and a bool), and returns a bool.
// NOTE(review): presumably returns true when the broadcast happens on the
// leading and trailing dimensions only, writes the three segment sizes to
// *pre / *mid / *nxt, and sets *broadcast_1st to indicate whether the first
// operand is the one being broadcast -- confirm against the definition.
TORCH_API bool IsBothEndsBroadcastBinaryOp(
    const int ndim,
    const int* A_dims,
    const int* B_dims,
    int* pre,
    int* mid,
    int* nxt,
    bool* broadcast_1st);

// Declaration only; the definition is not part of this header.
// Takes a dimension count and a read-only axes array and returns a bool.
// NOTE(review): presumably reports whether `axes` leaves the leading
// (batch) dimensions in place and swaps only the last two -- confirm
// against the definition.
TORCH_API bool IsBatchTranspose2D(const int ndim, const int* axes);

// Overload taking an explicit list of reduce axes.
// Declaration only; the definition is not part of this header.
//   num_dims:        total number of dimensions.
//   num_reduce_axes: number of entries in reduce_axes (presumably -- TODO
//                    confirm).
//   reduce_axes:     read-only axes being reduced.
//   transpose_axes:  the only mutable argument; the function returns void,
//                    so the computed permutation is presumably written here.
// NOTE(review): the ordering of reduced vs. kept axes in the output is
// decided by the definition -- confirm there.
TORCH_API void ComputeTransposeAxesForReduceOp(
    const int num_dims,
    const int num_reduce_axes,
    const int* reduce_axes,
    int* transpose_axes);

// Overload taking a dims array instead of an explicit reduce-axes list.
// Declaration only; the definition is not part of this header.
//   ndim: number of dimensions.
//   dims: read-only per-dimension values.
//   axes: the only mutable argument; the function returns void, so the
//         computed permutation is presumably written here.
// NOTE(review): how `dims` marks an axis as reduced is decided by the
// definition -- confirm there.
TORCH_API void
ComputeTransposeAxesForReduceOp(const int ndim, const int* dims, int* axes);

// Declaration only; the definition is not part of this header.
//   ndim:    number of dimensions.
//   dims:    read-only per-dimension sizes.
//   axes:    read-only axis permutation.
//   strides: the only mutable argument; the function returns void, so the
//            computed strides are presumably written here.
// NOTE(review): presumably yields, for each output axis, the stride of the
// corresponding input axis under the `axes` permutation -- confirm against
// the definition.
template <typename TIndex>
TORCH_API void ComputeTransposedStrides(
    int ndim,
    const TIndex* dims,
    const int* axes,
    TIndex* strides);

} // namespace utils

// Computes ceil(a / b) as (a + b - 1) / b. It is the caller's
// responsibility to make sure the intermediate sum neither overflows nor
// underflows.
template <typename T>
constexpr T DivUp(const T a, const T b) {
  const auto biased_numerator = a + b - T(1);
  return biased_numerator / b;
}

// Rounds a up to the nearest multiple of b that is not below a, computed as
// DivUp(a, b) * b. The caller must make sure the DivUp computation neither
// overflows nor underflows.
template <typename T>
constexpr T RoundUp(const T a, const T b) {
  const T multiples_needed = DivUp<T>(a, b);
  return multiples_needed * b;
}

// Returns log2(n), rounded down, for a positive integer type. The value is
// found by halving n until it is no greater than 1 and counting the steps.
// The count starts from p, so the result is p plus the number of halvings;
// callers normally leave p at its default of 0.
template <typename T>
constexpr int IntegerLog2(T n, int p = 0) {
  while (n > 1) {
    n /= 2;
    ++p;
  }
  return p;
}

// Returns the next highest power-of-2 for an integer type. When
// IntegerIsPowerOf2(v) holds the result is 2 * v; otherwise it is
// 1 << (IntegerLog2(v) + 1).
template <typename T>
constexpr T IntegerNextHighestPowerOf2(T v) {
  if (IntegerIsPowerOf2(v)) {
    return T(2) * v;
  }
  return T(1) << (IntegerLog2(v) + 1);
}

} // namespace math
} // namespace caffe2

#endif // CAFFE2_UTILS_MATH_UTILS_H_