import torch
from .backend_config import (
BackendConfig,
BackendPatternConfig,
DTypeConfig,
ObservationType
)
from ._common_operator_config_utils import (
_get_binary_op_configs,
_get_linear_configs,
_get_conv_configs,
_get_share_qparams_op_configs,
_get_tensor_info_op_configs,
)
# Public names of this module; both are functions defined below.
__all__ = [
"get_tensorrt_backend_config",
"get_tensorrt_backend_config_dict",
]
def get_tensorrt_backend_config() -> BackendConfig:
"""
Return the `BackendConfig` for the TensorRT backend.
NOTE: Current api will change in the future, it's just to unblock experimentation for
new backends, please don't use it right now.
TODO: add a README when it's more stable
"""
# dtype configs
weighted_op_qint8_dtype_config = DTypeConfig(
input_dtype=torch.qint8,
output_dtype=torch.qint8,
weight_dtype=torch.qint8,
bias_dtype=torch.float,
)
non_weighted_op_qint8_dtype_config = DTypeConfig(
input_dtype=torch.qint8,
output_dtype=torch.qint8,
)
addmm_config = BackendPatternConfig(torch.addmm) \
.set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
.add_dtype_config(weighted_op_qint8_dtype_config) \
._set_input_type_to_index({
"bias": 0,
"input": 1,
"weight": 2,
})
cat_config = BackendPatternConfig(torch.cat) \
.set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) \
.add_dtype_config(non_weighted_op_qint8_dtype_config)
conv_dtype_configs = [
weighted_op_qint8_dtype_config,
]
linear_dtype_configs = [
weighted_op_qint8_dtype_config,
]
binary_op_dtype_configs = [
weighted_op_qint8_dtype_config,
]
share_qparams_op_dtype_configs = [
non_weighted_op_qint8_dtype_config,
]
tensor_info_op_dtype_configs = [
non_weighted_op_qint8_dtype_config,
]
# there might be things not supported in fx2trt, but it will error out
# during fx2trt conversion and can support them after that
return BackendConfig("tensorrt") \
.set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \
.set_backend_pattern_config(addmm_config) \
.set_backend_pattern_config(cat_config) \
.set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \
.set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \
.set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \
.set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs))
def get_tensorrt_backend_config_dict():
    """
    Return the TensorRT backend config in dictionary form.

    This is the result of calling `to_dict()` on the object returned by
    `get_tensorrt_backend_config()`.
    """
    backend_config = get_tensorrt_backend_config()
    return backend_config.to_dict()