!28878 Fix bug for Scatter_xx ops in r1.6

Merge pull request !28878 from 张毅辉/cherry-pick-1641891038
This commit is contained in:
i-robot 2022-01-12 08:26:43 +00:00 committed by Gitee
commit 605b15335e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 29 additions and 30 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -43,18 +43,17 @@ abstract::ShapePtr ScatterNdAddInferShape(const PrimitivePtr &primitive,
auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto updates_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
if (indices_shape.size() == 1) {
(void)CheckAndConvertUtils::CheckInteger("indices_shape", indices_shape[0], kNotEqual, -1);
}
MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]->BuildShape());
auto last_dim = indices_shape.back();
indices_shape.pop_back();
indices_shape.insert(indices_shape.end(), input_x_shape.begin() + last_dim, input_x_shape.end());
(void)CheckAndConvertUtils::CheckInteger("length of updates_shape and indices_shape + x_shape[1:]",
updates_shape.size(), kEqual, indices_shape.size(), prim_name);
for (size_t i = 0; i < updates_shape.size(); i++) {
(void)CheckAndConvertUtils::CheckInteger("elements of updates_shape and indices_shape + x_shape[1:]",
updates_shape[i], kEqual, indices_shape[i], prim_name);
if (last_dim < SizeToLong(input_x_shape.size())) {
indices_shape.insert(indices_shape.end(), input_x_shape.begin() + last_dim, input_x_shape.end());
}
if (updates_shape != indices_shape) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", "
<< "updates_shape = indices_shape[:-1] + x_shape[indices_shape[-1]:], but got x_shape: "
<< input_x_shape_ptr->ToString() << ", indices_shape: " << indices_shape_ptr->ToString()
<< ", updates_shape: " << updates_shape_ptr->ToString() << ".";
}
auto output_shape = input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
return output_shape;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -42,18 +42,18 @@ abstract::ShapePtr ScatterNdUpdateInferShape(const PrimitivePtr &primitive,
auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape];
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices_shape_ptr)[kShape];
auto updates_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(updates_shape_ptr)[kShape];
if (indices_shape.size() == 1) {
(void)CheckAndConvertUtils::CheckInteger("indices_shape", indices_shape[0], kNotEqual, -1);
}
auto last_dim = indices_shape.back();
indices_shape.pop_back();
indices_shape.insert(indices_shape.end(), input_x_shape.begin() + last_dim, input_x_shape.end());
(void)CheckAndConvertUtils::CheckInteger("length of indices_shape[:-1] + x_shape[indices_shape[-1]:]",
updates_shape.size(), kEqual, indices_shape.size(), prim_name);
for (size_t i = 0; i < updates_shape.size(); i++) {
(void)CheckAndConvertUtils::CheckInteger("elements of indices_shape[:-1] + x_shape[indices_shape[-1]:]",
updates_shape[i], kEqual, indices_shape[i], prim_name);
if (last_dim < SizeToLong(input_x_shape.size())) {
indices_shape.insert(indices_shape.end(), input_x_shape.begin() + last_dim, input_x_shape.end());
}
if (updates_shape != indices_shape) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", "
<< "updates_shape = indices_shape[:-1] + x_shape[indices_shape[-1]:], but got x_shape: "
<< input_x_shape_ptr->ToString() << ", indices_shape: " << indices_shape_ptr->ToString()
<< ", updates_shape: " << updates_shape_ptr->ToString() << ".";
}
return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -41,19 +41,19 @@ abstract::ShapePtr ScatterNonAliasingAddInferShape(const PrimitivePtr &primitive
auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape];
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices_shape_ptr)[kShape];
auto updates_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(updates_shape_ptr)[kShape];
if (indices_shape.size() == 1) {
(void)CheckAndConvertUtils::CheckInteger("indices_shape", indices_shape[0], kNotEqual, -1);
}
MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]->BuildShape());
auto last_dim = indices_shape.back();
indices_shape.pop_back();
indices_shape.insert(indices_shape.end(), input_x_shape.begin() + last_dim, input_x_shape.end());
(void)CheckAndConvertUtils::CheckInteger("length of updates_shape and indices_shape + x_shape[1:]",
updates_shape.size(), kEqual, indices_shape.size(), prim_name);
for (size_t i = 0; i < updates_shape.size(); i++) {
(void)CheckAndConvertUtils::CheckInteger("elements of updates_shape and indices_shape + x_shape[1:]",
updates_shape[i], kEqual, indices_shape[i], prim_name);
if (last_dim < SizeToLong(input_x_shape.size())) {
indices_shape.insert(indices_shape.end(), input_x_shape.begin() + last_dim, input_x_shape.end());
}
if (updates_shape != indices_shape) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", "
<< "updates_shape = indices_shape[:-1] + x_shape[indices_shape[-1]:], but got x_shape: "
<< input_x_shape_ptr->ToString() << ", indices_shape: " << indices_shape_ptr->ToString()
<< ", updates_shape: " << updates_shape_ptr->ToString() << ".";
}
return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
}