From 110caf48e74967dfb709a89f1dcf5fdc05a87070 Mon Sep 17 00:00:00 2001 From: hujiahui8 Date: Tue, 26 Jul 2022 15:52:28 +0800 Subject: [PATCH] modify the codecheck --- mindspore/core/ops/scatter_arithmetic.cc | 4 + mindspore/core/ops/scatter_div.cc | 95 ------------------------ mindspore/core/ops/scatter_div.h | 3 - 3 files changed, 4 insertions(+), 98 deletions(-) delete mode 100644 mindspore/core/ops/scatter_div.cc diff --git a/mindspore/core/ops/scatter_arithmetic.cc b/mindspore/core/ops/scatter_arithmetic.cc index 13808db38f3..9f0d1bcc565 100644 --- a/mindspore/core/ops/scatter_arithmetic.cc +++ b/mindspore/core/ops/scatter_arithmetic.cc @@ -21,6 +21,7 @@ #include "ops/scatter_update.h" #include "ops/scatter_min.h" #include "ops/scatter_max.h" +#include "ops/scatter_div.h" #include "abstract/ops/primitive_infer_map.h" #include "ops/op_utils.h" #include "utils/check_convert_utils.h" @@ -96,5 +97,8 @@ REGISTER_PRIMITIVE_EVAL_IMPL(ScatterMin, prim::kPrimScatterMin, ScatterArithmeti MIND_API_OPERATOR_IMPL(ScatterMax, BaseOperator); REGISTER_PRIMITIVE_EVAL_IMPL(ScatterMax, prim::kPrimScatterMax, ScatterArithmeticInfer, nullptr, true); + +MIND_API_OPERATOR_IMPL(ScatterDiv, BaseOperator); +REGISTER_PRIMITIVE_EVAL_IMPL(ScatterDiv, prim::kPrimScatterDiv, ScatterArithmeticInfer, nullptr, true); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/scatter_div.cc b/mindspore/core/ops/scatter_div.cc deleted file mode 100644 index 5b3f7aa27fb..00000000000 --- a/mindspore/core/ops/scatter_div.cc +++ /dev/null @@ -1,95 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "ops/scatter_div.h" -#include -#include -#include -#include "abstract/ops/primitive_infer_map.h" -#include "ops/op_utils.h" -#include "utils/check_convert_utils.h" -#include "mindapi/src/helper.h" - -namespace mindspore { -namespace ops { -namespace { -abstract::ShapePtr ScatterDivInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { - BaseShapePtr input_x_shape_ptr = input_args[kInputIndex0]->BuildShape(); - MS_EXCEPTION_IF_NULL(input_x_shape_ptr); - BaseShapePtr indices_shape_ptr = input_args[kInputIndex1]->BuildShape(); - MS_EXCEPTION_IF_NULL(indices_shape_ptr); - BaseShapePtr updates_shape_ptr = input_args[kInputIndex2]->BuildShape(); - MS_EXCEPTION_IF_NULL(updates_shape_ptr); - - if (input_x_shape_ptr->IsDynamic()) { - MS_EXCEPTION(ValueError) << "For " << primitive->name() << ", " - << "the 'input_x' does not support dynamic shape, but got the shape of 'input_x' is " - << input_x_shape_ptr->ToString(); - } - - if (indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) { - return input_x_shape_ptr->cast(); - } - - std::vector input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape]; - std::vector indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices_shape_ptr)[kShape]; - std::vector updates_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(updates_shape_ptr)[kShape]; - std::vector check_update_shape(indices_shape); - for (int64_t i = 1; i < SizeToLong(input_x_shape.size()); ++i) { - check_update_shape.push_back(input_x_shape[i]); - } - if (updates_shape != check_update_shape) { - MS_EXCEPTION(ValueError) << "For " << primitive->name() << ", " - << "updates_shape = indices_shape + x_shape[1:], but got x_shape: " - << input_x_shape_ptr->ToString() << ", indices_shape: " << indices_shape_ptr->ToString() - << ", updates_shape: " << updates_shape_ptr->ToString() << "."; - } - - auto output_shape = input_args[kInputIndex0]->BuildShape()->cast(); - return output_shape; -} - -TypePtr ScatterDivInferType(const PrimitivePtr &primitive, const std::vector &input_args) { - auto input_x_type_ptr = input_args[kInputIndex0]->BuildType(); - auto indiecs_type_ptr = input_args[kInputIndex1]->BuildType(); - auto updates_type_ptr = input_args[kInputIndex2]->BuildType(); - auto prim_name = primitive->name(); - const std::set indices_types = {kInt32, kInt64}; - std::set valid_types(common_valid_types); - valid_types.emplace(kComplex128); - (void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indiecs_type_ptr, indices_types, prim_name); - (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x type", input_x_type_ptr, valid_types, prim_name); - (void)CheckAndConvertUtils::CheckTensorTypeValid("updates type", updates_type_ptr, valid_types, prim_name); - - std::map type_dict; - type_dict.emplace("input_x", input_x_type_ptr); - type_dict.emplace("updates", updates_type_ptr); - return CheckAndConvertUtils::CheckTensorTypeSame(type_dict, common_valid_types, prim_name); -} -} // namespace - -MIND_API_OPERATOR_IMPL(ScatterDiv, BaseOperator); -AbstractBasePtr ScatterDivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(primitive); - const int64_t input_num = 3; - (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name()); - auto infer_type = ScatterDivInferType(primitive, input_args); - auto infer_shape = ScatterDivInferShape(primitive, input_args); - return abstract::MakeAbstract(infer_shape, infer_type); -} -REGISTER_PRIMITIVE_EVAL_IMPL(ScatterDiv, prim::kPrimScatterDiv, ScatterDivInfer, nullptr, true); -} // namespace ops -} // namespace mindspore diff --git a/mindspore/core/ops/scatter_div.h b/mindspore/core/ops/scatter_div.h index 42e97d7aa36..bf26f175c34 100644 --- a/mindspore/core/ops/scatter_div.h +++ b/mindspore/core/ops/scatter_div.h @@ -30,9 +30,6 @@ class MIND_API ScatterDiv : public BaseOperator { /// \brief Constructor. ScatterDiv() : BaseOperator(kNameScatterDiv) { InitIOName({"input_x", "indices", "updates"}, {"output"}); } }; - -abstract::AbstractBasePtr ScatterDivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); } // namespace ops } // namespace mindspore