Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ include / torch / csrc / jit / passes / quantization / quantization_patterns.h

#pragma once

#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/quantization/helper.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <string>
#include <unordered_map>
#include <utility>

namespace torch {
namespace jit {

// Bundles one fusion rule: a textual IR graph to look for, the textual IR
// graph to substitute for it, and optional filters attached to the rule.
struct QuantFusionInfo {
  // Operator name associated with this rule. The helpers in this file fill it
  // with either the aten op name or the quantized op name depending on the
  // helper.
  std::string quantized_op_name;
  // IR graph text describing the subgraph to be matched.
  std::string pattern;
  // IR graph text describing what the matched subgraph is replaced with.
  std::string replacement;
  // Extra match filters carried with the rule; empty by default.
  // NOTE(review): presumably every filter must accept a match before the
  // rewrite is applied — confirm at the site that consumes this struct.
  std::vector<MatchFilter> filters = {};
};

namespace {
// Renders `extra_args` as a comma-prefixed argument list that can be appended
// directly after a leading argument inside an IR pattern string.
//
// Every element is emitted as ", " followed by the element itself, so
// {"%x", "%y"} becomes ", %x, %y" and an empty vector yields "".
//
// The result is built by appending into a single pre-sized buffer; the
// previous std::accumulate formulation copied the growing accumulator once
// per element and depended on <numeric>, which this header does not include.
std::string getExtraArgList(const std::vector<std::string>& extra_args) {
  // Pre-compute the final length: each argument contributes ", " plus itself.
  std::string::size_type total_length = 0;
  for (const auto& arg : extra_args) {
    total_length += arg.size() + 2;
  }

  std::string arg_list;
  arg_list.reserve(total_length);
  for (const auto& arg : extra_args) {
    arg_list += ", ";
    arg_list += arg;
  }
  return arg_list;
}

// Builds the replacement graph text: `graph_header`, followed by a single
// `%r = op_name(%a_quant, <extra args>)` call and `return (%r)`.
//
// When `scalar_args` is true, each extra argument `X` is first converted with
// an `X_scalar = aten::item(X)` line and the op is called with `X_scalar`
// instead of `X`. Otherwise the extra arguments are passed through unchanged.
std::string getAtenOpPattern(
    const std::string& graph_header,
    const std::string& op_name,
    const std::vector<std::string>& extra_op_args,
    bool scalar_args = false) {
  std::string replacement = graph_header;

  // Arguments that will actually be passed to `op_name` after %a_quant.
  std::vector<std::string> call_args;
  call_args.reserve(extra_op_args.size());
  for (const std::string& arg : extra_op_args) {
    if (!scalar_args) {
      call_args.push_back(arg);
      continue;
    }
    // Emit the aten::item conversion line and call the op with its result.
    replacement += R"(
          )";
    replacement += arg;
    replacement += "_scalar = aten::item(";
    replacement += arg;
    replacement += ")";
    call_args.push_back(arg + "_scalar");
  }

  const std::string call_arg_list = getExtraArgList(std::move(call_args));
  replacement += R"(
          %r = )";
  replacement += op_name;
  replacement += "(%a_quant";
  replacement += call_arg_list;
  replacement += ")";
  replacement += R"(
          return (%r) )";
  return replacement;
}

// Emits the IR lines that turn the scalar `value` into a quantized tensor
// named `<value>_quant`:
//   1. a constant `<value>_float_scalar_type` holding 6 (the `torch.float`
//      ScalarType) and a `None` constant `<value>_none`,
//   2. `<value>_tensor = aten::scalar_tensor(value, float type, None x3)`,
//   3. `<value>_quant = aten::quantize_per_tensor(<value>_tensor,
//      <value>_scale, <value>_zero_point, <value>_dtype)`.
std::string getQuantizeForScalar(const std::string& value) {
  // Every emitted statement starts on a fresh, indented line.
  const std::string line_start = R"(
          )";

  std::string ops;
  ops += line_start + value +
      "_float_scalar_type : int = prim::Constant[value=6]()";
  ops += line_start + value + "_none : None = prim::Constant()";

  ops += line_start + value + "_tensor : Tensor = aten::scalar_tensor(" +
      value + ", " + value + "_float_scalar_type";
  // The remaining three arguments of the call are all passed as None.
  for (int remaining = 3; remaining > 0; --remaining) {
    ops += ", " + value + "_none";
  }
  ops += ")";

  const std::string qparam_args = getExtraArgList(
      {value + "_scale", value + "_zero_point", value + "_dtype"});
  ops += line_start + value + "_quant = aten::quantize_per_tensor(" + value +
      "_tensor" + qparam_args + ")";
  return ops;
}

// Emits the IR line `<value>_dequant = aten::dequantize(<value>_quant)`,
// preceded by a newline and the pattern indentation.
std::string getDequantize(const std::string& value) {
  std::string line = R"(
          )";
  line += value;
  line += "_dequant = aten::dequantize(";
  line += value;
  line += "_quant)";
  return line;
}

// Emits the IR line `<value>_scalar : float = aten::item(<value>_dequant)`,
// preceded by a newline and the pattern indentation.
std::string getItem(const std::string& value) {
  std::string line = R"(
          )";
  line += value;
  line += "_scalar : float = aten::item(";
  line += value;
  line += "_dequant)";
  return line;
}

// Builds the match pattern for ops whose output reuses the quantization
// parameters of the input: dequantize %a_quant, apply `op_name` with the
// extra arguments, then re-quantize the result using the scale, zero point
// and dtype read back from %a_quant.
std::string getInputTensorQParamOpPattern(
    const std::string& op_name,
    const std::vector<std::string>& extra_op_args) {
  const std::string arg_list = getExtraArgList(extra_op_args);

  std::string pattern = "graph(%a_quant";
  pattern += arg_list;
  pattern += "):";
  pattern += R"(
          %a_dequant = aten::dequantize(%a_quant)
          %r = )";
  pattern += op_name;
  pattern += "(%a_dequant";
  pattern += arg_list;
  pattern += ")";
  pattern += R"(
          %r_scale : float = aten::q_scale(%a_quant)
          %r_zero_point : int = aten::q_zero_point(%a_quant)
          %r_dtype : int = prim::dtype(%a_quant)
          %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
          return (%r_quant) )";
  return pattern;
}

// Creates the fusion rule for an op that inherits its quantization parameters
// from the input: the dequantize/op/quantize pattern is replaced by the same
// `op_name` applied directly to %a_quant. No filters are attached.
QuantFusionInfo getInputTensorQParamOpFusionInfo(
    const std::string& op_name,
    const std::vector<std::string>& extra_op_args) {
  // Pattern and replacement share the same graph inputs.
  const std::string header =
      "graph(%a_quant" + getExtraArgList(extra_op_args) + "):";

  return QuantFusionInfo{
      op_name,
      getInputTensorQParamOpPattern(op_name, extra_op_args),
      getAtenOpPattern(header, op_name, extra_op_args)};
}

// Creates the fusion rule for tensor-with-scalar binary ops such as
// `quantized::add_scalar` and `quantized::mul_scalar`.
//
// The pattern is the dequantize/`op_name`/quantize sequence that reuses the
// input's quantization parameters; the replacement calls `quantized_op_name`
// on %a_quant with `extra_quantized_op_args`. `filters` is copied into the
// returned rule.
//
// NOTE(review): the rule's `quantized_op_name` field is populated from
// `op_name`, not from the `quantized_op_name` parameter — confirm this is
// intended.
QuantFusionInfo getBinaryOpScalarFusionInfo(
    const std::string& op_name,
    const std::vector<std::string>& extra_op_args,
    const std::string& quantized_op_name,
    const std::vector<std::string>& extra_quantized_op_args,
    const std::vector<MatchFilter>& filters = {}) {
  // The replacement keeps the inputs of the pattern graph, so its header is
  // built from the floating-point op's extra arguments.
  const std::string header =
      "graph(%a_quant" + getExtraArgList(extra_op_args) + "):";

  return QuantFusionInfo{
      op_name,
      getInputTensorQParamOpPattern(op_name, extra_op_args),
      getAtenOpPattern(header, quantized_op_name, extra_quantized_op_args),
      filters};
}

// Creates the fusion rule for clamp-like ops whose scalar arguments were
// themselves quantized and dequantized in the graph.
//
// Graph inputs are, in order: %a_quant, every extra argument, then
// `<arg>_scale`, `<arg>_zero_point`, `<arg>_dtype` for every extra argument,
// and finally %r_scale, %r_zero_point, %r_dtype.
//
// The pattern quantizes, dequantizes and extracts (`aten::item`) each extra
// argument, dequantizes %a_quant, calls `op_name` on the dequantized values
// and quantizes the result. The replacement calls `op_name` directly on
// %a_quant with the original extra arguments.
QuantFusionInfo getClampOpFusionInfo(
    const std::string& op_name,
    const std::vector<std::string>& extra_op_args) {
  const char* const qparam_suffixes[] = {"_scale", "_zero_point", "_dtype"};

  // Assemble the graph header arguments in the order documented above.
  std::vector<std::string> graph_inputs(
      extra_op_args.begin(), extra_op_args.end());
  for (const std::string& arg : extra_op_args) {
    for (const char* suffix : qparam_suffixes) {
      graph_inputs.push_back(arg + suffix);
    }
  }
  for (const char* suffix : qparam_suffixes) {
    graph_inputs.push_back(std::string("%r") + suffix);
  }
  const std::string graph_header =
      "graph(%a_quant" + getExtraArgList(std::move(graph_inputs)) + "):";

  // Per-argument preamble of the pattern, collecting the `_scalar` names the
  // op call will consume along the way.
  std::string pattern = graph_header;
  std::vector<std::string> scalar_names;
  scalar_names.reserve(extra_op_args.size());
  for (const std::string& arg : extra_op_args) {
    pattern += getQuantizeForScalar(arg);
    pattern += getDequantize(arg);
    pattern += getItem(arg);
    scalar_names.push_back(arg + "_scalar");
  }

  pattern += getDequantize("%a");
  pattern += R"(
          %r = )";
  pattern += op_name;
  pattern += "(%a_dequant";
  pattern += getExtraArgList(std::move(scalar_names));
  pattern += ")";
  // Output is quantized with the qparams supplied as graph inputs.
  pattern += R"(
          %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
          return (%r_quant) )";

  std::string replacement =
      getAtenOpPattern(graph_header, op_name, extra_op_args);

  return {op_name, std::move(pattern), std::move(replacement)};
}

// Creates the fusion rule for ops that have fixed output quantization
// parameters.
//
// The pattern dequantizes %a_quant, applies `op_name`, and quantizes the
// result with constant qparams:
//   - symmetric  (`is_symmetric` true):  scale 0.0078125,  zero point 128
//   - asymmetric (`is_symmetric` false): scale 0.00390625, zero point 0
// Both variants use the dtype constant 13.
// The replacement calls `op_name` directly on %a_quant.
QuantFusionInfo getFixedQParamOpFusionInfo(
    const std::string& op_name,
    const std::vector<std::string>& extra_op_args,
    bool is_symmetric) {
  const std::string arg_list = getExtraArgList(extra_op_args);
  const std::string graph_header = "graph(%a_quant" + arg_list + "):";

  std::string pattern = graph_header;
  pattern += R"(
          %a_dequant = aten::dequantize(%a_quant)
          %r = )";
  pattern += op_name;
  pattern += "(%a_dequant";
  pattern += arg_list;
  pattern += ")";

  // Append the quantize tail carrying the fixed qparams for the chosen scheme.
  if (is_symmetric) {
    pattern += R"(
          %r_scale : float = prim::Constant[value=0.0078125]()
          %r_zero_point : int = prim::Constant[value=128]()
          %r_dtype : int = prim::Constant[value=13]()
          %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
          return (%r_quant) )";
  } else {
    pattern += R"(
          %r_scale : float = prim::Constant[value=0.00390625]()
          %r_zero_point : int = prim::Constant[value=0]()
          %r_dtype : int = prim::Constant[value=13]()
          %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
          return (%r_quant) )";
  }

  std::string replacement =
      getAtenOpPattern(graph_header, op_name, extra_op_args);

  return {op_name, std::move(pattern), std::move(replacement)};
}

// Match filter: returns whether the value bound to the pattern's `b_scalar`
// input is a scalar, as decided by isScalar().
//
// `vmap` maps pattern value names to pattern values; `match.values_map` is
// then used to look up the value that pattern value was matched to. A missing
// "b_scalar" entry in `vmap` makes std::unordered_map::at throw
// std::out_of_range.
bool input_b_is_scalar(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  const auto& pattern_value = vmap.at("b_scalar");
  auto matched_value = match.values_map.at(pattern_value);
  return isScalar(matched_value);
}

// Creates the fusion rule for ops whose output quantization parameters are
// supplied as graph inputs (%r_scale, %r_zero_point, %r_dtype).
//
// Before fusion:
//
//   graph(%a_quant, {fp_extra_args}, %r_scale, %r_zero_point, %r_dtype):
//       %a_dequant = aten::dequantize(%a_quant)
//       %r = {fp_op_name}(%a_dequant, {fp_extra_args})
//       %r_quant = aten::quantize_per_tensor(
//           %r, %r_scale, %r_zero_point, %r_dtype)
//       return (%r_quant)
//
// After fusion:
//
//   graph(%a_quant, {fp_extra_args}, %r_scale, %r_zero_point, %r_dtype):
//       %r_quant = {q_op_name}(
//           %a_quant, {q_extra_args}, %r_scale, %r_zero_point)
//       return (%r_quant)
//
// The returned rule is labelled with `q_op_name` and carries no filters.
QuantFusionInfo getObservedQParamOpFusionInfo(
    const std::string& fp_op_name,
    const std::string& q_op_name,
    const std::vector<std::string>& fp_extra_args,
    const std::vector<std::string>& q_extra_args) {
  const std::string fp_arg_list = getExtraArgList(fp_extra_args);

  // Both graphs take identical inputs, so the header is built once.
  const std::string graph_header = "graph(%a_quant" + fp_arg_list +
      ", %r_scale, %r_zero_point, %r_dtype):";

  std::string fp_pattern = graph_header;
  fp_pattern += R"(
          %a_dequant = aten::dequantize(%a_quant)
          %r = )";
  fp_pattern += fp_op_name;
  fp_pattern += "(%a_dequant";
  fp_pattern += fp_arg_list;
  fp_pattern += ")";
  fp_pattern += R"(
          %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
          return (%r_quant) )";

  std::string fused_pattern = graph_header;
  fused_pattern += R"(
          %r_quant = )";
  fused_pattern += q_op_name;
  fused_pattern += "(%a_quant";
  fused_pattern += getExtraArgList(q_extra_args);
  fused_pattern += ", %r_scale, %r_zero_point)";
  fused_pattern += R"(
          return (%r_quant) )";

  return {q_op_name, std::move(fp_pattern), std::move(fused_pattern)};
}

} // namespace

std::vector<QuantFusionInfo> quant_fusion_pattern_and_replacements() {
  // aten::conv1d
  std::string conv1d = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
        %a_dequant = aten::dequantize(%a_quant)
        %w_quant : Tensor, %b : Tensor? = quantized::conv1d_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant)
        %r = aten::conv1d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
        %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
        return (%r_quant) )";

  // aten::conv1d - aten::relu
  std::string conv1d_relu = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
        %a_dequant = aten::dequantize(%a_quant)
        %w_quant : Tensor, %b : Tensor? = quantized::conv1d_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant)
        %conv_out = aten::conv1d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
        %r = aten::relu(%conv_out)
        %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
        return (%r_quant) )";

  // aten::conv1d - aten::relu_
  std::string conv1d_inplace_relu = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
        %a_dequant = aten::dequantize(%a_quant)
        %w_quant : Tensor, %b : Tensor? = quantized::conv1d_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant)
        %conv_out = aten::conv1d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
        %r = aten::relu_(%conv_out)
        %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
        return (%r_quant) )";

  // quantized::conv1d
  std::string quantized_conv1d = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
        %r_quant = quantized::conv1d(%a_quant, %packed_params, %r_scale, %r_zero_point)
        return (%r_quant) )";

  // quantized::conv1d_relu
  std::string quantized_conv1d_relu = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
        %r_quant = quantized::conv1d_relu(%a_quant, %packed_params, %r_scale, %r_zero_point)
        return (%r_quant) )";

  // aten::conv2d
  std::string conv2d = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
        %a_dequant = aten::dequantize(%a_quant)
        %w_quant : Tensor, %b : Tensor? = quantized::conv2d_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant)
        %r = aten::conv2d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
        %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
        return (%r_quant) )";

  // aten::conv2d - aten::relu
  std::string conv2d_relu = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
        %a_dequant = aten::dequantize(%a_quant)
        %w_quant : Tensor, %b : Tensor? = quantized::conv2d_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant)
        %conv_out = aten::conv2d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
        %r = aten::relu(%conv_out)
        %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
        return (%r_quant) )";

  // aten::conv2d - aten::relu_
  std::string conv2d_inplace_relu = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
        %a_dequant = aten::dequantize(%a_quant)
        %w_quant : Tensor, %b : Tensor? = quantized::conv2d_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant)
        %conv_out = aten::conv2d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
        %r = aten::relu_(%conv_out)
        %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
        return (%r_quant) )";

  // quantized::conv2d
  std::string quantized_conv2d = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
        %r_quant = quantized::conv2d(%a_quant, %packed_params, %r_scale, %r_zero_point)
        return (%r_quant) )";

  // quantized::conv2d_relu
  std::string quantized_conv2d_relu = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
        %r_quant = quantized::conv2d_relu(%a_quant, %packed_params, %r_scale, %r_zero_point)
        return (%r_quant) )";

  // aten::conv3d
  std::string conv3d = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
        %a_dequant = aten::dequantize(%a_quant)
        %w_quant : Tensor, %b : Tensor? = quantized::conv3d_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant)
        %r = aten::conv3d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
        %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
        return (%r_quant) )";

  // aten::conv3d - aten::relu
  std::string conv3d_relu = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
        %a_dequant = aten::dequantize(%a_quant)
        %w_quant : Tensor, %b : Tensor? = quantized::conv3d_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant)
        %conv_out = aten::conv3d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
        %r = aten::relu(%conv_out)
        %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
        return (%r_quant) )";

  // aten::conv3d - aten::relu_
  std::string conv3d_inplace_relu = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
        %a_dequant = aten::dequantize(%a_quant)
        %w_quant : Tensor, %b : Tensor? = quantized::conv3d_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant)
        %conv_out = aten::conv3d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
        %r = aten::relu_(%conv_out)
        %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
        return (%r_quant) )";

  // quantized::conv3d
  std::string quantized_conv3d = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
        %r_quant = quantized::conv3d(%a_quant, %packed_params, %r_scale, %r_zero_point)
        return (%r_quant) )";

  // quantized::conv3d_relu
  std::string quantized_conv3d_relu = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
        %r_quant = quantized::conv3d_relu(%a_quant, %packed_params, %r_scale, %r_zero_point)
        return (%r_quant) )";

  // aten::conv_transpose1d
  std::string conv_transpose1d = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation):
        %a_dequant = aten::dequantize(%a_quant)
        %w_quant : Tensor, %b : Tensor? = quantized::conv_transpose1d_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant)
        %r = aten::conv_transpose1d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation)
        %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
        return (%r_quant) )";

  // quantized::conv_transpose1d
  std::string quantized_conv_transpose1d = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation):
        %r_quant = quantized::conv_transpose1d(%a_quant, %packed_params, %r_scale, %r_zero_point)
        return (%r_quant) )";

  // aten::conv_transpose2d
  std::string conv_transpose2d = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation):
        %a_dequant = aten::dequantize(%a_quant)
        %w_quant : Tensor, %b : Tensor? = quantized::conv_transpose2d_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant)
        %r = aten::conv_transpose2d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation)
        %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
        return (%r_quant) )";

  // quantized::conv_transpose2d
  std::string quantized_conv_transpose2d = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation):
        %r_quant = quantized::conv_transpose2d(%a_quant, %packed_params, %r_scale, %r_zero_point)
        return (%r_quant) )";

  std::string add_relu = R"(
graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
         %a_dequant = aten::dequantize(%a_quant)
         %b_dequant = aten::dequantize(%b_quant)
         %r_add = aten::add(%a_dequant, %b_dequant, %alpha)
         %r_relu = aten::relu(%r_add)
         %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
         return (%r) )";

  std::string add_inplace_relu = R"(
graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
         %a_dequant = aten::dequantize(%a_quant)
         %b_dequant = aten::dequantize(%b_quant)
         %r_add = aten::add(%a_dequant, %b_dequant, %alpha)
         %r_relu = aten::relu_(%r_add)
         %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
         return (%r) )";

  std::string inplace_add_relu = R"(
graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
         %a_dequant = aten::dequantize(%a_quant)
         %b_dequant = aten::dequantize(%b_quant)
         %r_add = aten::add_(%a_dequant, %b_dequant, %alpha)
         %r_relu = aten::relu(%r_add)
         %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
         return (%r) )";

  std::string inplace_add_inplace_relu = R"(
graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
         %a_dequant = aten::dequantize(%a_quant)
         %b_dequant = aten::dequantize(%b_quant)
         %r_add = aten::add_(%a_dequant, %b_dequant, %alpha)
         %r_relu = aten::relu_(%r_add)
         %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
         return (%r) )";

  std::string quantized_add_relu = R"(
graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
         %r = quantized::add_relu(%a_quant, %b_quant, %scale, %zero_point)
         return (%r) )";

  // aten::linear
  std::string linear = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype):
        %a_dequant = aten::dequantize(%a_quant)
        %w_quant : Tensor, %b : Tensor? = quantized::linear_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant)
        %r = aten::linear(%a_dequant, %w_dequant, %b)
        %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
        return (%r_quant) )";

  std::string linear_relu = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype):
        %a_dequant = aten::dequantize(%a_quant)
        %w_quant : Tensor, %b : Tensor? = quantized::linear_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant)
        %linear_out = aten::linear(%a_dequant, %w_dequant, %b)
        %r = aten::relu(%linear_out)
        %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
        return (%r_quant) )";

  std::string linear_inplace_relu = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype):
        %a_dequant = aten::dequantize(%a_quant)
        %w_quant : Tensor, %b : Tensor? = quantized::linear_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant)
        %linear_out = aten::linear(%a_dequant, %w_dequant, %b)
        %r = aten::relu_(%linear_out)
        %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
        return (%r_quant) )";

  // quantized::linear
  std::string quantized_linear = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype):
        %r = quantized::linear(%a_quant, %packed_params, %r_scale, %r_zero_point)
        return (%r) )";

  std::string quantized_linear_relu = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype):
        %r = quantized::linear_relu(%a_quant, %packed_params, %r_scale, %r_zero_point)
        return (%r) )";

  std::string cat = R"(
graph(%input_quant, %dim, %r_scale, %r_zero_point, %r_dtype):
        %input_dequant = aten::dequantize(%input_quant)
        %r = aten::cat(%input_dequant, %dim)
        %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
        return (%r_quant) )";

  std::string quantized_cat = R"(
graph(%input_quant, %dim, %r_scale, %r_zero_point, %r_dtype):
         %r_quant = quantized::cat(%input_quant, %dim, %r_scale, %r_zero_point)
         return (%r_quant) )";

  // aten::add
  std::string add = R"(
graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
         %a_dequant = aten::dequantize(%a_quant)
         %b_dequant = aten::dequantize(%b_quant)
         %r_add = aten::add(%a_dequant, %b_dequant, %alpha)
         %r = aten::quantize_per_tensor(%r_add, %scale, %zero_point, %dtype)
         return (%r) )";

  // TODO: add %dtype once https://github.com/pytorch/pytorch/issues/34351
  // is fixed
  // quantized::add
  std::string quantized_add = R"(
graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
         %r = quantized::add(%a_quant, %b_quant, %scale, %zero_point)
         return (%r) )";

  // aten::add_
  std::string inplace_add = R"(
graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
         %a_dequant = aten::dequantize(%a_quant)
         %b_dequant = aten::dequantize(%b_quant)
         %r_add = aten::add_(%a_dequant, %b_dequant, %alpha)
         %r = aten::quantize_per_tensor(%r_add, %scale, %zero_point, %dtype)
         return (%r) )";

  auto add_scalar = getBinaryOpScalarFusionInfo(
      "aten::add",
      {"%b_scalar", "%alpha"},
      "quantized::add_scalar",
      {"%b_scalar"},
      {aten_add_alpha_is_one, input_b_is_scalar});

  auto add_scalar_out = getBinaryOpScalarFusionInfo(
      "aten::add_",
      {"%b_scalar", "%alpha"},
      "quantized::add_scalar_out",
      {"%b_scalar", "%a_quant"},
      {aten_add_alpha_is_one, input_b_is_scalar});

  // quantized::add_scalar_relu -- fusing quantized::add_scalar
  // and aten::relu
  auto quantized_add_scalar_relu_pattern = R"(
graph(%a_quant, %b_scalar):
         %r_add = quantized::add_scalar(%a_quant, %b_scalar)
         %r = aten::relu(%r_add)
         return (%r) )";

  auto quantized_add_scalar_inplace_relu_pattern = R"(
graph(%a_quant, %b_scalar):
         %r_add = quantized::add_scalar(%a_quant, %b_scalar)
         %r = aten::relu_(%r_add)
         return (%r) )";

  auto quantized_add_scalar_relu_replacement = R"(
graph(%a_quant, %b_scalar):
         %r = quantized::add_scalar_relu(%a_quant, %b_scalar)
         return (%r) )";

  // quantized::add_scalar_relu_out -- fusing quantized::add_scalar_out
  // and aten::relu
  auto quantized_add_scalar_relu_out_pattern = R"(
graph(%a_quant, %b_scalar):
         %r_add = quantized::add_scalar_out(%a_quant, %b_scalar, %a_quant)
         %r = aten::relu(%r_add)
         return (%r) )";

  auto quantized_add_scalar_inplace_relu_out_pattern = R"(
graph(%a_quant, %b_scalar):
         %r_add = quantized::add_scalar_out(%a_quant, %b_scalar, %a_quant)
         %r = aten::relu_(%r_add)
         return (%r) )";

  auto quantized_add_scalar_relu_out_replacement = R"(
graph(%a_quant, %b_scalar):
         %r = quantized::add_scalar_relu_out(%a_quant, %b_scalar, %a_quant)
         return (%r) )";

  // quantized::batch_norm
  std::string batch_norm = R"(
graph(%a_quant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7, %scale, %zero_point, %scalar_type):
         %a_dequant = aten::dequantize(%a_quant)
         %r_bn = aten::batch_norm(%a_dequant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7)
         %r = aten::quantize_per_tensor(%r_bn, %scale, %zero_point, %scalar_type)
         return (%r) )";
  std::string quantized_batch_norm = R"(
graph(%a_quant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7, %scale, %zero_point, %scalar_type):
         %r = quantized::batch_norm(%a_quant, %weight, %bias, %mean, %var, %eps, %scale, %zero_point)
         return (%r) )";

  std::string batch_norm_relu = R"(
graph(%a_quant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7, %scale, %zero_point, %scalar_type):
         %a_dequant = aten::dequantize(%a_quant)
         %bn_out = aten::batch_norm(%a_dequant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7)
         %relu = aten::relu(%bn_out)
         %r = aten::quantize_per_tensor(%relu, %scale, %zero_point, %scalar_type)
         return (%r) )";
  std::string batch_norm_inplace_relu = R"(
graph(%a_quant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7, %scale, %zero_point, %scalar_type):
         %a_dequant = aten::dequantize(%a_quant)
         %bn_out = aten::batch_norm(%a_dequant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7)
         %relu = aten::relu_(%bn_out)
         %r = aten::quantize_per_tensor(%relu, %scale, %zero_point, %scalar_type)
         return (%r) )";

  std::string quantized_batch_norm_relu = R"(
graph(%a_quant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7, %scale, %zero_point, %scalar_type):
         %r = quantized::batch_norm_relu(%a_quant, %weight, %bias, %mean, %var, %eps, %scale, %zero_point)
         return (%r) )";

  // aten::mul
  std::string mul = R"(
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
         %a_dequant = aten::dequantize(%a_quant)
         %b_dequant = aten::dequantize(%b_quant)
         %r_mul = aten::mul(%a_dequant, %b_dequant)
         %r = aten::quantize_per_tensor(%r_mul, %scale, %zero_point, %dtype)
         return (%r) )";

  // aten::mul_
  std::string inplace_mul = R"(
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
         %a_dequant = aten::dequantize(%a_quant)
         %b_dequant = aten::dequantize(%b_quant)
         %r_mul = aten::mul_(%a_dequant, %b_dequant)
         %r = aten::quantize_per_tensor(%r_mul, %scale, %zero_point, %dtype)
         return (%r) )";

  // quantized::mul
  std::string quantized_mul = R"(
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
         %r = quantized::mul(%a_quant, %b_quant, %scale, %zero_point)
         return (%r) )";

  auto mul_scalar = getBinaryOpScalarFusionInfo(
      "aten::mul",
      {"%b_scalar"},
      "quantized::mul_scalar",
      {"%b_scalar"},
      {input_b_is_scalar});

  auto mul_scalar_out = getBinaryOpScalarFusionInfo(
      "aten::mul_",
      {"%b_scalar"},
      "quantized::mul_scalar_out",
      {"%b_scalar", "%a_quant"},
      {input_b_is_scalar});

  // quantized::mul_relu
  std::string mul_relu = R"(
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
         %a_dequant = aten::dequantize(%a_quant)
         %b_dequant = aten::dequantize(%b_quant)
         %r_mul = aten::mul(%a_dequant, %b_dequant)
         %r_relu = aten::relu(%r_mul)
         %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
         return (%r) )";

  std::string mul_inplace_relu = R"(
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
         %a_dequant = aten::dequantize(%a_quant)
         %b_dequant = aten::dequantize(%b_quant)
         %r_mul = aten::mul(%a_dequant, %b_dequant)
         %r_relu = aten::relu_(%r_mul)
         %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
         return (%r) )";

  std::string inplace_mul_relu = R"(
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
         %a_dequant = aten::dequantize(%a_quant)
         %b_dequant = aten::dequantize(%b_quant)
         %r_mul = aten::mul_(%a_dequant, %b_dequant)
         %r_relu = aten::relu(%r_mul)
         %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
         return (%r) )";

  std::string inplace_mul_inplace_relu = R"(
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
         %a_dequant = aten::dequantize(%a_quant)
         %b_dequant = aten::dequantize(%b_quant)
         %r_mul = aten::mul_(%a_dequant, %b_dequant)
         %r_relu = aten::relu_(%r_mul)
         %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
         return (%r) )";

  std::string quantized_mul_relu = R"(
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
         %r = quantized::mul_relu(%a_quant, %b_quant, %scale, %zero_point)
         return (%r) )";

  // quantized::mul_scalar_relu -- fusing quantized::mul_scalar
  // and aten::relu
  auto quantized_mul_scalar_relu_pattern = R"(
graph(%a_quant, %b_scalar):
         %r_mul = quantized::mul_scalar(%a_quant, %b_scalar)
         %r = aten::relu(%r_mul)
         return (%r) )";

  auto quantized_mul_scalar_inplace_relu_pattern = R"(
graph(%a_quant, %b_scalar):
         %r_mul = quantized::mul_scalar(%a_quant, %b_scalar)
         %r = aten::relu_(%r_mul)
         return (%r) )";

  auto quantized_mul_scalar_relu_replacement = R"(
graph(%a_quant, %b_scalar):
         %r = quantized::mul_scalar_relu(%a_quant, %b_scalar)
         return (%r) )";

  // quantized::mul_scalar_relu_out -- fusing quantized::mul_scalar_out
  // and aten::relu
  auto quantized_mul_scalar_relu_out_pattern = R"(
graph(%a_quant, %b_scalar):
         %r_mul = quantized::mul_scalar_out(%a_quant, %b_scalar, %a_quant)
         %r = aten::relu(%r_mul)
         return (%r) )";

  auto quantized_mul_scalar_inplace_relu_out_pattern = R"(
graph(%a_quant, %b_scalar):
         %r_mul = quantized::mul_scalar_out(%a_quant, %b_scalar, %a_quant)
         %r = aten::relu_(%r_mul)
         return (%r) )";

  auto quantized_mul_scalar_relu_out_replacement = R"(
graph(%a_quant, %b_scalar):
         %r = quantized::mul_scalar_relu_out(%a_quant, %b_scalar, %a_quant)
         return (%r) )";

  // quantized::elu
  std::string elu = R"(
graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype):
         %a_dequant = aten::dequantize(%a_quant)
         %r = aten::elu(%a_dequant, %alpha, %scale, %input_scale)
         %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
         return (%r_quant) )";

  std::string quantized_elu = R"(
graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype):
         %r_quant = quantized::elu(%a_quant, %r_scale, %r_zero_point, %alpha, %scale, %input_scale)
         return (%r_quant) )";

  std::string elu_ = R"(
graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype):
         %a_dequant = aten::dequantize(%a_quant)
         %r = aten::elu_(%a_dequant, %alpha, %scale, %input_scale)
         %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
         return (%r_quant) )";

  // ============= General Ops that inherit quantization parameters from input
  // tensor =============
  auto avg_pool1d = getInputTensorQParamOpFusionInfo(
      "aten::avg_pool1d",
      {"%kernel_size",
       "%stride",
       "%padding",
       "%ceil_mode",
       "%count_include_pad"});

  auto avg_pool2d = getInputTensorQParamOpFusionInfo(
      "aten::avg_pool2d",
      {"%kernel_size",
       "%stride",
       "%padding",
       "%ceil_mode",
       "%count_include_pad",
       "%divisor_override"});

  std::string common_general_value_op = R"(
          %r_scale : float = aten::q_scale(%a_quant)
          %r_zero_point : int = aten::q_zero_point(%a_quant)
          %r_dtype : int = prim::dtype(%a_quant)
          %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
          return (%r_quant) )";

  auto avg_pool3d = getInputTensorQParamOpFusionInfo(
      "aten::avg_pool3d",
      {"%kernel_size",
       "%stride",
       "%padding",
       "%ceil_mode",
       "%count_include_pad",
       "%divisor_override"});

  auto adaptive_avg_pool1d = getInputTensorQParamOpFusionInfo(
      "aten::adaptive_avg_pool1d", {"%output_size"});

  auto adaptive_avg_pool2d = getInputTensorQParamOpFusionInfo(
      "aten::adaptive_avg_pool2d", {"%output_size"});

  auto adaptive_avg_pool3d = getInputTensorQParamOpFusionInfo(
      "aten::adaptive_avg_pool3d", {"%output_size"});

  auto mean1 = getInputTensorQParamOpFusionInfo("aten::mean", {"%dim"});

  auto mean2 = getInputTensorQParamOpFusionInfo(
      "aten::mean", {"%dim", "%keepdim", "%out"});

  auto upsample_nearest1d_vec = getInputTensorQParamOpFusionInfo(
      "aten::upsample_nearest1d", {"%output_size", "%scale_factors"});

  auto upsample_nearest2d_vec = getInputTensorQParamOpFusionInfo(
      "aten::upsample_nearest2d", {"%output_size", "%scale_factors"});

  auto upsample_nearest3d_vec = getInputTensorQParamOpFusionInfo(
      "aten::upsample_nearest3d", {"%output_size", "%scale_factors"});

  auto upsample_linear1d_vec = getInputTensorQParamOpFusionInfo(
      "aten::upsample_linear1d",
      {"%output_size", "%align_corners", "%scale_factors"});

  auto upsample_bilinear2d_vec = getInputTensorQParamOpFusionInfo(
      "aten::upsample_bilinear2d",
      {"%output_size", "%align_corners", "%scale_factors"});

  auto upsample_trilinear3d_vec = getInputTensorQParamOpFusionInfo(
      "aten::upsample_trilinear3d",
      {"%output_size", "%align_corners", "%scale_factors"});

  auto upsample_nearest1d = getInputTensorQParamOpFusionInfo(
      "aten::upsample_nearest1d", {"%output_size", "%scales"});

  auto upsample_nearest2d = getInputTensorQParamOpFusionInfo(
      "aten::upsample_nearest2d", {"%output_size", "%scale_h", "%scale_w"});

  auto upsample_nearest3d = getInputTensorQParamOpFusionInfo(
      "aten::upsample_nearest3d",
      {"%output_size", "%scale_d", "%scale_h", "%scale_w"});

  auto upsample_linear1d = getInputTensorQParamOpFusionInfo(
      "aten::upsample_linear1d", {"%output_size", "%align_corners", "%scales"});

  auto upsample_bilinear2d = getInputTensorQParamOpFusionInfo(
      "aten::upsample_bilinear2d",
      {"%output_size", "%align_corners", "%scale_h", "%scale_w"});

  auto upsample_trilinear3d = getInputTensorQParamOpFusionInfo(
      "aten::upsample_trilinear3d",
      {"%output_size", "%align_corners", "%scale_d", "%scale_h", "%scale_w"});

  auto clamp = getClampOpFusionInfo("aten::clamp", {"%min", "%max"});

  auto hardtanh = getClampOpFusionInfo("aten::hardtanh", {"%min", "%max"});

  auto hardtanh_ = getClampOpFusionInfo("aten::hardtanh_", {"%min", "%max"});

  auto leaky_relu =
      getInputTensorQParamOpFusionInfo("aten::leaky_relu", {"%negative_slope"});

  auto leaky_relu_ = getInputTensorQParamOpFusionInfo(
      "aten::leaky_relu_", {"%negative_slope"});

  // Ops with fixed quantization parameters
  auto hardsigmoid = getFixedQParamOpFusionInfo("aten::hardsigmoid", {}, false);

  auto hardsigmoid_ =
      getFixedQParamOpFusionInfo("aten::hardsigmoid_", {}, false);

  auto sigmoid = getFixedQParamOpFusionInfo("aten::sigmoid", {}, false);

  auto sigmoid_ = getFixedQParamOpFusionInfo("aten::sigmoid_", {}, false);

  auto tanh = getFixedQParamOpFusionInfo("aten::tanh", {}, true);

  auto tanh_ = getFixedQParamOpFusionInfo("aten::tanh_", {}, true);

  auto hardswish = getObservedQParamOpFusionInfo(
      "aten::hardswish", "quantized::hardswish", {}, {});

  auto hardswish_ = getObservedQParamOpFusionInfo(
      "aten::hardswish_", "quantized::hardswish", {}, {});

  auto layer_norm = getObservedQParamOpFusionInfo(
      "aten::layer_norm",
      "quantized::layer_norm",
      {"%normalized_shape", "%weight", "%bias", "%eps", "%cudnn_enabled"},
      {"%normalized_shape", "%weight", "%bias", "%eps"});

  auto group_norm = getObservedQParamOpFusionInfo(
      "aten::group_norm",
      "quantized::group_norm",
      {"%num_groups", "%weight", "%bias", "%eps", "%cudnn_enabled"},
      {"%num_groups", "%weight", "%bias", "%eps"});

  auto instance_norm = getObservedQParamOpFusionInfo(
      "aten::instance_norm",
      "quantized::instance_norm",
      {"%weight",
       "%bias",
       "%running_mean",
       "%running_var",
       "%use_input_stats",
       "%momentum",
       "%eps",
       "%cudnn_enabled"},
      {"%weight", "%bias", "%eps"});

  return {
      {"quantized::conv1d", std::move(conv1d), std::move(quantized_conv1d)},
      {"quantized::conv1d_relu", std::move(conv1d_relu), quantized_conv1d_relu},
      {"quantized::conv1d_relu",
       std::move(conv1d_inplace_relu),
       std::move(quantized_conv1d_relu)},
      {"quantized::conv2d", std::move(conv2d), std::move(quantized_conv2d)},
      {"quantized::conv2d_relu", std::move(conv2d_relu), quantized_conv2d_relu},
      {"quantized::conv2d_relu",
       std::move(conv2d_inplace_relu),
       std::move(quantized_conv2d_relu)},
      {"quantized::conv3d", std::move(conv3d), std::move(quantized_conv3d)},
      {"quantized::conv3d_relu", std::move(conv3d_relu), quantized_conv3d_relu},
      {"quantized::conv3d_relu",
       std::move(conv3d_inplace_relu),
       std::move(quantized_conv3d_relu)},
      {"quantized::conv_transpose1d",
       std::move(conv_transpose1d),
       std::move(quantized_conv_transpose1d)},
      {"quantized::conv_transpose2d",
       std::move(conv_transpose2d),
       std::move(quantized_conv_transpose2d)},
      {"quantized::linear", std::move(linear), std::move(quantized_linear)},
      {"quantized::linear_relu", std::move(linear_relu), quantized_linear_relu},
      {"quantized::linear_relu",
       std::move(linear_inplace_relu),
       std::move(quantized_linear_relu)},
      {"quantized::add_relu",
       std::move(add_relu),
       quantized_add_relu,
       {aten_add_alpha_is_one}},
      {"quantized::add_relu",
       std::move(add_inplace_relu),
       quantized_add_relu,
       {aten_add_alpha_is_one}},
      {"quantized::add_relu",
       std::move(inplace_add_relu),
       quantized_add_relu,
       {aten_add_alpha_is_one}},
      {"quantized::add_relu",
       std::move(inplace_add_inplace_relu),
       std::move(quantized_add_relu),
       {aten_add_alpha_is_one}},
      std::move(add_scalar),
      std::move(add_scalar_out),
      // note that these must come after quantized::add_scalar and
      // quantized::add_scalar_out patterns
      {"quantized::add_scalar_relu",
       quantized_add_scalar_relu_pattern,
       quantized_add_scalar_relu_replacement},
      {"quantized::add_scalar_relu",
       quantized_add_scalar_inplace_relu_pattern,
       quantized_add_scalar_relu_replacement},
      {"quantized::add_scalar_relu_out",
       quantized_add_scalar_relu_out_pattern,
       quantized_add_scalar_relu_out_replacement},
      {"quantized::add_scalar_relu_out",
       quantized_add_scalar_inplace_relu_out_pattern,
       quantized_add_scalar_relu_out_replacement},
      {"quantized::add",
       std::move(add),
       quantized_add,
       {aten_add_alpha_is_one}},
      {"quantized::add",
       std::move(inplace_add),
       std::move(quantized_add),
       {aten_add_alpha_is_one}},
      {"quantized::cat", std::move(cat), std::move(quantized_cat)},
      {"quantized::batch_norm",
       std::move(batch_norm),
       std::move(quantized_batch_norm)},
      {"quantized::batch_norm_relu",
       std::move(batch_norm_relu),
       quantized_batch_norm_relu},
      {"quantized::batch_norm_relu",
       std::move(batch_norm_inplace_relu),
       std::move(quantized_batch_norm_relu)},
      std::move(mul_scalar),
      std::move(mul_scalar_out),
      // note that these must come after quantized::mul_scalar and
      // quantized::mul_scalar_out patterns
      {"quantized::mul_scalar_relu",
       quantized_mul_scalar_relu_pattern,
       quantized_mul_scalar_relu_replacement},
      {"quantized::mul_scalar_relu",
       quantized_mul_scalar_inplace_relu_pattern,
       quantized_mul_scalar_relu_replacement},
      {"quantized::mul_scalar_relu_out",
       quantized_mul_scalar_relu_out_pattern,
       quantized_mul_scalar_relu_out_replacement},
      {"quantized::mul_scalar_relu_out",
       quantized_mul_scalar_inplace_relu_out_pattern,
       quantized_mul_scalar_relu_out_replacement},
      {"quantized::mul_relu", std::move(mul_relu), quantized_mul_relu},
      {"quantized::mul_relu", std::move(mul_inplace_relu), quantized_mul_relu},
      {"quantized::mul_relu", std::move(inplace_mul_relu), quantized_mul_relu},
      {"quantized::mul_relu",
       std::move(inplace_mul_inplace_relu),
       std::move(quantized_mul_relu)},
      {"quantized::mul", std::move(mul), quantized_mul},
      {"quantized::mul", std::move(inplace_mul), std::move(quantized_mul)},
      std::move(hardswish),
      std::move(hardswish_),
      std::move(layer_norm),
      std::move(group_norm),
      std::move(instance_norm),
      {"quantized::elu", std::move(elu), quantized_elu},
      {"quantized::elu_", std::move(elu_), std::move(quantized_elu)},
      std::move(avg_pool1d),
      std::move(avg_pool2d),
      std::move(avg_pool3d),
      std::move(adaptive_avg_pool1d),
      std::move(adaptive_avg_pool2d),
      std::move(adaptive_avg_pool3d),
      std::move(mean1),
      std::move(mean2),
      std::move(upsample_nearest1d),
      std::move(upsample_nearest2d),
      std::move(upsample_nearest3d),
      std::move(upsample_linear1d),
      std::move(upsample_bilinear2d),
      std::move(upsample_trilinear3d),
      std::move(upsample_nearest1d_vec),
      std::move(upsample_nearest2d_vec),
      std::move(upsample_nearest3d_vec),
      std::move(upsample_linear1d_vec),
      std::move(upsample_bilinear2d_vec),
      std::move(upsample_trilinear3d_vec),
      std::move(clamp),
      std::move(hardtanh),
      std::move(hardtanh_),
      std::move(leaky_relu),
      std::move(leaky_relu_),
      // fixed qparam ops
      std::move(hardsigmoid),
      std::move(hardsigmoid_),
      std::move(sigmoid),
      std::move(sigmoid_),
      std::move(tanh),
      std::move(tanh_),
  };
}

inline std::vector<QuantFusionInfo>
dynamic_quantized_linear_pattern_and_replacements() {
  std::string linear_dynamic = R"(
graph(%packed_params, %a):
        %w_quant : Tensor, %b : Tensor? = quantized::linear_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant)
        %r = aten::linear(%a, %w_dequant, %b)
        return (%r) )";

  // This pattern ignores reduce range
  // Set the reduce range to default to true, since qnnpack backend ignores this
  // argument.
  std::string quantized_linear_dynamic = R"(
graph(%packed_params, %a):
        %reduce_range : bool = prim::Constant[value=1]()
        %r = quantized::linear_dynamic(%a, %packed_params, %reduce_range)
        return (%r) )";

  return {
      {"quantized::linear_dynamic",
       std::move(linear_dynamic),
       std::move(quantized_linear_dynamic)},
  };
}

// Returns the (pattern, replacement) graph pairs used to fuse dynamically
// quantized linear ops.
//
// Two rules are produced, in this order:
//   1. "quantized::linear_dynamic": choose_qparams + quantize + dequantize on
//      the activation, followed by linear_unpack + dequantize on the weight
//      and aten::linear, is replaced by quantized::linear_dynamic.
//   2. "quantized::linear_dynamic_fp16": linear_unpack_fp16 followed by
//      aten::linear is replaced by quantized::linear_dynamic_fp16.
//
// This function is defined in a header, so it is declared `inline` (like
// dynamic_quantized_linear_pattern_and_replacements) to keep the header safe
// to include from more than one translation unit.
inline std::vector<QuantFusionInfo>
dynamic_quant_fusion_pattern_and_replacements() {
  std::string linear_dynamic = R"(
graph(%packed_params, %a, %reduce_range, %a_dtype):
        %a_scale : float, %a_zero_point : int = aten::_choose_qparams_per_tensor(%a, %reduce_range)
        %a_quant = aten::quantize_per_tensor(%a, %a_scale, %a_zero_point, %a_dtype)
        %a_dequant = aten::dequantize(%a_quant)
        %w_quant : Tensor, %b : Tensor? = quantized::linear_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant)
        %r = aten::linear(%a_dequant, %w_dequant, %b)
        return (%r) )";

  // %a_dtype stays in the signature so the replacement graph takes the same
  // inputs as the pattern; the fused op itself does not consume it.
  std::string quantized_linear_dynamic = R"(
graph(%packed_params, %a, %reduce_range, %a_dtype):
        %r = quantized::linear_dynamic(%a, %packed_params, %reduce_range)
        return (%r) )";

  std::string linear_dynamic_fp16 = R"(
graph(%packed_params, %a):
        %w_unpacked : Tensor, %b : Tensor? = quantized::linear_unpack_fp16(%packed_params)
        %r = aten::linear(%a, %w_unpacked, %b)
        return (%r) )";

  std::string quantized_linear_dynamic_fp16 = R"(
graph(%packed_params, %a):
        %r = quantized::linear_dynamic_fp16(%a, %packed_params)
        return (%r) )";

  return {
      {"quantized::linear_dynamic",
       std::move(linear_dynamic),
       std::move(quantized_linear_dynamic)},
      {"quantized::linear_dynamic_fp16",
       std::move(linear_dynamic_fp16),
       std::move(quantized_linear_dynamic_fp16)},
  };
}

// Returns the (pattern, replacement) graph pairs that insert a
// prepack/unpack pair in front of aten::linear.
//
// Two rules are produced, in this order:
//   1. "linear_prepack_unpack": aten::linear on a dequantized weight is
//      rewritten so the quantized weight and bias first go through
//      quantized::linear_prepack and quantized::linear_unpack.
//   2. "linear_fp16_prepack_unpack": aten::linear on a weight produced by
//      aten::_saturate_weight_to_fp16 is rewritten to use
//      quantized::linear_prepack_fp16 and quantized::linear_unpack_fp16 on
//      the original weight.
//
// This function is defined in a header, so it is declared `inline` to keep
// the header safe to include from more than one translation unit.
inline std::vector<QuantFusionInfo> linear_prepack_unpack_patterns() {
  std::string linear_with_quant = R"(
graph(%a_dequant, %w_quant, %b):
        %w_dequant = aten::dequantize(%w_quant)
        %r = aten::linear(%a_dequant, %w_dequant, %b)
        return (%r) )";

  std::string linear_with_quant_prepack = R"(
graph(%a_dequant, %w_quant, %b):
        %packed_params = quantized::linear_prepack(%w_quant, %b)
        %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::linear_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant_unpacked)
        %r = aten::linear(%a_dequant, %w_dequant, %b_unpacked)
        return (%r) )";
  std::string linear_fp16_with_cast = R"(
graph(%w, %a_dq, %b):
        %fp16_tensor = aten::_saturate_weight_to_fp16(%w)
        %r = aten::linear(%a_dq, %fp16_tensor, %b)
        return (%r) )";
  // The replacement feeds %w straight into the fp16 prepack op; the explicit
  // aten::_saturate_weight_to_fp16 call from the pattern is not re-emitted.
  std::string linear_fp16_with_prepack = R"(
graph(%w, %a_dq, %b):
        %packed_params = quantized::linear_prepack_fp16(%w, %b)
        %w_unpacked : Tensor, %b_unpacked : Tensor? = quantized::linear_unpack_fp16(%packed_params)
        %r = aten::linear(%a_dq, %w_unpacked, %b_unpacked)
        return (%r) )";

  return {
      {"linear_prepack_unpack",
       std::move(linear_with_quant),
       std::move(linear_with_quant_prepack)},
      {"linear_fp16_prepack_unpack",
       std::move(linear_fp16_with_cast),
       std::move(linear_fp16_with_prepack)},
  };
}

// Returns the (pattern, replacement) graph pairs that insert a
// prepack/unpack pair in front of convolution ops.
//
// For each of aten::conv1d, aten::conv2d, aten::conv3d,
// aten::conv_transpose1d and aten::conv_transpose2d, the pattern matches the
// op applied to a dequantized weight, and the replacement routes the
// quantized weight and bias through the matching quantized::*_prepack and
// quantized::*_unpack ops before dequantizing and calling the same aten op.
// Rules are returned in the order listed above.
//
// This function is defined in a header, so it is declared `inline` to keep
// the header safe to include from more than one translation unit.
inline std::vector<QuantFusionInfo> conv_prepack_unpack_patterns() {
  std::string conv1d_with_quant = R"(
graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
        %w_dequant = aten::dequantize(%w_quant)
        %r = aten::conv1d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
        return (%r) )";

  // Note: the conv1d packed params are annotated as Conv2dPackedParamsBase.
  std::string conv1d_with_quant_prepack = R"(
graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
        %packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = quantized::conv1d_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups)
        %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv1d_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant_unpacked)
        %r = aten::conv1d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string conv2d_with_quant = R"(
graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
        %w_dequant = aten::dequantize(%w_quant)
        %r = aten::conv2d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string conv2d_with_quant_prepack = R"(
graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
        %packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = quantized::conv2d_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups)
        %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv2d_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant_unpacked)
        %r = aten::conv2d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string conv3d_with_quant = R"(
graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
        %w_dequant = aten::dequantize(%w_quant)
        %r = aten::conv3d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string conv3d_with_quant_prepack = R"(
graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
        %packed_params : __torch__.torch.classes.quantized.Conv3dPackedParamsBase = quantized::conv3d_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups)
        %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv3d_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant_unpacked)
        %r = aten::conv3d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string conv_transpose1d_with_quant = R"(
graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation):
        %w_dequant = aten::dequantize(%w_quant)
        %r = aten::conv_transpose1d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation)
        return (%r) )";

  // Note: the transposed-conv prepack ops below are passed
  // (..., %output_padding, %dilation, %groups), whereas the aten ops are
  // passed (..., %output_padding, %groups, %dilation).
  std::string conv_transpose1d_with_quant_prepack = R"(
graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation):
        %packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = quantized::conv_transpose1d_prepack(%w_quant, %b, %stride, %padding, %output_padding, %dilation, %groups)
        %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv_transpose1d_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant_unpacked)
        %r = aten::conv_transpose1d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %output_padding, %groups, %dilation)
        return (%r) )";

  std::string conv_transpose2d_with_quant = R"(
graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation):
        %w_dequant = aten::dequantize(%w_quant)
        %r = aten::conv_transpose2d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation)
        return (%r) )";

  std::string conv_transpose2d_with_quant_prepack = R"(
graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation):
        %packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = quantized::conv_transpose2d_prepack(%w_quant, %b, %stride, %padding, %output_padding, %dilation, %groups)
        %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv_transpose2d_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant_unpacked)
        %r = aten::conv_transpose2d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %output_padding, %groups, %dilation)
        return (%r) )";

  return {
      {"conv1d_prepack_unpack",
       std::move(conv1d_with_quant),
       std::move(conv1d_with_quant_prepack)},
      {"conv2d_prepack_unpack",
       std::move(conv2d_with_quant),
       std::move(conv2d_with_quant_prepack)},
      {"conv3d_prepack_unpack",
       std::move(conv3d_with_quant),
       std::move(conv3d_with_quant_prepack)},
      {"conv_transpose1d_prepack_unpack",
       std::move(conv_transpose1d_with_quant),
       std::move(conv_transpose1d_with_quant_prepack)},
      {"conv_transpose2d_prepack_unpack",
       std::move(conv_transpose2d_with_quant),
       std::move(conv_transpose2d_with_quant_prepack)}};
}

} // namespace jit
} // namespace torch