# distributed/_tensor/ops/math_ops.py

# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import cast, Optional, Sequence

import torch

from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.distributed._tensor.ops.common_rules import pointwise_rule, reduction_rule
from torch.distributed._tensor.ops.utils import (
    as_list,
    normalize_dims,
    register_prop_rule,
)
from torch.distributed._tensor.placement_types import DTensorSpec


aten = torch.ops.aten


def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[Sequence[int]]:
    if dims_arg is None:
        return None
    dims = cast(Sequence[int], as_list(dims_arg))
    dims = normalize_dims(dims, ndim)
    empty_dims = [[0], [-1], []]
    if ndim == 0 and dims_arg in empty_dims:
        return None
    return dims
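
# Illustrative behavior of _infer_reduction_dims (assuming normalize_dims maps
# negative dims onto their non-negative counterparts):
#   _infer_reduction_dims(None, 3)  -> None  (reduce over all dims)
#   _infer_reduction_dims([-1], 3)  -> [2]
#   _infer_reduction_dims([0], 0)   -> None  (0-dim input: treated as a full reduction)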


@register_prop_rule(aten.all.default)
def default_reduction_rule(op_schema: OpSchema) -> OutputSharding:
    return reduction_rule(op_schema, reduction_linear=True)
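
# all() of a sharded tensor equals all() of the per-shard partial results, so
# the reduction is propagated as linearly reducible (reduction_linear=True).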


@register_prop_rule(
    [
        aten.sum.default,
        aten.sum.dim_IntList,
    ]
)
def sum_rule(op_schema: OpSchema) -> OutputSharding:
    args_schema = op_schema.args_schema
    input_spec = cast(DTensorSpec, args_schema[0])
    dims = None
    if len(args_schema) > 1:
        dims = _infer_reduction_dims(args_schema[1], input_spec.ndim)

    keep_dim = len(args_schema) > 2 and bool(args_schema[2])
    return reduction_rule(
        op_schema, dims=dims, keep_dim=keep_dim, reduction_linear=True
    )
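
# Rough positional layout of the schemas handled above, as seen in args_schema:
#   aten.sum.default(self)                        -> dims=None (reduce everything)
#   aten.sum.dim_IntList(self, dim, keepdim, ...) -> dims from args_schema[1],
#                                                    keep_dim from args_schema[2]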


@register_prop_rule(aten._softmax.default)
def softmax_rule(op_schema: OpSchema) -> OutputSharding:
    input_spec, softmax_dim, _ = op_schema.args_schema
    input_spec = cast(DTensorSpec, input_spec)
    softmax_dim = cast(int, softmax_dim)
    dim_map = input_spec.dim_map
    if softmax_dim < len(dim_map) and dim_map[softmax_dim] >= 0:
        raise RuntimeError("Cannot run softmax on sharding dimension!")
    return OutputSharding(input_spec)
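
# Illustrative example: for an input with dim_map = [0, -1] (tensor dim 0
# sharded on mesh dim 0, tensor dim 1 not sharded), softmax over dim 0 is
# rejected above, while softmax over dim 1 simply keeps the input sharding.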


@register_prop_rule(aten._softmax_backward_data.default)
def softmax_bwd_rule(op_schema: OpSchema) -> OutputSharding:
    grad_out_spec, out_spec, softmax_dim, _ = op_schema.args_schema
    grad_out_spec = cast(DTensorSpec, grad_out_spec)
    out_spec = cast(DTensorSpec, out_spec)
    softmax_dim = cast(int, softmax_dim)
    grad_out_dim_map = grad_out_spec.dim_map
    out_dim_map = out_spec.dim_map
    if softmax_dim < len(grad_out_dim_map) and (
        grad_out_dim_map[softmax_dim] >= 0 or out_dim_map[softmax_dim] >= 0
    ):
        raise RuntimeError("Cannot run _softmax_backward_data on sharding dimension!")
    return pointwise_rule(op_schema)
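
# Once the check above guarantees the softmax dim is not sharded, the
# grad_output and output shardings can be combined like any pointwise op,
# hence the reuse of pointwise_rule.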


@register_prop_rule([aten.mean.default, aten.mean.dim, aten.mean.out])
def mean_rule(op_schema: OpSchema) -> OutputSharding:
    args_schema = op_schema.args_schema
    input_spec = cast(DTensorSpec, args_schema[0])
    dims = None
    # if there is more than one positional arg, the second one holds the reduction dims
    if len(args_schema) > 1:
        dims = _infer_reduction_dims(args_schema[1], input_spec.ndim)

    keep_dim = len(args_schema) > 2 and bool(args_schema[2])
    return reduction_rule(
        op_schema, dims=dims, keep_dim=keep_dim, reduction_linear=False
    )
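
# Unlike sum, a global mean cannot in general be obtained by combining
# per-shard means (shards may have different sizes), hence
# reduction_linear=False here.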


@register_prop_rule(
    [
        aten.var.default,
        aten.var.dim,
        aten.var.out,
    ]
)
def var_rule(op_schema: OpSchema) -> OutputSharding:
    args_schema = op_schema.args_schema
    input_spec = cast(DTensorSpec, args_schema[0])
    dims = None
    # if there is more than one positional arg, it may hold the reduction dims;
    # note that var.default takes the unbiased flag right after the input, so
    # only treat args_schema[1] as dims when it is not a bool
    if len(args_schema) > 1 and not isinstance(args_schema[1], bool):
        dims = _infer_reduction_dims(args_schema[1], input_spec.ndim)

    keep_dim = len(args_schema) > 3 and bool(args_schema[3])
    return reduction_rule(
        op_schema, dims=dims, keep_dim=keep_dim, reduction_linear=False
    )
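
# Rough positional layouts handled above:
#   aten.var.default(self, unbiased)            -> args_schema[1] is a bool, no dims
#   aten.var.dim(self, dim, unbiased, keepdim)  -> dims from args_schema[1],
#                                                  keep_dim from args_schema[3]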


@register_prop_rule([aten.var.correction, aten.var.correction_out])
def var_correction_rule(op_schema: OpSchema) -> OutputSharding:
    args_schema = op_schema.args_schema
    input_spec = cast(DTensorSpec, args_schema[0])
    dims = None
    if len(args_schema) > 1:
        dims = _infer_reduction_dims(args_schema[1], input_spec.ndim)

    # keep_dim comes in as the `keepdim` kwarg rather than a positional arg for var.correction
    keep_dim = cast(bool, op_schema.kwargs_schema.get("keepdim", False))
    return reduction_rule(
        op_schema, dims=dims, keep_dim=keep_dim, reduction_linear=False
    )
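
# In aten.var.correction, `correction` and `keepdim` are keyword-only
# (roughly: var.correction(self, dim=None, *, correction=None, keepdim=False)),
# hence the kwargs_schema lookup above.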