Repository URL to install this package:
|
Version:
0.16.2 ▾
|
torchvision
/
_meta_registrations.py
|
|---|
import functools
import torch
import torch.library
# Ensure that torch.ops.torchvision is visible
import torchvision.extension # noqa: F401
@functools.lru_cache(maxsize=None)
def get_meta_lib():
    """Return the shared ``torch.library.Library`` used to register Meta kernels.

    The library is created as an ``"IMPL"`` fragment for the ``torchvision``
    namespace, bound to the ``"Meta"`` dispatch key. Because of the unbounded
    ``lru_cache`` (the function takes no arguments), it is constructed on the
    first call only. Every later call returns that same object, and the cache
    keeps it referenced for the life of the process.
    """
    # NOTE(review): the long-lived reference presumably matters so kernels
    # registered through this Library stay registered -- confirm against the
    # torch.library lifetime semantics.
    namespace, kind, dispatch_key = "torchvision", "IMPL", "Meta"
    return torch.library.Library(namespace, kind, dispatch_key)
def register_meta(op_name, overload_name="default"):
    """Create a decorator that registers a function as a torchvision Meta kernel.

    The decorated function is attached to
    ``torch.ops.torchvision.<op_name>.<overload_name>`` through the shared
    library returned by ``get_meta_lib()``. If ``torchvision.extension._has_ops()``
    reports that the compiled ops are unavailable, nothing is registered. In
    every case the decorated function is returned unchanged, so it stays
    directly callable.

    Args:
        op_name: name of the op under ``torch.ops.torchvision``.
        overload_name: overload attribute to register against (``"default"``
            unless specified).
    """

    def decorator(fn):
        if not torchvision.extension._has_ops():
            # Compiled ops not available: leave fn unregistered.
            return fn
        # Obtain the library before resolving the op, matching the original
        # evaluation order of `get_meta_lib().impl(<op lookup>, fn)`.
        meta_lib = get_meta_lib()
        packet = getattr(torch.ops.torchvision, op_name)
        overload = getattr(packet, overload_name)
        meta_lib.impl(overload, fn)
        return fn

    return decorator
@register_meta("roi_align")
def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
    """Meta kernel for ``torchvision::roi_align``.

    It computes only the output's shape and dtype and allocates no real data.

    Args:
        input: feature tensor. It is unpacked into four sizes below, so it must
            be 4-D; only the second size (channels) is used.
        rois: region tensor. ``rois.size(1)`` must be 5 (``Tensor[K, 5]``) and
            its dtype must equal ``input.dtype``.
        spatial_scale: part of the op schema; not needed to infer the output.
        pooled_height: height of the output.
        pooled_width: width of the output.
        sampling_ratio: part of the op schema; not needed to infer the output.
        aligned: part of the op schema; not needed to infer the output.

    Returns:
        ``input.new_empty((K, channels, pooled_height, pooled_width))``, where
        ``K = rois.size(0)``.

    Raises:
        The error raised by ``torch._check`` (RuntimeError) when a check fails.
        ValueError from tuple unpacking if ``input`` is not 4-D.
    """
    torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
    torch._check(
        input.dtype == rois.dtype,
        lambda: (
            "Expected tensor for input to have the same type as tensor for rois; "
            f"but type {input.dtype} does not equal {rois.dtype}"
        ),
    )
    num_rois = rois.size(0)
    # The 4-way unpack is kept on purpose: it still rejects non-4-D inputs.
    # Only the channel count is used, so the spatial sizes are discarded.
    _, channels, _, _ = input.size()
    return input.new_empty((num_rois, channels, pooled_height, pooled_width))
@register_meta("_roi_align_backward")
def meta_roi_align_backward(
    grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned
):
    """Meta kernel for ``torchvision::_roi_align_backward``.

    It checks that ``grad`` and ``rois`` share a dtype, then returns an
    uninitialized tensor (created with ``grad.new_empty``) of shape
    ``(batch_size, channels, height, width)``. That shape comes entirely from
    the explicit size arguments.

    The remaining parameters (``spatial_scale``, ``pooled_height``,
    ``pooled_width``, ``sampling_ratio``, ``aligned``) belong to the op schema
    but are not needed to infer the output.

    Raises:
        The error raised by ``torch._check`` (RuntimeError) on a dtype mismatch.
    """

    def _mismatch_message():
        # Built lazily: only evaluated by torch._check when the check fails.
        return (
            "Expected tensor for grad to have the same type as tensor for rois; "
            f"but type {grad.dtype} does not equal {rois.dtype}"
        )

    dtypes_match = grad.dtype == rois.dtype
    torch._check(dtypes_match, _mismatch_message)

    grad_input_shape = (batch_size, channels, height, width)
    return grad.new_empty(grad_input_shape)