!11327 Dynamic shape support for Conv2d Op

From: @tom__chen
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-01-17 22:46:18 +08:00 committed by Gitee
commit 8008843562
6 changed files with 399 additions and 77 deletions

View File

@ -31,31 +31,7 @@ namespace kernel {
template <typename T>
class Conv2dGpuFwdKernel : public GpuKernel {
public:
Conv2dGpuFwdKernel()
: cudnn_handle_(nullptr),
input_desc_(nullptr),
output_desc_(nullptr),
filter_desc_(nullptr),
conv_desc_(nullptr),
padded_desc_(nullptr),
cudnn_data_type_(CUDNN_DATA_FLOAT),
compute_format_(CUDNN_TENSOR_NCHW),
old_height_(0),
old_width_(0),
pad_height_(0),
pad_width_(0),
pad_top_(0),
pad_left_(0),
n_(0),
c_(0),
group_(1),
is_null_input_(false),
input_size_(0),
filter_size_(0),
output_size_(0),
padded_size_(0),
workspace_size_(0),
use_pad_(true) {}
Conv2dGpuFwdKernel() { ResetResource(); }
~Conv2dGpuFwdKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
@ -194,6 +170,38 @@ class Conv2dGpuFwdKernel : public GpuKernel {
return true;
}
void ResetResource() noexcept override {
cudnn_handle_ = nullptr;
input_desc_ = nullptr;
output_desc_ = nullptr;
filter_desc_ = nullptr;
conv_desc_ = nullptr;
padded_desc_ = nullptr;
cudnn_data_type_ = CUDNN_DATA_FLOAT;
compute_format_ = CUDNN_TENSOR_NCHW;
old_height_ = 0;
old_width_ = 0;
pad_height_ = 0;
pad_width_ = 0;
pad_top_ = 0;
pad_left_ = 0;
n_ = 0;
c_ = 0;
stride_.clear();
dilation_.clear();
group_ = 1;
is_null_input_ = false;
input_size_ = 0;
filter_size_ = 0;
output_size_ = 0;
padded_size_ = 0;
workspace_size_ = 0;
use_pad_ = true;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyConvolutionDescriptor(conv_desc_),
"cudnnDestroyConvolutionDescriptor failed");

View File

@ -57,6 +57,8 @@ AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const Primitiv
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
#include <cmath>
#include "abstract/infer_functions.h"
#include "abstract/utils.h"
#include "abstract/param_validator.h"
@ -267,6 +268,194 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit
return std::make_shared<AbstractTuple>(rets);
}
AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(input_x);
MS_EXCEPTION_IF_NULL(input_x->shape());
ShapeVector x_shape = input_x->shape()->shape();
ShapeVector x_min_shape = input_x->shape()->min_shape();
ShapeVector x_max_shape = input_x->shape()->max_shape();
(void)CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
AbstractTensorPtr input_w = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
MS_EXCEPTION_IF_NULL(input_w);
MS_EXCEPTION_IF_NULL(input_w->shape());
ShapeVector w_shape = input_w->shape()->shape();
ShapeVector w_min_shape = input_w->shape()->min_shape();
ShapeVector w_max_shape = input_w->shape()->max_shape();
(void)CheckMinMaxShape(w_shape, &w_min_shape, &w_max_shape);
std::set<std::string> available_data_format{"NCHW", "NHWC"};
int64_t n_axis = 0;
int64_t c_axis = 1;
int64_t h_axis = 2;
int64_t w_axis = 3;
auto data_format_ptr = primitive->GetAttr("format");
std::string data_format = "NCHW";
if ((data_format_ptr != nullptr) && data_format_ptr->isa<StringImm>()) {
data_format = data_format_ptr->cast<StringImmPtr>()->value();
if (available_data_format.find(data_format) == available_data_format.end()) {
MS_LOG(EXCEPTION) << "Unsupported data format: " << data_format << ". use NCHW or NHWC";
}
if (data_format == "NHWC") {
c_axis = 3;
h_axis = 1;
w_axis = 2;
}
}
int64_t group = primitive->GetAttr("group")->cast<Int64ImmPtr>()->value();
if (group <= 0) {
MS_LOG(EXCEPTION) << "Invalid group value: " << group << ", should be greater then 0";
}
if ((x_shape[c_axis] != Shape::SHP_ANY) && (x_shape[c_axis] % group != 0)) {
MS_LOG(EXCEPTION) << "x_shape[" << c_axis << "] = " << x_shape[c_axis]
<< " (channels) must be divisible by group = " << group;
}
int64_t out_channel = primitive->GetAttr("out_channel")->cast<Int64ImmPtr>()->value();
if (out_channel <= 0) {
MS_LOG(EXCEPTION) << "Invalid out_channel value: " << out_channel << ", should be greater then 0";
}
if ((w_shape[0] != Shape::SHP_ANY) && (w_shape[0] != out_channel)) {
MS_LOG(EXCEPTION) << "w_shape[0] = " << w_shape[0] << " must equal to = " << out_channel;
}
int64_t kernel_h = 0;
int64_t kernel_w = 0;
ValuePtr kernel_size_attr = primitive->GetAttr("kernel_size");
MS_EXCEPTION_IF_NULL(kernel_size_attr);
if (kernel_size_attr->isa<ValueTuple>()) {
std::vector<ValuePtr> kernel_size_vec = kernel_size_attr->cast<ValueTuplePtr>()->value();
kernel_h = GetValue<int64_t>(kernel_size_vec[0]);
kernel_w = GetValue<int64_t>(kernel_size_vec[1]);
} else {
int64_t kernel_size = kernel_size_attr->cast<Int64ImmPtr>()->value();
kernel_h = kernel_size;
kernel_w = kernel_size;
}
if ((w_shape[2] != Shape::SHP_ANY) && (w_shape[2] != kernel_h)) {
MS_LOG(EXCEPTION) << "weight height, w_shape[2] = " << w_shape[2] << ", must equal to = " << kernel_h;
}
if ((w_shape[3] != Shape::SHP_ANY) && (w_shape[3] != kernel_w)) {
MS_LOG(EXCEPTION) << "weight width, w_shape[3] = " << w_shape[3] << ", must equal to = " << kernel_w;
}
int64_t stride_h = 0;
int64_t stride_w = 0;
ValuePtr stride_attr = primitive->GetAttr("stride");
MS_EXCEPTION_IF_NULL(stride_attr);
if (stride_attr->isa<ValueTuple>()) {
std::vector<ValuePtr> stride_vec = stride_attr->cast<ValueTuplePtr>()->value();
stride_h = GetValue<int64_t>(stride_vec[2]);
stride_w = GetValue<int64_t>(stride_vec[3]);
} else {
int64_t stride = stride_attr->cast<Int64ImmPtr>()->value();
stride_h = stride;
stride_w = stride;
}
int64_t dilation_h = 0;
int64_t dilation_w = 0;
ValuePtr dilation_attr = primitive->GetAttr("dilation");
MS_EXCEPTION_IF_NULL(dilation_attr);
if (dilation_attr->isa<ValueTuple>()) {
std::vector<ValuePtr> dilation_vec = dilation_attr->cast<ValueTuplePtr>()->value();
dilation_h = GetValue<int64_t>(dilation_vec[2]);
dilation_w = GetValue<int64_t>(dilation_vec[3]);
} else {
int64_t dilation = dilation_attr->cast<Int64ImmPtr>()->value();
dilation_h = dilation;
dilation_w = dilation;
}
std::vector<int64_t> padding;
ValuePtr padding_attr = primitive->GetAttr("pad");
MS_EXCEPTION_IF_NULL(padding_attr);
if (padding_attr->isa<ValueTuple>()) {
std::vector<ValuePtr> padding_vec = padding_attr->cast<ValueTuplePtr>()->value();
(void)std::transform(std::begin(padding_vec), std::end(padding_vec), std::back_inserter(padding),
[](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
} else {
int64_t padding_val = padding_attr->cast<Int64ImmPtr>()->value();
padding = {padding_val, padding_val, padding_val, padding_val};
}
std::set<std::string> available_pad_mode{"pad", "same", "valid"};
ValuePtr pad_mode_attr = primitive->GetAttr("pad_mode");
MS_EXCEPTION_IF_NULL(pad_mode_attr);
auto pad_mode = pad_mode_attr->cast<StringImmPtr>()->value();
if (available_pad_mode.find(pad_mode) == available_pad_mode.end()) {
MS_LOG(EXCEPTION) << "Unsupported pad mode: " << pad_mode << ". use pad, same, valid";
}
std::function<void(int64_t, int64_t, std::vector<int64_t> &, std::vector<int64_t> &)> pad_function =
[kernel_h, kernel_w, stride_h, stride_w, dilation_h, dilation_w, pad_mode, padding](
int64_t x_h, int64_t x_w, std::vector<int64_t> &output_hw, std::vector<int64_t> &pad_list) {
if (pad_mode == "valid") {
output_hw.push_back(std::ceil(((x_h * 1.0) - dilation_h * (kernel_h - 1)) / stride_h));
output_hw.push_back(std::ceil(((x_w * 1.0) - dilation_w * (kernel_w - 1)) / stride_w));
pad_list = {0, 0, 0, 0};
} else if (pad_mode == "same") {
output_hw.push_back(std::ceil((x_h * 1.0) / stride_h));
output_hw.push_back(std::ceil((x_w * 1.0) / stride_w));
int64_t pad_needed_h = (output_hw[0] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_h;
pad_needed_h = std::max((int64_t)0, pad_needed_h);
pad_list.push_back(std::floor(pad_needed_h / 2));
pad_list.push_back(pad_needed_h - pad_list[0]);
int64_t pad_needed_w = (output_hw[1] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_w;
pad_needed_w = std::max((int64_t)0, pad_needed_w);
pad_list.push_back(std::floor(pad_needed_w / 2));
pad_list.push_back(pad_needed_w - pad_list[2]);
} else if (pad_mode == "pad") {
pad_list = padding;
output_hw.push_back(std::floor(
1 + ((x_h * 1.0) + pad_list[0] + pad_list[1] - kernel_h - (kernel_h - 1) * (dilation_h - 1)) / stride_h));
output_hw.push_back(std::floor(
1 + ((x_w * 1.0) + pad_list[2] + pad_list[3] - kernel_w - (kernel_w - 1) * (dilation_w - 1)) / stride_w));
}
};
std::vector<int64_t> output_hw;
std::vector<int64_t> pad_list;
std::vector<int64_t> output_hw_min;
std::vector<int64_t> pad_list_min;
std::vector<int64_t> output_hw_max;
std::vector<int64_t> pad_list_max;
pad_function(x_shape[h_axis], x_shape[w_axis], output_hw, pad_list);
if (x_shape[h_axis] == Shape::SHP_ANY) {
output_hw[0] = Shape::SHP_ANY;
}
if (x_shape[w_axis] == Shape::SHP_ANY) {
output_hw[1] = Shape::SHP_ANY;
}
pad_function(x_min_shape[h_axis], x_min_shape[w_axis], output_hw_min, pad_list_min);
pad_function(x_max_shape[h_axis], x_max_shape[w_axis], output_hw_max, pad_list_max);
std::vector<ValuePtr> pad_list_val = {MakeValue(pad_list[0]), MakeValue(pad_list[1]), MakeValue(pad_list[2]),
MakeValue(pad_list[3])};
primitive->set_attr("pad_list", MakeValue(pad_list_val));
ShapeVector output_shape;
ShapeVector output_shape_min;
ShapeVector output_shape_max;
if (data_format == "NHWC") {
output_shape = {x_shape[n_axis], output_hw[0], output_hw[1], out_channel};
output_shape_min = {x_min_shape[n_axis], output_hw_min[0], output_hw_min[1], out_channel};
output_shape_max = {x_max_shape[n_axis], output_hw_max[0], output_hw_max[1], out_channel};
} else {
output_shape = {x_shape[n_axis], out_channel, output_hw[0], output_hw[1]};
output_shape_min = {x_min_shape[n_axis], out_channel, output_hw_min[0], output_hw_min[1]};
output_shape_max = {x_max_shape[n_axis], out_channel, output_hw_max[0], output_hw_max[1]};
}
ShapePtr output_shape_ptr = std::make_shared<Shape>(output_shape, output_shape_min, output_shape_max);
return std::make_shared<AbstractTensor>(input_x->element(), output_shape_ptr);
}
AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: three tensors(doutput, input, filters).

View File

@ -111,6 +111,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}},
{prim::kPrimBatchNormGrad, {InferImplBatchNormGrad, true}},
{prim::kPrimReluGrad, {InferImplReluGrad, true}},
{prim::kPrimConv2D, {InferImplConv2D, true}},
{prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}},
{prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}},
{prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}},

View File

@ -1207,7 +1207,7 @@ class BatchNorm(PrimitiveWithInfer):
return (input_x, scale, bias, input_x, input_x)
class Conv2D(PrimitiveWithInfer):
class Conv2D(PrimitiveWithCheck):
r"""
2D convolution layer.
@ -1314,65 +1314,16 @@ class Conv2D(PrimitiveWithInfer):
self.add_prim_attr('groups', self.group)
self.add_prim_attr('offset_a', 0)
def infer_shape(self, x_shape, w_shape, b_shape=None):
def check_shape(self, x_shape, w_shape, b_shape=None):
x_shape_norm = x_shape if self.format == "NCHW" else (x_shape[0], x_shape[3], x_shape[1], x_shape[2])
w_shape_norm = w_shape if self.format == "NCHW" else (w_shape[0], w_shape[3], w_shape[1], w_shape[2])
validator.check_equal_int(len(w_shape_norm), 4, "weight rank", self.name)
validator.check_equal_int(len(x_shape_norm), 4, "x rank", self.name)
validator.check(f"x_shape[1] / group", x_shape_norm[1] // self.group, "w_shape[1]", w_shape_norm[1], \
Rel.EQ, self.name)
validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape_norm[0], Rel.EQ, self.name)
validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape_norm[2:4]), Rel.EQ, self.name)
kernel_size_h = w_shape_norm[2]
kernel_size_w = w_shape_norm[3]
stride_h = self.stride[2]
stride_w = self.stride[3]
dilation_h = self.dilation[2]
dilation_w = self.dilation[3]
if self.pad_mode == "valid":
h_out = math.ceil((x_shape_norm[2] - dilation_h * (kernel_size_h - 1)) / stride_h)
w_out = math.ceil((x_shape_norm[3] - dilation_w * (kernel_size_w - 1)) / stride_w)
pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0
elif self.pad_mode == "same":
h_out = math.ceil(x_shape_norm[2] / stride_h)
w_out = math.ceil(x_shape_norm[3] / stride_w)
pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape_norm[2])
pad_top = math.floor(pad_needed_h / 2)
pad_bottom = pad_needed_h - pad_top
pad_needed_w = max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape_norm[3])
pad_left = math.floor(pad_needed_w / 2)
pad_right = pad_needed_w - pad_left
elif self.pad_mode == 'pad':
pad_top, pad_bottom, pad_left, pad_right = self.padding
h_out = 1 + (x_shape_norm[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) \
* (dilation_h - 1)) / stride_h
w_out = 1 + (x_shape_norm[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) \
* (dilation_w - 1)) / stride_w
h_out = math.floor(h_out)
w_out = math.floor(w_out)
self.pad_list = [pad_top, pad_bottom, pad_left, pad_right]
self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right))
out_channel = self.out_channel
out_shape = [x_shape_norm[0], out_channel, h_out, w_out] if self.format == "NCHW" else \
[x_shape_norm[0], h_out, w_out, out_channel]
_check_shape('output', out_shape, self.name)
return out_shape
def infer_dtype(self, x_dtype, w_dtype, b_dtype=None):
def check_dtype(self, x_dtype, w_dtype, b_dtype=None):
args = {'x': x_dtype, 'w': w_dtype}
valid_dtypes = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
if x_dtype.element_type() == mstype.int8:
return mstype.tensor_type(mstype.int32)
return x_dtype
class DepthwiseConv2dNative(PrimitiveWithInfer):

View File

@ -20,6 +20,9 @@ import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
class NetConv2d(nn.Cell):
@ -61,3 +64,171 @@ def test_conv2d():
conv2d = NetConv2d()
output = conv2d(x, w)
assert (output.asnumpy() == expect).all()
class NetConv(nn.Cell):
def __init__(self, weight, x):
super(NetConv, self).__init__()
self.conv = nn.Conv2d(in_channels=3,
out_channels=3,
kernel_size=(5, 3),
stride=2,
pad_mode='same',
padding=(0, 0, 0, 0),
dilation=(1, 1),
group=1,
has_bias=False,
weight_init=Tensor(weight)
)
self.x = Parameter(initializer(Tensor(x), [1, 3, 4, 2]), name="x")
def construct(self):
return self.conv(self.x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_conv():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
weight = np.array([[[[0.38968208, 0.14398979, 0.7962463],
[-2.1836321, -0.63823014, -0.50588065],
[0.6660469, 0.64673275, -0.13160042],
[1.3683757, 1.4005762, -0.37235805],
[-0.22638111, 0.45427424, -0.10293389]],
[[1.4985064, -0.29318333, -0.92694616],
[1.539068, 0.8937254, -1.2598171],
[0.9658142, -0.63945454, -0.23185322],
[1.363089, -0.41694695, -2.2750475],
[-0.4865508, -1.6938025, 0.609849]],
[[1.1844803, 0.99874926, -1.9475793],
[0.4987858, 0.5307887, -0.04226681],
[0.4529779, -1.1960793, 0.9456575],
[3.133675, 0.2309789, -0.29201075],
[-0.59632736, -0.0789804, -0.69486314]]],
[[[-0.5606142, 0.6420862, 0.2478745],
[0.02717604, 1.5483379, -0.9373383],
[-1.1017276, -0.259478, 1.0311872],
[1.8387799, 0.16468556, 0.33392152],
[-1.8781787, 1.0158662, 1.6527579]],
[[0.45696944, -0.5652523, -1.5618048],
[-0.30304828, 0.1331878, -0.36955845],
[0.91655576, 0.66612357, 0.3068175],
[-0.45732066, 0.8923335, 1.0542952],
[-0.73519516, 1.0518405, -1.0273266]],
[[-0.79712886, -0.26814285, 0.12779616],
[1.0367643, -1.6180774, 0.42999932],
[-0.81818223, -0.81502074, 0.882194],
[0.53640485, 0.4178927, 1.6037121],
[0.9256354, -1.1006796, 0.16614541]]],
[[[-1.5216796, -1.2473261, 0.6549515],
[0.63627815, 0.7221449, 0.02977821],
[-0.61331123, -0.49451825, 0.33852202],
[1.4510741, -1.3818305, -0.791747],
[0.6989747, 0.49558765, 1.0813237]],
[[-0.03969796, 0.71586496, 0.8326594],
[-0.15443641, 1.0389746, -0.59301984],
[0.7197836, 0.03257621, 1.8398637],
[0.6111736, -0.16166899, -2.4869773],
[1.3066711, -1.8003578, 0.17412892]],
[[-0.31470737, -0.5938182, -1.1311078],
[-0.99081016, 0.4005125, 0.44154453],
[1.0876914, -2.5958562, -0.5914863],
[1.3759689, -0.7741513, 0.19928917],
[1.6792973, 2.2744863, -0.04308867]]]]).astype(np.float32)
x = np.array([[[[-1.4311737, 1.015344],
[0.04431088, -2.2886624],
[1.4832113, 1.240908],
[0.67040104, 0.15266363]],
[[0.44226435, 1.1461105],
[1.194218, 1.5547837],
[0.23152256, 1.5911953],
[0.11206784, 0.17978816]],
[[-0.57803905, 0.8039611],
[0.0823025, -0.6134477],
[-1.4171146, 1.6269946],
[0.48878875, 0.9117505]]]]).astype(np.float32)
conv2d = NetConv(weight, x)
output = conv2d()
expected = np.array([[[[2.3498724],
[-1.9199573]],
[[5.376562],
[-5.425745]],
[[5.9105043],
[7.469034]]]]).astype(np.float32)
loss = np.abs(expected - output.asnumpy())
error = 1e-4 * np.ones(loss.shape)
assert (loss < error).all()
class NetConv2dDynamic(nn.Cell):
def __init__(self, axis=0, out_nums=1):
super(NetConv2dDynamic, self).__init__()
self.dynshape = inner.GpuConvertToDynamicShape()
out_channel = 2
kernel_size = 1
self.conv = P.Conv2D(out_channel,
kernel_size,
mode=1,
pad_mode="valid",
pad=0,
stride=1,
dilation=1,
group=1)
def construct(self, x, w):
x_dyn = self.dynshape(x)
w_dyn = self.dynshape(w)
x_conv = self.conv(x_dyn, w_dyn)
return x_conv
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_conv2d_dynamic():
x1 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
w1 = Tensor(np.arange(2 * 3 * 1 * 1).reshape(2, 3, 1, 1).astype(np.float32))
expect1 = np.array([[[[45, 48, 51],
[54, 57, 60],
[63, 66, 69]],
[[126, 138, 150],
[162, 174, 186],
[198, 210, 222]]]]).astype(np.float32)
x2 = Tensor(np.arange(5 * 1 * 2 * 2).reshape(5, 1, 2, 2).astype(np.float32))
w2 = Tensor(np.arange(2 * 1 * 1 * 1).reshape(2, 1, 1, 1).astype(np.float32))
expect2 = np.array([[[[0., 0.],
[0., 0.]],
[[0., 1.],
[2., 3.]]],
[[[0., 0.],
[0., 0.]],
[[4., 5.],
[6., 7.]]],
[[[0., 0.],
[0., 0.]],
[[8., 9.],
[10., 11.]]],
[[[0., 0.],
[0., 0.]],
[[12., 13.],
[14., 15.]]],
[[[0., 0.],
[0., 0.]],
[[16., 17.],
[18., 19.]]]]).astype(np.float32)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
conv2d = NetConv2dDynamic()
output1 = conv2d(x1, w1)
assert (output1.asnumpy() == expect1).all()
output2 = conv2d(x2, w2)
assert (output2.asnumpy() == expect2).all()