diff --git a/mindspore/core/abstract/param_validator.cc b/mindspore/core/abstract/param_validator.cc
index 69f30b6ea9b..0a88482b73e 100644
--- a/mindspore/core/abstract/param_validator.cc
+++ b/mindspore/core/abstract/param_validator.cc
@@ -16,6 +16,8 @@
 
 #include "abstract/param_validator.h"
 
+#include <algorithm>
+#include <set>
 #include <string>
 #include <sstream>
 #include <memory>
@@ -143,5 +145,68 @@ void CheckArgsSize(const std::string &op, const mindspore::abstract::AbstractBas
     MS_EXCEPTION_IF_NULL(args_spec_list[i]);
   }
 }
+
+void CheckShapeAllPositive(const std::string &op, const ShapeVector &shape) {
+  for (size_t i = 0; i < shape.size(); ++i) {
+    if (shape[i] < 0) {
+      MS_LOG(EXCEPTION) << op << " shape element [" << i << "] must be positive integer, but got " << shape[i];
+    }
+  }
+}
+
+void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape) {
+  for (size_t i = 0; i < shape.size(); ++i) {
+    if ((shape[i] < 0) && (shape[i] != Shape::SHP_ANY)) {
+      MS_LOG(EXCEPTION) << op << " shape element [" << i << "] must be positive integer or SHP_ANY, but got "
+                        << shape[i];
+    }
+  }
+}
+
+int64_t CheckAttrPositiveInt64(const std::string &op, const ValuePtr &attr, const std::string &attr_name) {
+  int64_t attr_val = attr->cast<Int64ImmPtr>()->value();
+  if (attr_val <= 0) {
+    MS_LOG(EXCEPTION) << "Invalid " << attr_name << " value: " << attr_val << ", should be greater than 0";
+  }
+  return attr_val;
+}
+
+std::vector<int64_t> CheckAttrIntOrTuple(const std::string &op, const ValuePtr &attr, const size_t start_idx,
+                                         const size_t num_element) {
+  std::vector<int64_t> result;
+  MS_EXCEPTION_IF_NULL(attr);
+  if (attr->isa<ValueTuple>()) {
+    std::vector<ValuePtr> attr_vec = attr->cast<ValueTuplePtr>()->value();
+    auto it_start = attr_vec.begin() + start_idx;
+    (void)std::transform(it_start, it_start + num_element, std::back_inserter(result),
+                         [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
+  } else {
+    int64_t attr_val = attr->cast<Int64ImmPtr>()->value();
+    result.insert(result.begin(), num_element, attr_val);
+  }
+  return result;
+}
+
+std::string CheckAttrStringSet(const std::string &op, const ValuePtr &attr, const std::string &attr_name,
+                               const std::set<std::string> &val_set) {
+  MS_EXCEPTION_IF_NULL(attr);
+  std::string attr_val = attr->cast<StringImmPtr>()->value();
+  if (val_set.find(attr_val) == val_set.end()) {
+    std::ostringstream buffer;
+    bool f_begin = true;
+    buffer << "{";
+    for (auto &x : val_set) {
+      if (!f_begin) {
+        buffer << ", ";
+      } else {
+        f_begin = false;
+      }
+      buffer << x;
+    }
+    buffer << "}";
+    MS_LOG(EXCEPTION) << op << " unsupported " << attr_name << ": " << attr_val << ". use " << buffer.str();
+  }
+  return attr_val;
+}
 }  // namespace abstract
 }  // namespace mindspore
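Review note: for readers who don't live on the C++ side, here is a rough Python mirror of what `CheckAttrIntOrTuple` normalizes (the name `attr_int_or_tuple` is made up for illustration; this sketch is not part of the patch). A scalar attribute is broadcast to `num_element` copies, while a tuple attribute contributes the slice `[start_idx, start_idx + num_element)`:

```python
# Hypothetical Python mirror of CheckAttrIntOrTuple's semantics (illustration only).
def attr_int_or_tuple(attr, start_idx, num_element):
    if isinstance(attr, (tuple, list)):
        # Tuple attribute: take num_element entries starting at start_idx.
        return list(attr[start_idx:start_idx + num_element])
    # Scalar attribute: broadcast to num_element copies.
    return [attr] * num_element

# Conv2D stores 4-element stride/dilation attributes in NCHW order,
# so the H/W components are read with start_idx=2:
assert attr_int_or_tuple((1, 1, 2, 3), 2, 2) == [2, 3]
assert attr_int_or_tuple(2, 0, 2) == [2, 2]
```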
use " << buffer.str(); + } + return attr_val; +} } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/param_validator.h b/mindspore/core/abstract/param_validator.h index 49b7b93aa91..838df0d45ac 100644 --- a/mindspore/core/abstract/param_validator.h +++ b/mindspore/core/abstract/param_validator.h @@ -18,6 +18,7 @@ #define MINDSPORE_CORE_ABSTRACT_PARAM_VALIDATOR_H_ #include +#include #include #include #include @@ -48,6 +49,18 @@ int64_t CheckAxis(const std::string &op, const ValuePtr &axis, int64_t min, int6 void CheckArgsSize(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t size_expect); +void CheckShapeAllPositive(const std::string &op, const ShapeVector &shape); + +void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape); + +int64_t CheckAttrPositiveInt64(const std::string &op, const ValuePtr &attr, const std::string &attr_name); + +std::vector CheckAttrIntOrTuple(const std::string &op, const ValuePtr &attr, const size_t start_idx, + const size_t num_element); + +std::string CheckAttrStringSet(const std::string &op, const ValuePtr &attr, const std::string &attr_name, + const std::set &val_set); + template struct ReportNameTraits {}; diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc index e2cd0be0b58..1aab6623760 100644 --- a/mindspore/core/abstract/prim_nn.cc +++ b/mindspore/core/abstract/prim_nn.cc @@ -268,200 +268,109 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit return std::make_shared(rets); } +void Conv2DPadFunction(std::vector *output_hw, std::vector *pad_list, const int64_t x_h, + const int64_t x_w, const std::vector &kernel, const std::vector &stride, + const std::vector &dilation, const std::string &pad_mode, + const std::vector &padding) { + if (pad_mode == "valid") { + output_hw->push_back(std::ceil(((x_h * 1.0) - dilation[0] * (kernel[0] - 1)) / stride[0])); + output_hw->push_back(std::ceil(((x_w * 1.0) - dilation[1] * (kernel[1] - 1)) / stride[1])); + pad_list->insert(pad_list->begin(), 4, 0); + } else if (pad_mode == "same") { + output_hw->push_back(std::ceil((x_h * 1.0) / stride[0])); + output_hw->push_back(std::ceil((x_w * 1.0) / stride[1])); + int64_t pad_needed_h = (output_hw->at(0) - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h; + pad_needed_h = std::max((int64_t)0, pad_needed_h); + pad_list->push_back(std::floor(pad_needed_h / 2)); + pad_list->push_back(pad_needed_h - pad_list->at(0)); + int64_t pad_needed_w = (output_hw->at(1) - 1) * stride[1] + dilation[1] * (kernel[1] - 1) + 1 - x_w; + pad_needed_w = std::max((int64_t)0, pad_needed_w); + pad_list->push_back(std::floor(pad_needed_w / 2)); + pad_list->push_back(pad_needed_w - pad_list->at(2)); + } else if (pad_mode == "pad") { + pad_list->insert(pad_list->begin(), padding.begin(), padding.end()); + output_hw->push_back(std::floor( + 1 + + ((x_h * 1.0) + pad_list->at(0) + pad_list->at(1) - kernel[0] - (kernel[0] - 1) * (dilation[0] - 1)) / stride[0])); + output_hw->push_back(std::floor( + 1 + + ((x_w * 1.0) + pad_list->at(2) + pad_list->at(3) - kernel[1] - (kernel[1] - 1) * (dilation[1] - 1)) / stride[1])); + } +} + AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { const std::string op_name = primitive->name(); CheckArgsSize(op_name, args_spec_list, 2); - AbstractTensorPtr input_x = CheckArg(op_name, args_spec_list, 0); MS_EXCEPTION_IF_NULL(input_x); 
diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc
index e2cd0be0b58..1aab6623760 100644
--- a/mindspore/core/abstract/prim_nn.cc
+++ b/mindspore/core/abstract/prim_nn.cc
@@ -268,200 +268,109 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit
   return std::make_shared<AbstractTuple>(rets);
 }
 
+void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pad_list, const int64_t x_h,
+                       const int64_t x_w, const std::vector<int64_t> &kernel, const std::vector<int64_t> &stride,
+                       const std::vector<int64_t> &dilation, const std::string &pad_mode,
+                       const std::vector<int64_t> &padding) {
+  if (pad_mode == "valid") {
+    output_hw->push_back(std::ceil(((x_h * 1.0) - dilation[0] * (kernel[0] - 1)) / stride[0]));
+    output_hw->push_back(std::ceil(((x_w * 1.0) - dilation[1] * (kernel[1] - 1)) / stride[1]));
+    pad_list->insert(pad_list->begin(), 4, 0);
+  } else if (pad_mode == "same") {
+    output_hw->push_back(std::ceil((x_h * 1.0) / stride[0]));
+    output_hw->push_back(std::ceil((x_w * 1.0) / stride[1]));
+    int64_t pad_needed_h = (output_hw->at(0) - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h;
+    pad_needed_h = std::max((int64_t)0, pad_needed_h);
+    pad_list->push_back(std::floor(pad_needed_h / 2));
+    pad_list->push_back(pad_needed_h - pad_list->at(0));
+    int64_t pad_needed_w = (output_hw->at(1) - 1) * stride[1] + dilation[1] * (kernel[1] - 1) + 1 - x_w;
+    pad_needed_w = std::max((int64_t)0, pad_needed_w);
+    pad_list->push_back(std::floor(pad_needed_w / 2));
+    pad_list->push_back(pad_needed_w - pad_list->at(2));
+  } else if (pad_mode == "pad") {
+    pad_list->insert(pad_list->begin(), padding.begin(), padding.end());
+    output_hw->push_back(std::floor(
+      1 +
+      ((x_h * 1.0) + pad_list->at(0) + pad_list->at(1) - kernel[0] - (kernel[0] - 1) * (dilation[0] - 1)) / stride[0]));
+    output_hw->push_back(std::floor(
+      1 +
+      ((x_w * 1.0) + pad_list->at(2) + pad_list->at(3) - kernel[1] - (kernel[1] - 1) * (dilation[1] - 1)) / stride[1]));
+  }
+}
+
 AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list) {
   const std::string op_name = primitive->name();
   CheckArgsSize(op_name, args_spec_list, 2);
-
   AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
   MS_EXCEPTION_IF_NULL(input_x);
   MS_EXCEPTION_IF_NULL(input_x->shape());
   ShapeVector x_shape = input_x->shape()->shape();
   ShapeVector x_min_shape = input_x->shape()->min_shape();
   ShapeVector x_max_shape = input_x->shape()->max_shape();
-  (void)CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
-  for (size_t i = 0; i < x_shape.size(); ++i) {
-    if ((x_shape[i] < 0) && (x_shape[i] != Shape::SHP_ANY)) {
-      MS_LOG(EXCEPTION) << "Shape element x_shape[" << i << "] must be positive integer, but got " << x_shape[i];
-    }
-    if (x_min_shape[i] < 0) {
-      MS_LOG(EXCEPTION) << "Min Shape element x_min_shape[" << i << "] must be positive integer, but got "
-                        << x_min_shape[i];
-    }
-    if (x_max_shape[i] < 0) {
-      MS_LOG(EXCEPTION) << "Max Shape element x_max_shape[" << i << "] must be positive integer, but got "
-                        << x_max_shape[i];
-    }
-  }
-
+  CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
+  CheckShapeAnyAndPositive(op_name + " x_shape", x_shape);
+  CheckShapeAllPositive(op_name + " x_min_shape", x_min_shape);
+  CheckShapeAllPositive(op_name + " x_max_shape", x_max_shape);
   AbstractTensorPtr input_w = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
   MS_EXCEPTION_IF_NULL(input_w);
   MS_EXCEPTION_IF_NULL(input_w->shape());
   ShapeVector w_shape = input_w->shape()->shape();
   ShapeVector w_min_shape = input_w->shape()->min_shape();
   ShapeVector w_max_shape = input_w->shape()->max_shape();
-  (void)CheckMinMaxShape(w_shape, &w_min_shape, &w_max_shape);
-  for (size_t i = 0; i < w_shape.size(); ++i) {
-    if ((w_shape[i] < 0) && (w_shape[i] != Shape::SHP_ANY)) {
-      MS_LOG(EXCEPTION) << "Shape element w_shape[" << i << "] must be positive integer, but got " << w_shape[i];
-    }
-    if (w_min_shape[i] < 0) {
-      MS_LOG(EXCEPTION) << "Min Shape element w_min_shape[" << i << "] must be positive integer, but got "
-                        << w_min_shape[i];
-    }
-    if (w_max_shape[i] < 0) {
-      MS_LOG(EXCEPTION) << "Max Shape element w_max_shape[" << i << "] must be positive integer, but got "
-                        << w_max_shape[i];
-    }
-  }
-
-  std::set<std::string> available_data_format{"NCHW", "NHWC"};
+  CheckMinMaxShape(w_shape, &w_min_shape, &w_max_shape);
+  CheckShapeAnyAndPositive(op_name + " w_shape", w_shape);
+  CheckShapeAllPositive(op_name + " w_min_shape", w_min_shape);
+  CheckShapeAllPositive(op_name + " w_max_shape", w_max_shape);
+  std::string data_format = CheckAttrStringSet(op_name, primitive->GetAttr("format"), "format", {"NCHW", "NHWC"});
   int64_t n_axis = 0;
   int64_t c_axis = 1;
   int64_t h_axis = 2;
   int64_t w_axis = 3;
-  auto data_format_ptr = primitive->GetAttr("format");
-  std::string data_format = "NCHW";
-  if ((data_format_ptr != nullptr) && data_format_ptr->isa<StringImm>()) {
-    data_format = data_format_ptr->cast<StringImmPtr>()->value();
-    if (available_data_format.find(data_format) == available_data_format.end()) {
-      MS_LOG(EXCEPTION) << "Unsupported data format: " << data_format << ". use NCHW or NHWC";
-    }
-    if (data_format == "NHWC") {
-      c_axis = 3;
-      h_axis = 1;
-      w_axis = 2;
-    }
-  }
-
-  int64_t group = primitive->GetAttr("group")->cast<Int64ImmPtr>()->value();
-  if (group <= 0) {
-    MS_LOG(EXCEPTION) << "Invalid group value: " << group << ", should be greater then 0";
+  if (data_format == "NHWC") {
+    c_axis = 3;
+    h_axis = 1;
+    w_axis = 2;
   }
+  int64_t group = CheckAttrPositiveInt64(op_name, primitive->GetAttr("group"), "group");
   if ((x_shape[c_axis] != Shape::SHP_ANY) && (x_shape[c_axis] % group != 0)) {
     MS_LOG(EXCEPTION) << "x_shape[" << c_axis << "] = " << x_shape[c_axis]
                       << " (channels) must be divisible by group = " << group;
   }
-
-  int64_t out_channel = primitive->GetAttr("out_channel")->cast<Int64ImmPtr>()->value();
-  if (out_channel <= 0) {
-    MS_LOG(EXCEPTION) << "Invalid out_channel value: " << out_channel << ", should be greater then 0";
+  int64_t out_channel = CheckAttrPositiveInt64(op_name, primitive->GetAttr("out_channel"), "out_channel");
+  if ((w_shape[n_axis] != Shape::SHP_ANY) && (w_shape[n_axis] != out_channel)) {
+    MS_LOG(EXCEPTION) << "w_shape[" << n_axis << "] = " << w_shape[n_axis] << " must be equal to " << out_channel;
   }
-  if ((w_shape[0] != Shape::SHP_ANY) && (w_shape[0] != out_channel)) {
-    MS_LOG(EXCEPTION) << "w_shape[0] = " << w_shape[0] << " must equal to = " << out_channel;
+  std::vector<int64_t> kernel_size = CheckAttrIntOrTuple(op_name, primitive->GetAttr("kernel_size"), 0, 2);
+  if ((w_shape[h_axis] != Shape::SHP_ANY) && (w_shape[h_axis] != kernel_size[0])) {
+    MS_LOG(EXCEPTION) << "weight height = " << w_shape[h_axis] << " must be equal to " << kernel_size[0];
   }
-
-  int64_t kernel_h = 0;
-  int64_t kernel_w = 0;
-  ValuePtr kernel_size_attr = primitive->GetAttr("kernel_size");
-  MS_EXCEPTION_IF_NULL(kernel_size_attr);
-  if (kernel_size_attr->isa<ValueTuple>()) {
-    std::vector<ValuePtr> kernel_size_vec = kernel_size_attr->cast<ValueTuplePtr>()->value();
-    kernel_h = GetValue<int64_t>(kernel_size_vec[0]);
-    kernel_w = GetValue<int64_t>(kernel_size_vec[1]);
-  } else {
-    int64_t kernel_size = kernel_size_attr->cast<Int64ImmPtr>()->value();
-    kernel_h = kernel_size;
-    kernel_w = kernel_size;
+  if ((w_shape[w_axis] != Shape::SHP_ANY) && (w_shape[w_axis] != kernel_size[1])) {
+    MS_LOG(EXCEPTION) << "weight width = " << w_shape[w_axis] << " must be equal to " << kernel_size[1];
   }
-  if ((w_shape[2] != Shape::SHP_ANY) && (w_shape[2] != kernel_h)) {
-    MS_LOG(EXCEPTION) << "weight height, w_shape[2] = " << w_shape[2] << ", must equal to = " << kernel_h;
-  }
-  if ((w_shape[3] != Shape::SHP_ANY) && (w_shape[3] != kernel_w)) {
-    MS_LOG(EXCEPTION) << "weight width, w_shape[3] = " << w_shape[3] << ", must equal to = " << kernel_w;
-  }
-
-  int64_t stride_h = 0;
-  int64_t stride_w = 0;
-  ValuePtr stride_attr = primitive->GetAttr("stride");
-  MS_EXCEPTION_IF_NULL(stride_attr);
-  if (stride_attr->isa<ValueTuple>()) {
-    std::vector<ValuePtr> stride_vec = stride_attr->cast<ValueTuplePtr>()->value();
-    stride_h = GetValue<int64_t>(stride_vec[2]);
-    stride_w = GetValue<int64_t>(stride_vec[3]);
-  } else {
-    int64_t stride = stride_attr->cast<Int64ImmPtr>()->value();
-    stride_h = stride;
-    stride_w = stride;
-  }
-
-  int64_t dilation_h = 0;
-  int64_t dilation_w = 0;
-  ValuePtr dilation_attr = primitive->GetAttr("dilation");
-  MS_EXCEPTION_IF_NULL(dilation_attr);
-  if (dilation_attr->isa<ValueTuple>()) {
-    std::vector<ValuePtr> dilation_vec = dilation_attr->cast<ValueTuplePtr>()->value();
-    dilation_h = GetValue<int64_t>(dilation_vec[2]);
-    dilation_w = GetValue<int64_t>(dilation_vec[3]);
-  } else {
-    int64_t dilation = dilation_attr->cast<Int64ImmPtr>()->value();
-    dilation_h = dilation;
-    dilation_w = dilation;
-  }
-
-  std::vector<int64_t> padding;
-  ValuePtr padding_attr = primitive->GetAttr("pad");
-  MS_EXCEPTION_IF_NULL(padding_attr);
-  if (padding_attr->isa<ValueTuple>()) {
-    std::vector<ValuePtr> padding_vec = padding_attr->cast<ValueTuplePtr>()->value();
-    (void)std::transform(std::begin(padding_vec), std::end(padding_vec), std::back_inserter(padding),
-                         [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
-  } else {
-    int64_t padding_val = padding_attr->cast<Int64ImmPtr>()->value();
-    padding = {padding_val, padding_val, padding_val, padding_val};
-  }
-
-  std::set<std::string> available_pad_mode{"pad", "same", "valid"};
-  ValuePtr pad_mode_attr = primitive->GetAttr("pad_mode");
-  MS_EXCEPTION_IF_NULL(pad_mode_attr);
-  auto pad_mode = pad_mode_attr->cast<StringImmPtr>()->value();
-  if (available_pad_mode.find(pad_mode) == available_pad_mode.end()) {
-    MS_LOG(EXCEPTION) << "Unsupported pad mode: " << pad_mode << ". use pad, same, valid";
-  }
-
-  std::function<void(int64_t, int64_t, std::vector<int64_t> &, std::vector<int64_t> &)> pad_function =
-    [kernel_h, kernel_w, stride_h, stride_w, dilation_h, dilation_w, pad_mode, padding](
-      int64_t x_h, int64_t x_w, std::vector<int64_t> &output_hw, std::vector<int64_t> &pad_list) {
-      if (pad_mode == "valid") {
-        output_hw.push_back(std::ceil(((x_h * 1.0) - dilation_h * (kernel_h - 1)) / stride_h));
-        output_hw.push_back(std::ceil(((x_w * 1.0) - dilation_w * (kernel_w - 1)) / stride_w));
-        pad_list = {0, 0, 0, 0};
-      } else if (pad_mode == "same") {
-        output_hw.push_back(std::ceil((x_h * 1.0) / stride_h));
-        output_hw.push_back(std::ceil((x_w * 1.0) / stride_w));
-        int64_t pad_needed_h = (output_hw[0] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_h;
-        pad_needed_h = std::max((int64_t)0, pad_needed_h);
-        pad_list.push_back(std::floor(pad_needed_h / 2));
-        pad_list.push_back(pad_needed_h - pad_list[0]);
-        int64_t pad_needed_w = (output_hw[1] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_w;
-        pad_needed_w = std::max((int64_t)0, pad_needed_w);
-        pad_list.push_back(std::floor(pad_needed_w / 2));
-        pad_list.push_back(pad_needed_w - pad_list[2]);
-      } else if (pad_mode == "pad") {
-        pad_list = padding;
-        output_hw.push_back(std::floor(
-          1 + ((x_h * 1.0) + pad_list[0] + pad_list[1] - kernel_h - (kernel_h - 1) * (dilation_h - 1)) / stride_h));
-        output_hw.push_back(std::floor(
-          1 + ((x_w * 1.0) + pad_list[2] + pad_list[3] - kernel_w - (kernel_w - 1) * (dilation_w - 1)) / stride_w));
-      }
-    };
-
+  std::vector<int64_t> stride = CheckAttrIntOrTuple(op_name, primitive->GetAttr("stride"), 2, 2);
+  std::vector<int64_t> dilation = CheckAttrIntOrTuple(op_name, primitive->GetAttr("dilation"), 2, 2);
+  std::vector<int64_t> padding = CheckAttrIntOrTuple(op_name, primitive->GetAttr("pad"), 0, 4);
+  std::string pad_mode =
+    CheckAttrStringSet(op_name, primitive->GetAttr("pad_mode"), "pad_mode", {"pad", "same", "valid"});
   std::vector<int64_t> output_hw;
   std::vector<int64_t> pad_list;
   std::vector<int64_t> output_hw_min;
   std::vector<int64_t> pad_list_min;
   std::vector<int64_t> output_hw_max;
   std::vector<int64_t> pad_list_max;
-  pad_function(x_shape[h_axis], x_shape[w_axis], output_hw, pad_list);
+  Conv2DPadFunction(&output_hw, &pad_list, x_shape[h_axis], x_shape[w_axis], kernel_size, stride, dilation, pad_mode,
+                    padding);
   if (x_shape[h_axis] == Shape::SHP_ANY) {
     output_hw[0] = Shape::SHP_ANY;
   }
   if (x_shape[w_axis] == Shape::SHP_ANY) {
     output_hw[1] = Shape::SHP_ANY;
   }
-  pad_function(x_min_shape[h_axis], x_min_shape[w_axis], output_hw_min, pad_list_min);
-  pad_function(x_max_shape[h_axis], x_max_shape[w_axis], output_hw_max, pad_list_max);
-
+  Conv2DPadFunction(&output_hw_min, &pad_list_min, x_min_shape[h_axis], x_min_shape[w_axis], kernel_size, stride,
+                    dilation, pad_mode, padding);
+  Conv2DPadFunction(&output_hw_max, &pad_list_max, x_max_shape[h_axis], x_max_shape[w_axis], kernel_size, stride,
+                    dilation, pad_mode, padding);
   std::vector<ValuePtr> pad_list_val = {MakeValue(pad_list[0]), MakeValue(pad_list[1]), MakeValue(pad_list[2]),
                                         MakeValue(pad_list[3])};
   primitive->set_attr("pad_list", MakeValue(pad_list_val));
@@ -477,28 +386,15 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p
     output_shape_min = {x_min_shape[n_axis], out_channel, output_hw_min[0], output_hw_min[1]};
     output_shape_max = {x_max_shape[n_axis], out_channel, output_hw_max[0], output_hw_max[1]};
   }
-  for (size_t i = 0; i < output_shape.size(); ++i) {
-    if ((output_shape[i] < 0) && (output_shape[i] != Shape::SHP_ANY)) {
-      MS_LOG(EXCEPTION) << "Shape element output_shape[" << i << "] must be positive integer, but got "
-                        << output_shape[i];
-    }
-    if (output_shape_min[i] < 0) {
-      MS_LOG(EXCEPTION) << "Min Shape element output_shape_min[" << i << "] must be positive integer, but got "
-                        << output_shape_min[i];
-    }
-    if (output_shape_max[i] < 0) {
-      MS_LOG(EXCEPTION) << "Max Shape element output_shape_max[" << i << "] must be positive integer, but got "
-                        << output_shape_max[i];
-    }
+  CheckShapeAnyAndPositive(op_name + " output_shape", output_shape);
+  CheckShapeAllPositive(op_name + " output_shape_min", output_shape_min);
+  CheckShapeAllPositive(op_name + " output_shape_max", output_shape_max);
+  TypePtr x_type = input_x->element()->GetTypeTrack();
+  if (x_type->type_id() == TypeId::kNumberTypeInt8) {
+    x_type = kInt32;
   }
-
   ShapePtr output_shape_ptr = std::make_shared<Shape>(output_shape, output_shape_min, output_shape_max);
-  if (input_x->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt8) {
-    auto output = std::make_shared<AbstractTensor>(kInt32, output_shape);
-    output->set_shape(output_shape_ptr);
-    return output;
-  }
-  return std::make_shared<AbstractTensor>(input_x->element(), output_shape_ptr);
+  return std::make_shared<AbstractTensor>(x_type, output_shape_ptr);
 }
 
 AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
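Review note: the three `pad_mode` branches of `Conv2DPadFunction` are the usual convolution output-size formulas. A standalone Python cross-check of that arithmetic (a sketch using the same [top, bottom, left, right] `pad_list` convention, not shipped code):

```python
import math

def conv2d_pad(x_h, x_w, kernel, stride, dilation, pad_mode, padding):
    """Python cross-check of Conv2DPadFunction: returns (output_hw, pad_list)."""
    if pad_mode == "valid":
        out_h = math.ceil((x_h - dilation[0] * (kernel[0] - 1)) / stride[0])
        out_w = math.ceil((x_w - dilation[1] * (kernel[1] - 1)) / stride[1])
        return [out_h, out_w], [0, 0, 0, 0]
    if pad_mode == "same":
        out_h = math.ceil(x_h / stride[0])
        out_w = math.ceil(x_w / stride[1])
        pad_h = max(0, (out_h - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h)
        pad_w = max(0, (out_w - 1) * stride[1] + dilation[1] * (kernel[1] - 1) + 1 - x_w)
        return [out_h, out_w], [pad_h // 2, pad_h - pad_h // 2, pad_w // 2, pad_w - pad_w // 2]
    # pad_mode == "pad": caller supplies explicit [top, bottom, left, right] padding
    out_h = 1 + (x_h + padding[0] + padding[1] - kernel[0] - (kernel[0] - 1) * (dilation[0] - 1)) // stride[0]
    out_w = 1 + (x_w + padding[2] + padding[3] - kernel[1] - (kernel[1] - 1) * (dilation[1] - 1)) // stride[1]
    return [out_h, out_w], list(padding)

# The NHWC test below: 4x4 input, 2x2 kernel, stride 2, "valid" -> 2x2 output, no padding.
assert conv2d_pad(4, 4, [2, 2], [2, 2], [1, 1], "valid", [0, 0, 0, 0]) == ([2, 2], [0, 0, 0, 0])
```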
diff --git a/tests/st/ops/gpu/test_conv2d_op.py b/tests/st/ops/gpu/test_conv2d_op.py
index afa15a818d1..d5919c5b1c8 100644
--- a/tests/st/ops/gpu/test_conv2d_op.py
+++ b/tests/st/ops/gpu/test_conv2d_op.py
@@ -232,3 +232,36 @@ def test_conv2d_dynamic():
     assert (output1.asnumpy() == expect1).all()
     output2 = conv2d(x2, w2)
     assert (output2.asnumpy() == expect2).all()
+
+
+class NetConvNHWC(nn.Cell):
+    def __init__(self, weight, x):
+        super(NetConvNHWC, self).__init__()
+        self.conv = nn.Conv2d(in_channels=1,
+                              out_channels=3,
+                              kernel_size=2,
+                              stride=2,
+                              pad_mode="valid",
+                              weight_init=Tensor(weight),
+                              data_format='NHWC'
+                              )
+        self.x = Parameter(initializer(Tensor(x), [1, 4, 4, 1]), name="x")
+
+    def construct(self):
+        return self.conv(self.x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_conv_NHWC():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    x1 = Tensor(np.arange(1 * 4 * 4 * 1).reshape(1, 4, 4, 1).astype(np.float32))
+    w1 = Tensor(np.arange(3 * 2 * 2 * 1).reshape(3, 2, 2, 1).astype(np.float32))
+    expected = np.array([[[[24., 64., 104.],
+                           [36., 108., 180.]],
+                          [[72., 240., 408.],
+                           [84., 284., 484.]]]]).astype(np.float32)
+    conv2d = NetConvNHWC(w1, x1)
+    output = conv2d()
+    assert (output.asnumpy() == expected).all()
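Review note: the `expected` constant in `test_conv_NHWC` can be reproduced by hand, since with a 2x2 kernel, stride 2 and `pad_mode="valid"` each output pixel is a single patch/filter dot product. A numpy sketch of that cross-check (assuming the weight tensor is filter-major, which the expected values bear out):

```python
import numpy as np

x = np.arange(16, dtype=np.float32).reshape(4, 4)     # the single NHWC image plane
w = np.arange(12, dtype=np.float32).reshape(3, 2, 2)  # 3 filters of size 2x2

out = np.zeros((2, 2, 3), dtype=np.float32)
for i in range(2):
    for j in range(2):
        patch = x[2 * i:2 * i + 2, 2 * j:2 * j + 2]   # stride-2 "valid" window
        out[i, j] = [np.sum(patch * w[c]) for c in range(3)]

# out[0, 0] == [24., 64., 104.], matching expected[0, 0, 0] in the test.
```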