forked from mindspore-Ecosystem/mindspore
fix conv2d NHWC filter shape and cyclomatic complexity errors
parent 59a277756e
commit f34934a317
@@ -16,6 +16,8 @@
 #include "abstract/param_validator.h"
 
+#include <algorithm>
 #include <set>
 #include <string>
+#include <sstream>
 #include <memory>
 
@@ -143,5 +145,68 @@ void CheckArgsSize(const std::string &op, const mindspore::abstract::AbstractBas
     MS_EXCEPTION_IF_NULL(args_spec_list[i]);
   }
 }
+
+void CheckShapeAllPositive(const std::string &op, const ShapeVector &shape) {
+  for (size_t i = 0; i < shape.size(); ++i) {
+    if (shape[i] < 0) {
+      MS_LOG(EXCEPTION) << op << " shape element [" << i << "] must be positive integer, but got " << shape[i];
+    }
+  }
+}
+
+void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape) {
+  for (size_t i = 0; i < shape.size(); ++i) {
+    if ((shape[i] < 0) && (shape[i] != Shape::SHP_ANY)) {
+      MS_LOG(EXCEPTION) << op << " shape element [" << i << "] must be positive integer or SHP_ANY, but got "
+                        << shape[i];
+    }
+  }
+}
+
+int64_t CheckAttrPositiveInt64(const std::string &op, const ValuePtr &attr, const std::string &attr_name) {
+  int64_t attr_val = attr->cast<Int64ImmPtr>()->value();
+  if (attr_val <= 0) {
+    MS_LOG(EXCEPTION) << "Invalid " << attr_name << " value: " << attr_val << ", should be greater than 0";
+  }
+  return attr_val;
+}
+
+std::vector<int64_t> CheckAttrIntOrTuple(const std::string &op, const ValuePtr &attr, const size_t start_idx,
+                                         const size_t num_element) {
+  std::vector<int64_t> result;
+  MS_EXCEPTION_IF_NULL(attr);
+  if (attr->isa<ValueTuple>()) {
+    std::vector<ValuePtr> attr_vec = attr->cast<ValueTuplePtr>()->value();
+    auto it_start = attr_vec.begin() + start_idx;
+    (void)std::transform(it_start, it_start + num_element, std::back_inserter(result),
+                         [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
+  } else {
+    int64_t attr_val = attr->cast<Int64ImmPtr>()->value();
+    result.insert(result.begin(), num_element, attr_val);
+  }
+  return result;
+}
+
+std::string CheckAttrStringSet(const std::string &op, const ValuePtr &attr, const std::string &attr_name,
+                               const std::set<std::string> &val_set) {
+  MS_EXCEPTION_IF_NULL(attr);
+  std::string attr_val = attr->cast<StringImmPtr>()->value();
+  if (val_set.find(attr_val) == val_set.end()) {
+    std::ostringstream buffer;
+    bool f_begin = true;
+    buffer << "{";
+    for (auto &x : val_set) {
+      if (!f_begin) {
+        buffer << ", ";
+      } else {
+        f_begin = false;
+      }
+      buffer << x;
+    }
+    buffer << "}";
+    MS_LOG(EXCEPTION) << op << " Unsupported " << attr_name << ": " << attr_val << ". use " << buffer.str();
+  }
+  return attr_val;
+}
 } // namespace abstract
 } // namespace mindspore
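
Note: the point of CheckAttrIntOrTuple is that a scalar attribute and a tuple attribute normalize to the same vector form. A minimal standalone sketch of that contract, using a std::variant stand-in for MindSpore's ValuePtr/ValueTuple (hypothetical types, not the framework API):

#include <cstdint>
#include <iostream>
#include <variant>
#include <vector>

// Stand-in for an attribute that is either a single int64 or a tuple of them.
using IntOrTuple = std::variant<int64_t, std::vector<int64_t>>;

// Mirrors CheckAttrIntOrTuple: a scalar is broadcast num_element times; a
// tuple contributes num_element values starting at start_idx.
std::vector<int64_t> NormalizeIntOrTuple(const IntOrTuple &attr, size_t start_idx, size_t num_element) {
  if (const auto *tuple = std::get_if<std::vector<int64_t>>(&attr)) {
    return std::vector<int64_t>(tuple->begin() + start_idx, tuple->begin() + start_idx + num_element);
  }
  return std::vector<int64_t>(num_element, std::get<int64_t>(attr));
}

int main() {
  // A stride attr given as a 4-tuple (N, C, H, W): take the H/W pair via start_idx = 2.
  IntOrTuple stride = std::vector<int64_t>{1, 1, 2, 3};
  // A kernel_size attr given as a scalar: broadcast to {5, 5} via start_idx = 0.
  IntOrTuple kernel = int64_t{5};
  for (int64_t v : NormalizeIntOrTuple(stride, 2, 2)) std::cout << v << ' ';  // prints: 2 3
  std::cout << '\n';
  for (int64_t v : NormalizeIntOrTuple(kernel, 0, 2)) std::cout << v << ' ';  // prints: 5 5
  std::cout << '\n';
}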
@@ -18,6 +18,7 @@
 #define MINDSPORE_CORE_ABSTRACT_PARAM_VALIDATOR_H_
 
 #include <memory>
+#include <set>
 #include <string>
 #include <utility>
 #include <vector>
@@ -48,6 +49,18 @@ int64_t CheckAxis(const std::string &op, const ValuePtr &axis, int64_t min, int6
 
 void CheckArgsSize(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t size_expect);
 
+void CheckShapeAllPositive(const std::string &op, const ShapeVector &shape);
+
+void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape);
+
+int64_t CheckAttrPositiveInt64(const std::string &op, const ValuePtr &attr, const std::string &attr_name);
+
+std::vector<int64_t> CheckAttrIntOrTuple(const std::string &op, const ValuePtr &attr, const size_t start_idx,
+                                         const size_t num_element);
+
+std::string CheckAttrStringSet(const std::string &op, const ValuePtr &attr, const std::string &attr_name,
+                               const std::set<std::string> &val_set);
+
 template <typename T>
 struct ReportNameTraits {};
 
@@ -268,200 +268,109 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit
   return std::make_shared<AbstractTuple>(rets);
 }
 
+void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pad_list, const int64_t x_h,
+                       const int64_t x_w, const std::vector<int64_t> &kernel, const std::vector<int64_t> &stride,
+                       const std::vector<int64_t> &dilation, const std::string &pad_mode,
+                       const std::vector<int64_t> &padding) {
+  if (pad_mode == "valid") {
+    output_hw->push_back(std::ceil(((x_h * 1.0) - dilation[0] * (kernel[0] - 1)) / stride[0]));
+    output_hw->push_back(std::ceil(((x_w * 1.0) - dilation[1] * (kernel[1] - 1)) / stride[1]));
+    pad_list->insert(pad_list->begin(), 4, 0);
+  } else if (pad_mode == "same") {
+    output_hw->push_back(std::ceil((x_h * 1.0) / stride[0]));
+    output_hw->push_back(std::ceil((x_w * 1.0) / stride[1]));
+    int64_t pad_needed_h = (output_hw->at(0) - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h;
+    pad_needed_h = std::max((int64_t)0, pad_needed_h);
+    pad_list->push_back(std::floor(pad_needed_h / 2));
+    pad_list->push_back(pad_needed_h - pad_list->at(0));
+    int64_t pad_needed_w = (output_hw->at(1) - 1) * stride[1] + dilation[1] * (kernel[1] - 1) + 1 - x_w;
+    pad_needed_w = std::max((int64_t)0, pad_needed_w);
+    pad_list->push_back(std::floor(pad_needed_w / 2));
+    pad_list->push_back(pad_needed_w - pad_list->at(2));
+  } else if (pad_mode == "pad") {
+    pad_list->insert(pad_list->begin(), padding.begin(), padding.end());
+    output_hw->push_back(std::floor(
+      1 +
+      ((x_h * 1.0) + pad_list->at(0) + pad_list->at(1) - kernel[0] - (kernel[0] - 1) * (dilation[0] - 1)) / stride[0]));
+    output_hw->push_back(std::floor(
+      1 +
+      ((x_w * 1.0) + pad_list->at(2) + pad_list->at(3) - kernel[1] - (kernel[1] - 1) * (dilation[1] - 1)) / stride[1]));
+  }
+}
+
 AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list) {
   const std::string op_name = primitive->name();
   CheckArgsSize(op_name, args_spec_list, 2);
 
   AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
   MS_EXCEPTION_IF_NULL(input_x);
   MS_EXCEPTION_IF_NULL(input_x->shape());
   ShapeVector x_shape = input_x->shape()->shape();
   ShapeVector x_min_shape = input_x->shape()->min_shape();
   ShapeVector x_max_shape = input_x->shape()->max_shape();
-  (void)CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
-  for (size_t i = 0; i < x_shape.size(); ++i) {
-    if ((x_shape[i] < 0) && (x_shape[i] != Shape::SHP_ANY)) {
-      MS_LOG(EXCEPTION) << "Shape element x_shape[" << i << "] must be positive integer, but got " << x_shape[i];
-    }
-    if (x_min_shape[i] < 0) {
-      MS_LOG(EXCEPTION) << "Min Shape element x_min_shape[" << i << "] must be positive integer, but got "
-                        << x_min_shape[i];
-    }
-    if (x_max_shape[i] < 0) {
-      MS_LOG(EXCEPTION) << "Max Shape element x_max_shape[" << i << "] must be positive integer, but got "
-                        << x_max_shape[i];
-    }
-  }
-
+  CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
+  CheckShapeAnyAndPositive(op_name + " x_shape", x_shape);
+  CheckShapeAllPositive(op_name + " x_min_shape", x_min_shape);
+  CheckShapeAllPositive(op_name + " x_max_shape", x_max_shape);
   AbstractTensorPtr input_w = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
   MS_EXCEPTION_IF_NULL(input_w);
   MS_EXCEPTION_IF_NULL(input_w->shape());
   ShapeVector w_shape = input_w->shape()->shape();
   ShapeVector w_min_shape = input_w->shape()->min_shape();
   ShapeVector w_max_shape = input_w->shape()->max_shape();
-  (void)CheckMinMaxShape(w_shape, &w_min_shape, &w_max_shape);
-  for (size_t i = 0; i < w_shape.size(); ++i) {
-    if ((w_shape[i] < 0) && (w_shape[i] != Shape::SHP_ANY)) {
-      MS_LOG(EXCEPTION) << "Shape element w_shape[" << i << "] must be positive integer, but got " << w_shape[i];
-    }
-    if (w_min_shape[i] < 0) {
-      MS_LOG(EXCEPTION) << "Min Shape element w_min_shape[" << i << "] must be positive integer, but got "
-                        << w_min_shape[i];
-    }
-    if (w_max_shape[i] < 0) {
-      MS_LOG(EXCEPTION) << "Max Shape element w_max_shape[" << i << "] must be positive integer, but got "
-                        << w_max_shape[i];
-    }
-  }
-
-  std::set<std::string> available_data_format{"NCHW", "NHWC"};
+  CheckMinMaxShape(w_shape, &w_min_shape, &w_max_shape);
+  CheckShapeAnyAndPositive(op_name + " w_shape", w_shape);
+  CheckShapeAllPositive(op_name + " w_min_shape", w_min_shape);
+  CheckShapeAllPositive(op_name + " w_max_shape", w_max_shape);
+  std::string data_format = CheckAttrStringSet(op_name, primitive->GetAttr("format"), "format", {"NCHW", "NHWC"});
   int64_t n_axis = 0;
   int64_t c_axis = 1;
   int64_t h_axis = 2;
   int64_t w_axis = 3;
-  auto data_format_ptr = primitive->GetAttr("format");
-  std::string data_format = "NCHW";
-  if ((data_format_ptr != nullptr) && data_format_ptr->isa<StringImm>()) {
-    data_format = data_format_ptr->cast<StringImmPtr>()->value();
-    if (available_data_format.find(data_format) == available_data_format.end()) {
-      MS_LOG(EXCEPTION) << "Unsupported data format: " << data_format << ". use NCHW or NHWC";
-    }
-    if (data_format == "NHWC") {
-      c_axis = 3;
-      h_axis = 1;
-      w_axis = 2;
-    }
-  }
-
-  int64_t group = primitive->GetAttr("group")->cast<Int64ImmPtr>()->value();
-  if (group <= 0) {
-    MS_LOG(EXCEPTION) << "Invalid group value: " << group << ", should be greater than 0";
+  if (data_format == "NHWC") {
+    c_axis = 3;
+    h_axis = 1;
+    w_axis = 2;
   }
+  int64_t group = CheckAttrPositiveInt64(op_name, primitive->GetAttr("group"), "group");
+  if ((x_shape[c_axis] != Shape::SHP_ANY) && (x_shape[c_axis] % group != 0)) {
+    MS_LOG(EXCEPTION) << "x_shape[" << c_axis << "] = " << x_shape[c_axis]
+                      << " (channels) must be divisible by group = " << group;
+  }
 
-  int64_t out_channel = primitive->GetAttr("out_channel")->cast<Int64ImmPtr>()->value();
-  if (out_channel <= 0) {
-    MS_LOG(EXCEPTION) << "Invalid out_channel value: " << out_channel << ", should be greater than 0";
+  int64_t out_channel = CheckAttrPositiveInt64(op_name, primitive->GetAttr("out_channel"), "out_channel");
+  if ((w_shape[n_axis] != Shape::SHP_ANY) && (w_shape[n_axis] != out_channel)) {
+    MS_LOG(EXCEPTION) << "w_shape[" << n_axis << "] = " << w_shape[n_axis] << " must be equal to " << out_channel;
   }
-  if ((w_shape[0] != Shape::SHP_ANY) && (w_shape[0] != out_channel)) {
-    MS_LOG(EXCEPTION) << "w_shape[0] = " << w_shape[0] << " must equal to = " << out_channel;
+  std::vector<int64_t> kernel_size = CheckAttrIntOrTuple(op_name, primitive->GetAttr("kernel_size"), 0, 2);
+  if ((w_shape[h_axis] != Shape::SHP_ANY) && (w_shape[h_axis] != kernel_size[0])) {
+    MS_LOG(EXCEPTION) << "weight height = " << w_shape[h_axis] << ", must be equal to " << kernel_size[0];
   }
-
-  int64_t kernel_h = 0;
-  int64_t kernel_w = 0;
-  ValuePtr kernel_size_attr = primitive->GetAttr("kernel_size");
-  MS_EXCEPTION_IF_NULL(kernel_size_attr);
-  if (kernel_size_attr->isa<ValueTuple>()) {
-    std::vector<ValuePtr> kernel_size_vec = kernel_size_attr->cast<ValueTuplePtr>()->value();
-    kernel_h = GetValue<int64_t>(kernel_size_vec[0]);
-    kernel_w = GetValue<int64_t>(kernel_size_vec[1]);
-  } else {
-    int64_t kernel_size = kernel_size_attr->cast<Int64ImmPtr>()->value();
-    kernel_h = kernel_size;
-    kernel_w = kernel_size;
+  if ((w_shape[w_axis] != Shape::SHP_ANY) && (w_shape[w_axis] != kernel_size[1])) {
+    MS_LOG(EXCEPTION) << "weight width = " << w_shape[w_axis] << ", must be equal to " << kernel_size[1];
   }
-  if ((w_shape[2] != Shape::SHP_ANY) && (w_shape[2] != kernel_h)) {
-    MS_LOG(EXCEPTION) << "weight height, w_shape[2] = " << w_shape[2] << ", must equal to = " << kernel_h;
-  }
-  if ((w_shape[3] != Shape::SHP_ANY) && (w_shape[3] != kernel_w)) {
-    MS_LOG(EXCEPTION) << "weight width, w_shape[3] = " << w_shape[3] << ", must equal to = " << kernel_w;
-  }
-
-  int64_t stride_h = 0;
-  int64_t stride_w = 0;
-  ValuePtr stride_attr = primitive->GetAttr("stride");
-  MS_EXCEPTION_IF_NULL(stride_attr);
-  if (stride_attr->isa<ValueTuple>()) {
-    std::vector<ValuePtr> stride_vec = stride_attr->cast<ValueTuplePtr>()->value();
-    stride_h = GetValue<int64_t>(stride_vec[2]);
-    stride_w = GetValue<int64_t>(stride_vec[3]);
-  } else {
-    int64_t stride = stride_attr->cast<Int64ImmPtr>()->value();
-    stride_h = stride;
-    stride_w = stride;
-  }
-
-  int64_t dilation_h = 0;
-  int64_t dilation_w = 0;
-  ValuePtr dilation_attr = primitive->GetAttr("dilation");
-  MS_EXCEPTION_IF_NULL(dilation_attr);
-  if (dilation_attr->isa<ValueTuple>()) {
-    std::vector<ValuePtr> dilation_vec = dilation_attr->cast<ValueTuplePtr>()->value();
-    dilation_h = GetValue<int64_t>(dilation_vec[2]);
-    dilation_w = GetValue<int64_t>(dilation_vec[3]);
-  } else {
-    int64_t dilation = dilation_attr->cast<Int64ImmPtr>()->value();
-    dilation_h = dilation;
-    dilation_w = dilation;
-  }
-
-  std::vector<int64_t> padding;
-  ValuePtr padding_attr = primitive->GetAttr("pad");
-  MS_EXCEPTION_IF_NULL(padding_attr);
-  if (padding_attr->isa<ValueTuple>()) {
-    std::vector<ValuePtr> padding_vec = padding_attr->cast<ValueTuplePtr>()->value();
-    (void)std::transform(std::begin(padding_vec), std::end(padding_vec), std::back_inserter(padding),
-                         [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
-  } else {
-    int64_t padding_val = padding_attr->cast<Int64ImmPtr>()->value();
-    padding = {padding_val, padding_val, padding_val, padding_val};
-  }
-
-  std::set<std::string> available_pad_mode{"pad", "same", "valid"};
-  ValuePtr pad_mode_attr = primitive->GetAttr("pad_mode");
-  MS_EXCEPTION_IF_NULL(pad_mode_attr);
-  auto pad_mode = pad_mode_attr->cast<StringImmPtr>()->value();
-  if (available_pad_mode.find(pad_mode) == available_pad_mode.end()) {
-    MS_LOG(EXCEPTION) << "Unsupported pad mode: " << pad_mode << ". use pad, same, valid";
-  }
-
-  std::function<void(int64_t, int64_t, std::vector<int64_t> &, std::vector<int64_t> &)> pad_function =
-    [kernel_h, kernel_w, stride_h, stride_w, dilation_h, dilation_w, pad_mode, padding](
-      int64_t x_h, int64_t x_w, std::vector<int64_t> &output_hw, std::vector<int64_t> &pad_list) {
-      if (pad_mode == "valid") {
-        output_hw.push_back(std::ceil(((x_h * 1.0) - dilation_h * (kernel_h - 1)) / stride_h));
-        output_hw.push_back(std::ceil(((x_w * 1.0) - dilation_w * (kernel_w - 1)) / stride_w));
-        pad_list = {0, 0, 0, 0};
-      } else if (pad_mode == "same") {
-        output_hw.push_back(std::ceil((x_h * 1.0) / stride_h));
-        output_hw.push_back(std::ceil((x_w * 1.0) / stride_w));
-        int64_t pad_needed_h = (output_hw[0] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_h;
-        pad_needed_h = std::max((int64_t)0, pad_needed_h);
-        pad_list.push_back(std::floor(pad_needed_h / 2));
-        pad_list.push_back(pad_needed_h - pad_list[0]);
-        int64_t pad_needed_w = (output_hw[1] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_w;
-        pad_needed_w = std::max((int64_t)0, pad_needed_w);
-        pad_list.push_back(std::floor(pad_needed_w / 2));
-        pad_list.push_back(pad_needed_w - pad_list[2]);
-      } else if (pad_mode == "pad") {
-        pad_list = padding;
-        output_hw.push_back(std::floor(
-          1 + ((x_h * 1.0) + pad_list[0] + pad_list[1] - kernel_h - (kernel_h - 1) * (dilation_h - 1)) / stride_h));
-        output_hw.push_back(std::floor(
-          1 + ((x_w * 1.0) + pad_list[2] + pad_list[3] - kernel_w - (kernel_w - 1) * (dilation_w - 1)) / stride_w));
-      }
-    };
-
+  std::vector<int64_t> stride = CheckAttrIntOrTuple(op_name, primitive->GetAttr("stride"), 2, 2);
+  std::vector<int64_t> dilation = CheckAttrIntOrTuple(op_name, primitive->GetAttr("dilation"), 2, 2);
+  std::vector<int64_t> padding = CheckAttrIntOrTuple(op_name, primitive->GetAttr("pad"), 0, 4);
+  std::string pad_mode =
+    CheckAttrStringSet(op_name, primitive->GetAttr("pad_mode"), "pad_mode", {"pad", "same", "valid"});
   std::vector<int64_t> output_hw;
   std::vector<int64_t> pad_list;
   std::vector<int64_t> output_hw_min;
   std::vector<int64_t> pad_list_min;
   std::vector<int64_t> output_hw_max;
   std::vector<int64_t> pad_list_max;
-  pad_function(x_shape[h_axis], x_shape[w_axis], output_hw, pad_list);
+  Conv2DPadFunction(&output_hw, &pad_list, x_shape[h_axis], x_shape[w_axis], kernel_size, stride, dilation, pad_mode,
+                    padding);
   if (x_shape[h_axis] == Shape::SHP_ANY) {
     output_hw[0] = Shape::SHP_ANY;
   }
   if (x_shape[w_axis] == Shape::SHP_ANY) {
     output_hw[1] = Shape::SHP_ANY;
   }
-  pad_function(x_min_shape[h_axis], x_min_shape[w_axis], output_hw_min, pad_list_min);
-  pad_function(x_max_shape[h_axis], x_max_shape[w_axis], output_hw_max, pad_list_max);
-
+  Conv2DPadFunction(&output_hw_min, &pad_list_min, x_min_shape[h_axis], x_min_shape[w_axis], kernel_size, stride,
+                    dilation, pad_mode, padding);
+  Conv2DPadFunction(&output_hw_max, &pad_list_max, x_max_shape[h_axis], x_max_shape[w_axis], kernel_size, stride,
+                    dilation, pad_mode, padding);
   std::vector<ValuePtr> pad_list_val = {MakeValue(pad_list[0]), MakeValue(pad_list[1]), MakeValue(pad_list[2]),
                                         MakeValue(pad_list[3])};
   primitive->set_attr("pad_list", MakeValue(pad_list_val));
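
Note: for pad_mode "same", Conv2DPadFunction above sizes the output first and then back-solves the padding. A self-contained sketch of that arithmetic for one spatial axis, with assumed example numbers (x_h = 5, kernel = 3, stride = 2, dilation = 1; plain C++, not MindSpore code):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  // Same-mode arithmetic from Conv2DPadFunction, one spatial axis.
  const int64_t x_h = 5, kernel = 3, stride = 2, dilation = 1;
  // Output size depends only on input size and stride: ceil(5 / 2) = 3.
  const int64_t out_h = static_cast<int64_t>(std::ceil((x_h * 1.0) / stride));
  // Total padding needed so that 3 strided taps of a dilated 3-kernel cover 5 pixels:
  // (3 - 1) * 2 + 1 * (3 - 1) + 1 - 5 = 2.
  int64_t pad_needed = (out_h - 1) * stride + dilation * (kernel - 1) + 1 - x_h;
  pad_needed = std::max<int64_t>(0, pad_needed);
  const int64_t pad_top = pad_needed / 2;           // 1 (floor division)
  const int64_t pad_bottom = pad_needed - pad_top;  // 1
  std::cout << out_h << " " << pad_top << " " << pad_bottom << "\n";  // prints: 3 1 1
}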
@@ -477,28 +386,15 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p
     output_shape_min = {x_min_shape[n_axis], out_channel, output_hw_min[0], output_hw_min[1]};
     output_shape_max = {x_max_shape[n_axis], out_channel, output_hw_max[0], output_hw_max[1]};
   }
-  for (size_t i = 0; i < output_shape.size(); ++i) {
-    if ((output_shape[i] < 0) && (output_shape[i] != Shape::SHP_ANY)) {
-      MS_LOG(EXCEPTION) << "Shape element output_shape[" << i << "] must be positive integer, but got "
-                        << output_shape[i];
-    }
-    if (output_shape_min[i] < 0) {
-      MS_LOG(EXCEPTION) << "Min Shape element output_shape_min[" << i << "] must be positive integer, but got "
-                        << output_shape_min[i];
-    }
-    if (output_shape_max[i] < 0) {
-      MS_LOG(EXCEPTION) << "Max Shape element output_shape_max[" << i << "] must be positive integer, but got "
-                        << output_shape_max[i];
-    }
-  }
+  CheckShapeAnyAndPositive(op_name + " output_shape", output_shape);
+  CheckShapeAllPositive(op_name + " output_shape_min", output_shape_min);
+  CheckShapeAllPositive(op_name + " output_shape_max", output_shape_max);
+  TypePtr x_type = input_x->element()->GetTypeTrack();
+  if (x_type->type_id() == TypeId::kNumberTypeInt8) {
+    x_type = kInt32;
+  }
   ShapePtr output_shape_ptr = std::make_shared<Shape>(output_shape, output_shape_min, output_shape_max);
-  if (input_x->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt8) {
-    auto output = std::make_shared<AbstractTensor>(kInt32, output_shape);
-    output->set_shape(output_shape_ptr);
-    return output;
-  }
-  return std::make_shared<AbstractTensor>(input_x->element(), output_shape_ptr);
+  return std::make_shared<AbstractTensor>(x_type, output_shape_ptr);
 }
 
 AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
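
Note: the NHWC filter-shape fix above comes down to reading out_channel and the kernel dimensions from format-dependent axes (w_shape[n_axis], w_shape[h_axis], w_shape[w_axis]) instead of the hard-coded w_shape[0], w_shape[2], w_shape[3]. A small sketch of that axis mapping (illustrative helper, not MindSpore code):

#include <array>
#include <iostream>
#include <string>

// Axis order (n, c, h, w) for a given data format, as in the diff above:
// NCHW -> (0, 1, 2, 3), NHWC -> (0, 3, 1, 2).
std::array<int, 4> FormatAxes(const std::string &format) {
  if (format == "NHWC") return {0, 3, 1, 2};
  return {0, 1, 2, 3};  // NCHW default
}

int main() {
  auto [n, c, h, w] = FormatAxes("NHWC");
  // For an NHWC weight of shape (3, 2, 2, 1): out_channel = shape[n] = 3,
  // kernel_h = shape[h] = 2, kernel_w = shape[w] = 2, per-group in_channel = shape[c] = 1.
  std::cout << n << " " << c << " " << h << " " << w << "\n";  // prints: 0 3 1 2
}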
@@ -232,3 +232,36 @@ def test_conv2d_dynamic():
     assert (output1.asnumpy() == expect1).all()
     output2 = conv2d(x2, w2)
     assert (output2.asnumpy() == expect2).all()
+
+
+class NetConvNHWC(nn.Cell):
+    def __init__(self, weight, x):
+        super(NetConvNHWC, self).__init__()
+        self.conv = nn.Conv2d(in_channels=1,
+                              out_channels=3,
+                              kernel_size=2,
+                              stride=2,
+                              pad_mode="valid",
+                              weight_init=Tensor(weight),
+                              data_format='NHWC'
+                              )
+        self.x = Parameter(initializer(Tensor(x), [1, 4, 4, 1]), name="x")
+
+    def construct(self):
+        return self.conv(self.x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_conv_NHWC():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    x1 = Tensor(np.arange(1 * 4 * 4 * 1).reshape(1, 4, 4, 1).astype(np.float32))
+    w1 = Tensor(np.arange(3 * 2 * 2 * 1).reshape(3, 2, 2, 1).astype(np.float32))
+    expected = np.array([[[[24., 64., 104.],
+                           [36., 108., 180.]],
+                          [[72., 240., 408.],
+                           [84., 284., 484.]]]]).astype(np.float32)
+    conv2d = NetConvNHWC(w1, x1)
+    output = conv2d()
+    assert (output.asnumpy() == expected).all()
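
Sanity check on the expected values: with stride 2 and valid padding, output[0, 0, 0, 0] applies filter 0 ([[0, 1], [2, 3]]) to the top-left input window [[0, 1], [4, 5]], giving 0*0 + 1*1 + 4*2 + 5*3 = 24; filter 1 ([[4, 5], [6, 7]]) on the same window gives 0*4 + 1*5 + 4*6 + 5*7 = 64, and filter 2 ([[8, 9], [10, 11]]) gives 104, matching the first output row [24., 64., 104.].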