Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
torch / include / ATen / native / transformers / sdp_utils_cpp.h
Size: Mime:
#pragma once
#include <ATen/Context.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/grad_mode.h>
#include <ATen/native/DispatchStub.h>
#include <c10/core/ScalarType.h>

#include <c10/util/Exception.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>

#include <c10/core/SymInt.h>
#include <c10/core/SymFloat.h>
#include <c10/util/string_view.h>
#include <c10/util/Array.h>
#include <cmath>
#include <cstdint>
#include <functional>

namespace sdp {

constexpr int32_t num_backends = 4;
enum class SDPBackend {
  error = -1,
  math = 0,
  flash_attention = 1,
  efficient_attention = 2,
  cudnn_attention = 3
};

// Note that if this changed make sure to update
// the templated enum in mem_eff/kernel_forward.h and mem_eff/kernel_backward.h
enum class CustomMaskType {
  NoCustomMask = 0,
  CausalFromTopLeft = 1,
  CausalFromBottomRight = 2,
  NumCustomMaskTypes,
};

struct sdp_params {
  at::Tensor query;
  at::Tensor key;
  at::Tensor value;
  std::optional<at::Tensor> attn_mask;
  double dropout;
  bool is_causal;
};

SDPBackend select_sdp_backend_cpp(sdp_params const& kernel_params);

inline c10::SymFloat calculate_scale(
    const at::Tensor& query,
    std::optional<double> scale) {
  const auto softmax_scale = scale.has_value()
      ? scale.value()
      : (c10::SymFloat(1.0) / (c10::SymFloat(query.sym_size(-1)).sqrt()));
  return c10::SymFloat(softmax_scale);
}

using c10::array_of;

inline bool input_requires_grad(sdp_params const& params) {
  const bool any_inputs_require_grad = params.query.requires_grad() ||
      params.key.requires_grad() || params.value.requires_grad();
  const bool gradmode_enabled = at::GradMode::is_enabled();
  return any_inputs_require_grad && gradmode_enabled;
}

inline bool has_for_nested_inputs(sdp_params const& params) {
  return
      (params.query.is_nested() && params.query.layout() == c10::kStrided) ||
      (params.key.is_nested() && params.key.layout() == c10::kStrided) ||
      (params.value.is_nested() && params.value.layout() == c10::kStrided);
}

inline bool has_for_dense_inputs(sdp_params const& params) {
  return !params.query.is_nested() || !params.key.is_nested() || !params.value.is_nested();
}

inline bool has_only_dense_inputs(sdp_params const& params) {
  return !params.query.is_nested() && !params.key.is_nested() && !params.value.is_nested();
}

template <typename dtype_vector>
inline bool check_tensor_dtype(
    sdp_params const& params,
    dtype_vector allowed_dtypes,
    bool debug) {
  auto query_dtype = params.query.dtype();
  if (!(query_dtype == params.key.dtype() &&
        query_dtype == params.value.dtype() &&
        (std::find(allowed_dtypes.begin(), allowed_dtypes.end(), query_dtype) !=
         allowed_dtypes.end()))) {
    if (debug) {
      TORCH_WARN(
          "Expected query, key and value to all be of dtype: {",
          c10::Join(", ", allowed_dtypes),
          "}. Got ",
          "Query dtype: ",
          params.query.dtype(),
          ", Key dtype: ",
          params.key.dtype(),
          ", and Value dtype: ",
          params.value.dtype(),
          " instead.");
    }
    return false;
  }
  return true;
}


inline bool try_broadcast_param_size(
    const c10::SymInt q_size,
    const c10::SymInt k_size,
    const c10::SymInt v_size,
    c10::string_view param_name,
    bool debug) {
  auto max_size = std::max({q_size, k_size, v_size});
  if ((q_size != max_size && q_size != 1) ||
      (k_size != max_size && k_size != 1) ||
      (v_size != max_size && v_size != 1)) {
    if (debug) {
      TORCH_WARN(
          "Both fused kernels require query, key and value to have broadcastable ",
          param_name,
          "got Query ",
          param_name,
          q_size,
          ", Key ",
          param_name,
          k_size,
          ", Value ",
          param_name,
          v_size,
          " instead.");
    }
    return false;
  }
  return true;
}

inline bool check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
    at::Tensor const& param,
    c10::string_view param_name,
    bool debug) {
  const auto nt_tensor_impl = at::native::get_nested_tensor_impl(param);
  const at::Tensor& sizes = nt_tensor_impl->get_nested_sizes();
  auto num_head_dims = nt_tensor_impl->opt_size(1);
  if (!num_head_dims.has_value()) {
    // num_head_dims is ragged
    if (debug) {
      TORCH_WARN(
          "Fused kernels do not support ragged num_head_dims, ",
          param_name,
          "has a ragged num_heads.");
    }
    return false;
  }

  auto* sizes_ptr = sizes.data_ptr<int64_t>();
  const int64_t n_tensors = param.size(0);
  const int64_t size_tensor_stride = sizes.stride(0);

  // This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
  for (const auto i : c10::irange(n_tensors)) {
    if (sizes_ptr[(i * size_tensor_stride) + 1] == 0) {
      if (debug) {
        TORCH_WARN(
            "Fused kernels do not support seq_len == 0, ",
            param_name,
            "has a seq len of 0.");
      }
      return false;
    }
  }
  return true;
}

inline bool check_for_seq_len_0_nested_tensor(sdp_params const& params, bool debug) {
  // When this function is called we are assured that the nt is dim==4
  bool q_is_safe = params.query.is_nested()
      ? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
            params.query, "query ", debug)
      : true;
  // short circuit if any is unsafe
  if (!q_is_safe) {
    return false;
  }

  bool k_is_safe = params.key.is_nested()
      ? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
            params.key, "key ", debug)
      : true;
  if (!k_is_safe) {
    return false;
  }

  bool v_is_safe = params.value.is_nested()
      ? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
            params.value, "value ", debug)
      : true;
  if (!v_is_safe) {
    return false;
  }

  // We now know none of the inputs have ragged num_heads, so we can safely
  // access .size(1)
  auto q_num_heads = params.query.size(1);
  auto k_num_heads = params.key.size(1);
  auto v_num_heads = params.value.size(1);
  bool same_num_heads =
      q_num_heads == k_num_heads && q_num_heads == v_num_heads;

  if (!same_num_heads) {
    if (input_requires_grad(params)){
      if (debug) {
        TORCH_WARN(
              "Both fused kernels do not support training with broadcasted NT inputs.");
      }
      return false;
    }
    return try_broadcast_param_size(
        q_num_heads, k_num_heads, v_num_heads, "num heads ", debug);
  }

  return true;
}

inline bool check_nested_tensor(sdp_params const& params, bool debug) {
  // Return false if have nested tensor
  if (!has_only_dense_inputs(params)) {
    if (debug) {
      TORCH_WARN(
          "Both fused kernels of cpp version currently do not support Nested Tensor inputs.");
    }
    return false;
  }
  return true;
}

inline bool check_for_dropout(sdp_params const& params, bool debug) {
  if (params.dropout > 0.0) {
    if (debug) {
      TORCH_WARN("Both fused kernels do not support non-zero dropout.");
    }
    return false;
  }
  return true;
}

inline bool check_requires_grad_and_nested(sdp_params const& params, bool debug) {
  if (input_requires_grad(params)) {
    if (debug) {
      TORCH_WARN(
          "Memory efficient attention currently doesn't support training with NT inputs.");
    }
    return false;
  }
  return true;
}

inline bool check_for_attn_mask(sdp_params const& params, bool debug) {
  if (params.attn_mask.has_value()) {
    if (debug) {
      TORCH_WARN("Flash Attention does not support non-null attn_mask.");
    }
    return false;
  }
  return true;
}

inline bool check_attn_mask_shape(sdp_params const& params, bool debug) {
  auto attn_mask = params.attn_mask;
  if (!attn_mask.has_value()) {
    return true;
  }
  if (attn_mask.value().requires_grad()) {
    return false;
  }
  auto batchSize = params.query.sym_size(0);
  auto qSize = params.query.sym_size(2);
  auto kvSize = params.key.sym_size(2);
  auto num_head = params.query.sym_size(1);
  if (attn_mask.value().sym_size(-2) != qSize && attn_mask.value().sym_size(-2) != 1) {
    return false;
  }
  if (attn_mask.value().sym_size(-1) != kvSize && attn_mask.value().sym_size(-1) != 1) {
    return false;
  }
  if (attn_mask.value().dim() == 2) {
    return true;
  } else if (attn_mask.value().dim() == 4) {
    if ((attn_mask.value().sym_size(0) == 1 || attn_mask.value().sym_size(0) == batchSize)
        && (attn_mask.value().sym_size(1) == 1 || attn_mask.value().sym_size(1) == num_head)) {
      return true;
    }
  }
  if (debug) {
    TORCH_WARN("Please use the following attn mask shapes: ",
        "2d - ({Q_seq_len, 1}  x {KV_seq_len, 1}); ",
        "4d - ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1}  x {KV_seq_len, 1})");
  }
  return false;
}

inline bool check_tensor_shapes(sdp_params const& params, bool debug) {
  auto query_dim = params.query.dim();
  if (!(query_dim == params.key.dim() && query_dim == params.value.dim() &&
        (query_dim == 4))) {
    if (debug) {
      TORCH_WARN(
          "Both fused kernels requires query, key and value to be 4 dimensional, but got Query dim: ",
          query_dim,
          ", Key dim: ",
          params.key.dim(),
          ", Value dim: ",
          params.value.dim(),
          " instead.");
    }
    return false;
  }
  return true;
}

inline bool check_safe_kv_broadcast(at::Tensor const& param, bool debug) {
  const auto nt_tensor_impl = at::native::get_nested_tensor_impl(param);
  auto seq_len = nt_tensor_impl->opt_size(2);
  if (!seq_len.has_value()) {
    if (debug) {
      TORCH_WARN(
          "For both fused kernels, if one of key/value batch_size requires "
          "broadcasting and the other does not, then the other must have a ",
          "consistent seq_len dim.")
    }
    return false;
  }
  return true;
}

inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool debug) {
  // This is expected to be called after check_tensor_shapes ensuring that the
  // size() calls won't error since the inputs are all 4 dimensional

  auto q_batch_size = params.query.sym_size(0);
  auto k_batch_size = params.key.sym_size(0);
  auto v_batch_size = params.value.sym_size(0);

  bool same_batch_size =
      q_batch_size == k_batch_size && q_batch_size == v_batch_size;

  auto q_num_heads = params.query.sym_size(1);
  auto k_num_heads = params.key.sym_size(1);
  auto v_num_heads = params.value.sym_size(1);
  bool same_num_heads =
      q_num_heads == k_num_heads && q_num_heads == v_num_heads;

  if (!(same_batch_size && same_num_heads)) {
    if (debug) {
      TORCH_WARN(
          "For dense inputs, both fused kernels require query, key and value to have the same batch_size and num_heads. ",
          "Query.sizes(): ",
          params.query.sizes(),
          ", Key sizes(): ",
          params.key.sizes(),
          ", Value sizes(): ",
          params.value.sizes(),
          " instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel.");
    }
    return false;
  }
  return true;
}

inline bool check_batch_size_nested(sdp_params const& params, bool debug) {
  // This is expected to be called after check_tensor_shapes ensuring that the
  // size() calls won't error since the inputs are all 4 dimensional
  auto q_batch_size = params.query.sym_size(0);
  auto k_batch_size = params.key.sym_size(0);
  auto v_batch_size = params.value.sym_size(0);

  bool same_batch_size =
      q_batch_size == k_batch_size && q_batch_size == v_batch_size;

  // num_heads logic for nested input is checked in
  // check_for_seq_len_0_nested_tensor as there is handling there to make sure
  // num_heads is not ragged
  bool broadcastable_batch_size = true;
  if (!same_batch_size) {
    if (input_requires_grad(params)){
      if (debug) {
        TORCH_WARN(
            "Both fused kernels do not support training with broadcasted NT inputs.");
      }
      return false;
    }
    // try to broadcast batchsize
    broadcastable_batch_size = try_broadcast_param_size(
        q_batch_size, k_batch_size, v_batch_size, "batch size ", debug);

    // if only one of k or v require broadcasting of batch size, the other
    // must have a consistent seq_len dim
    if (broadcastable_batch_size) {
      if (k_batch_size == 1 && v_batch_size != 1 &&
          !check_safe_kv_broadcast(params.value, debug)) {
        return false;
      }
      if (v_batch_size == 1 && k_batch_size != 1 &&
          !check_safe_kv_broadcast(params.key, debug)) {
        return false;
      }
    }
  }
  return broadcastable_batch_size;
}

inline bool check_nonzero_sequence_lengths_dense(sdp_params const& params, bool debug) {
  // In some cases people will pass in 0 sized tensors, this will
  // cause the fused path to error with unaligned mask
  bool zero_seq_len_q = params.query.sym_size(-2) == 0;
  bool zero_seq_len_k = params.key.sym_size(-2) == 0;
  if (zero_seq_len_q || zero_seq_len_k) {
    if (debug) {
      TORCH_WARN(
          "Both fused kernels do not support zero seq_len_q or seq_len_kv.");
    }
    return false;
  }
  return true;
}

template<bool ignore_singleton_dim>
inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool debug) {
  // The stride checking for NestedTensors is done within the kernel
  // And .contiguous will be called if needed

  // This function checks that the last dimension of the inputs to
  // fused_attention have stride 1
  bool qkv_strides_equal_1 = params.query.sym_stride(-1) == 1 &&
      params.key.sym_stride(-1) == 1 && params.value.sym_stride(-1) == 1;

  // https://github.com/pytorch/pytorch/issues/116333
  // If the head_dim is size 1 the stride won't matter, but we
  // check this condition before padding the head_dim to 1
  if (ignore_singleton_dim){
    qkv_strides_equal_1 = qkv_strides_equal_1 || params.query.sym_size(-1) == 1;
  }
  bool mask_stride_equal_1 = params.attn_mask.has_value()
      ? params.attn_mask.value().sym_stride(-1) == 1
      : true;
  if (!(qkv_strides_equal_1 && mask_stride_equal_1)) {
    if (debug) {
      std::ostringstream epilogue_message;
      if (params.attn_mask.has_value()) {
        epilogue_message << ", Attn_mask.stride(-1): "
                         << params.attn_mask.value().sym_stride(-1);
      }
      epilogue_message << " instead.";
      TORCH_WARN(
          "Both fused kernels require the last dimension of the input to have stride 1. ",
          "Got Query.stride(-1): ",
          params.query.sym_stride(-1),
          ", Key.stride(-1): ",
          params.key.sym_stride(-1),
          ", Value.stride(-1): ",
          params.value.sym_stride(-1),
          epilogue_message.str());
    }

    return false;
  }
  return true;
}

inline bool check_runtime_disabled_flash(sdp_params const& params, bool debug) {
  // We check the global context to see if user has explicitly turned of flash
  // sdp kernels
  if (!at::globalContext().userEnabledFlashSDP()) {
    if (debug) {
      TORCH_WARN("Flash attention has been runtime disabled.");
    }
    return false;
  }
  return true;
}

inline bool check_runtime_disabled_mem_efficient(sdp_params const& params, bool debug) {
  // We check the global context to see if user has explicitly turned of
  // mem_efficient sdp kernels
  if (!at::globalContext().userEnabledMemEfficientSDP()) {
    if (debug) {
      TORCH_WARN("Memory Efficient attention has been runtime disabled.");
    }
    return false;
  }
  return true;
}


} // namespace sdp