Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / ATen / core / function_schema.h

#pragma once

#include <c10/util/StringUtil.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/alias_info.h>
#include <ATen/core/operator_name.h>
#include <ATen/core/dispatch/OperatorOptions.h>
#include <unordered_map>

namespace c10 {

// schema as used in the compiler for resolving function calls and reporting
// errors. These objects should be constructed from C10 schema once those
// are available.

// Forward declarations so the types can be named before their definitions
// further down in this header.
struct Argument;
struct FunctionSchema;

// Equality of two Argument descriptions; defined inline after the struct.
bool operator==(const Argument& lhs, const Argument& rhs);

struct Argument {
  Argument(
      std::string name = "",
      TypePtr type = nullptr,
      c10::optional<int32_t> N = c10::nullopt,
      c10::optional<IValue> default_value = c10::nullopt,
      bool kwarg_only = false,
      c10::optional<AliasInfo> alias_info = c10::nullopt)
      : name_(std::move(name)),
        type_(type ? type : TensorType::get()),
        N_(std::move(N)),
        default_value_(std::move(default_value)),
        kwarg_only_(kwarg_only),
        alias_info_(std::move(alias_info)) {
  }
  const std::string& name() const {
    return name_;
  }
  const TypePtr& type() const {
    return type_;
  }
  c10::optional<int32_t> N() const {
    return N_;
  }
  const c10::optional<IValue>& default_value() const {
    return default_value_;
  }
  bool kwarg_only() const {
    return kwarg_only_;
  }
  const c10::optional<AliasInfo>& alias_info() const {
    return alias_info_;
  }
  bool is_inferred_type() const {
    bool is_inferred_type = false;
    TORCH_INTERNAL_ASSERT(type_);
    if (auto pt = type_->cast<TensorType>()) {
      if (pt->isInferredType()) {
        is_inferred_type = true;
      }
    }
    return is_inferred_type;
  }

  std::string formatTypeMismatchMsg(const std::string& actual_type) const {
    std::string inferred_type_hint;
    if (is_inferred_type()) {
      inferred_type_hint = c10::str(
          "Inferred '",
          name(),
          "' to be of type 'Tensor' ",
          "because it was not annotated with an explicit type.\n");
    }
    return c10::str(
        "Expected a value of type '",
        type()->repr_str(),
        "' for argument '",
        name(),
        "' but instead found type '",
        actual_type,
        "'.\n",
        inferred_type_hint);
  }

  Argument cloneWithType(TypePtr new_type) const {
    return Argument(
        name_,
        std::move(new_type),
        N_,
        default_value_,
        kwarg_only_,
        alias_info_);
  }

  // this function checks whether this Argument is backward compatible with
  // the old one. we consider the following cases are backward compatible:
  //   1) two arguments are equal
  //   2) this arg's type should be subtype of old
  //   3) this arg must provide the same default value if old arg has one,
  bool isBackwardCompatibleWith(
      const Argument& old,
      std::ostream* why_not=nullptr) const;

 private:
  std::string name_;
  TypePtr type_;
  // for list types, an optional statically known length for the list
  // e.g. for int[3]: type = ListType::ofInts(), N = 3
  // If present, this will allow scalars to be broadcast to this length to
  // become a list.
  c10::optional<int32_t> N_;

  c10::optional<IValue> default_value_;
  // is this only specifiable as a keyword argument?
  bool kwarg_only_;
  c10::optional<AliasInfo> alias_info_;
};

// Two arguments are equal when name, type (compared by value, not by
// pointer), list length, default value, keyword-only flag and alias
// annotation all match. Fields are compared in that order and the first
// mismatch ends the comparison.
inline bool operator==(const Argument& lhs, const Argument& rhs) {
  if (!(lhs.name() == rhs.name())) {
    return false;
  }
  if (!(*lhs.type() == *rhs.type())) {
    return false;
  }
  if (!(lhs.N() == rhs.N())) {
    return false;
  }
  if (!(lhs.default_value() == rhs.default_value())) {
    return false;
  }
  if (!(lhs.kwarg_only() == rhs.kwarg_only())) {
    return false;
  }
  return lhs.alias_info() == rhs.alias_info();
}

// Equality of two FunctionSchema objects; declared here ahead of the struct
// definition (the definition is not part of this section of the header).
bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs);

struct FunctionSchema {
  // Builds a schema from a name, an overload name, the argument list and the
  // return list. `is_vararg` / `is_varret` default to false. All by-value
  // parameters are moved into the members, after which checkSchema()
  // validates the ordering of defaulted and non-defaulted arguments.
  // alias_kind_ is not initialized here and therefore starts out unset.
  FunctionSchema(
      std::string name,
      std::string overload_name,
      std::vector<Argument> arguments,
      std::vector<Argument> returns,
      bool is_vararg = false,
      bool is_varret = false)
      : name_({std::move(name), std::move(overload_name)}),
        arguments_(std::move(arguments)),
        returns_(std::move(returns)),
        is_vararg_(is_vararg),
        is_varret_(is_varret) {
    checkSchema();
  }

  // Convenience overload that takes the operator name as a Symbol. The
  // qualified string of the symbol is forwarded, together with all other
  // parameters, to the std::string-based constructor.
  FunctionSchema(
      Symbol name,
      std::string overload_name,
      std::vector<Argument> arguments,
      std::vector<Argument> returns,
      bool is_vararg = false,
      bool is_varret = false)
      : FunctionSchema(
            name.toQualString(),
            std::move(overload_name),
            std::move(arguments),
            std::move(returns),
            is_vararg,
            is_varret) {
    // The delegated-to constructor has already run checkSchema(); running it
    // again here would only repeat the same validation on the same state.
  }

  // Checks whether this schema is backward compatible with the old one.
  // The following conditions must be true:
  // [Function structure] The new schema's name, overload-name, varargs, and
  //      return arity are the same.
  // [Output Narrowing] The new schema's output type must be the same class
  //      or inherit from the old schema's output type.
  // [Argument count] The new schema must have at least as many arguments as
  //      the old schema (considering the list of positional and kwargs).
  // [Arg Compatibility] Every argument in the old schema has a corresponding
  //      argument in the new schema that:
  //        * is at the same position.
  //        * has the same name.
  //        * is either positional, or kwarg and the old argument was kwarg.
  //        * has the same type, or the old argument's type inherits from the
  //          new argument's type.
  // [Default Values] Every new argument must have a default value.
  // E.g. (written as new => old)
  //   OK    f_new(a, b, c=1) => f_old(a, b)
  //   NOK   f_new(a, c=1, *, b) => f_old(a, *, b)
  //   OK    f_new(a, b, *, c) => f_old(a, *, b, c)
  //   NOK   f_new(a, *, b, c) => f_old(a, b, *, c)
  //   NOK   f_new(a, *, c, b) => f_old(a, *, b, c)
  //   OK    f_new(a, *, b, c, d=1) => f_old(a, *, b, c)
  // NOTE(review): `why_not` presumably receives a human-readable reason when
  // the check fails; the definition is not in this header - confirm.
  bool isBackwardCompatibleWith(
      const FunctionSchema& old,
      std::ostream* why_not = nullptr) const;

 private:
  // Operator name together with its overload name.
  OperatorName name_;
  // Declared arguments, in declaration order.
  std::vector<Argument> arguments_;
  // Declared return values.
  std::vector<Argument> returns_;
  // if true then this schema takes an arbitrary number of additional arguments
  // after the argument specified in arguments
  // currently this is used primarily to represent 'primitive' operators whose
  // arguments are not checked by schema
  bool is_vararg_;
  // NOTE(review): presumably the return-side counterpart of is_vararg_
  // (a variable number of returns) - confirm against the users of is_varret().
  bool is_varret_;

  // if no alias information is directly specified, what kind of "default"
  // alias information should we infer?
  // NB: due to alias analysis kind merging, this may be nullopt.  Eventually
  // this should always be set no matter what
  c10::optional<AliasAnalysisKind> alias_kind_;

  // Declared here, defined elsewhere.
  // NOTE(review): presumably validates `value` against `argument`, with `pos`
  // being the optional positional index used in error messages - confirm.
  void checkArg(const IValue& value, const Argument& argument, optional<size_t> pos) const;

  void checkSchema() const {
    bool seen_default_arg = false;
    for (const auto& arg : arguments()) {
      if (arg.default_value()) {
        seen_default_arg = true;
      } else {
        // we have historically serialized broadcasting lists wo/default values,
        // so to not break BC allow lists here
        if (arg.type()->kind() == ListType::Kind) {
          continue;
        }
        TORCH_INTERNAL_ASSERT(
            !seen_default_arg || arg.kwarg_only(),
            "Non-default positional argument follows default argument. Parameter ",
            arg.name(),
            " in ",
            *this);
      }
    }
  }

 public:

  // Declared here, defined elsewhere.
  // NOTE(review): presumably prints the schema for debugging - confirm.
  void dump() const;

  // Full operator name (name plus overload name).
  const OperatorName& operator_name() const {
    return name_;
  }
  // The `name` component of the operator name.
  const std::string& name() const {
    return name_.name;
  }
  // The `overload_name` component of the operator name.
  const std::string& overload_name() const {
    return name_.overload_name;
  }
  // Declared arguments, in declaration order.
  const std::vector<Argument>& arguments() const {
    return arguments_;
  }
  // Declared return values.
  const std::vector<Argument>& returns() const {
    return returns_;
  }
  // Whether additional arguments beyond arguments() are accepted.
  bool is_vararg() const {
    return is_vararg_;
  }
  // Value of the is_varret flag given at construction.
  bool is_varret() const {
    return is_varret_;
  }
  bool is_mutable() const {
    return std::any_of(
        arguments_.cbegin(), arguments_.cend(), [](const Argument& arg) {
          const auto& aliasInfo = arg.alias_info();
          return aliasInfo && aliasInfo.value().isWrite();
        });
  }

  // Returns the position of the first argument called `name`, or
  // c10::nullopt when no argument has that name.
  c10::optional<int> argumentIndexWithName(const std::string& name) const {
    const size_t num_args = arguments_.size();
    for (size_t i = 0; i < num_args; ++i) {
      if (name == arguments_[i].name()) {
        // The return type is int-based; make the size_t -> int narrowing
        // explicit instead of relying on an implicit conversion.
        return static_cast<int>(i);
      }
    }
    return c10::nullopt;
  }
  // Returns a copy of this schema under a different name / overload name.
  // Arguments, returns and the vararg/varret flags are reused. The copy is
  // built through the constructor, so alias_kind_ is not carried over.
  FunctionSchema cloneWithName(std::string name, std::string overload_name) const {
    FunctionSchema renamed(
        std::move(name),
        std::move(overload_name),
        arguments_,
        returns_,
        is_vararg_,
        is_varret_);
    return renamed;
  }
  // Returns a copy of this schema whose argument list is replaced by
  // `new_arguments`; name, returns and the vararg/varret flags are reused.
  // The copy is built through the constructor, so alias_kind_ is not carried
  // over.
  FunctionSchema cloneWithArguments(std::vector<Argument> new_arguments) const {
    FunctionSchema with_new_args(
        name_.name,
        name_.overload_name,
        std::move(new_arguments),
        returns_,
        is_vararg_,
        is_varret_);
    return with_new_args;
  }
  // Returns a copy of this schema whose return list is replaced by
  // `new_returns`; name, arguments and the vararg/varret flags are reused.
  // The copy is built through the constructor, so alias_kind_ is not carried
  // over.
  FunctionSchema cloneWithReturns(std::vector<Argument> new_returns) const {
    FunctionSchema with_new_returns(
        name_.name,
        name_.overload_name,
        arguments_,
        std::move(new_returns),
        is_vararg_,
        is_varret_);
    return with_new_returns;
  }

  // Declared here, defined elsewhere.
  // NOTE(review): presumably the schema-level counterpart of
  // Argument::formatTypeMismatchMsg, optionally mentioning the argument
  // position and the offending value - confirm against the definition.
  std::string formatTypeMismatchMsg(
      const Argument& expected,
      const std::string& actual_type,
      c10::optional<size_t> position = c10::nullopt,
      c10::optional<std::string> value = c10::nullopt) const;

  // Declared here, defined elsewhere.
  // NOTE(review): presumably returns a copy of the schema in which each
  // argument/return type has been passed through `type_map` - confirm.
  FunctionSchema cloneWithRemappedTypes(
      const std::function<TypePtr(TypePtr)> type_map) const;

  // Check that inputs have the correct types and appends any missing default
  // values.
  // `kwargs` defaults to an empty map.
  void checkAndNormalizeInputs(
      std::vector<IValue>& inputs,
      const std::unordered_map<std::string, IValue>& kwargs =
          std::unordered_map<std::string, IValue>{}) const;

  // Declared here, defined elsewhere.
  // NOTE(review): presumably returns an error description for the given
  // keyword-argument names - confirm against the definition.
  std::string findErrorInKwargs(const std::vector<std::string>& kwargs) const;

  bool hasAnyAliasInfo() const {
    for (const auto& arg : arguments_) {
      if (arg.alias_info().has_value()) {
        return true;
      }
    }
    for (const auto& ret : returns_) {
      if (ret.alias_info().has_value()) {
        return true;
      }
    }
    return false;
  }


  // TODO remove the mutation here
  // True while no alias analysis kind has been stored on this schema.
  bool isDefaultAliasAnalysisKind() const {
    return !alias_kind_.has_value();
  }
  // Returns the stored alias analysis kind, falling back to CONSERVATIVE
  // when none has been stored.
  AliasAnalysisKind aliasAnalysis() const {
    if (alias_kind_.has_value()) {
      return *alias_kind_;
    }
    return AliasAnalysisKind::CONSERVATIVE;
  }
  // Stores `v` as this schema's alias analysis kind.
  void setAliasAnalysis(AliasAnalysisKind v) {
    alias_kind_ = v;
  }

  c10::optional<c10::string_view> getNamespace() const {
Loading ...