!42903 solve the bug of inputy is a scalar core dump master

Merge pull request !42903 from zong_shuai/debug_squared_diff_core_dump
This commit is contained in:
i-robot 2022-09-27 12:39:34 +00:00 committed by Gitee
commit 9ec380970b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 14 additions and 13 deletions

View File

@ -42,12 +42,14 @@ int SquaredDifferenceOpGpuKernelMod::Resize(const BaseOperatorPtr &base_operator
need_broadcast_ = false;
if (input_shape1.size() != input_shape2.size()) {
need_broadcast_ = true;
}
for (size_t i = 0; i < input_shape1.size(); i++) {
if (input_shape1[i] != input_shape2[i]) {
need_broadcast_ = true;
} else {
for (size_t i = 0; i < input_shape1.size(); i++) {
if (input_shape1[i] != input_shape2[i]) {
need_broadcast_ = true;
}
}
}
if (need_broadcast_ && output_shape.size() > MAX_DIMS) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of output cannot be greater than " << MAX_DIMS
<< ", but got " << output_shape.size();

View File

@ -33,15 +33,14 @@ abstract::ShapePtr DataFormatVecPermuteInferShape(const PrimitivePtr &primitive,
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto x_shape_ptr = input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
if (input_args[kInputIndex0]->isa<abstract::AbstractTensor>() &&
!input_args[kInputIndex0]->BuildValue()->isa<AnyValue>() &&
!input_args[kInputIndex0]->BuildValue()->isa<None>()) {
std::vector<int64_t> shape1 = {4};
std::vector<int64_t> shape2 = {4, 2};
if (x_shape != shape1 && x_shape != shape2) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", input shape must be (4, ) or (4, 2), but got " << x_shape
<< ".";
}
if (IsDynamic(x_shape)) {
return x_shape_ptr;
}
std::vector<int64_t> shape1 = {4};
std::vector<int64_t> shape2 = {4, 2};
if (x_shape != shape1 && x_shape != shape2) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", input shape must be (4, ) or (4, 2), but got " << x_shape
<< ".";
}
return x_shape_ptr;
}