modify the codecheck

This commit is contained in:
hujiahui8 2022-07-26 15:52:28 +08:00
parent efc85e972a
commit 110caf48e7
3 changed files with 4 additions and 98 deletions

View File

@ -21,6 +21,7 @@
#include "ops/scatter_update.h"
#include "ops/scatter_min.h"
#include "ops/scatter_max.h"
#include "ops/scatter_div.h"
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
@ -96,5 +97,8 @@ REGISTER_PRIMITIVE_EVAL_IMPL(ScatterMin, prim::kPrimScatterMin, ScatterArithmeti
MIND_API_OPERATOR_IMPL(ScatterMax, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(ScatterMax, prim::kPrimScatterMax, ScatterArithmeticInfer, nullptr, true);
MIND_API_OPERATOR_IMPL(ScatterDiv, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(ScatterDiv, prim::kPrimScatterDiv, ScatterArithmeticInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -1,95 +0,0 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/scatter_div.h"
#include <set>
#include <map>
#include <string>
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr ScatterDivInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
BaseShapePtr input_x_shape_ptr = input_args[kInputIndex0]->BuildShape();
MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
BaseShapePtr indices_shape_ptr = input_args[kInputIndex1]->BuildShape();
MS_EXCEPTION_IF_NULL(indices_shape_ptr);
BaseShapePtr updates_shape_ptr = input_args[kInputIndex2]->BuildShape();
MS_EXCEPTION_IF_NULL(updates_shape_ptr);
if (input_x_shape_ptr->IsDynamic()) {
MS_EXCEPTION(ValueError) << "For " << primitive->name() << ", "
<< "the 'input_x' does not support dynamic shape, but got the shape of 'input_x' is "
<< input_x_shape_ptr->ToString();
}
if (indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) {
return input_x_shape_ptr->cast<abstract::ShapePtr>();
}
std::vector<int64_t> input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape];
std::vector<int64_t> indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices_shape_ptr)[kShape];
std::vector<int64_t> updates_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(updates_shape_ptr)[kShape];
std::vector<int64_t> check_update_shape(indices_shape);
for (int64_t i = 1; i < SizeToLong(input_x_shape.size()); ++i) {
check_update_shape.push_back(input_x_shape[i]);
}
if (updates_shape != check_update_shape) {
MS_EXCEPTION(ValueError) << "For " << primitive->name() << ", "
<< "updates_shape = indices_shape + x_shape[1:], but got x_shape: "
<< input_x_shape_ptr->ToString() << ", indices_shape: " << indices_shape_ptr->ToString()
<< ", updates_shape: " << updates_shape_ptr->ToString() << ".";
}
auto output_shape = input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
return output_shape;
}
TypePtr ScatterDivInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto input_x_type_ptr = input_args[kInputIndex0]->BuildType();
auto indiecs_type_ptr = input_args[kInputIndex1]->BuildType();
auto updates_type_ptr = input_args[kInputIndex2]->BuildType();
auto prim_name = primitive->name();
const std::set<TypePtr> indices_types = {kInt32, kInt64};
std::set<TypePtr> valid_types(common_valid_types);
valid_types.emplace(kComplex128);
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indiecs_type_ptr, indices_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x type", input_x_type_ptr, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("updates type", updates_type_ptr, valid_types, prim_name);
std::map<std::string, TypePtr> type_dict;
type_dict.emplace("input_x", input_x_type_ptr);
type_dict.emplace("updates", updates_type_ptr);
return CheckAndConvertUtils::CheckTensorTypeSame(type_dict, common_valid_types, prim_name);
}
} // namespace
MIND_API_OPERATOR_IMPL(ScatterDiv, BaseOperator);
AbstractBasePtr ScatterDivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 3;
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
auto infer_type = ScatterDivInferType(primitive, input_args);
auto infer_shape = ScatterDivInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ScatterDiv, prim::kPrimScatterDiv, ScatterDivInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -30,9 +30,6 @@ class MIND_API ScatterDiv : public BaseOperator {
/// \brief Constructor.
ScatterDiv() : BaseOperator(kNameScatterDiv) { InitIOName({"input_x", "indices", "updates"}, {"output"}); }
};
abstract::AbstractBasePtr ScatterDivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore