Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ testing / _internal / distributed / rpc / jit / rpc_test.py

import time
import io
from typing import Dict, List, Tuple, Any

import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from torch.autograd.profiler import record_function
from torch.distributed.rpc import RRef
from torch.distributed.rpc.internal import RPCExecMode, _build_rpc_profiling_key
from torch.futures import Future
from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.dist_utils import (
    dist_init,
    get_function_event,
    initialize_pg,
    worker_name,
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)

def rref_isinstance(rref, cls_to_check):
    """Return whether ``rref.local_value()`` is an instance of ``cls_to_check``."""
    local_obj = rref.local_value()
    return isinstance(local_obj, cls_to_check)

def sleep(t):
    """Block the calling thread for ``t`` seconds via ``time.sleep``."""
    time.sleep(t)
    return None


def rpc_return_rref(dst):
    """Call ``rpc.remote`` on ``dst`` for ``torch.add(torch.ones(2, 2), 1)``.

    Returns whatever ``rpc.remote`` returns for that call.
    """
    add_operands = (torch.ones(2, 2), 1)
    return rpc.remote(dst, torch.add, args=add_operands)


@torch.jit.script
def rref_local_value(rref: RRef[Tensor]) -> Tensor:
    # Scripted wrapper that returns the result of rref.local_value().
    value = rref.local_value()
    return value


@torch.jit.script
def list_create() -> List[int]:
    # Build and return a new list [1, 2, 3] on every call.
    return [1, 2, 3]


@torch.jit.script
def rref_list_mutate(rref: RRef[List[int]]) -> None:
    # Append 4, 5 and 6 to the list obtained from `rref`, using a different
    # accessor for each append: local_value(), to_here(), and to_here(5.0).
    # test_rref_list_mutate in this file runs this via rpc_sync on the same
    # worker it created the list on with rpc.remote, and expects the list to
    # read back as [1, 2, 3, 4, 5, 6].
    rref.local_value().append(4)
    rref.to_here().append(5)
    # NOTE(review): 5.0 is presumably a timeout in seconds — confirm against
    # the RRef.to_here documentation.
    rref.to_here(5.0).append(6)


def return_value(value: int) -> int:
    """Identity helper: return ``value`` unchanged."""
    result = value
    return result


class RRefAPITest:
    """Tests that call RRef methods from TorchScript and Python over RPC.

    The methods read ``self.rank`` / ``self.world_size`` and call
    ``assertEqual`` / ``assertRaisesRegex``, none of which are defined in this
    class, so it is presumably meant to be mixed into a test fixture that
    supplies them (TODO confirm).
    """

    @dist_init
    def test_rref_is_owner(self):
        # The RRef is requested on the next rank (wrapping around at
        # world_size); the scripted is_owner() call made here is expected to
        # report False.
        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
        rref_var = rpc_return_rref(dst_worker_name)

        @torch.jit.script
        def rref_tensor_is_owner(rref_var: RRef[Tensor]) -> bool:
            return rref_var.is_owner()

        res = rref_tensor_is_owner(rref_var)
        self.assertEqual(res, False)

    @dist_init
    def test_rref_local_value(self):
        # Only rank 0 runs the body of this test.
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
        rref = rpc_return_rref(dst_worker_name)

        # Calling the scripted local_value() wrapper on this side must fail
        # with the non-owner error message.
        with self.assertRaisesRegex(
            RuntimeError, r"Can't call RRef.local_value\(\) on a non-owner RRef"
        ):
            rref_local_value(rref)

        # Running the same scripted wrapper on dst_worker_name via rpc_sync
        # must return the value the RRef was created with.
        ret = rpc.rpc_sync(dst_worker_name, rref_local_value, (rref,))
        self.assertEqual(ret, torch.add(torch.ones(2, 2), 1))

    @dist_init
    def test_local_rref_local_value(self):
        # Only rank 0 runs the body of this test.
        if self.rank != 0:
            return

        # The RRef is created on this worker itself, so local_value() can be
        # called directly.
        # NOTE(review): rref_local_value is annotated for RRef[Tensor] while
        # return_value returns an int — confirm this mismatch is intended.
        dst_worker_name = worker_name(self.rank)
        rref = rpc.remote(dst_worker_name, return_value, (5,), {})

        ret = rref_local_value(rref)
        self.assertEqual(ret, 5)

    def _create_rref(self):
        """Return ``rpc.remote`` of ``torch.add(torch.zeros(2, 2), 1)`` on the
        worker two ranks ahead of this one (modulo ``world_size``)."""
        owner_rank = (self.rank + 2) % self.world_size
        return rpc.remote(
            worker_name(owner_rank), torch.add, args=(torch.zeros(2, 2), 1)
        )

    @dist_init
    def test_user_rrefs_confirmed(self):
        # Pass the RRef from _create_rref() to the next rank through rpc_sync
        # and expect script_check_rref_confirmed (not defined in this class)
        # to return True there.
        dst_rank = (self.rank + 1) % self.world_size
        rref = self._create_rref()
        ret = rpc.rpc_sync(
            worker_name(dst_rank), script_check_rref_confirmed, args=(rref,)
        )
        self.assertEqual(ret, True)

    @dist_init
    def test_user_rrefs_confirmed_remote(self):
        # Same check as test_user_rrefs_confirmed, but issued with rpc.remote
        # and read back through to_here().
        dst_rank = (self.rank + 1) % self.world_size
        rref = self._create_rref()
        ret_rref = rpc.remote(
            worker_name(dst_rank), script_check_rref_confirmed, args=(rref,)
        )
        self.assertEqual(ret_rref.to_here(), True)

    @dist_init
    def test_rref_list_mutate(self):
        # Create the list on dst, mutate it there through the RRef, and check
        # that all three appends are visible when the list is fetched back.
        dst = worker_name((self.rank + 1) % self.world_size)
        list_rref = rpc.remote(dst, list_create)

        rpc.rpc_sync(dst, rref_list_mutate, args=(list_rref,))
        self.assertEqual(list_rref.to_here(), [1, 2, 3, 4, 5, 6])


@torch.jit.script
def no_arg():
    # Zero-argument scripted function that always returns the integer 0.
    result = 0
    return result


@torch.jit.script
def one_arg(value):
    # Return `value` plus one.
    incremented = value + 1
    return incremented

@torch.jit.script
def script_add_ones(x):
    # Add a one-element tensor of ones to `x` with torch.add.
    ones = torch.ones(1)
    return torch.add(x, ones)

@torch.jit.script
def script_add_ones_with_record_function(x, block: str):
    # Add torch.ones(1) to `x` inside a record_function scope labelled `block`
    # and return the sum.
    with record_function(block):
        ones = torch.ones(1)
        summed = torch.add(x, ones)
    return summed


@torch.jit.script
def record_function_on_caller_rpc_async(dst_worker_name: str, block: str) -> Tensor:
    # Inside a record_function scope labelled `block`, issue two rpc_async
    # calls of script_add_ones to `dst_worker_name` with the same input
    # tensor, wait on both, and return the sum of the two results.
    t: Tensor = torch.ones(1)
    # NOTE(review): `rf` is bound here but never read in this function.
    with record_function(block) as rf:
        fut1 = rpc.rpc_async(dst_worker_name, script_add_ones, (t, ))
        fut2 = rpc.rpc_async(dst_worker_name, script_add_ones, (t, ))
        # Both futures are waited on before the scope is exited.
        res = fut1.wait() + fut2.wait()
    return res



@torch.jit.script
def script_fork_wait_udf(tensor):
    # Fork script_add_ones on `tensor`, then wait for and return its result.
    forked = torch.jit._fork(script_add_ones, tensor)
    return torch.jit._wait(forked)


@torch.jit.script
def rref_to_here(rref_var: RRef[Tensor]) -> Tensor:
    # Scripted wrapper that returns the result of rref_var.to_here().
    fetched = rref_var.to_here()
    return fetched


@torch.jit.script
def return_rref(rref_var: RRef[Tensor]) -> RRef[Tensor]:
    # Scripted identity function: returns the RRef it was given.
    same_rref = rref_var
    return same_rref


@torch.jit.script
def script_raise_func(value):
    # Raise ValueError("Expected error") when `value` has exactly two
    # elements; otherwise return value + 1.
    element_count = value.numel()
    if element_count == 2:
        raise ValueError("Expected error")
    return value + 1


@torch.jit.script
def script_fork_wait_throw(invalue):
    # Fork script_raise_func on `invalue` and return the result of waiting on
    # the forked future.
    forked = torch.jit._fork(script_raise_func, invalue)
    return torch.jit._wait(forked)


@torch.jit.script
def call_rpc_with_profiling(handle: Tensor, dst_worker_name: str) -> Tensor:
    # Issue rpc_async from inside a ScriptFunction and check that profiling
    # callbacks can be attached to the returned future. `handle` is a Tensor
    # representation of RecordFunction.
    rpc_fut = rpc.rpc_async(dst_worker_name, one_arg, (torch.tensor(1),))
    torch.ops.profiler._call_end_callbacks_on_jit_fut(handle, rpc_fut)
    return rpc_fut.wait()

@torch.jit.script
def call_rpc_torchscript_with_record_function(dst_worker_name: str, block: str) -> Tensor:
    # rpc_async script_add_ones_with_record_function to `dst_worker_name` with
    # torch.tensor(1) and the label `block`, then wait for and return the result.
    pending = rpc.rpc_async(
        dst_worker_name,
        script_add_ones_with_record_function,
        (torch.tensor(1), block),
    )
    result = pending.wait()
    return result


@torch.jit.script
def call_fork_with_profiling(handle: Tensor) -> Tensor:
    # Fork from inside a ScriptFunction and check that profiling callbacks can
    # be attached to the resulting future. `handle` is a Tensor representation
    # of RecordFunction.
    forked = torch.jit._fork(one_arg, torch.tensor(1))
    torch.ops.profiler._call_end_callbacks_on_jit_fut(handle, forked)
    return forked.wait()


class MyScriptModuleWithRRefs(torch.jit.ScriptModule):
    # ScriptModule that stores four RRefs, each obtained from
    # rpc_return_rref(dst_worker), and sums their fetched values in forward().
    def __init__(self, dst_worker):
        super().__init__()
        self.rrefs = [rpc_return_rref(dst_worker) for _ in range(4)]

    @torch.jit.script_method
    def forward(self) -> Tensor:
        # Start from a 2x2 tensor of ones and add each RRef's to_here() value
        # to it in place.
        total = torch.ones(2, 2)
        for remote_ref in self.rrefs:
            total += remote_ref.to_here()

        return total


@torch.jit.ignore
def rref_python_annotation(rref_var: RRef[Tensor]) -> RRef[Tensor]:
    """Return ``rref_var`` unchanged; marked ``torch.jit.ignore``."""
    same_rref = rref_var
    return same_rref


@torch.jit.script
def rref_script_annotation(rref_var: RRef[Tensor]) -> Tensor:
    # Pass the RRef through the jit-ignored rref_python_annotation, then
    # return the result of to_here() on what comes back.
    annotated = rref_python_annotation(rref_var)
    return annotated.to_here()


class RRefTypingTest:
    # Tests for RRefs used as arguments, return values and module attributes
    # of TorchScript code. Reads self.rank / self.world_size and calls
    # self.assertEqual, which are not defined in this class.
    @dist_init
    def test_rref_as_arg_and_return(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        expected = one_arg(torch.ones(2, 2))

        # RRef created on the current rank.
        local_rref = rpc.remote(
            worker_name(self.rank), one_arg, args=(torch.ones(2, 2),)
        )

        # RRef passed as an argument of an rpc_sync call.
        fetched = rpc.rpc_sync(dst, rref_to_here, args=(local_rref,))
        self.assertEqual(fetched, expected)

        # RRef returned from an rpc_sync call.
        returned_rref = rpc.rpc_sync(dst, return_rref, args=(local_rref,))
        self.assertEqual(returned_rref.to_here(), expected)

        # RRef passed as an argument of a remote call.
        value_rref = rpc.remote(dst, rref_to_here, args=(local_rref,))
        self.assertEqual(value_rref.to_here(), expected)

        # RRef returned from a remote call: the outer to_here() yields the
        # RRef, the inner one its value.
        nested_rref = rpc.remote(dst, return_rref, args=(local_rref,))
        self.assertEqual(nested_rref.to_here().to_here(), expected)

    @dist_init
    def test_my_script_module_with_rrefs(self):
        dst = worker_name((self.rank + 1) % self.world_size)

        # Calling the module is expected to produce a 2x2 tensor of nines.
        output = MyScriptModuleWithRRefs(dst)()
        self.assertEqual(output, torch.ones(2, 2) * 9)

    @dist_init
    def test_rref_python_annotation(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        remote_rref = rpc_return_rref(dst)

        # The scripted function routes the RRef through a jit-ignored Python
        # function before fetching its value.
        fetched = rref_script_annotation(remote_rref)
        self.assertEqual(fetched, torch.ones(2, 2) + 1)


class FutureTypingTest:
    # Tests that Future objects can be handed from Python to TorchScript and
    # from TorchScript back to Python. Reads self.rank / self.world_size and
    # calls self.assertEqual, which are not defined in this class.
    @dist_init
    def test_future_passed_between_python_and_jit(self):
        dst_rank = (self.rank + 1) % self.world_size
        inputs = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        expected_res = torch.tensor([10, 10])

        # Python -> TorchScript: a future from Python-side rpc_async is waited
        # on inside a scripted function.
        python_fut = rpc.rpc_async(
            worker_name(dst_rank), two_args_two_kwargs, args=inputs
        )

        @torch.jit.script
        def future_wait_in_script(fut: Future[Tensor]) -> Tensor:
            return fut.wait()

        self.assertEqual(future_wait_in_script(python_fut), expected_res)

        # TorchScript -> Python: a future created by scripted rpc_async is
        # returned to Python and waited on here.
        @torch.jit.script
        def future_return_to_python(
            dst_rank: int, inputs: Tuple[Tensor, Tensor]
        ) -> Future[Tensor]:
            dst_name = "worker{}".format(dst_rank)
            return rpc.rpc_async(dst_name, two_args_two_kwargs, inputs)

        script_fut = future_return_to_python(dst_rank, inputs)
        self.assertEqual(script_fut.wait(), expected_res)

    @dist_init
    def test_future_python_annotation(self):
        # Only rank 0 runs the body of this test.
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
        input_0 = torch.ones(2, 2)
        input_1 = 1

        # Jit-ignored Python function whose Future[Tensor] return annotation
        # is what the scripted caller below relies on.
        @torch.jit.ignore
        def python_return_future() -> Future[Tensor]:
            return rpc.rpc_async(dst_worker_name, torch.add, (input_0, input_1), {})

        @torch.jit.script
        def script_use_future() -> Tensor:
            fut = python_return_future()
            return fut.wait()

        self.assertEqual(script_use_future(), torch.add(input_0, input_1))


@torch.jit.script
class MyScriptClass:
    # TorchScript class that stores one int and exposes it via get_value().
    def __init__(self, a: int):
        # Store the constructor argument on the instance.
        self.a = a

    def get_value(self) -> int:
        # Return the int stored by __init__.
        return self.a
Loading ...