forked from mindspore-Ecosystem/mindspore
!21280 [MS][LITE]add ScatterNdUpdate C op
Merge pull request !21280 from mengyuanli/add_scatterNDims
This commit is contained in:
commit
9aa662d0fc
|
@ -227,8 +227,9 @@ enum PrimType {
|
|||
PrimType_Affine = 200,
|
||||
PrimType_Attention = 201,
|
||||
PrimType_LSTMGrad = 202,
|
||||
PrimType_ScatterNdUpdate = 203,
|
||||
PrimType_MIN = PrimType_NONE,
|
||||
PrimType_MAX = PrimType_LSTMGrad + 1
|
||||
PrimType_MAX = PrimType_ScatterNdUpdate + 1
|
||||
};
|
||||
|
||||
void RegInfer(int prim_type, InferShape func);
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
#define MINDSPORE_NNACL_SCATTER_ND_INFER_H
|
||||
|
||||
#include "nnacl/infer/common_infer.h"
|
||||
#include "nnacl/softmax_parameter.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/infer/scatter_nd_update_infer.h"
|
||||
#include "nnacl/infer/infer_register.h"
|
||||
|
||||
int ScatterNdUpdateInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
|
||||
OpParameter *parameter) {
|
||||
int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1);
|
||||
if (check_ret != NNACL_OK) {
|
||||
return check_ret;
|
||||
}
|
||||
|
||||
const TensorC *input_x = inputs[0];
|
||||
const TensorC *indices = inputs[1];
|
||||
const TensorC *update = inputs[2];
|
||||
TensorC *output = outputs[0];
|
||||
|
||||
SetDataTypeFormat(output, input_x);
|
||||
if (!InferFlag(inputs, inputs_size)) {
|
||||
return NNACL_INFER_INVALID;
|
||||
}
|
||||
if (indices->shape_size_ != update->shape_size_) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
|
||||
SetShapeArray(output, input_x->shape_, input_x->shape_size_);
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
REG_INFER(ScatterNdUpdate, PrimType_ScatterNdUpdate, ScatterNdUpdateInferShape)
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_NNACL_SCATTER_ND_INFER_H
|
||||
#define MINDSPORE_NNACL_SCATTER_ND_INFER_H
|
||||
|
||||
#include "nnacl/infer/common_infer.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
int ScatterNdUpdateInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
|
||||
OpParameter *parameter);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_NNACL_SCATTER_ND_INFER_H
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/scatter_nd_update.h"
|
||||
#include <set>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto update_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("indices_shape[0] and update_shape[0]", indices_shape[0], kEqual,
|
||||
update_shape[0], "ScatterNdUpdate");
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
const std::set<TypePtr> input_x_valid_types = {kTensorType};
|
||||
const std::set<TypePtr> indices_valid_types = {kInt32, kInt64};
|
||||
const std::set<TypePtr> update_valid_types = {kTensorType};
|
||||
auto input_x_type = input_args[0]->BuildType();
|
||||
auto indices_type = input_args[1]->BuildType();
|
||||
auto update_type = input_args[2]->BuildType();
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("input_x type", input_x_type, input_x_valid_types, prim->name());
|
||||
CheckAndConvertUtils::CheckTypeValid("update type", update_type, update_valid_types, prim->name());
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices_type, indices_valid_types, prim->name());
|
||||
return input_args[0]->BuildType();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr ScatterNdUpdateInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), InferShape(input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameScatterNdUpdate, ScatterNdUpdate);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SCATTER_ND_UPDATE_H_
|
||||
#define MINDSPORE_CORE_OPS_SCATTER_ND_UPDATE_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameScatterNdUpdate = "ScatterNdUpdate";
|
||||
class ScatterNdUpdate : public PrimitiveC {
|
||||
public:
|
||||
ScatterNdUpdate() : PrimitiveC(kNameScatterNdUpdate) { InitIOName({"input_x", "indices", "update"}, {"output"}); }
|
||||
~ScatterNdUpdate() = default;
|
||||
MS_DECLARE_PARENT(ScatterNdUpdate, PrimitiveC);
|
||||
void Init() {}
|
||||
};
|
||||
AbstractBasePtr ScatterNdUpdateInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimScatterNdUpdatePtr = std::shared_ptr<ScatterNdUpdate>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_SCATTER_ND_UPDATE_H_
|
|
@ -220,6 +220,7 @@ union PrimitiveType {
|
|||
Affine,
|
||||
Attention,
|
||||
LSTMGrad,
|
||||
ScatterNdUpdate,
|
||||
}
|
||||
|
||||
table Abs {
|
||||
|
@ -1212,3 +1213,6 @@ table Affine {
|
|||
|
||||
table Attention {
|
||||
}
|
||||
|
||||
table ScatterNdUpdate {
|
||||
}
|
||||
|
|
|
@ -220,6 +220,7 @@ OP_TYPE(TensorArrayWrite)
|
|||
OP_TYPE(Affine)
|
||||
OP_TYPE(Attention)
|
||||
OP_TYPE(LSTMGrad)
|
||||
OP_TYPE(ScatterNdUpdate)
|
||||
OP_TYPE_DEF_END(PrimitiveType)
|
||||
|
||||
OP_SCHEMA_DEF(Abs)
|
||||
|
@ -1212,3 +1213,6 @@ OP_SCHEMA_DEF_END(Affine)
|
|||
|
||||
OP_SCHEMA_DEF(Attention)
|
||||
OP_SCHEMA_DEF_END(Attention)
|
||||
|
||||
OP_SCHEMA_DEF(ScatterNdUpdate)
|
||||
OP_SCHEMA_DEF_END(ScatterNdUpdate)
|
||||
|
|
|
@ -131,6 +131,7 @@
|
|||
#include "ops/rsqrt.h"
|
||||
#include "ops/scale.h"
|
||||
#include "ops/scatter_nd.h"
|
||||
#include "ops/scatter_nd_update.h"
|
||||
#include "ops/select.h"
|
||||
#include "ops/sgd.h"
|
||||
#include "ops/shape.h"
|
||||
|
@ -462,6 +463,7 @@ FUNC_MSOP2SCHEMAOP_DECLARE(TensorArrayRead)
|
|||
FUNC_MSOP2SCHEMAOP_DECLARE(TensorArrayWrite)
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(Affine)
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(Attention)
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(ScatterNdUpdate)
|
||||
#endif
|
||||
} // namespace mindspore::lite::ops
|
||||
#else
|
||||
|
|
|
@ -809,6 +809,11 @@ std::unique_ptr<schema::PrimitiveT> AttentionPrimitiveCreator(const AnfNodePtr &
|
|||
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::PrimitiveT> ScatterNdUpdatePrimitiveCreator(const AnfNodePtr &node) {
|
||||
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::ScatterNdUpdate>>(node);
|
||||
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
|
||||
}
|
||||
|
||||
RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator);
|
||||
RegistryMSOps g_absGradPrimitiveCreatorRegistry("AbsGrad", AbsGradPrimitiveCreator);
|
||||
RegistryMSOps g_activationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator);
|
||||
|
@ -1034,6 +1039,7 @@ RegistryMSOps g_TensorArrayReadCreatorRegistry("TensorArrayRead", TensorArrayRea
|
|||
RegistryMSOps g_TensorArrayWriteCreatorRegistry("TensorArrayWrite", TensorArrayWritePrimitiveCreator);
|
||||
RegistryMSOps g_AffineCreatorRegistry("Affine", AffinePrimitiveCreator);
|
||||
RegistryMSOps g_AttentionCreatorRegistry("Attention", AttentionPrimitiveCreator);
|
||||
RegistryMSOps g_ScatterNdUpdateCreatorRegistry("ScatterNdUpdate", ScatterNdUpdatePrimitiveCreator);
|
||||
|
||||
std::unique_ptr<schema::PrimitiveT> CustomPrimitiveCreator(const AnfNodePtr &node) {
|
||||
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Custom>>(node);
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
using mindspore::schema::PrimitiveType_ScatterNdUpdate;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
OpParameter *PopulateScatterNDUpdateParameter(const void *prim) {
|
||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
|
||||
auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ScatterNDParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(param, 0, sizeof(OpParameter));
|
||||
|
||||
param->type_ = primitive->value_type();
|
||||
return reinterpret_cast<OpParameter *>(param);
|
||||
}
|
||||
REG_POPULATE(PrimitiveType_ScatterNdUpdate, PopulateScatterNDUpdateParameter, SCHEMA_CUR)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue