Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

include/ATen/native/quantized/affine_quantizer.h

#pragma once

#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/quantized/affine_quantizer_base.h>

namespace at {
namespace native {

// Argument-order convention for every entry point and kernel signature in
// this header:
//   quantize_*   : (rtensor, qtensor, ...)
//   dequantize_* : (qtensor, rtensor, ...)
// `rtensor` names the real-valued tensor and `qtensor` the quantized tensor.
// Parameter names in a declaration are not part of the function type, but
// they are the only in-header documentation of argument order, so keep them
// consistent with this convention and with the out-of-line definitions.

// Quantize entry point, per-tensor affine scheme: one (scale, zero_point)
// pair for the whole tensor.
// NOTE(review): presumably returns `qtensor`; confirm against the definition.
Tensor quantize_tensor_per_tensor_affine(
    Tensor rtensor,
    Tensor qtensor,
    double scale,
    int64_t zero_point);

// Quantize entry point, per-channel affine scheme.
// NOTE(review): `scales` / `zero_points` presumably hold one entry per slice
// along dimension `axis`; confirm.
Tensor quantize_tensor_per_channel_affine(
    Tensor rtensor,
    Tensor qtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);

// Quantize entry point, per-channel scheme with "float qparams".
// NOTE(review): presumably differs from the affine variant in that
// `zero_points` are floating-point values; confirm.
Tensor quantize_tensor_per_channel_float_qparams(
    Tensor rtensor,
    Tensor qtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);

// Dequantize entry points; these mirror the quantize entry points above with
// the tensor arguments in (qtensor, rtensor) order.
Tensor dequantize_tensor_per_tensor_affine(
    Tensor qtensor,
    Tensor rtensor,
    double scale,
    int64_t zero_point);
Tensor dequantize_tensor_per_channel_affine(
    Tensor qtensor,
    Tensor rtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);
Tensor dequantize_tensor_per_channel_float_qparams(
    Tensor qtensor,
    Tensor rtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);

// Kernel signatures for the dispatch stubs declared below. They take the same
// arguments, in the same order, as the matching entry points above but
// return void.
using quantize_tensor_per_tensor_affine_fn =
    void (*)(Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point);

using quantize_tensor_per_channel_affine_fn = void (*)(
    Tensor rtensor,
    Tensor qtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);

using quantize_tensor_per_channel_float_qparams_fn = void (*)(
    Tensor rtensor,
    Tensor qtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);

using dequantize_tensor_per_tensor_affine_fn =
    void (*)(Tensor qtensor, Tensor rtensor, double scale, int64_t zero_point);

using dequantize_tensor_per_channel_affine_fn = void (*)(
    Tensor qtensor,
    Tensor rtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);

using dequantize_tensor_per_channel_float_qparams_fn = void (*)(
    Tensor qtensor,
    Tensor rtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);

// Sub-byte kernels take `scale` and `zero_point` as float, unlike the
// double / int64_t pair used by the per-tensor affine kernels above.
using quantize_tensor_per_tensor_affine_sub_byte_fn =
    void (*)(Tensor rtensor, Tensor qtensor, float scale, float zero_point);

using dequantize_tensor_per_tensor_affine_sub_byte_fn =
    void (*)(Tensor qtensor, Tensor rtensor, float scale, float zero_point);

// Dispatch stubs (see ATen/native/DispatchStub.h). Each stub's kernel type is
// the matching *_fn alias above.
DECLARE_DISPATCH(
    quantize_tensor_per_tensor_affine_fn,
    quantize_tensor_per_tensor_affine_stub);
DECLARE_DISPATCH(
    quantize_tensor_per_channel_affine_fn,
    quantize_tensor_per_channel_affine_stub);
DECLARE_DISPATCH(
    quantize_tensor_per_channel_float_qparams_fn,
    quantize_tensor_per_channel_float_qparams_stub);

DECLARE_DISPATCH(
    dequantize_tensor_per_tensor_affine_fn,
    dequantize_tensor_per_tensor_affine_stub);
DECLARE_DISPATCH(
    dequantize_tensor_per_channel_affine_fn,
    dequantize_tensor_per_channel_affine_stub);
DECLARE_DISPATCH(
    dequantize_tensor_per_channel_float_qparams_fn,
    dequantize_tensor_per_channel_float_qparams_stub);

DECLARE_DISPATCH(
    quantize_tensor_per_tensor_affine_sub_byte_fn,
    quantize_tensor_per_tensor_affine_sub_byte_stub);

DECLARE_DISPATCH(
    dequantize_tensor_per_tensor_affine_sub_byte_fn,
    dequantize_tensor_per_tensor_affine_sub_byte_stub);

// Exported per-tensor helpers templated on T.
// NOTE(review): T is presumably the quantized scalar type (e.g. c10::qint8);
// confirm against the explicit instantiations in the definition file.
template <typename T>
TORCH_API Tensor quantize_tensor(
    Tensor rtensor,
    Tensor qtensor,
    double scale,
    int64_t zero_point);
template <typename T>
TORCH_API Tensor dequantize_tensor(
    Tensor qtensor,
    Tensor rtensor,
    double scale,
    int64_t zero_point);

} // namespace native
} // namespace at