!25626 [assistant][ops]New operator implementation, including ScatterNonAliasingAdd

Merge pull request !25626 from 陈桂锟/ScatterNonAliasingAdd
This commit is contained in:
i-robot 2021-12-20 02:08:58 +00:00 committed by Gitee
commit 00e912da2a
6 changed files with 178 additions and 8 deletions

View File

@ -262,6 +262,7 @@ inline const PrimitivePtr kPrimBatchToSpace = std::make_shared<Primitive>("Batch
inline const PrimitivePtr kPrimSpaceToBatch = std::make_shared<Primitive>("SpaceToBatch");
inline const PrimitivePtr kPrimScatterNd = std::make_shared<Primitive>("ScatterNd");
inline const PrimitivePtr kPrimScatterNdUpdate = std::make_shared<Primitive>("ScatterNdUpdate");
inline const PrimitivePtr kPrimScatterNonAliasingAdd = std::make_shared<Primitive>("ScatterNonAliasingAdd");
inline const PrimitivePtr kPrimConstantOfShape = std::make_shared<Primitive>("ConstantOfShape");
inline const PrimitivePtr kPrimSquaredDifference = std::make_shared<Primitive>("SquaredDifference");
inline const PrimitivePtr kPrimReverseV2 = std::make_shared<Primitive>("ReverseV2");

View File

@ -0,0 +1,84 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/scatter_non_aliasing_add.h"
#include <map>
#include <set>
#include <string>
#include "abstract/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr ScatterNonAliasingAddInferShape(const PrimitivePtr &primitive,
                                                   const std::vector<AbstractBasePtr> &input_args) {
  // Infers the output shape of ScatterNonAliasingAdd. The output always has the
  // same shape as input_x; the checks below only validate that
  //   updates.shape == indices.shape[:-1] + input_x.shape[indices.shape[-1]:].
  auto prim_name = primitive->name();
  auto input_x_shape_ptr = input_args[kInputIndex0]->BuildShape();
  MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
  auto indices_shape_ptr = input_args[kInputIndex1]->BuildShape();
  MS_EXCEPTION_IF_NULL(indices_shape_ptr);
  auto updates_shape_ptr = input_args[kInputIndex2]->BuildShape();
  MS_EXCEPTION_IF_NULL(updates_shape_ptr);
  // With any dynamic dimension the static consistency checks cannot run;
  // just forward input_x's shape.
  if (input_x_shape_ptr->IsDynamic() || indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) {
    return input_x_shape_ptr->cast<abstract::ShapePtr>();
  }
  auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape];
  auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices_shape_ptr)[kShape];
  auto updates_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(updates_shape_ptr)[kShape];
  // indices must be at least rank 1: indices_shape.back() below would be UB on
  // a scalar (empty) shape.
  (void)CheckAndConvertUtils::CheckInteger("rank of indices", static_cast<int64_t>(indices_shape.size()),
                                           kGreaterEqual, 1, prim_name);
  if (indices_shape.size() == 1) {
    (void)CheckAndConvertUtils::CheckInteger("indices_shape", indices_shape[0], kNotEqual, -1, prim_name);
  }
  auto last_dim = indices_shape.back();
  // indices.shape[-1] counts how many leading dims of input_x each index
  // addresses, so it must not exceed input_x's rank (guards the iterator
  // arithmetic below against UB).
  (void)CheckAndConvertUtils::CheckInteger("indices.shape[-1]", last_dim, kLessEqual,
                                           static_cast<int64_t>(input_x_shape.size()), prim_name);
  // Expected updates shape = indices.shape[:-1] + input_x.shape[last_dim:].
  indices_shape.pop_back();
  indices_shape.insert(indices_shape.end(), input_x_shape.begin() + last_dim, input_x_shape.end());
  (void)CheckAndConvertUtils::CheckInteger("length of updates_shape and indices_shape + x_shape[1:]",
                                           static_cast<int64_t>(updates_shape.size()), kEqual,
                                           static_cast<int64_t>(indices_shape.size()), prim_name);
  for (size_t i = 0; i < updates_shape.size(); i++) {
    (void)CheckAndConvertUtils::CheckInteger("elements of updates_shape and indices_shape + x_shape[1:]",
                                             updates_shape[i], kEqual, indices_shape[i], prim_name);
  }
  return input_x_shape_ptr->cast<abstract::ShapePtr>();
}
TypePtr ScatterNonAliasingAddInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  // indices must be int32; input_x and updates must share one of the commonly
  // supported tensor dtypes, which is also the inferred output dtype.
  const auto prim_name = primitive->name();
  const std::set<TypePtr> valid_indices_types = {kInt32};
  auto indices_type = input_args[kInputIndex1]->BuildType();
  (void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices_type, valid_indices_types, prim_name);
  std::map<std::string, TypePtr> args;
  (void)args.emplace("input_x", input_args[kInputIndex0]->BuildType());
  (void)args.emplace("updates", input_args[kInputIndex2]->BuildType());
  return CheckAndConvertUtils::CheckTensorTypeSame(args, common_valid_types, prim_name);
}
} // namespace
AbstractBasePtr ScatterNonAliasingAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                           const std::vector<AbstractBasePtr> &input_args) {
  // Top-level infer entry: validates the argument count, then combines the
  // dtype and shape inference results into a single abstract value.
  MS_EXCEPTION_IF_NULL(primitive);
  constexpr int64_t input_num = 3;  // input_x, indices, updates
  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
  auto out_type = ScatterNonAliasingAddInferType(primitive, input_args);
  auto out_shape = ScatterNonAliasingAddInferShape(primitive, input_args);
  return abstract::MakeAbstract(out_shape, out_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ScatterNonAliasingAdd, prim::kPrimScatterNonAliasingAdd, ScatterNonAliasingAddInfer,
                             nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,44 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SCATTER_NON_ALIASING_ADD_H_
#define MINDSPORE_CORE_OPS_SCATTER_NON_ALIASING_ADD_H_
#include <memory>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
// Canonical operator name shared by the primitive and its registrations.
constexpr auto kNameScatterNonAliasingAdd = "ScatterNonAliasingAdd";
// C++ primitive definition for ScatterNonAliasingAdd. It only declares the
// operator's identity and I/O names; shape/dtype inference lives in the
// matching scatter_non_aliasing_add.cc. NOTE(review): "non-aliasing"
// presumably means indices are assumed not to overlap — confirm against the
// backend kernel's contract.
class ScatterNonAliasingAdd : public PrimitiveC {
 public:
  // Registers the input names (input_x, indices, updates) and output name (y).
  ScatterNonAliasingAdd() : PrimitiveC(kNameScatterNonAliasingAdd) {
    InitIOName({"input_x", "indices", "updates"}, {"y"});
  }
  ~ScatterNonAliasingAdd() = default;
  MS_DECLARE_PARENT(ScatterNonAliasingAdd, PrimitiveC);
};
// Standalone infer entry point, implemented in scatter_non_aliasing_add.cc.
AbstractBasePtr ScatterNonAliasingAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                           const std::vector<AbstractBasePtr> &input_args);
using kPrimScatterNonAliasingAddPtr = std::shared_ptr<ScatterNonAliasingAdd>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SCATTER_NON_ALIASING_ADD_H_

View File

@ -142,6 +142,7 @@ from .scatter_nd_add import _scatter_nd_add_tbe
from .scatter_nd_add_ds import _scatter_nd_add_ds_tbe
from .scatter_nd_sub import _scatter_nd_sub_tbe
from .scatter_non_aliasing_add import _scatter_non_aliasing_add_tbe
from .scatter_non_aliasing_add_ds import _scatter_non_aliasing_add_ds_tbe
from .reduce_mean import _reduce_mean_tbe
from .tile import _tile_tbe
from .tile_ds import _tile_ds_tbe

View File

@ -0,0 +1,40 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ScatterNonAliasingAdd op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
# TBE registration info for the dynamic-shape ("_ds") variant of
# ScatterNonAliasingAdd (dynamic_shape is set to True below).
# NOTE(review): the binfile/kernel names carry no "_ds" suffix — presumably the
# same compiled kernel serves both static and dynamic shape; confirm against
# the static-shape registration in scatter_non_aliasing_add.py.
scatter_non_aliasing_add_ds_op_info = TBERegOp("ScatterNonAliasingAdd") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("scatter_non_aliasing_add.so") \
    .compute_cost(10) \
    .kernel_name("scatter_non_aliasing_add") \
    .partial_flag(True) \
    .dynamic_shape(True) \
    .input(0, "input_x", False, "required", "all") \
    .input(1, "indices", False, "required", "all") \
    .input(2, "updates", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .get_op_info()


@op_info_register(scatter_non_aliasing_add_ds_op_info)
def _scatter_non_aliasing_add_ds_tbe():
    """Register the ScatterNonAliasingAdd (dynamic-shape) op info with TBE."""
    return

View File

@ -4929,7 +4929,7 @@ class ScatterNdSub(_ScatterNdOp):
"""
class ScatterNonAliasingAdd(_ScatterNdOp):
class ScatterNonAliasingAdd(Primitive):
"""
Applies sparse addition to the input using individual values or slices.
@ -4969,18 +4969,18 @@ class ScatterNonAliasingAdd(_ScatterNdOp):
[ 1. 10. 9. 4. 12. 6. 7. 17.]
"""
__mindspore_signature__ = (
sig.make_sig('input_x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('indices', dtype=sig.sig_dtype.T1),
sig.make_sig('updates', dtype=sig.sig_dtype.T)
)
@prim_attr_register
def __init__(self):
"""Initialize ScatterNonAliasingAdd"""
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
self.init_prim_io_names(inputs=['input_x', 'indices', 'updates'], outputs=['y'])
self.add_prim_attr('side_effect_mem', True)
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "updates": updates_dtype}
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32, mstype.int32], self.name)
return x_dtype
class SpaceToDepth(PrimitiveWithInfer):
r"""