Repository URL to install this package:
Version:
2.1.2+cpu ▾
|
# Copyright (c) Meta Platforms, Inc. and affiliates
from torch.distributed.tensor.parallel.api import parallelize_module
from torch.distributed.tensor.parallel.style import (
ColwiseParallel,
make_input_replicate_1d,
make_input_reshard_replicate,
make_input_shard_1d,
make_input_shard_1d_last_dim,
make_sharded_output_tensor,
make_output_replicate_1d,
make_output_reshard_tensor,
make_output_shard_1d,
make_output_tensor,
PairwiseParallel,
ParallelStyle,
RowwiseParallel,
SequenceParallel,
)
__all__ = [
"ColwiseParallel",
"PairwiseParallel",
"ParallelStyle",
"RowwiseParallel",
"SequenceParallel",
"make_input_replicate_1d",
"make_input_reshard_replicate",
"make_input_shard_1d",
"make_input_shard_1d_last_dim",
"make_sharded_output_tensor",
"make_output_replicate_1d",
"make_output_reshard_tensor",
"make_output_tensor",
"make_output_shard_1d",
"parallelize_module",
]