diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.h index 351e4f70086..5929f8e3f4f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.h @@ -227,8 +227,9 @@ enum PrimType { PrimType_Affine = 200, PrimType_Attention = 201, PrimType_LSTMGrad = 202, + PrimType_ScatterNdUpdate = 203, PrimType_MIN = PrimType_NONE, - PrimType_MAX = PrimType_LSTMGrad + 1 + PrimType_MAX = PrimType_ScatterNdUpdate + 1 }; void RegInfer(int prim_type, InferShape func); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/scatter_nd_infer.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/scatter_nd_infer.h index 699405e831f..7b035b15a0e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/scatter_nd_infer.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/scatter_nd_infer.h @@ -17,7 +17,6 @@ #define MINDSPORE_NNACL_SCATTER_ND_INFER_H #include "nnacl/infer/common_infer.h" -#include "nnacl/softmax_parameter.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/scatter_nd_update_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/scatter_nd_update_infer.c new file mode 100644 index 00000000000..3262a5b1c53 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/scatter_nd_update_infer.c @@ -0,0 +1,44 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/scatter_nd_update_infer.h" +#include "nnacl/infer/infer_register.h" + +int ScatterNdUpdateInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input_x = inputs[0]; + const TensorC *indices = inputs[1]; + const TensorC *update = inputs[2]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input_x); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (indices->shape_size_ != update->shape_size_) { + return NNACL_ERR; + } + + SetShapeArray(output, input_x->shape_, input_x->shape_size_); + return NNACL_OK; +} + +REG_INFER(ScatterNdUpdate, PrimType_ScatterNdUpdate, ScatterNdUpdateInferShape) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/scatter_nd_update_infer.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/scatter_nd_update_infer.h new file mode 100644 index 00000000000..72878fb9d04 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/scatter_nd_update_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SCATTER_ND_INFER_H +#define MINDSPORE_NNACL_SCATTER_ND_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ScatterNdUpdateInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SCATTER_ND_INFER_H diff --git a/mindspore/core/ops/scatter_nd_update.cc b/mindspore/core/ops/scatter_nd_update.cc new file mode 100644 index 00000000000..eb128665abf --- /dev/null +++ b/mindspore/core/ops/scatter_nd_update.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/scatter_nd_update.h" +#include +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/primitive_infer_map.h" + +namespace mindspore { +namespace ops { +namespace { +abstract::ShapePtr InferShape(const std::vector &input_args) { + auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; + auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; + auto update_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; + (void)CheckAndConvertUtils::CheckInteger("indices_shape[0] and update_shape[0]", indices_shape[0], kEqual, + update_shape[0], "ScatterNdUpdate"); + return std::make_shared(in_shape); +} + +TypePtr InferType(const PrimitivePtr &prim, const std::vector &input_args) { + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } + const std::set input_x_valid_types = {kTensorType}; + const std::set indices_valid_types = {kInt32, kInt64}; + const std::set update_valid_types = {kTensorType}; + auto input_x_type = input_args[0]->BuildType(); + auto indices_type = input_args[1]->BuildType(); + auto update_type = input_args[2]->BuildType(); + CheckAndConvertUtils::CheckTensorTypeValid("input_x type", input_x_type, input_x_valid_types, prim->name()); + CheckAndConvertUtils::CheckTypeValid("update type", update_type, update_valid_types, prim->name()); + CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices_type, indices_valid_types, prim->name()); + return input_args[0]->BuildType(); +} +} // namespace + +AbstractBasePtr ScatterNdUpdateInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + return std::make_shared(InferType(primitive, input_args), InferShape(input_args)->shape()); +} +REGISTER_PRIMITIVE_C(kNameScatterNdUpdate, ScatterNdUpdate); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/scatter_nd_update.h b/mindspore/core/ops/scatter_nd_update.h new file mode 100644 index 00000000000..5909f0ef48a --- /dev/null +++ b/mindspore/core/ops/scatter_nd_update.h @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_SCATTER_ND_UPDATE_H_ +#define MINDSPORE_CORE_OPS_SCATTER_ND_UPDATE_H_ +#include +#include + +#include "ops/primitive_c.h" +#include "abstract/abstract_value.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameScatterNdUpdate = "ScatterNdUpdate"; +class ScatterNdUpdate : public PrimitiveC { + public: + ScatterNdUpdate() : PrimitiveC(kNameScatterNdUpdate) { InitIOName({"input_x", "indices", "update"}, {"output"}); } + ~ScatterNdUpdate() = default; + MS_DECLARE_PARENT(ScatterNdUpdate, PrimitiveC); + void Init() {} +}; +AbstractBasePtr ScatterNdUpdateInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimScatterNdUpdatePtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_SCATTER_ND_UPDATE_H_ diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index ded169b7171..e1721611ed2 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -220,6 +220,7 @@ union PrimitiveType { Affine, Attention, LSTMGrad, + ScatterNdUpdate, } table Abs { @@ -1212,3 +1213,6 @@ table Affine { table Attention { } + +table ScatterNdUpdate { +} diff --git a/mindspore/lite/src/ops/ops_def.cc b/mindspore/lite/src/ops/ops_def.cc index b64ca1619fb..46c264b86c9 100644 --- a/mindspore/lite/src/ops/ops_def.cc +++ b/mindspore/lite/src/ops/ops_def.cc @@ -220,6 +220,7 @@ OP_TYPE(TensorArrayWrite) OP_TYPE(Affine) OP_TYPE(Attention) OP_TYPE(LSTMGrad) +OP_TYPE(ScatterNdUpdate) OP_TYPE_DEF_END(PrimitiveType) OP_SCHEMA_DEF(Abs) @@ -1212,3 +1213,6 @@ OP_SCHEMA_DEF_END(Affine) OP_SCHEMA_DEF(Attention) OP_SCHEMA_DEF_END(Attention) + +OP_SCHEMA_DEF(ScatterNdUpdate) +OP_SCHEMA_DEF_END(ScatterNdUpdate) diff --git a/mindspore/lite/src/ops/ops_func_declare.h b/mindspore/lite/src/ops/ops_func_declare.h index a2dee794b4e..da54b2dc899 100644 --- a/mindspore/lite/src/ops/ops_func_declare.h +++ b/mindspore/lite/src/ops/ops_func_declare.h @@ -131,6 +131,7 @@ #include "ops/rsqrt.h" #include "ops/scale.h" #include "ops/scatter_nd.h" +#include "ops/scatter_nd_update.h" #include "ops/select.h" #include "ops/sgd.h" #include "ops/shape.h" @@ -462,6 +463,7 @@ FUNC_MSOP2SCHEMAOP_DECLARE(TensorArrayRead) FUNC_MSOP2SCHEMAOP_DECLARE(TensorArrayWrite) FUNC_MSOP2SCHEMAOP_DECLARE(Affine) FUNC_MSOP2SCHEMAOP_DECLARE(Attention) +FUNC_MSOP2SCHEMAOP_DECLARE(ScatterNdUpdate) #endif } // namespace mindspore::lite::ops #else diff --git a/mindspore/lite/src/ops/ops_utils.cc b/mindspore/lite/src/ops/ops_utils.cc index 10a23304de7..90f57a89bb5 100644 --- a/mindspore/lite/src/ops/ops_utils.cc +++ b/mindspore/lite/src/ops/ops_utils.cc @@ -809,6 +809,11 @@ std::unique_ptr AttentionPrimitiveCreator(const AnfNodePtr & return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; } +std::unique_ptr ScatterNdUpdatePrimitiveCreator(const AnfNodePtr &node) { + auto ms_primc = GetValueNode>(node); + return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; +} + RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator); RegistryMSOps g_absGradPrimitiveCreatorRegistry("AbsGrad", AbsGradPrimitiveCreator); RegistryMSOps g_activationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator); @@ -1034,6 +1039,7 @@ RegistryMSOps g_TensorArrayReadCreatorRegistry("TensorArrayRead", TensorArrayRea RegistryMSOps g_TensorArrayWriteCreatorRegistry("TensorArrayWrite", TensorArrayWritePrimitiveCreator); RegistryMSOps g_AffineCreatorRegistry("Affine", AffinePrimitiveCreator); RegistryMSOps g_AttentionCreatorRegistry("Attention", AttentionPrimitiveCreator); +RegistryMSOps g_ScatterNdUpdateCreatorRegistry("ScatterNdUpdate", ScatterNdUpdatePrimitiveCreator); std::unique_ptr CustomPrimitiveCreator(const AnfNodePtr &node) { auto ms_primc = GetValueNode>(node); diff --git a/mindspore/lite/src/ops/populate/scatter_nd_update_populate.cc b/mindspore/lite/src/ops/populate/scatter_nd_update_populate.cc new file mode 100644 index 00000000000..cf8a3475f73 --- /dev/null +++ b/mindspore/lite/src/ops/populate/scatter_nd_update_populate.cc @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/ops/populate/populate_register.h" +using mindspore::schema::PrimitiveType_ScatterNdUpdate; + +namespace mindspore { +namespace lite { +OpParameter *PopulateScatterNDUpdateParameter(const void *prim) { + auto primitive = static_cast(prim); + MS_ASSERT(primitive != nullptr); + + auto *param = reinterpret_cast(malloc(sizeof(OpParameter))); + if (param == nullptr) { + MS_LOG(ERROR) << "malloc ScatterNDParameter failed."; + return nullptr; + } + memset(param, 0, sizeof(OpParameter)); + + param->type_ = primitive->value_type(); + return reinterpret_cast(param); +} +REG_POPULATE(PrimitiveType_ScatterNdUpdate, PopulateScatterNDUpdateParameter, SCHEMA_CUR) +} // namespace lite +} // namespace mindspore