import collections
import functools
import unittest
import caffe2.python._import_c_extension as C
import caffe2.python.hip_test_util as hiputl
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial
import hypothesis.strategies as st
import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import brew, core, utils, workspace
from caffe2.python.model_helper import ModelHelper
from hypothesis import assume, given, settings
def _cudnn_supports(dilation=False, nhwc=False, backward=False):
    """Report whether cuDNN can handle the requested convolution setup.

    Args:
        dilation: True when the convolution is dilated.
        nhwc: True when the tensors use NHWC order.
        backward: True when the backward ops are being asked about.

    Returns:
        False for NHWC in backward mode, for forward dilation on a cuDNN
        version below 6000, and for forward dilation combined with NHWC;
        True in every other case.
    """
    # Queried up front in every mode, even though only the forward dilation
    # check consults the value.
    cudnn_version = workspace.GetCuDNNVersion()

    if backward:
        # nhwc isn't supported in backward ops; nothing else is restricted.
        return not nhwc

    if not dilation:
        # Forward mode without dilation is always accepted.
        return True

    # Forward dilation: not supported until v6, and never together with NHWC.
    return not (cudnn_version < 6000 or nhwc)
def _cudnn_convolution_algo_count(direction):
try:
if direction == "fwd":
return st.integers(0, C.cudnn_convolution_fwd_algo_count - 1)
elif direction == "dgrad":
return st.integers(0, C.cudnn_convolution_bwd_data_algo_count - 1)
elif direction == "wgrad":
return st.integers(0, C.cudnn_convolution_bwd_filter_algo_count - 1)
else:
assert False
except Exception:
return st.sampled_from([-1])
class TestConvolution(serial.SerializedTestCase):
# CUDNN does NOT support different padding values and we skip it: only the
# default ("") and EIGEN engines are sampled below.
# The test builds a Conv/Conv2D operator with independent stride_h/stride_w
# and pad_t/pad_l/pad_b/pad_r, then runs the device checks plus one gradient
# check per input. When the padded input is smaller than the kernel it
# instead expects the device checks to raise RuntimeError.
@given(
    op_type=st.sampled_from(["Conv", "Conv2D"]),
    stride_h=st.integers(1, 3),
    stride_w=st.integers(1, 3),
    pad_t=st.integers(0, 3),
    pad_l=st.integers(0, 3),
    pad_b=st.integers(0, 3),
    pad_r=st.integers(0, 3),
    kernel=st.integers(3, 5),
    size=st.integers(1, 8),
    input_channels=st.integers(1, 3),
    output_channels=st.integers(1, 3),
    batch_size=st.integers(0, 3),
    group=st.integers(1, 2),
    order=st.sampled_from(["NCHW", "NHWC"]),
    engine=st.sampled_from(["", "EIGEN"]),
    shared_buffer=st.booleans(),
    use_bias=st.booleans(),
    **hu.gcs
)
@settings(deadline=None, max_examples=50)
def test_convolution_separate_stride_pad_gradients(
    self,
    op_type,
    stride_h,
    stride_w,
    pad_t,
    pad_l,
    pad_b,
    pad_r,
    kernel,
    size,
    input_channels,
    output_channels,
    batch_size,
    group,
    order,
    engine,
    shared_buffer,
    use_bias,
    gc,
    dc,
):
    def centered_rand(*shape):
        # float32 samples from np.random.rand, shifted down by 0.5.
        return np.random.rand(*shape).astype(np.float32) - 0.5

    is_grouped = group != 1

    # TODO: Group conv in NHWC not implemented for GPU yet.
    assume(not is_grouped or order == "NCHW" or gc.device_type == caffe2_pb2.CPU)
    if is_grouped and order == "NHWC":
        # Keep only the CPU device options for the device checks.
        dc = [dev for dev in dc if dev.device_type == caffe2_pb2.CPU]
    # Group conv not implemented with EIGEN engine.
    assume(not is_grouped or engine != "EIGEN")

    # The sampled channel counts are per group; scale them up to the totals.
    input_channels *= group
    output_channels *= group

    blob_names = ["X", "w", "b"] if use_bias else ["X", "w"]
    conv_op = core.CreateOperator(
        op_type,
        blob_names,
        ["Y"],
        stride_h=stride_h,
        stride_w=stride_w,
        pad_t=pad_t,
        pad_l=pad_l,
        pad_b=pad_b,
        pad_r=pad_r,
        kernel=kernel,
        group=group,
        order=order,
        engine=engine,
        shared_buffer=int(shared_buffer),
    )

    # Data is drawn in NHWC layout, always in the order X, w, b.
    X = centered_rand(batch_size, size, size, input_channels)
    w = centered_rand(output_channels, kernel, kernel, int(input_channels / group))
    b = centered_rand(output_channels)
    if order == "NCHW":
        X, w = utils.NHWC2NCHW(X), utils.NHWC2NCHW(w)

    inputs = [X, w]
    if use_bias:
        inputs.append(b)

    # Error handling path: the padded input is smaller than the kernel along
    # at least one axis, so running the operator is expected to fail.
    padded_too_small = (
        size + pad_r + pad_l < kernel or size + pad_t + pad_b < kernel
    )
    if padded_too_small:
        with self.assertRaises(RuntimeError):
            self.assertDeviceChecks(dc, conv_op, inputs, [0])
        return

    self.assertDeviceChecks(dc, conv_op, inputs, [0])
    for input_idx, _ in enumerate(inputs):
        self.assertGradientChecks(gc, conv_op, inputs, input_idx, [0])
# CUDNN does NOT support different padding values and we skip it: only the
# default ("") and EIGEN engines are sampled below.
# The test runs the same Conv/Conv2D configuration once per storage order on
# identical data and asserts that the NCHW output matches the NHWC output
# after converting the latter to NCHW.
@given(
    op_type=st.sampled_from(["Conv", "Conv2D"]),
    stride_h=st.integers(1, 3),
    stride_w=st.integers(1, 3),
    pad_t=st.integers(0, 3),
    pad_l=st.integers(0, 3),
    pad_b=st.integers(0, 3),
    pad_r=st.integers(0, 3),
    kernel=st.integers(1, 5),
    size=st.integers(7, 10),
    input_channels=st.integers(1, 8),
    output_channels=st.integers(1, 8),
    batch_size=st.integers(0, 3),
    engine=st.sampled_from(["", "EIGEN"]),
    use_bias=st.booleans(),
    **hu.gcs
)
@settings(deadline=1000)
def test_convolution_separate_stride_pad_layout(
    self,
    op_type,
    stride_h,
    stride_w,
    pad_t,
    pad_l,
    pad_b,
    pad_r,
    kernel,
    size,
    input_channels,
    output_channels,
    batch_size,
    engine,
    use_bias,
    gc,
    dc,
):
    def centered_rand(*shape):
        # float32 samples from np.random.rand, shifted down by 0.5.
        return np.random.rand(*shape).astype(np.float32) - 0.5

    # Reference data is drawn once in NHWC layout, in the order X, w, b.
    X = centered_rand(batch_size, size, size, input_channels)
    w = centered_rand(output_channels, kernel, kernel, input_channels)
    b = centered_rand(output_channels)

    results = {}
    for order in ("NCHW", "NHWC"):
        conv_op = core.CreateOperator(
            op_type,
            ["X", "w", "b"] if use_bias else ["X", "w"],
            ["Y"],
            stride_h=stride_h,
            stride_w=stride_w,
            kernel=kernel,
            pad_t=pad_t,
            pad_l=pad_l,
            pad_b=pad_b,
            pad_r=pad_r,
            order=order,
            engine=engine,
            device_option=gc,
        )

        # Only the NCHW run needs the data converted; NHWC uses it as drawn.
        to_nchw = order == "NCHW"
        feeds = (
            ("X", utils.NHWC2NCHW(X) if to_nchw else X),
            ("w", utils.NHWC2NCHW(w) if to_nchw else w),
            # The bias blob is fed even when the operator does not list it.
            ("b", b),
        )
        for blob_name, value in feeds:
            self.ws.create_blob(blob_name).feed(value, device_option=gc)

        self.ws.run(conv_op)
        results[order] = self.ws.blobs["Y"].fetch()

    # Compare both runs in NCHW layout.
    nhwc_as_nchw = utils.NHWC2NCHW(results["NHWC"])
    np.testing.assert_allclose(
        results["NCHW"], nhwc_as_nchw, atol=1e-4, rtol=1e-4
    )
@given(
op_type=st.sampled_from(["Conv", "Conv2D"]),
stride=st.integers(1, 3),
pad=st.integers(0, 3),
kernel=st.integers(1, 5),
dilation=st.integers(1, 3),
size=st.integers(7, 10),
input_channels=st.integers(1, 8),
output_channels=st.integers(1, 8),
batch_size=st.integers(0, 3),
group=st.integers(1, 2),
order=st.sampled_from(["NCHW", "NHWC"]),
engine=st.sampled_from(["", "CUDNN", "MKLDNN"]),
use_bias=st.booleans(),
force_algo_fwd=_cudnn_convolution_algo_count("fwd"),
force_algo_dgrad=_cudnn_convolution_algo_count("dgrad"),
force_algo_wgrad=_cudnn_convolution_algo_count("wgrad"),
**hu.gcs
)
@settings(max_examples=20, deadline=None)
def test_convolution_gradients(
self,
op_type,
stride,
pad,
kernel,
dilation,
size,
input_channels,
output_channels,
batch_size,
group,
order,
engine,
use_bias,
force_algo_fwd,
force_algo_dgrad,
force_algo_wgrad,
gc,
dc,
):
# TODO: Group conv in NHWC not implemented for GPU yet.
assume(
group == 1
or (order == "NCHW" or gc.device_type == caffe2_pb2.CPU)
and engine != "MKLDNN"
)
if group != 1 and order == "NHWC":
dc = [d for d in dc if d.device_type == caffe2_pb2.CPU]
input_channels *= group
output_channels *= group
dkernel = dilation * (kernel - 1) + 1
if engine == "CUDNN":
if hiputl.run_in_hip(gc, dc):
assume((order == "NCHW") and not (dilation > 1 and group > 1))
else:
assume(
_cudnn_supports(
dilation=(dilation > 1), nhwc=(order == "NHWC"), backward=True
)
)
assume(engine != "MKLDNN" or use_bias is True)
op = core.CreateOperator(
op_type,
["X", "w", "b"] if use_bias else ["X", "w"],
["Y"],
stride=stride,
kernel=kernel,
dilation=dilation,
pad=pad,
group=group,
order=order,
engine=engine,
force_algo_fwd=force_algo_fwd,
force_algo_dgrad=force_algo_dgrad,
force_algo_wgrad=force_algo_wgrad,
)
X = (
np.random.rand(batch_size, size, size, input_channels).astype(np.float32)
- 0.5
)
w = (
np.random.rand(
output_channels, kernel, kernel, int(input_channels / group)
).astype(np.float32)
- 0.5
)
b = np.random.rand(output_channels).astype(np.float32) - 0.5
if order == "NCHW":
X = utils.NHWC2NCHW(X)
w = utils.NHWC2NCHW(w)
inputs = [X, w, b] if use_bias else [X, w]
# Error handling path.
if size + pad + pad < dkernel or size + pad + pad < dkernel:
with self.assertRaises(RuntimeError):
self.assertDeviceChecks(dc, op, inputs, [0])
return
try:
self.assertDeviceChecks(dc, op, inputs, [0])
except RuntimeError as e:
es = str(e)
# CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM should always have
# implementation
if (
"status == CUDNN_STATUS_SUCCESS" not in es
or "CUDNN_STATUS_NOT_SUPPORTED" not in es
or force_algo_fwd == 0
):
raise e
Loading ...