Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

/ distributed / algorithms / _quantization / quantization.py

import functools
import torch
import torch.distributed as dist


from enum import Enum


# Most negative / most positive finite values representable in float16.
# _fp32_to_fp16_with_clamp clamps to this range before casting to half.
TORCH_HALF_MIN, TORCH_HALF_MAX = (
    torch.finfo(torch.float16).min,
    torch.finfo(torch.float16).max,
)

class DQuantType(Enum):
    """
    Different quantization methods for auto_quantize API are identified here.
    auto_quantize API currently supports fp16 and bfp16 methods.

    Each member's value is the plain string name of the method, so
    ``str(DQuantType.FP16) == "fp16"``.
    """
    # No trailing comma: ``FP16 = "fp16",`` would make the value the tuple
    # ``("fp16",)``, and ``__str__`` (and any f-string using the member)
    # would then raise TypeError because it must return a str.
    FP16 = "fp16"
    BFP16 = "bfp16"

    def __str__(self) -> str:
        return self.value


def _fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
    """Cast ``tensor`` to float16 after saturating it to the float16 range.

    Elements outside ``[TORCH_HALF_MIN, TORCH_HALF_MAX]`` are clamped first,
    so out-of-range values become the largest finite half value rather than
    overflowing during the cast.
    """
    saturated = tensor.clamp(min=TORCH_HALF_MIN, max=TORCH_HALF_MAX)
    return saturated.to(torch.float16)

def _quantize_tensor(tensor, qtype):
    if not isinstance(tensor, torch.Tensor):
        raise RuntimeError(
            f"_quantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
        )
    if (qtype == DQuantType.FP16):
        return _fp32_to_fp16_with_clamp(tensor)
    elif (qtype == DQuantType.BFP16):
        return torch.ops.quantization._FloatToBfloat16Quantized(tensor)
    else:
        raise RuntimeError(
            f'Quantization type {qtype} is not supported'
        )

def _quantize_tensor_list(tensor_list, qtype):
    if not isinstance(tensor_list, list) or not all(
        isinstance(p, torch.Tensor) for p in tensor_list
    ):
        raise RuntimeError(
            f"_quantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
        )
    quantized_tensor_list = [_quantize_tensor(t, qtype) for t in tensor_list]
    return quantized_tensor_list

def _dequantize_tensor(tensor, qtype, quant_loss=None):
    if not isinstance(tensor, torch.Tensor):
        raise RuntimeError(
            f"_dequantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
        )
    if (qtype == DQuantType.FP16):
        if tensor.dtype != torch.float16:
            raise RuntimeError(
                f"tensor dtype is {tensor.dtype} while expected to be FP16."
            )
        elif tensor.dtype == torch.float16 and quant_loss is None:
            return tensor.float()
        else:
            return tensor.float() / quant_loss
    elif (qtype == DQuantType.BFP16):
        if tensor.dtype != torch.float16:
            raise RuntimeError(
                f"tensor dtype is {tensor.dtype} while expected to be FP16."
            )
        else:
            return torch.ops.quantization._Bfloat16QuantizedToFloat(tensor)
    else:
        raise RuntimeError(
            f'Quantization type {qtype} is not supported'
        )


def _dequantize_tensor_list(tensor_list, qtype, quant_loss=None):
    if not isinstance(tensor_list, list) or not all(
        isinstance(p, torch.Tensor) for p in tensor_list
    ):
        raise RuntimeError(
            f"_dequantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
        )
    dequantized_tensor_list = [_dequantize_tensor(t, qtype) for t in tensor_list]
    return dequantized_tensor_list


def auto_quantize(func, qtype, quant_loss=None):
    """
    This is a prototype API that automatically quantizes the input tensors, chooses the precision types,
    passes other necessary arguments and then dequantizes the output.
    Currently it only supports:
        . FP16 and BFP16 quantization method supported for gloo and nccl backends
        . all_gather, all_to_all and all_to_all_single collective ops
    Note: BFP16 only supports 2D tensors.
    Args:
        func (Callable): A function representing collective operations. Must be one of
            ``dist.all_gather``, ``dist.all_to_all`` or ``dist.all_to_all_single``.
        qtype (DQuantType): Quantization method
        quant_loss (float, optional): This can be used to improve accuracy in the dequantization.
            It is passed through to the dequantization step.
    Returns:
        (Callable): the same collective as func but enables automatic quantization/dequantization.
        The wrapper expects the output as ``args[0]`` and the input as ``args[1]``; it writes the
        dequantized results back into ``args[0]`` in place and returns ``None``.
        For ``all_to_all_single`` the split sizes are read from the ``out_splits`` / ``in_splits``
        keyword arguments.
    Raises:
        RuntimeError: (raised by the wrapper) if ``async_op=True`` is requested or ``func`` is not
            one of the supported collectives.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        group = kwargs.get('group', None)
        async_op = kwargs.get('async_op', False)
        if async_op is True:
            raise RuntimeError(
                'The async_op=True mode is not supported yet.'
            )
        if func == dist.all_gather:
            tensors = args[0]
            input_tensors = _quantize_tensor(args[1], qtype)
            out_tensors = _quantize_tensor_list(tensors, qtype)
            dist.all_gather(out_tensors, input_tensors, group=group, async_op=async_op)
            # Replace the caller's list entries with the dequantized results.
            for i, t in enumerate(_dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)):
                tensors[i] = t

        elif func == dist.all_to_all:
            tensors = args[0]
            input_tensors = _quantize_tensor_list(args[1], qtype)
            out_tensors = _quantize_tensor_list(tensors, qtype)
            dist.all_to_all(out_tensors, input_tensors, group=group, async_op=async_op)
            # Replace the caller's list entries with the dequantized results.
            for i, t in enumerate(_dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)):
                tensors[i] = t

        elif func == dist.all_to_all_single:
            tensors = args[0]
            out_splits = kwargs.get('out_splits', None)
            in_splits = kwargs.get('in_splits', None)
            # Quantizing the input/output tensor
            input_tensors = _quantize_tensor(args[1], qtype)
            out_tensors = _quantize_tensor(tensors, qtype)
            dist.all_to_all_single(out_tensors, input_tensors, out_splits, in_splits, group=group)
            # Copy the whole dequantized result into the caller's output tensor
            # in a single op, instead of assigning it one row at a time (which
            # also cannot iterate a 0-d tensor).
            tensors.copy_(_dequantize_tensor(out_tensors, qtype, quant_loss=quant_loss))
        else:
            raise RuntimeError(
                f"The collective op {func} is not supported yet"
            )

    return wrapper