forked from mindspore-Ecosystem/mindspore
modify the codecheck
This commit is contained in:
parent
efc85e972a
commit
110caf48e7
|
@ -21,6 +21,7 @@
|
|||
#include "ops/scatter_update.h"
|
||||
#include "ops/scatter_min.h"
|
||||
#include "ops/scatter_max.h"
|
||||
#include "ops/scatter_div.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
@ -96,5 +97,8 @@ REGISTER_PRIMITIVE_EVAL_IMPL(ScatterMin, prim::kPrimScatterMin, ScatterArithmeti
|
|||
|
||||
MIND_API_OPERATOR_IMPL(ScatterMax, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ScatterMax, prim::kPrimScatterMax, ScatterArithmeticInfer, nullptr, true);
|
||||
|
||||
MIND_API_OPERATOR_IMPL(ScatterDiv, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ScatterDiv, prim::kPrimScatterDiv, ScatterArithmeticInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,95 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ops/scatter_div.h"
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr ScatterDivInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
BaseShapePtr input_x_shape_ptr = input_args[kInputIndex0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
|
||||
BaseShapePtr indices_shape_ptr = input_args[kInputIndex1]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(indices_shape_ptr);
|
||||
BaseShapePtr updates_shape_ptr = input_args[kInputIndex2]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(updates_shape_ptr);
|
||||
|
||||
if (input_x_shape_ptr->IsDynamic()) {
|
||||
MS_EXCEPTION(ValueError) << "For " << primitive->name() << ", "
|
||||
<< "the 'input_x' does not support dynamic shape, but got the shape of 'input_x' is "
|
||||
<< input_x_shape_ptr->ToString();
|
||||
}
|
||||
|
||||
if (indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) {
|
||||
return input_x_shape_ptr->cast<abstract::ShapePtr>();
|
||||
}
|
||||
|
||||
std::vector<int64_t> input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape];
|
||||
std::vector<int64_t> indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices_shape_ptr)[kShape];
|
||||
std::vector<int64_t> updates_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(updates_shape_ptr)[kShape];
|
||||
std::vector<int64_t> check_update_shape(indices_shape);
|
||||
for (int64_t i = 1; i < SizeToLong(input_x_shape.size()); ++i) {
|
||||
check_update_shape.push_back(input_x_shape[i]);
|
||||
}
|
||||
if (updates_shape != check_update_shape) {
|
||||
MS_EXCEPTION(ValueError) << "For " << primitive->name() << ", "
|
||||
<< "updates_shape = indices_shape + x_shape[1:], but got x_shape: "
|
||||
<< input_x_shape_ptr->ToString() << ", indices_shape: " << indices_shape_ptr->ToString()
|
||||
<< ", updates_shape: " << updates_shape_ptr->ToString() << ".";
|
||||
}
|
||||
|
||||
auto output_shape = input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
|
||||
return output_shape;
|
||||
}
|
||||
|
||||
TypePtr ScatterDivInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto input_x_type_ptr = input_args[kInputIndex0]->BuildType();
|
||||
auto indiecs_type_ptr = input_args[kInputIndex1]->BuildType();
|
||||
auto updates_type_ptr = input_args[kInputIndex2]->BuildType();
|
||||
auto prim_name = primitive->name();
|
||||
const std::set<TypePtr> indices_types = {kInt32, kInt64};
|
||||
std::set<TypePtr> valid_types(common_valid_types);
|
||||
valid_types.emplace(kComplex128);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indiecs_type_ptr, indices_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x type", input_x_type_ptr, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("updates type", updates_type_ptr, valid_types, prim_name);
|
||||
|
||||
std::map<std::string, TypePtr> type_dict;
|
||||
type_dict.emplace("input_x", input_x_type_ptr);
|
||||
type_dict.emplace("updates", updates_type_ptr);
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(type_dict, common_valid_types, prim_name);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(ScatterDiv, BaseOperator);
|
||||
AbstractBasePtr ScatterDivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 3;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
|
||||
auto infer_type = ScatterDivInferType(primitive, input_args);
|
||||
auto infer_shape = ScatterDivInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ScatterDiv, prim::kPrimScatterDiv, ScatterDivInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -30,9 +30,6 @@ class MIND_API ScatterDiv : public BaseOperator {
|
|||
/// \brief Constructor.
|
||||
ScatterDiv() : BaseOperator(kNameScatterDiv) { InitIOName({"input_x", "indices", "updates"}, {"output"}); }
|
||||
};
|
||||
|
||||
abstract::AbstractBasePtr ScatterDivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
Loading…
Reference in New Issue