forked from mindspore-Ecosystem/mindspore
!26922 modify error info of conv2d bprop
Merge pull request !26922 from wangnan39/bug_fix_conv_x_size_error
This commit is contained in:
commit
d74ad572d9
|
@ -848,9 +848,9 @@ AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &
|
|||
if (args_spec_list.size() == 2) {
|
||||
auto input_value = args_spec_list[1]->BuildValue();
|
||||
if (input_value->isa<tensor::Tensor>()) {
|
||||
shape = CheckAndConvertUtils::CheckTensorIntValue("reshape args value", input_value, op_name);
|
||||
shape = CheckAndConvertUtils::CheckTensorIntValue("shape", input_value, op_name);
|
||||
} else {
|
||||
shape = CheckAndConvertUtils::CheckAttrTupleInt("reshape args value", input_value, op_name);
|
||||
shape = CheckAndConvertUtils::CheckTupleInt("input[shape]", input_value, op_name);
|
||||
}
|
||||
} else {
|
||||
ValuePtr sh = primitive->GetAttr("shape");
|
||||
|
|
|
@ -99,11 +99,13 @@ abstract::ShapePtr Conv2DBackpropFilterInferShape(const PrimitivePtr &primitive,
|
|||
}
|
||||
} else if (filter_size->isa<abstract::AbstractTuple>()) {
|
||||
// check tensor, tuple or int to raise error.
|
||||
out_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("filter_size", filter_size_v, prim_name);
|
||||
out_shape = CheckAndConvertUtils::CheckTupleInt("input[filter_size]", filter_size_v, prim_name);
|
||||
ret_shape = std::make_shared<abstract::Shape>(out_shape);
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "Conv2DBackpropFilter filter_size must be a tuple or tensor, but "
|
||||
<< filter_size->ToString();
|
||||
auto size_type = filter_size->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(size_type);
|
||||
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s input[filter size] must be a tuple or Tensor, "
|
||||
<< "but got " << size_type->ToString();
|
||||
}
|
||||
return ret_shape;
|
||||
}
|
||||
|
|
|
@ -35,9 +35,10 @@ void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_
|
|||
auto prim_name = primitive->name();
|
||||
// check
|
||||
auto kernel_size =
|
||||
CheckAndConvertUtils::CheckAttrIntOrTupleInt("kernel_size", primitive->GetAttr(kKernelSize), prim_name);
|
||||
auto stride = CheckAndConvertUtils::CheckAttrIntOrTupleInt("stride", primitive->GetAttr(kStride), prim_name);
|
||||
auto dilation = CheckAndConvertUtils::CheckAttrIntOrTupleInt("dilation", primitive->GetAttr(kDilation), prim_name);
|
||||
CheckAndConvertUtils::CheckIntOrTupleInt("attribute[kernel_size]", primitive->GetAttr(kKernelSize), prim_name);
|
||||
auto stride = CheckAndConvertUtils::CheckIntOrTupleInt("attribute[stride]", primitive->GetAttr(kStride), prim_name);
|
||||
auto dilation =
|
||||
CheckAndConvertUtils::CheckIntOrTupleInt("attribute[dilation]", primitive->GetAttr(kDilation), prim_name);
|
||||
// default pad mode is valid
|
||||
auto attr_pad_list_prt = primitive->GetAttr(kPadList);
|
||||
int64_t pad_mode;
|
||||
|
@ -128,10 +129,13 @@ abstract::ShapePtr Conv2DBackpropInputInferShape(const PrimitivePtr &primitive,
|
|||
}
|
||||
} else if (input_size->isa<abstract::AbstractTuple>()) {
|
||||
// check tensor, tuple or int to raise error.
|
||||
out_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("input x size", input_size_v, prim_name);
|
||||
out_shape = CheckAndConvertUtils::CheckTupleInt("input[x size]", input_size_v, prim_name);
|
||||
ret_shape = std::make_shared<abstract::Shape>(out_shape);
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "Conv2DBackpropInput x_size must be a tuple or tensor, but " << input_size->ToString();
|
||||
auto size_type = input_size->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(size_type);
|
||||
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s input[x size] must be a tuple or Tensor, "
|
||||
<< "but got " << size_type->ToString();
|
||||
}
|
||||
auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDoutIndex]->BuildShape())[kShape];
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ namespace {
|
|||
void CheckSliceType(const AbstractBasePtr &input_arg, const std::string &arg_name, const std::string &prim_name) {
|
||||
if (input_arg->isa<abstract::AbstractTuple>()) {
|
||||
auto temp_value = input_arg->BuildValue();
|
||||
(void)CheckAndConvertUtils::CheckAttrTupleInt(arg_name, temp_value, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTupleInt(arg_name, temp_value, prim_name);
|
||||
return;
|
||||
} else if (input_arg->isa<abstract::AbstractTensor>()) {
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid(arg_name, input_arg->BuildType(), {kInt64}, prim_name);
|
||||
|
@ -52,7 +52,7 @@ abstract::ShapePtr StridedSliceGradInferShape(const PrimitivePtr &primitive,
|
|||
abstract::ShapePtr ret_shape;
|
||||
|
||||
if (shapex->isa<abstract::AbstractTuple>()) {
|
||||
out_shape = CheckAndConvertUtils::CheckAttrTupleInt("shapex", shape_value, prim_name);
|
||||
out_shape = CheckAndConvertUtils::CheckTupleInt("input[shapex]", shape_value, prim_name);
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ abstract::ShapePtr OnesInferShape(const PrimitivePtr &primitive, const std::vect
|
|||
auto prim_name = primitive->name();
|
||||
// check
|
||||
auto shape_value = input_args[0]->BuildValue();
|
||||
std::vector<int64_t> out_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("shape", shape_value, prim_name);
|
||||
std::vector<int64_t> out_shape = CheckAndConvertUtils::CheckIntOrTupleInt("input[shape]", shape_value, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, prim_name);
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
|
|
@ -46,7 +46,7 @@ abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vec
|
|||
if (input_value->isa<tensor::Tensor>()) {
|
||||
tmp_input = CheckAndConvertUtils::CheckTensorIntValue("slice args value", input_value, prim_name);
|
||||
} else {
|
||||
tmp_input = CheckAndConvertUtils::CheckAttrTupleInt("slice args value", input_value, prim_name);
|
||||
tmp_input = CheckAndConvertUtils::CheckTupleInt("slice args value", input_value, prim_name);
|
||||
}
|
||||
(void)input_values.emplace_back(tmp_input);
|
||||
}
|
||||
|
|
|
@ -69,52 +69,63 @@ void GetAndCheckAttrMask(const PrimitivePtr &primitive, std::vector<int64_t> *be
|
|||
return;
|
||||
}
|
||||
|
||||
int64_t GetSlicingLengthForPositiveStrides(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) {
|
||||
int64_t slicing_length = 0;
|
||||
if ((start_pos < x_dim) && end_pos >= -x_dim) {
|
||||
if (-x_dim <= start_pos && start_pos < 0) {
|
||||
start_pos += x_dim;
|
||||
}
|
||||
if (start_pos < -x_dim) {
|
||||
start_pos = 0;
|
||||
}
|
||||
if (-x_dim <= end_pos && end_pos < 0) {
|
||||
end_pos += x_dim;
|
||||
}
|
||||
if (end_pos > x_dim) {
|
||||
end_pos = x_dim;
|
||||
}
|
||||
if (start_pos > end_pos) {
|
||||
slicing_length = 0;
|
||||
} else {
|
||||
slicing_length = 1 + (end_pos - 1 - start_pos) / strides;
|
||||
}
|
||||
}
|
||||
return slicing_length;
|
||||
}
|
||||
|
||||
int64_t GetSlicingLengthForNegativeStrides(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) {
|
||||
int64_t slicing_length = 0;
|
||||
if (start_pos >= -x_dim && end_pos < x_dim) {
|
||||
if (start_pos > 0 && start_pos < x_dim) {
|
||||
start_pos += -x_dim;
|
||||
}
|
||||
if (start_pos >= x_dim) {
|
||||
start_pos = -1;
|
||||
}
|
||||
if (end_pos >= 0 && end_pos < x_dim) {
|
||||
end_pos += -x_dim;
|
||||
}
|
||||
if (end_pos < -x_dim - 1) {
|
||||
end_pos = -x_dim - 1;
|
||||
}
|
||||
if (start_pos <= end_pos) {
|
||||
slicing_length = 0;
|
||||
} else {
|
||||
slicing_length = 1 + (end_pos + 1 - start_pos) / strides;
|
||||
}
|
||||
}
|
||||
return slicing_length;
|
||||
}
|
||||
|
||||
int64_t ComputeSlicingLength(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) {
|
||||
int64_t slicing_length = 0;
|
||||
if (strides == 0) {
|
||||
MS_EXCEPTION(ValueError) << "StridedSlice's input strides cannot contain 0.";
|
||||
}
|
||||
if (strides > 0) {
|
||||
if ((start_pos >= x_dim) || end_pos < -x_dim) {
|
||||
slicing_length = 0;
|
||||
} else {
|
||||
if (-x_dim <= start_pos && start_pos < 0) {
|
||||
start_pos += x_dim;
|
||||
}
|
||||
if (start_pos < -x_dim) {
|
||||
start_pos = 0;
|
||||
}
|
||||
if (-x_dim <= end_pos && end_pos < 0) {
|
||||
end_pos += x_dim;
|
||||
}
|
||||
if (end_pos > x_dim) {
|
||||
end_pos = x_dim;
|
||||
}
|
||||
if (start_pos > end_pos) {
|
||||
slicing_length = 0;
|
||||
} else {
|
||||
slicing_length = 1 + (end_pos - 1 - start_pos) / strides;
|
||||
}
|
||||
}
|
||||
slicing_length = GetSlicingLengthForPositiveStrides(start_pos, end_pos, strides, x_dim);
|
||||
} else {
|
||||
if (start_pos < -x_dim || end_pos >= x_dim) {
|
||||
slicing_length = 0;
|
||||
} else {
|
||||
if (start_pos > 0 && start_pos < x_dim) {
|
||||
start_pos += -x_dim;
|
||||
}
|
||||
if (start_pos >= x_dim) {
|
||||
start_pos = -1;
|
||||
}
|
||||
if (end_pos >= 0 && end_pos < x_dim) {
|
||||
end_pos += -x_dim;
|
||||
}
|
||||
if (end_pos < -x_dim - 1) {
|
||||
end_pos = -x_dim - 1;
|
||||
}
|
||||
if (start_pos <= end_pos) {
|
||||
slicing_length = 0;
|
||||
} else {
|
||||
slicing_length = 1 + (end_pos + 1 - start_pos) / strides;
|
||||
}
|
||||
}
|
||||
slicing_length = GetSlicingLengthForNegativeStrides(start_pos, end_pos, strides, x_dim);
|
||||
}
|
||||
return slicing_length;
|
||||
}
|
||||
|
@ -288,7 +299,7 @@ bool CheckAndGetDynamicSlice(const AbstractBasePtr &input_arg, const std::string
|
|||
auto input_value = input_arg->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(input_value);
|
||||
if (input_arg->isa<abstract::AbstractTuple>()) {
|
||||
*slice_value = CheckAndConvertUtils::CheckAttrTupleInt(arg_name, input_value, "StridedSlice");
|
||||
*slice_value = CheckAndConvertUtils::CheckTupleInt(arg_name, input_value, "StridedSlice");
|
||||
*slice_len = (*slice_value).size();
|
||||
} else if (input_arg->isa<abstract::AbstractTensor>()) {
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid(arg_name, input_arg->BuildType(), {kInt64}, "StridedSlice");
|
||||
|
|
|
@ -64,9 +64,9 @@ abstract::ShapePtr TileInferShape(const PrimitivePtr &primitive, const std::vect
|
|||
std::vector<int64_t> multiples_v;
|
||||
auto multiple_value = input_args[1]->BuildValue();
|
||||
if (multiple_value->isa<tensor::Tensor>()) {
|
||||
multiples_v = CheckAndConvertUtils::CheckTensorIntValue("tile multiples value", multiple_value, prim_name);
|
||||
multiples_v = CheckAndConvertUtils::CheckTensorIntValue("multiples", multiple_value, prim_name);
|
||||
} else {
|
||||
multiples_v = CheckAndConvertUtils::CheckAttrTupleInt("tile multiples value", multiple_value, prim_name);
|
||||
multiples_v = CheckAndConvertUtils::CheckTupleInt("input[multiples]", multiple_value, prim_name);
|
||||
}
|
||||
|
||||
for (auto multiple : multiples_v) {
|
||||
|
|
|
@ -46,9 +46,9 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
auto perm_value = input_args[1]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(perm_value);
|
||||
if (perm_value->isa<tensor::Tensor>()) {
|
||||
p_value = CheckAndConvertUtils::CheckTensorIntValue("perm value", perm_value, op_name);
|
||||
p_value = CheckAndConvertUtils::CheckTensorIntValue("perm", perm_value, op_name);
|
||||
} else {
|
||||
p_value = CheckAndConvertUtils::CheckAttrTupleInt("perm value", perm_value, op_name);
|
||||
p_value = CheckAndConvertUtils::CheckTupleInt("input[perm]", perm_value, op_name);
|
||||
}
|
||||
}
|
||||
if (x_shape.size() != p_value.size()) {
|
||||
|
|
|
@ -31,7 +31,7 @@ abstract::ShapePtr ZerosInferShape(const PrimitivePtr &primitive, const std::vec
|
|||
auto prim_name = primitive->name();
|
||||
// check
|
||||
auto shape_value = input_args[0]->BuildValue();
|
||||
std::vector<int64_t> out_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("shape", shape_value, prim_name);
|
||||
std::vector<int64_t> out_shape = CheckAndConvertUtils::CheckIntOrTupleInt("input[shape]", shape_value, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, prim_name);
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
|
|
@ -546,7 +546,7 @@ ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_na
|
|||
ShapeVector tensor_value;
|
||||
if (!value->isa<tensor::Tensor>()) {
|
||||
MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "] input argument[" << type_name
|
||||
<< "] must be a tensor,but got " << value->ToString();
|
||||
<< "] must be a tensor, but got " << value->ToString();
|
||||
}
|
||||
auto input_tensor = value->cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_tensor);
|
||||
|
@ -565,7 +565,7 @@ ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_na
|
|||
tensor_value = {tensor_data, tensor_data + data_size};
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "] input argument[" << type_name
|
||||
<< "] must be a Tensor[Int64] or Tensor[Int32] type,but got " << value->ToString();
|
||||
<< "] must be a Tensor[Int64] or Tensor[Int32] type, but got " << value->ToString();
|
||||
}
|
||||
return tensor_value;
|
||||
}
|
||||
|
@ -726,8 +726,8 @@ void CheckAndConvertUtils::CheckMode(const std::string &class_name) {
|
|||
}
|
||||
}
|
||||
|
||||
std::vector<int64_t> CheckAndConvertUtils::CheckAttrIntOrTupleInt(const std::string &arg_name, const ValuePtr &attr,
|
||||
const std::string &prim_name) {
|
||||
std::vector<int64_t> CheckAndConvertUtils::CheckIntOrTupleInt(const std::string &arg_name, const ValuePtr &attr,
|
||||
const std::string &prim_name) {
|
||||
std::vector<int64_t> result;
|
||||
bool is_correct = false;
|
||||
MS_EXCEPTION_IF_NULL(attr);
|
||||
|
@ -749,14 +749,14 @@ std::vector<int64_t> CheckAndConvertUtils::CheckAttrIntOrTupleInt(const std::str
|
|||
}
|
||||
}
|
||||
if (!is_correct) {
|
||||
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s attribute[" << arg_name
|
||||
<< "] must be a Int or a tuple with all Int elements, but got " << attr->ToString();
|
||||
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s " << arg_name
|
||||
<< " must be a Int or a tuple with all Int elements, but got " << attr->ToString();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<int64_t> CheckAndConvertUtils::CheckAttrTupleInt(const std::string &arg_name, const ValuePtr &attr,
|
||||
const std::string &prim_name) {
|
||||
std::vector<int64_t> CheckAndConvertUtils::CheckTupleInt(const std::string &arg_name, const ValuePtr &attr,
|
||||
const std::string &prim_name) {
|
||||
std::vector<int64_t> result;
|
||||
MS_EXCEPTION_IF_NULL(attr);
|
||||
if (attr->isa<ValueTuple>()) {
|
||||
|
@ -764,14 +764,14 @@ std::vector<int64_t> CheckAndConvertUtils::CheckAttrTupleInt(const std::string &
|
|||
(void)std::transform(
|
||||
attr_vec.begin(), attr_vec.end(), std::back_inserter(result), [=](const ValuePtr &e) -> int64_t {
|
||||
if (!e->isa<Int64Imm>()) {
|
||||
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s attribute[" << arg_name
|
||||
<< "] must be a tuple with all Int elements, but got " << attr->ToString();
|
||||
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s " << arg_name
|
||||
<< " must be a tuple with all Int elements, but got " << attr->ToString();
|
||||
}
|
||||
return GetValue<int64_t>(e);
|
||||
});
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s attribute[" << arg_name
|
||||
<< "] must be a tuple with all Int elements, but got " << attr->ToString() << ".";
|
||||
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s " << arg_name
|
||||
<< " must be a tuple with all Int elements, but got " << attr->ToString() << ".";
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -306,10 +306,10 @@ class CheckAndConvertUtils {
|
|||
static void CheckSummaryParam(const AbstractBasePtr &name, const AbstractBasePtr &value,
|
||||
const std::string &class_name);
|
||||
static void CheckMode(const std::string &class_name);
|
||||
static std::vector<int64_t> CheckAttrIntOrTupleInt(const std::string &prim_name, const ValuePtr &attr,
|
||||
const std::string &arg_name);
|
||||
static std::vector<int64_t> CheckAttrTupleInt(const std::string &prim_name, const ValuePtr &attr,
|
||||
const std::string &arg_name);
|
||||
static std::vector<int64_t> CheckIntOrTupleInt(const std::string &prim_name, const ValuePtr &attr,
|
||||
const std::string &arg_name);
|
||||
static std::vector<int64_t> CheckTupleInt(const std::string &prim_name, const ValuePtr &attr,
|
||||
const std::string &arg_name);
|
||||
static void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape);
|
||||
static int64_t GetAndCheckFormat(const ValuePtr &value);
|
||||
static size_t GetRemoveMonadAbsNum(const AbstractBasePtrList &abs_list);
|
||||
|
|
|
@ -34,7 +34,7 @@ def _check_mul():
|
|||
finally:
|
||||
pass
|
||||
|
||||
print(f"MindSpore version:", ms.__version__)
|
||||
print(f"MindSpore version: ", ms.__version__)
|
||||
|
||||
input_x = ms.Tensor(np.array([1.0, 2.0, 3.0]), ms.float32)
|
||||
input_y = ms.Tensor(np.array([4.0, 5.0, 6.0]), ms.float32)
|
||||
|
|
Loading…
Reference in New Issue