ao/ns/fx/mappings.py

import operator

import torch
import torch.nn as nn
import torch.nn.functional as F
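# short alias for the quantized op namespace (used e.g. for toq.prelu below)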
toq = torch.ops.quantized

import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.qat as nnqat
import torch.ao.nn.qat.dynamic as nnqatd
from torch.ao.quantization.backend_config import get_native_backend_config
import torch.ao.quantization.fx._lower_to_native_backend as \
    _lower_to_native_backend
import torch.ao.quantization.quantization_mappings as quantization_mappings

from .ns_types import NSNodeTargetType

from typing import Callable, Dict, List, Optional, Set, Tuple


def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
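    """
    Returns a mapping from a string base op name to the set of ops related to
    it, i.e. different variants (module, functional, quantized, fused, etc.)
    of the same underlying operation. The sets are seeded manually below and
    then extended with patterns from the native backend config.
    """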
    # note: this set is modified below by items from backend_config
    sets_of_related_ops: List[Set[NSNodeTargetType]] = [
        # conv modules
        {
            nn.Conv1d,
        },
        {
            nn.Conv2d,
        },
        {
            nn.Conv3d,
        },
        # conv functionals
        {
            F.conv1d,
        },
        {
            F.conv2d,
        },
        {
            F.conv3d,
        },
        # linear modules
        {
            nn.Linear,
        },
        # linear functionals
        {
            F.linear,
        },
        # average pool
        {
            nn.AvgPool1d,
            torch.avg_pool1d,
        },
        {
            nn.AvgPool2d,
            torch._C._nn.avg_pool2d,
        },
        {
            nn.AvgPool3d,
            torch._C._nn.avg_pool3d,
        },
        # adaptive average pool
        {
            nn.AdaptiveAvgPool1d,
            F.adaptive_avg_pool1d,
        },
        {
            nn.AdaptiveAvgPool2d,
            F.adaptive_avg_pool2d,
        },
        {
            nn.AdaptiveAvgPool3d,
            F.adaptive_avg_pool3d,
        },
        # LSTM
        {
            nn.LSTM,
        },
        # add
        {
            torch.add,
            operator.add,  # x + y
        },
        # cat
        {
            torch.cat,
        },
        # mul
        {
            torch.mul,
            operator.mul,
        },
        # relu
        {
            F.relu,
            nn.ReLU,
            'relu',
            'relu_',
            torch.relu,
        },
        # maxpool
        {
            nn.MaxPool1d,
            F.max_pool1d,
        },
        {
            nn.MaxPool2d,
            F.max_pool2d,
        },
        {
            nn.MaxPool3d,
            F.max_pool3d,
        },
        # sigmoid
        {
            torch.sigmoid,
            'sigmoid',
            'sigmoid_',
            nn.Sigmoid,
            F.sigmoid,
        },
        # BatchNorm
        {
            nn.BatchNorm2d,
        },
        {
            nn.BatchNorm3d,
        },
        # ConvTranspose
        {
            nn.ConvTranspose1d,
        },
        {
            nn.ConvTranspose2d,
        },
        {
            nn.ConvTranspose3d,
        },
        # ELU
        {
            nn.ELU,
        },
        # Embedding
        {
            nn.Embedding,
        },
        # EmbeddingBag
        {
            nn.EmbeddingBag,
        },
        # GroupNorm
        {
            nn.GroupNorm,
        },
        # Hardswish
        {
            nn.Hardswish,
        },
        # InstanceNorm
        {
            nn.InstanceNorm1d,
        },
        {
            nn.InstanceNorm2d,
        },
        {
            nn.InstanceNorm3d,
        },
        # LayerNorm
        {
            nn.LayerNorm,
        },
        # LeakyReLU
        {
            nn.LeakyReLU,
        },
        # ReLU6
        {
            nn.ReLU6,
            F.relu6,
        },
        # F.elu
        {
            F.elu,
        },
        # F.hardswish
        {
            F.hardswish,
        },
        # F.group_norm
        {
            F.group_norm,
        },
        # F.instance_norm
        {
            F.instance_norm,
        },
        # F.layer_norm
        {
            F.layer_norm,
        },
        # F.leaky_relu
        {
            F.leaky_relu,
        },
        # F.silu
        {
            nn.SiLU,
            F.silu,
        },
        # F.mish
        {
            nn.Mish,
            F.mish,
        },
        # F.tanh
        {
            nn.Tanh,
            F.tanh,
            torch.tanh,
            'tanh_',
            'tanh',
        },
        # F.hardsigmoid
        {
            'hardsigmoid_',
            'hardsigmoid',
            F.hardsigmoid,
            nn.Hardsigmoid,
        },
        # F.hardtanh
        {
            nn.Hardtanh,
            F.hardtanh,
            F.hardtanh_,
        },
        # floordiv
        {
            operator.floordiv,
        },
        # unsqueeze
        {
            torch.unsqueeze,
        },
        # stack
        {
            torch.stack,
        },
        # squeeze
        {
            torch.squeeze,
        },
        # sort
        {
            torch.sort,
        },
        # repeat_interleave
        {
            torch.repeat_interleave,
        },
        # min
        {
            torch.min,
        },
        # mean
        {
            torch.mean,
        },
        # max
        {
            torch.max,
        },
        # transpose
        {
            torch.transpose,
        },
        # flatten
        {
            torch.flatten,
        },
        # clamp
        {
            torch.clamp,
        },
        # chunk
        {
            torch.chunk,
        },
        # interpolate
        {
            torch.nn.functional.interpolate,
        },
        # dropout
        {
            nn.Dropout,
        },
        # F.dropout
        {
            F.dropout,
        },
        # matmul
        {
            torch.matmul,
        },
        # Softmax
        {
            nn.Softmax,
        },
        # PReLU
        {
            nn.PReLU,
            nnq.PReLU,
        },
        # F.prelu
        {
            F.prelu,
            toq.prelu,
        },
    ]

    # for each floating point op, add versions of the op added by
    # backend_config
    backend_config = get_native_backend_config()

    new_connections: List[Tuple[Callable, Callable]] = [
        # technical debt edge case
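        # NonDynamicallyQuantizableLinear is the Linear variant used inside
        # nn.MultiheadAttention, so relate it to nn.Linear here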
        (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear),
    ]

    for pattern, config in backend_config._pattern_complex_format_to_config.items():

        # pattern format: (c, (b, a))
        first_element = pattern
        # look from the end, because pattern is in reverse order
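        # e.g. a pattern stored as (op_c, (op_b, op_a)) corresponds to the
        # sequence op_a -> op_b -> op_c, so the walk below ends at op_a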
        while isinstance(first_element, (list, tuple)):
            first_element = first_element[-1]