!28878 Fix bug for Scatter_xx ops in r1.6
Merge pull request !28878 from 张毅辉/cherry-pick-1641891038
This commit is contained in:
commit
605b15335e
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -43,18 +43,17 @@ abstract::ShapePtr ScatterNdAddInferShape(const PrimitivePtr &primitive,
|
|||
auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
auto updates_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
if (indices_shape.size() == 1) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("indices_shape", indices_shape[0], kNotEqual, -1);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]->BuildShape());
|
||||
auto last_dim = indices_shape.back();
|
||||
indices_shape.pop_back();
|
||||
indices_shape.insert(indices_shape.end(), input_x_shape.begin() + last_dim, input_x_shape.end());
|
||||
(void)CheckAndConvertUtils::CheckInteger("length of updates_shape and indices_shape + x_shape[1:]",
|
||||
updates_shape.size(), kEqual, indices_shape.size(), prim_name);
|
||||
for (size_t i = 0; i < updates_shape.size(); i++) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("elements of updates_shape and indices_shape + x_shape[1:]",
|
||||
updates_shape[i], kEqual, indices_shape[i], prim_name);
|
||||
if (last_dim < SizeToLong(input_x_shape.size())) {
|
||||
indices_shape.insert(indices_shape.end(), input_x_shape.begin() + last_dim, input_x_shape.end());
|
||||
}
|
||||
if (updates_shape != indices_shape) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", "
|
||||
<< "updates_shape = indices_shape[:-1] + x_shape[indices_shape[-1]:], but got x_shape: "
|
||||
<< input_x_shape_ptr->ToString() << ", indices_shape: " << indices_shape_ptr->ToString()
|
||||
<< ", updates_shape: " << updates_shape_ptr->ToString() << ".";
|
||||
}
|
||||
auto output_shape = input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
|
||||
return output_shape;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -42,18 +42,18 @@ abstract::ShapePtr ScatterNdUpdateInferShape(const PrimitivePtr &primitive,
|
|||
auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape];
|
||||
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices_shape_ptr)[kShape];
|
||||
auto updates_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(updates_shape_ptr)[kShape];
|
||||
if (indices_shape.size() == 1) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("indices_shape", indices_shape[0], kNotEqual, -1);
|
||||
}
|
||||
auto last_dim = indices_shape.back();
|
||||
indices_shape.pop_back();
|
||||
indices_shape.insert(indices_shape.end(), input_x_shape.begin() + last_dim, input_x_shape.end());
|
||||
(void)CheckAndConvertUtils::CheckInteger("length of indices_shape[:-1] + x_shape[indices_shape[-1]:]",
|
||||
updates_shape.size(), kEqual, indices_shape.size(), prim_name);
|
||||
for (size_t i = 0; i < updates_shape.size(); i++) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("elements of indices_shape[:-1] + x_shape[indices_shape[-1]:]",
|
||||
updates_shape[i], kEqual, indices_shape[i], prim_name);
|
||||
if (last_dim < SizeToLong(input_x_shape.size())) {
|
||||
indices_shape.insert(indices_shape.end(), input_x_shape.begin() + last_dim, input_x_shape.end());
|
||||
}
|
||||
if (updates_shape != indices_shape) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", "
|
||||
<< "updates_shape = indices_shape[:-1] + x_shape[indices_shape[-1]:], but got x_shape: "
|
||||
<< input_x_shape_ptr->ToString() << ", indices_shape: " << indices_shape_ptr->ToString()
|
||||
<< ", updates_shape: " << updates_shape_ptr->ToString() << ".";
|
||||
}
|
||||
|
||||
return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -41,19 +41,19 @@ abstract::ShapePtr ScatterNonAliasingAddInferShape(const PrimitivePtr &primitive
|
|||
auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape];
|
||||
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices_shape_ptr)[kShape];
|
||||
auto updates_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(updates_shape_ptr)[kShape];
|
||||
if (indices_shape.size() == 1) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("indices_shape", indices_shape[0], kNotEqual, -1);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]->BuildShape());
|
||||
auto last_dim = indices_shape.back();
|
||||
indices_shape.pop_back();
|
||||
indices_shape.insert(indices_shape.end(), input_x_shape.begin() + last_dim, input_x_shape.end());
|
||||
(void)CheckAndConvertUtils::CheckInteger("length of updates_shape and indices_shape + x_shape[1:]",
|
||||
updates_shape.size(), kEqual, indices_shape.size(), prim_name);
|
||||
for (size_t i = 0; i < updates_shape.size(); i++) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("elements of updates_shape and indices_shape + x_shape[1:]",
|
||||
updates_shape[i], kEqual, indices_shape[i], prim_name);
|
||||
if (last_dim < SizeToLong(input_x_shape.size())) {
|
||||
indices_shape.insert(indices_shape.end(), input_x_shape.begin() + last_dim, input_x_shape.end());
|
||||
}
|
||||
if (updates_shape != indices_shape) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", "
|
||||
<< "updates_shape = indices_shape[:-1] + x_shape[indices_shape[-1]:], but got x_shape: "
|
||||
<< input_x_shape_ptr->ToString() << ", indices_shape: " << indices_shape_ptr->ToString()
|
||||
<< ", updates_shape: " << updates_shape_ptr->ToString() << ".";
|
||||
}
|
||||
|
||||
return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue