!26922 modify error info of conv2d bprop

Merge pull request !26922 from wangnan39/bug_fix_conv_x_size_error
i-robot 2021-11-30 01:58:51 +00:00 committed by Gitee
commit d74ad572d9
13 changed files with 96 additions and 79 deletions


@@ -848,9 +848,9 @@ AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &
   if (args_spec_list.size() == 2) {
     auto input_value = args_spec_list[1]->BuildValue();
     if (input_value->isa<tensor::Tensor>()) {
-      shape = CheckAndConvertUtils::CheckTensorIntValue("reshape args value", input_value, op_name);
+      shape = CheckAndConvertUtils::CheckTensorIntValue("shape", input_value, op_name);
     } else {
-      shape = CheckAndConvertUtils::CheckAttrTupleInt("reshape args value", input_value, op_name);
+      shape = CheckAndConvertUtils::CheckTupleInt("input[shape]", input_value, op_name);
     }
   } else {
     ValuePtr sh = primitive->GetAttr("shape");


@@ -99,11 +99,13 @@ abstract::ShapePtr Conv2DBackpropFilterInferShape(const PrimitivePtr &primitive,
     }
   } else if (filter_size->isa<abstract::AbstractTuple>()) {
     // check tensor, tuple or int to raise error.
-    out_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("filter_size", filter_size_v, prim_name);
+    out_shape = CheckAndConvertUtils::CheckTupleInt("input[filter_size]", filter_size_v, prim_name);
     ret_shape = std::make_shared<abstract::Shape>(out_shape);
   } else {
-    MS_EXCEPTION(TypeError) << "Conv2DBackpropFilter filter_size must be a tuple or tensor, but "
-                            << filter_size->ToString();
+    auto size_type = filter_size->BuildType();
+    MS_EXCEPTION_IF_NULL(size_type);
+    MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s input[filter size] must be a tuple or Tensor, "
+                            << "but got " << size_type->ToString();
   }
   return ret_shape;
 }


@@ -35,9 +35,10 @@ void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_
   auto prim_name = primitive->name();
   // check
   auto kernel_size =
-    CheckAndConvertUtils::CheckAttrIntOrTupleInt("kernel_size", primitive->GetAttr(kKernelSize), prim_name);
-  auto stride = CheckAndConvertUtils::CheckAttrIntOrTupleInt("stride", primitive->GetAttr(kStride), prim_name);
-  auto dilation = CheckAndConvertUtils::CheckAttrIntOrTupleInt("dilation", primitive->GetAttr(kDilation), prim_name);
+    CheckAndConvertUtils::CheckIntOrTupleInt("attribute[kernel_size]", primitive->GetAttr(kKernelSize), prim_name);
+  auto stride = CheckAndConvertUtils::CheckIntOrTupleInt("attribute[stride]", primitive->GetAttr(kStride), prim_name);
+  auto dilation =
+    CheckAndConvertUtils::CheckIntOrTupleInt("attribute[dilation]", primitive->GetAttr(kDilation), prim_name);
   // default pad mode is valid
   auto attr_pad_list_prt = primitive->GetAttr(kPadList);
   int64_t pad_mode;
@@ -128,10 +129,13 @@ abstract::ShapePtr Conv2DBackpropInputInferShape(const PrimitivePtr &primitive,
     }
   } else if (input_size->isa<abstract::AbstractTuple>()) {
     // check tensor, tuple or int to raise error.
-    out_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("input x size", input_size_v, prim_name);
+    out_shape = CheckAndConvertUtils::CheckTupleInt("input[x size]", input_size_v, prim_name);
     ret_shape = std::make_shared<abstract::Shape>(out_shape);
   } else {
-    MS_EXCEPTION(TypeError) << "Conv2DBackpropInput x_size must be a tuple or tensor, but " << input_size->ToString();
+    auto size_type = input_size->BuildType();
+    MS_EXCEPTION_IF_NULL(size_type);
+    MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s input[x size] must be a tuple or Tensor, "
+                            << "but got " << size_type->ToString();
   }
   auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDoutIndex]->BuildShape())[kShape];
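
For reference, the wording the new branch assembles can be previewed with a tiny standalone program. It only mirrors the operator<< chain from the hunk above; the "Int64" string is a made-up stand-in for whatever size_type->ToString() would actually return:

// Illustration only -- not MindSpore code; it reproduces the message layout of the
// new TypeError branch so the resulting wording is easy to see at a glance.
#include <iostream>
#include <string>

int main() {
  const std::string prim_name = "Conv2DBackpropInput";
  const std::string size_type = "Int64";  // assumed stand-in for size_type->ToString()
  std::cout << "The primitive[" << prim_name << "]'s input[x size] must be a tuple or Tensor, "
            << "but got " << size_type << std::endl;
  return 0;
}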


@@ -30,7 +30,7 @@ namespace {
 void CheckSliceType(const AbstractBasePtr &input_arg, const std::string &arg_name, const std::string &prim_name) {
   if (input_arg->isa<abstract::AbstractTuple>()) {
     auto temp_value = input_arg->BuildValue();
-    (void)CheckAndConvertUtils::CheckAttrTupleInt(arg_name, temp_value, prim_name);
+    (void)CheckAndConvertUtils::CheckTupleInt(arg_name, temp_value, prim_name);
     return;
   } else if (input_arg->isa<abstract::AbstractTensor>()) {
     (void)CheckAndConvertUtils::CheckTensorTypeValid(arg_name, input_arg->BuildType(), {kInt64}, prim_name);
@@ -52,7 +52,7 @@ abstract::ShapePtr StridedSliceGradInferShape(const PrimitivePtr &primitive,
   abstract::ShapePtr ret_shape;
   if (shapex->isa<abstract::AbstractTuple>()) {
-    out_shape = CheckAndConvertUtils::CheckAttrTupleInt("shapex", shape_value, prim_name);
+    out_shape = CheckAndConvertUtils::CheckTupleInt("input[shapex]", shape_value, prim_name);
     return std::make_shared<abstract::Shape>(out_shape);
   }


@@ -30,7 +30,7 @@ abstract::ShapePtr OnesInferShape(const PrimitivePtr &primitive, const std::vect
   auto prim_name = primitive->name();
   // check
   auto shape_value = input_args[0]->BuildValue();
-  std::vector<int64_t> out_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("shape", shape_value, prim_name);
+  std::vector<int64_t> out_shape = CheckAndConvertUtils::CheckIntOrTupleInt("input[shape]", shape_value, prim_name);
   (void)CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, prim_name);
   return std::make_shared<abstract::Shape>(out_shape);
 }


@@ -46,7 +46,7 @@ abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vec
     if (input_value->isa<tensor::Tensor>()) {
       tmp_input = CheckAndConvertUtils::CheckTensorIntValue("slice args value", input_value, prim_name);
     } else {
-      tmp_input = CheckAndConvertUtils::CheckAttrTupleInt("slice args value", input_value, prim_name);
+      tmp_input = CheckAndConvertUtils::CheckTupleInt("slice args value", input_value, prim_name);
     }
     (void)input_values.emplace_back(tmp_input);
   }


@@ -69,12 +69,9 @@ void GetAndCheckAttrMask(const PrimitivePtr &primitive, std::vector<int64_t> *be
   return;
 }
 
-int64_t ComputeSlicingLength(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) {
+int64_t GetSlicingLengthForPositiveStrides(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) {
   int64_t slicing_length = 0;
-  if (strides > 0) {
-    if ((start_pos >= x_dim) || end_pos < -x_dim) {
-      slicing_length = 0;
-    } else {
+  if ((start_pos < x_dim) && end_pos >= -x_dim) {
     if (-x_dim <= start_pos && start_pos < 0) {
       start_pos += x_dim;
     }
@@ -93,10 +90,12 @@ int64_t ComputeSlicingLength(int64_t start_pos, int64_t end_pos, int64_t strides
       slicing_length = 1 + (end_pos - 1 - start_pos) / strides;
     }
   }
-  } else {
-    if (start_pos < -x_dim || end_pos >= x_dim) {
-      slicing_length = 0;
-    } else {
+  return slicing_length;
+}
+
+int64_t GetSlicingLengthForNegativeStrides(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) {
+  int64_t slicing_length = 0;
+  if (start_pos >= -x_dim && end_pos < x_dim) {
     if (start_pos > 0 && start_pos < x_dim) {
       start_pos += -x_dim;
     }
@@ -115,6 +114,18 @@ int64_t ComputeSlicingLength(int64_t start_pos, int64_t end_pos, int64_t strides
       slicing_length = 1 + (end_pos + 1 - start_pos) / strides;
     }
   }
+  return slicing_length;
+}
+
+int64_t ComputeSlicingLength(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) {
+  int64_t slicing_length = 0;
+  if (strides == 0) {
+    MS_EXCEPTION(ValueError) << "StridedSlice's input strides cannot contain 0.";
+  }
+  if (strides > 0) {
+    slicing_length = GetSlicingLengthForPositiveStrides(start_pos, end_pos, strides, x_dim);
+  } else {
+    slicing_length = GetSlicingLengthForNegativeStrides(start_pos, end_pos, strides, x_dim);
   }
   return slicing_length;
 }
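
The hunks above only show the edges of the refactor; the index-clamping arithmetic in the middle of each helper is unchanged and therefore not visible in the diff. As a rough guide to what the two helpers compute, here is a self-contained sketch of the slicing-length calculation with Python-style normalization of negative indices, split by stride sign the same way. It is an equivalent formulation for illustration, not the file's actual code:

// Sketch: how many elements start:end:stride selects on an axis of length x_dim.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <stdexcept>

int64_t SlicingLength(int64_t start, int64_t end, int64_t stride, int64_t x_dim) {
  if (stride == 0) {
    throw std::invalid_argument("stride cannot be 0");  // mirrors the new ValueError check
  }
  if (stride > 0) {
    // Normalize negative indices and clamp both endpoints into [0, x_dim].
    start = (start < 0) ? std::max<int64_t>(start + x_dim, 0) : std::min(start, x_dim);
    end = (end < 0) ? std::max<int64_t>(end + x_dim, 0) : std::min(end, x_dim);
    return (start < end) ? (end - start + stride - 1) / stride : 0;
  }
  // Negative stride: normalize and clamp both endpoints into [-1, x_dim - 1].
  start = (start < 0) ? std::max<int64_t>(start + x_dim, -1) : std::min(start, x_dim - 1);
  end = (end < 0) ? std::max<int64_t>(end + x_dim, -1) : std::min(end, x_dim - 1);
  return (start > end) ? (start - end - stride - 1) / (-stride) : 0;
}

int main() {
  std::cout << SlicingLength(0, 4, 1, 4) << "\n";     // 4, like x[0:4:1] on a length-4 axis
  std::cout << SlicingLength(-1, -5, -1, 4) << "\n";  // 4, like x[-1:-5:-1]
  std::cout << SlicingLength(5, 10, 2, 4) << "\n";    // 0, start is past the end of the axis
  return 0;
}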
@@ -288,7 +299,7 @@ bool CheckAndGetDynamicSlice(const AbstractBasePtr &input_arg, const std::string
   auto input_value = input_arg->BuildValue();
   MS_EXCEPTION_IF_NULL(input_value);
   if (input_arg->isa<abstract::AbstractTuple>()) {
-    *slice_value = CheckAndConvertUtils::CheckAttrTupleInt(arg_name, input_value, "StridedSlice");
+    *slice_value = CheckAndConvertUtils::CheckTupleInt(arg_name, input_value, "StridedSlice");
     *slice_len = (*slice_value).size();
   } else if (input_arg->isa<abstract::AbstractTensor>()) {
     (void)CheckAndConvertUtils::CheckTensorTypeValid(arg_name, input_arg->BuildType(), {kInt64}, "StridedSlice");


@@ -64,9 +64,9 @@ abstract::ShapePtr TileInferShape(const PrimitivePtr &primitive, const std::vect
   std::vector<int64_t> multiples_v;
   auto multiple_value = input_args[1]->BuildValue();
   if (multiple_value->isa<tensor::Tensor>()) {
-    multiples_v = CheckAndConvertUtils::CheckTensorIntValue("tile multiples value", multiple_value, prim_name);
+    multiples_v = CheckAndConvertUtils::CheckTensorIntValue("multiples", multiple_value, prim_name);
   } else {
-    multiples_v = CheckAndConvertUtils::CheckAttrTupleInt("tile multiples value", multiple_value, prim_name);
+    multiples_v = CheckAndConvertUtils::CheckTupleInt("input[multiples]", multiple_value, prim_name);
   }
   for (auto multiple : multiples_v) {


@@ -46,9 +46,9 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
     auto perm_value = input_args[1]->BuildValue();
     MS_EXCEPTION_IF_NULL(perm_value);
     if (perm_value->isa<tensor::Tensor>()) {
-      p_value = CheckAndConvertUtils::CheckTensorIntValue("perm value", perm_value, op_name);
+      p_value = CheckAndConvertUtils::CheckTensorIntValue("perm", perm_value, op_name);
     } else {
-      p_value = CheckAndConvertUtils::CheckAttrTupleInt("perm value", perm_value, op_name);
+      p_value = CheckAndConvertUtils::CheckTupleInt("input[perm]", perm_value, op_name);
     }
   }
   if (x_shape.size() != p_value.size()) {


@@ -31,7 +31,7 @@ abstract::ShapePtr ZerosInferShape(const PrimitivePtr &primitive, const std::vec
   auto prim_name = primitive->name();
   // check
   auto shape_value = input_args[0]->BuildValue();
-  std::vector<int64_t> out_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("shape", shape_value, prim_name);
+  std::vector<int64_t> out_shape = CheckAndConvertUtils::CheckIntOrTupleInt("input[shape]", shape_value, prim_name);
   (void)CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, prim_name);
   return std::make_shared<abstract::Shape>(out_shape);
 }


@@ -546,7 +546,7 @@ ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_na
   ShapeVector tensor_value;
   if (!value->isa<tensor::Tensor>()) {
     MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "] input argument[" << type_name
-                             << "] must be a tensor,but got " << value->ToString();
+                             << "] must be a tensor, but got " << value->ToString();
   }
   auto input_tensor = value->cast<tensor::TensorPtr>();
   MS_EXCEPTION_IF_NULL(input_tensor);
@@ -565,7 +565,7 @@ ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_na
     tensor_value = {tensor_data, tensor_data + data_size};
   } else {
     MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "] input argument[" << type_name
-                            << "] must be a Tensor[Int64] or Tensor[Int32] type,but got " << value->ToString();
+                            << "] must be a Tensor[Int64] or Tensor[Int32] type, but got " << value->ToString();
   }
   return tensor_value;
 }
@@ -726,7 +726,7 @@ void CheckAndConvertUtils::CheckMode(const std::string &class_name) {
   }
 }
 
-std::vector<int64_t> CheckAndConvertUtils::CheckAttrIntOrTupleInt(const std::string &arg_name, const ValuePtr &attr,
+std::vector<int64_t> CheckAndConvertUtils::CheckIntOrTupleInt(const std::string &arg_name, const ValuePtr &attr,
                                                               const std::string &prim_name) {
   std::vector<int64_t> result;
   bool is_correct = false;
@@ -749,13 +749,13 @@ std::vector<int64_t> CheckAndConvertUtils::CheckAttrIntOrTupleInt(const std::str
     }
   }
   if (!is_correct) {
-    MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s attribute[" << arg_name
-                            << "] must be a Int or a tuple with all Int elements, but got " << attr->ToString();
+    MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s " << arg_name
+                            << " must be a Int or a tuple with all Int elements, but got " << attr->ToString();
   }
   return result;
 }
 
-std::vector<int64_t> CheckAndConvertUtils::CheckAttrTupleInt(const std::string &arg_name, const ValuePtr &attr,
+std::vector<int64_t> CheckAndConvertUtils::CheckTupleInt(const std::string &arg_name, const ValuePtr &attr,
                                                           const std::string &prim_name) {
   std::vector<int64_t> result;
   MS_EXCEPTION_IF_NULL(attr);
@@ -764,14 +764,14 @@ std::vector<int64_t> CheckAndConvertUtils::CheckAttrTupleInt(const std::string &
     (void)std::transform(
       attr_vec.begin(), attr_vec.end(), std::back_inserter(result), [=](const ValuePtr &e) -> int64_t {
         if (!e->isa<Int64Imm>()) {
-          MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s attribute[" << arg_name
-                                  << "] must be a tuple with all Int elements, but got " << attr->ToString();
+          MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s " << arg_name
+                                  << " must be a tuple with all Int elements, but got " << attr->ToString();
         }
         return GetValue<int64_t>(e);
       });
   } else {
-    MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s attribute[" << arg_name
-                            << "] must be a tuple with all Int elements, but got " << attr->ToString() << ".";
+    MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s " << arg_name
+                            << " must be a tuple with all Int elements, but got " << attr->ToString() << ".";
   }
   return result;
 }
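
Because the renamed checkers no longer wrap the name in a hard-coded attribute[...] tag, the call sites in the earlier hunks now pass the whole label themselves ("input[shape]", "attribute[kernel_size]", and so on). A standalone sketch of how the resulting text composes; the helper below is a local stand-in that only mirrors the operator<< chain of CheckTupleInt, and the got-value strings are invented examples rather than real ToString() output:

// Illustration only -- not MindSpore code.
#include <iostream>
#include <string>

std::string TupleIntMessage(const std::string &prim_name, const std::string &arg_name, const std::string &got) {
  // The caller-supplied arg_name carries the "input[...]" or "attribute[...]" prefix,
  // so one checker produces a correctly worded message for either kind of argument.
  return "The primitive[" + prim_name + "]'s " + arg_name +
         " must be a tuple with all Int elements, but got " + got;
}

int main() {
  std::cout << TupleIntMessage("Reshape", "input[shape]", "(2.5, 3.0)") << "\n";
  std::cout << TupleIntMessage("Conv2DBackpropFilter", "input[filter_size]", "(None, 3, 3, 3)") << "\n";
  return 0;
}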


@@ -306,9 +306,9 @@ class CheckAndConvertUtils {
   static void CheckSummaryParam(const AbstractBasePtr &name, const AbstractBasePtr &value,
                                 const std::string &class_name);
   static void CheckMode(const std::string &class_name);
-  static std::vector<int64_t> CheckAttrIntOrTupleInt(const std::string &prim_name, const ValuePtr &attr,
+  static std::vector<int64_t> CheckIntOrTupleInt(const std::string &prim_name, const ValuePtr &attr,
                                                  const std::string &arg_name);
-  static std::vector<int64_t> CheckAttrTupleInt(const std::string &prim_name, const ValuePtr &attr,
+  static std::vector<int64_t> CheckTupleInt(const std::string &prim_name, const ValuePtr &attr,
                                             const std::string &arg_name);
   static void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape);
   static int64_t GetAndCheckFormat(const ValuePtr &value);


@@ -34,7 +34,7 @@ def _check_mul():
     finally:
         pass
-    print(f"MindSpore version:", ms.__version__)
+    print(f"MindSpore version: ", ms.__version__)
     input_x = ms.Tensor(np.array([1.0, 2.0, 3.0]), ms.float32)
     input_y = ms.Tensor(np.array([4.0, 5.0, 6.0]), ms.float32)