Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / torch / csrc / jit / tensorexpr / tensor.h

#pragma once

#include <torch/csrc/WindowsTorchApiMacro.h>
#include <functional>
#include <vector>

#include <torch/csrc/jit/tensorexpr/dim_arg.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/reduction.h>

namespace torch {
namespace jit {
namespace tensorexpr {

class TORCH_API Tensor : KernelScopedObject {
 public:
  Tensor(const Buf* buf, const std::vector<const Var*>& args, const Expr* body)
      : buf_(buf) {
    stmt_ = constructStmt(args, body, {}, {});
  }

  Tensor(
      const Buf* buf,
      const std::vector<const Var*>& args,
      const std::vector<const Expr*>& reduce_dims,
      const std::vector<const Var*>& reduce_args,
      const Expr* body)
      : buf_(buf) {
    stmt_ = constructStmt(args, body, reduce_dims, reduce_args);
  }

  Tensor(const Buf* buf, Stmt* stmt) : buf_(buf), stmt_(stmt) {}

  const Buf* buf() const {
    return buf_;
  }

  Stmt* stmt() const {
    return stmt_;
  }

  template <typename... Ts>
  inline ExprHandle operator()(const Ts&... ts);
  template <typename T>
  inline ExprHandle call(const std::vector<T>& args);
  template <typename... Ts>
  inline ExprHandle call(const Ts&... ts);

 private:
  Stmt* constructStmt(
      const std::vector<const Var*>& args,
      const Expr* body,
      const std::vector<const Expr*>& reduce_dims,
      const std::vector<const Var*>& reduce_args) const;

  const Buf* buf_;
  Stmt* stmt_;
};

// Wraps a Buf so that loads from it and stores to it can be written as IR
// nodes (Load / Store).
class Placeholder {
 public:
  // Wraps the Buf behind `data`.
  //
  // Throws malformed_input unless the buffer's base handle has dtype kHandle.
  // Also derives one stride expression per dimension: the last dimension
  // gets 1, and every earlier dimension gets the following stride multiplied
  // by the following dimension's extent.
  Placeholder(const BufHandle& data) : data_(data.node()) {
    if (data_->base_handle()->dtype() != kHandle) {
      throw malformed_input("Placeholder dtype must be Handle");
    }

    const int rank = ndim();
    std::vector<ExprHandle> stride_handles(rank);
    if (rank > 0) {
      // Base case of the recurrence: the last dimension.
      stride_handles[rank - 1] = 1;
      // Walk towards dimension 0, each stride building on the one after it.
      for (int axis = rank - 2; axis >= 0; --axis) {
        stride_handles[axis] =
            stride_handles[axis + 1] * ExprHandle(dim(axis + 1));
      }
    }
    strides_ = ExprHandleVectorToExprVector(stride_handles);
  }

  // Convenience form: creates a BufHandle from `name`, `dims` and `dtype`
  // and delegates to the constructor above.
  Placeholder(
      const std::string& name,
      const Dtype& dtype,
      const std::vector<ExprHandle>& dims)
      : Placeholder(BufHandle(name, dims, dtype)) {}

  // The wrapped buffer.
  const Buf* data() const {
    return data_;
  }

  // Forwarders to the wrapped buffer's dtype / rank / extents.
  Dtype dtype() const {
    return data_->dtype();
  }
  int ndim() const {
    return data_->ndim();
  }
  const Expr* dim(int index) const {
    return data_->dim(index);
  }
  std::vector<const Expr*> dims() const {
    return data_->dims();
  }

  // Load with each argument converted to an ExprHandle index; defined out of
  // line.
  template <typename... Ts>
  inline ExprHandle load(const Ts&... ts) const;

  // Load with indices taken from a vector; defined out of line.
  template <typename T>
  inline ExprHandle load(const std::vector<T>& args) const;

  // Creates a Load from the wrapped buffer at indices `args`, passing
  // `mask` through as the Load's mask argument.
  inline ExprHandle loadWithMask(
      const std::vector<ExprHandle>& args,
      const ExprHandle& mask) const {
    return ExprHandle(
        new Load(data(), ExprHandleVectorToExprVector(args), mask.node()));
  }

  // Creates a Store of `val` into the wrapped buffer at indices `args`. A
  // fresh IntImm(1) is passed in the position where storeWithMask() passes
  // its mask.
  inline Store* store(
      const std::vector<ExprHandle>& args,
      const ExprHandle& val) const {
    return new Store(
        data(), ExprHandleVectorToExprVector(args), val.node(), new IntImm(1));
  }

  // Same as store(), but with a caller-provided `mask` expression.
  inline Store* storeWithMask(
      const std::vector<ExprHandle>& args,
      const ExprHandle& val,
      const ExprHandle& mask) const {
    return new Store(
        data(), ExprHandleVectorToExprVector(args), val.node(), mask.node());
  }

 private:
  const Buf* data_;
  // One stride expression per dimension, filled in by the BufHandle
  // constructor.
  std::vector<const Expr*> strides_;
};

// Compute overloads. All take a name, a list of DimArgs and a callable that
// returns an ExprHandle; they differ only in how the callable receives its
// index variables (one to four individual VarHandles, or a single vector of
// VarHandles). Only declarations are visible in this header.
// NOTE(review): presumably body_func is invoked with one variable per entry
// of dim_args and its result becomes the tensor's element expression --
// confirm against the out-of-line definitions.

// Callable taking one index variable.
TORCH_API Tensor* Compute(
    const std::string& func_name,
    const std::vector<DimArg>& dim_args,
    const std::function<ExprHandle(const VarHandle&)>& body_func);
// Callable taking two index variables.
TORCH_API Tensor* Compute(
    const std::string& func_name,
    const std::vector<DimArg>& dim_args,
    const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
        body_func);
// Callable taking three index variables.
TORCH_API Tensor* Compute(
    const std::string& func_name,
    const std::vector<DimArg>& dim_args,
    const std::function<
        ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
        body_func);
// Callable taking four index variables.
TORCH_API Tensor* Compute(
    const std::string& func_name,
    const std::vector<DimArg>& dim_args,
    const std::function<ExprHandle(
        const VarHandle&,
        const VarHandle&,
        const VarHandle&,
        const VarHandle&)>& body_func);
// Callable taking all index variables as a single vector.
TORCH_API Tensor* Compute(
    const std::string& func_name,
    const std::vector<DimArg>& dim_args,
    const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func);

// Splits a list of DimArgs into two parallel vectors.
//
// For every entry of `dim_args`, in order:
//   - the node of its dim() expression is appended to `*dims`;
//   - a newly allocated Var of type kInt, named after its name_hint(), is
//     appended to `*vars`.
//
// Both output vectors are cleared first, so previous contents are discarded.
// `dims` and `vars` are dereferenced unconditionally and must be non-null.
inline void unpack_dim_args(
    const std::vector<DimArg>& dim_args,
    std::vector<const Expr*>* dims,
    std::vector<const Var*>* vars) {
  dims->clear();
  vars->clear();
  // The final size is known up front; reserving avoids repeated
  // reallocation while the loop appends.
  dims->reserve(dim_args.size());
  vars->reserve(dim_args.size());
  for (const DimArg& dim_arg : dim_args) {
    dims->push_back(dim_arg.dim().node());
    vars->push_back(new Var(dim_arg.name_hint(), kInt));
  }
}

// Builds a reduction Tensor from a Reducer and a value-producing body_func.
//
// `dim_args` describe the output axes and `reduce_args` the axes being
// reduced. The variable list handed to Reducer::getReduceBody is the output
// variables followed by the reduction variables. The result buffer is named
// `func_name`, has the output extents and the body's dtype, and is created
// with the reducer's initializer cast to that dtype.
template <typename BodyFunc>
Tensor* Reduce(
    const std::string& func_name,
    const std::vector<DimArg>& dim_args,
    const Reducer& reducer,
    const BodyFunc& body_func,
    const std::vector<DimArg>& reduce_args) {
  // Output axes: extents and fresh index variables.
  std::vector<const Expr*> out_dims;
  std::vector<const Var*> out_vars;
  unpack_dim_args(dim_args, &out_dims, &out_vars);

  // Reduction axes: extents and fresh index variables.
  std::vector<const Expr*> reduce_dims;
  std::vector<const Var*> reduce_vars;
  unpack_dim_args(reduce_args, &reduce_dims, &reduce_vars);

  // Output variables first, reduction variables after.
  std::vector<const Var*> all_vars(out_vars);
  all_vars.insert(all_vars.end(), reduce_vars.begin(), reduce_vars.end());

  ExprHandle body =
      Reducer::getReduceBody(body_func, VarVectorToVarHandleVector(all_vars));

  // The reducer wants the output indices as expressions rather than Vars.
  std::vector<const Expr*> output_args(out_vars.begin(), out_vars.end());

  const Expr* init_expr = new Cast(body.dtype(), reducer.initializer());
  Buf* result_buf = new Buf(func_name, out_dims, body.dtype(), init_expr);
  const ReduceOp* reduce_op =
      reducer(result_buf, body, output_args, reduce_vars);

  return new Tensor(result_buf, out_vars, reduce_dims, reduce_vars, reduce_op);
}

// Overload that accepts body_func as an rvalue (e.g. a lambda written inline
// at the call site) and forwards to another Reduce overload.
//
// Note that `const BodyFunc&&` is a const rvalue reference, not a forwarding
// reference. Inside the function the named parameter `body_func` is an
// lvalue, so the call below cannot bind to this overload again and therefore
// does not recurse.
template <typename BodyFunc>
Tensor* Reduce(
    const std::string& func_name,
    const std::vector<DimArg>& dim_args,
    const Reducer& reducer,
    const BodyFunc&& body_func,
    const std::vector<DimArg>& reduce_args) {
  return Reduce(func_name, dim_args, reducer, body_func, reduce_args);
}

// Overload for the common case of reducing over all dimensions of a
// Placeholder: `buffer` takes the place of the body function. Declared here;
// the definition is not in this header.
TORCH_API Tensor* Reduce(
    const std::string& func_name,
    const std::vector<DimArg>& dim_args,
    const Reducer& reducer,
    const Placeholder& buffer,
    const std::vector<DimArg>& reduce_args);

// Overload for the common case of reducing over all dimensions of a
// previously Computed Tensor: `tensor` takes the place of the body function.
// Declared here; the definition is not in this header.
TORCH_API Tensor* Reduce(
    const std::string& func_name,
    const std::vector<DimArg>& dim_args,
    const Reducer& reducer,
    Tensor* tensor,
    const std::vector<DimArg>& reduce_args);

// Call node that refers to a Tensor together with a list of parameter
// expressions. Its dtype is taken from the tensor's buffer.
class FunctionCall : public CallNode<FunctionCall> {
 public:
  using BaseClass = CallNode<FunctionCall>;

  // Unwraps every handle in `params` to its underlying node, creates a
  // FunctionCall on `tensor` with those nodes, and returns it wrapped in an
  // ExprHandle.
  static ExprHandle make(
      Tensor* tensor,
      const std::vector<ExprHandle>& params) {
    std::vector<const Expr*> nodes;
    nodes.reserve(params.size());
    for (const ExprHandle& handle : params) {
      nodes.push_back(handle.node());
    }
    return ExprHandle(new FunctionCall(tensor, nodes));
  }

  // The tensor this node calls (const and non-const access).
  const Tensor* tensor() const {
    return tensor_;
  }
  Tensor* tensor() {
    return tensor_;
  }

  // `tensor` is dereferenced to read its buffer's dtype, so it must not be
  // null.
  FunctionCall(Tensor* tensor, const std::vector<const Expr*>& params)
      : BaseClass(tensor->buf()->dtype(), kFunctionCall, params),
        tensor_(tensor) {}

 private:
  // Produces a new FunctionCall on the same tensor with `new_params`
  // substituted for the current parameters.
  const Expr* DefaultMutator(
      const std::vector<const Expr*>& new_params) const override {
    return new FunctionCall(tensor_, new_params);
  }

  // The reported function name is the name hint of the tensor's buffer.
  std::string func_name() const override {
    return tensor_->buf()->name_hint();
  }

  Tensor* tensor_;
};
// Builds a FunctionCall on this tensor, converting each argument to an
// ExprHandle to form the parameter list.
template <typename... Ts>
inline ExprHandle Tensor::operator()(const Ts&... ts) {
  // FunctionCall::make takes the vector by const reference, so it is passed
  // as a plain lvalue.
  const std::vector<ExprHandle> call_params = {ExprHandle(ts)...};
  return FunctionCall::make(this, call_params);
}

// Named equivalent of operator(): converts each argument to an ExprHandle
// and builds a FunctionCall on this tensor with them.
template <typename... Ts>
inline ExprHandle Tensor::call(const Ts&... ts) {
  // FunctionCall::make takes the vector by const reference, so it is passed
  // as a plain lvalue.
  const std::vector<ExprHandle> call_params = {ExprHandle(ts)...};
  return FunctionCall::make(this, call_params);
}

// Vector form of call(): every element of `args` is converted to an
// ExprHandle and the resulting list becomes the FunctionCall's parameters.
template <typename T>
inline ExprHandle Tensor::call(const std::vector<T>& args) {
  return FunctionCall::make(
      this, std::vector<ExprHandle>(args.begin(), args.end()));
}

// Creates a Load from the wrapped buffer. Each argument is converted to an
// ExprHandle and used as an index; a fresh IntImm(1) is passed as the Load's
// last constructor argument.
template <typename... Ts>
inline ExprHandle Placeholder::load(const Ts&... ts) const {
  const std::vector<ExprHandle> indices = {ExprHandle(ts)...};
  return ExprHandle(
      new Load(data(), ExprHandleVectorToExprVector(indices), new IntImm(1)));
}

// Vector form of load(): every element of `args` is converted to an
// ExprHandle index, and a fresh IntImm(1) is passed as the Load's last
// constructor argument.
template <typename T>
inline ExprHandle Placeholder::load(const std::vector<T>& args) const {
  const std::vector<ExprHandle> indices(args.begin(), args.end());
  return ExprHandle(
      new Load(data(), ExprHandleVectorToExprVector(indices), new IntImm(1)));
}

} // namespace tensorexpr
} // namespace jit
} // namespace torch