!21280 [MS][LITE]add ScatterNdUpdate C op

Merge pull request !21280 from mengyuanli/add_scatterNDims
This commit is contained in:
i-robot 2021-08-03 07:20:20 +00:00 committed by Gitee
commit 9aa662d0fc
11 changed files with 230 additions and 2 deletions

View File

@ -227,8 +227,9 @@ enum PrimType {
PrimType_Affine = 200,
PrimType_Attention = 201,
PrimType_LSTMGrad = 202,
PrimType_ScatterNdUpdate = 203,
PrimType_MIN = PrimType_NONE,
PrimType_MAX = PrimType_LSTMGrad + 1
PrimType_MAX = PrimType_ScatterNdUpdate + 1
};
void RegInfer(int prim_type, InferShape func);

View File

@ -17,7 +17,6 @@
#define MINDSPORE_NNACL_SCATTER_ND_INFER_H
#include "nnacl/infer/common_infer.h"
#include "nnacl/softmax_parameter.h"
#ifdef __cplusplus
extern "C" {

View File

@ -0,0 +1,44 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/infer/scatter_nd_update_infer.h"
#include "nnacl/infer/infer_register.h"
int ScatterNdUpdateInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1);
if (check_ret != NNACL_OK) {
return check_ret;
}
const TensorC *input_x = inputs[0];
const TensorC *indices = inputs[1];
const TensorC *update = inputs[2];
TensorC *output = outputs[0];
SetDataTypeFormat(output, input_x);
if (!InferFlag(inputs, inputs_size)) {
return NNACL_INFER_INVALID;
}
if (indices->shape_size_ != update->shape_size_) {
return NNACL_ERR;
}
SetShapeArray(output, input_x->shape_, input_x->shape_size_);
return NNACL_OK;
}
REG_INFER(ScatterNdUpdate, PrimType_ScatterNdUpdate, ScatterNdUpdateInferShape)

View File

@ -0,0 +1,31 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_SCATTER_ND_INFER_H
#define MINDSPORE_NNACL_SCATTER_ND_INFER_H
#include "nnacl/infer/common_infer.h"
#ifdef __cplusplus
extern "C" {
#endif
int ScatterNdUpdateInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_NNACL_SCATTER_ND_INFER_H

View File

@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/scatter_nd_update.h"
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const std::vector<AbstractBasePtr> &input_args) {
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto update_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("indices_shape[0] and update_shape[0]", indices_shape[0], kEqual,
update_shape[0], "ScatterNdUpdate");
return std::make_shared<abstract::Shape>(in_shape);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypePtr> input_x_valid_types = {kTensorType};
const std::set<TypePtr> indices_valid_types = {kInt32, kInt64};
const std::set<TypePtr> update_valid_types = {kTensorType};
auto input_x_type = input_args[0]->BuildType();
auto indices_type = input_args[1]->BuildType();
auto update_type = input_args[2]->BuildType();
CheckAndConvertUtils::CheckTensorTypeValid("input_x type", input_x_type, input_x_valid_types, prim->name());
CheckAndConvertUtils::CheckTypeValid("update type", update_type, update_valid_types, prim->name());
CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices_type, indices_valid_types, prim->name());
return input_args[0]->BuildType();
}
} // namespace
AbstractBasePtr ScatterNdUpdateInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), InferShape(input_args)->shape());
}
REGISTER_PRIMITIVE_C(kNameScatterNdUpdate, ScatterNdUpdate);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SCATTER_ND_UPDATE_H_
#define MINDSPORE_CORE_OPS_SCATTER_ND_UPDATE_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameScatterNdUpdate = "ScatterNdUpdate";
class ScatterNdUpdate : public PrimitiveC {
public:
ScatterNdUpdate() : PrimitiveC(kNameScatterNdUpdate) { InitIOName({"input_x", "indices", "update"}, {"output"}); }
~ScatterNdUpdate() = default;
MS_DECLARE_PARENT(ScatterNdUpdate, PrimitiveC);
void Init() {}
};
AbstractBasePtr ScatterNdUpdateInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimScatterNdUpdatePtr = std::shared_ptr<ScatterNdUpdate>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SCATTER_ND_UPDATE_H_

View File

@ -220,6 +220,7 @@ union PrimitiveType {
Affine,
Attention,
LSTMGrad,
ScatterNdUpdate,
}
table Abs {
@ -1212,3 +1213,6 @@ table Affine {
table Attention {
}
table ScatterNdUpdate {
}

View File

@ -220,6 +220,7 @@ OP_TYPE(TensorArrayWrite)
OP_TYPE(Affine)
OP_TYPE(Attention)
OP_TYPE(LSTMGrad)
OP_TYPE(ScatterNdUpdate)
OP_TYPE_DEF_END(PrimitiveType)
OP_SCHEMA_DEF(Abs)
@ -1212,3 +1213,6 @@ OP_SCHEMA_DEF_END(Affine)
OP_SCHEMA_DEF(Attention)
OP_SCHEMA_DEF_END(Attention)
OP_SCHEMA_DEF(ScatterNdUpdate)
OP_SCHEMA_DEF_END(ScatterNdUpdate)

View File

@ -131,6 +131,7 @@
#include "ops/rsqrt.h"
#include "ops/scale.h"
#include "ops/scatter_nd.h"
#include "ops/scatter_nd_update.h"
#include "ops/select.h"
#include "ops/sgd.h"
#include "ops/shape.h"
@ -462,6 +463,7 @@ FUNC_MSOP2SCHEMAOP_DECLARE(TensorArrayRead)
FUNC_MSOP2SCHEMAOP_DECLARE(TensorArrayWrite)
FUNC_MSOP2SCHEMAOP_DECLARE(Affine)
FUNC_MSOP2SCHEMAOP_DECLARE(Attention)
FUNC_MSOP2SCHEMAOP_DECLARE(ScatterNdUpdate)
#endif
} // namespace mindspore::lite::ops
#else

View File

@ -809,6 +809,11 @@ std::unique_ptr<schema::PrimitiveT> AttentionPrimitiveCreator(const AnfNodePtr &
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
std::unique_ptr<schema::PrimitiveT> ScatterNdUpdatePrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::ScatterNdUpdate>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator);
RegistryMSOps g_absGradPrimitiveCreatorRegistry("AbsGrad", AbsGradPrimitiveCreator);
RegistryMSOps g_activationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator);
@ -1034,6 +1039,7 @@ RegistryMSOps g_TensorArrayReadCreatorRegistry("TensorArrayRead", TensorArrayRea
RegistryMSOps g_TensorArrayWriteCreatorRegistry("TensorArrayWrite", TensorArrayWritePrimitiveCreator);
RegistryMSOps g_AffineCreatorRegistry("Affine", AffinePrimitiveCreator);
RegistryMSOps g_AttentionCreatorRegistry("Attention", AttentionPrimitiveCreator);
RegistryMSOps g_ScatterNdUpdateCreatorRegistry("ScatterNdUpdate", ScatterNdUpdatePrimitiveCreator);
std::unique_ptr<schema::PrimitiveT> CustomPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Custom>>(node);

View File

@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/populate/populate_register.h"
using mindspore::schema::PrimitiveType_ScatterNdUpdate;
namespace mindspore {
namespace lite {
OpParameter *PopulateScatterNDUpdateParameter(const void *prim) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ScatterNDParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(OpParameter));
param->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
}
REG_POPULATE(PrimitiveType_ScatterNdUpdate, PopulateScatterNDUpdateParameter, SCHEMA_CUR)
} // namespace lite
} // namespace mindspore