Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / perfkernels / lstm_unit_cpu-impl.h

#pragma once
#include <cmath>
#include <cstdint>
#include <cstring>
#include "caffe2/utils/conversions.h"

// VECTOR_LOOP is placed immediately before a `for` statement to ask the
// compiler to vectorize that loop. It expands to nothing when any of these
// hold:
//   - the build is a debug build;
//   - ENABLE_VECTORIZATION is not a positive value;
//   - clang reports an enabled sanitizer;
//   - no supported pragma is available.
#if (ENABLE_VECTORIZATION > 0) && !defined(_DEBUG) && !defined(DEBUG)
#if defined(__clang__) && (__clang_major__ > 7)
// Non-zero when clang reports that one of the listed sanitizers is enabled;
// the vectorization hint is skipped in that case.
// NOTE(review): `undefined_sanitizer` may not be a feature name clang
// recognizes (unknown names evaluate to 0) -- confirm UBSan builds are covered.
#define IS_SANITIZER                          \
  ((__has_feature(address_sanitizer) == 1) || \
   (__has_feature(memory_sanitizer) == 1) ||  \
   (__has_feature(thread_sanitizer) == 1) ||  \
   (__has_feature(undefined_sanitizer) == 1))

#if IS_SANITIZER == 0
#define VECTOR_LOOP _Pragma("clang loop vectorize(enable)")
#endif
#elif defined(_OPENMP) && (_OPENMP >= 201511)
// Support with OpenMP 4.5 and above. Use `omp simd` (vectorize only) rather
// than `omp for simd`. The latter is a worksharing construct: if these kernels
// were called from inside an OpenMP parallel region, the loop's iterations
// would be divided among the team (and an implicit barrier added) instead of
// each caller processing its own full row.
#define VECTOR_LOOP _Pragma("omp simd")
#endif
#endif

#ifndef VECTOR_LOOP
// Not supported: expand to nothing so the loop compiles as a plain `for`.
#define VECTOR_LOOP
#endif

namespace caffe2 {
namespace perfkernels {
namespace {
/// Logistic sigmoid: returns 1 / (1 + exp(-x)).
template <typename T>
inline T sigmoid(T x) {
  // `auto` keeps the exact type of the original expression, so the division
  // below is performed in the same type as before.
  const auto denominator = 1 + std::exp(-x);
  return 1 / denominator;
}

/// Hyperbolic tangent computed through `sigmoid`, using the identity
/// tanh(x) = 2 * sigmoid(2x) - 1.
template <typename T>
inline T host_tanh(T x) {
  const T squashed = sigmoid(2 * x);
  return 2 * squashed - 1;
}

/// Forward pass of one LSTM time step over a batch of N rows.
///
/// Buffers are contiguous per row:
///   - H_prev, C_prev, H and C advance by D values per row;
///   - X advances by 4*D values per row.
///
/// Within a row, X is split into four D-wide slices. The kernel computes
///   i = sigmoid(X[0:D])
///   f = sigmoid(X[D:2D] + forget_bias)
///   o = sigmoid(X[2D:3D])
///   g = tanh(X[3D:4D])
///   C = f * C_prev + i * g
///   H = o * tanh(C)
///
/// A row is valid when seqLengths is null or t < seqLengths[n]. For rows that
/// are not valid, H and C are zeroed when drop_states is true; otherwise they
/// are copied from H_prev and C_prev.
///
/// @param N            number of rows.
/// @param D            values per row in H_prev, C_prev, H and C.
/// @param t            current time step, compared against seqLengths[n].
/// @param H_prev       previous hidden state (N * D values).
/// @param C_prev       previous cell state (N * D values).
/// @param X            gate pre-activations (N * 4 * D values).
/// @param seqLengths   optional per-row lengths; nullptr marks every row valid.
/// @param drop_states  zero (true) or carry over (false) state for invalid rows.
/// @param C            output cell state (N * D values).
/// @param H            output hidden state (N * D values).
/// @param forget_bias  constant added to the forget-slice pre-activation.
template <typename T>
inline void LstmUnitImpl(
    const int N,
    const int D,
    const int t,
    const T* H_prev,
    const T* C_prev,
    const T* X,
    const int32_t* seqLengths,
    const bool drop_states,
    T* C,
    T* H,
    const float forget_bias) {
  const T forgetBias = convert::To<float, T>(forget_bias);
  for (int n = 0; n < N; ++n) {
    const bool valid = seqLengths == nullptr || t < seqLengths[n];
    if (!valid) {
      // std:: versions from <cstring>; unqualified names relied on transitive
      // includes to declare them.
      if (drop_states) {
        std::memset(H, 0, sizeof(T) * D);
        std::memset(C, 0, sizeof(T) * D);
      } else {
        std::memcpy(H, H_prev, sizeof(T) * D);
        std::memcpy(C, C_prev, sizeof(T) * D);
      }
    } else {
      // Start of the second, third and fourth D-wide slices of this row's X.
      const T* X_D = &X[D];
      const T* X_2D = &X[2 * D];
      const T* X_3D = &X[3 * D];
      VECTOR_LOOP for (int d = 0; d < D; ++d) {
        const T i = sigmoid(X[d]);
        const T f = sigmoid(X_D[d] + forgetBias);
        const T o = sigmoid(X_2D[d]);
        const T g = host_tanh(X_3D[d]);
        const T c_prev = C_prev[d];
        const T c = f * c_prev + i * g;
        C[d] = c;
        const T host_tanh_c = host_tanh(c);
        H[d] = o * host_tanh_c;
      }
    }
    // Advance every buffer to the next row.
    H_prev += D;
    C_prev += D;
    X += 4 * D;
    C += D;
    H += D;
  }
}

/// Backward pass of LstmUnitImpl for one time step over a batch of N rows.
///
/// Buffers are contiguous per row:
///   - C_prev, C, H, C_diff, H_diff, H_prev_diff and C_prev_diff advance by
///     D values per row;
///   - X and X_diff advance by 4*D values per row.
///
/// For valid rows (seqLengths null or t < seqLengths[n]), the gate values
/// i, f, o and g are recomputed from X exactly as in the forward pass, and
/// C is read from the forward output. The kernel then writes:
///   dC          = C_diff + H_diff * o * (1 - tanh(C)^2)
///   C_prev_diff = dC * f
///   H_prev_diff = 0
///   X_diff slices:
///     [dC * g * i(1-i) | dC * C_prev * f(1-f) | H_diff * tanh(C) * o(1-o) |
///      dC * i * (1-g^2)]
///
/// For rows that are not valid, X_diff is zeroed. C_prev_diff and H_prev_diff
/// are zeroed when drop_states is true; otherwise they are copied from C_diff
/// and H_diff. This mirrors the forward pass, which zeroes or copies the state.
///
/// NOTE: H is accepted and advanced per row but never read by this kernel.
///
/// @param N            number of rows.
/// @param D            values per row in the state buffers.
/// @param t            current time step, compared against seqLengths[n].
/// @param C_prev       previous cell state used in the forward pass.
/// @param X            gate pre-activations used in the forward pass.
/// @param seqLengths   optional per-row lengths; nullptr marks every row valid.
/// @param C            cell state produced by the forward pass.
/// @param H            hidden state produced by the forward pass (unused).
/// @param C_diff       incoming gradient w.r.t. C.
/// @param H_diff       incoming gradient w.r.t. H.
/// @param drop_states  must match the value used in the forward pass.
/// @param H_prev_diff  output gradient w.r.t. H_prev.
/// @param C_prev_diff  output gradient w.r.t. C_prev.
/// @param X_diff       output gradient w.r.t. X.
/// @param forget_bias  constant added to the forget-slice pre-activation.
template <typename T>
inline void LstmUnitGradientImpl(
    int N,
    int D,
    int t,
    const T* C_prev,
    const T* X,
    const int32_t* seqLengths,
    const T* C,
    const T* H,
    const T* C_diff,
    const T* H_diff,
    bool drop_states,
    T* H_prev_diff,
    T* C_prev_diff,
    T* X_diff,
    const float forget_bias) {
  const T localForgetBias = convert::To<float, T>(forget_bias);
  for (int n = 0; n < N; ++n) {
    const bool valid = seqLengths == nullptr || t < seqLengths[n];

    if (!valid) {
      // std:: versions from <cstring>; unqualified names relied on transitive
      // includes to declare them.
      if (drop_states) {
        std::memset(C_prev_diff, 0, sizeof(T) * D);
        std::memset(H_prev_diff, 0, sizeof(T) * D);
      } else {
        std::memcpy(H_prev_diff, H_diff, sizeof(T) * D);
        std::memcpy(C_prev_diff, C_diff, sizeof(T) * D);
      }
      std::memset(X_diff, 0, 4 * sizeof(T) * D);
    } else {
      VECTOR_LOOP for (int d = 0; d < D; ++d) {
        T* c_prev_diff = C_prev_diff + d;
        T* h_prev_diff = H_prev_diff + d;
        T* i_diff = X_diff + d;
        T* f_diff = X_diff + 1 * D + d;
        T* o_diff = X_diff + 2 * D + d;
        T* g_diff = X_diff + 3 * D + d;

        const T i = sigmoid(X[d]);
        const T f = sigmoid(X[1 * D + d] + localForgetBias);
        const T o = sigmoid(X[2 * D + d]);
        const T g = host_tanh(X[3 * D + d]);
        const T c_prev = C_prev[d];
        const T c = C[d];
        const T host_tanh_c = host_tanh(c);
        // Total gradient reaching c: the direct C_diff plus the path through
        // H = o * tanh(c).
        const T c_term_diff =
            C_diff[d] + H_diff[d] * o * (1 - host_tanh_c * host_tanh_c);
        *c_prev_diff = c_term_diff * f;
        *h_prev_diff = 0; // not used in 'valid' case
        *i_diff = c_term_diff * g * i * (1 - i);
        *f_diff = c_term_diff * c_prev * f * (1 - f);
        *o_diff = H_diff[d] * host_tanh_c * o * (1 - o);
        *g_diff = c_term_diff * i * (1 - g * g);
      }
    }
    // Advance every buffer to the next row.
    C_prev += D;
    X += 4 * D;
    C += D;
    H += D;
    C_diff += D;
    H_diff += D;
    X_diff += 4 * D;
    H_prev_diff += D;
    C_prev_diff += D;
  }
}

} // namespace
} // namespace perfkernels
} // namespace caffe2