forked from mindspore-Ecosystem/mindspore
!25626 [assistant][ops]New operator implementation, include ScatterNonAliasingAdd
Merge pull request !25626 from 陈桂锟/ScatterNonAliasingAdd
This commit is contained in:
commit
00e912da2a
|
@ -262,6 +262,7 @@ inline const PrimitivePtr kPrimBatchToSpace = std::make_shared<Primitive>("Batch
|
|||
inline const PrimitivePtr kPrimSpaceToBatch = std::make_shared<Primitive>("SpaceToBatch");
|
||||
inline const PrimitivePtr kPrimScatterNd = std::make_shared<Primitive>("ScatterNd");
|
||||
inline const PrimitivePtr kPrimScatterNdUpdate = std::make_shared<Primitive>("ScatterNdUpdate");
|
||||
inline const PrimitivePtr kPrimScatterNonAliasingAdd = std::make_shared<Primitive>("ScatterNonAliasingAdd");
|
||||
inline const PrimitivePtr kPrimConstantOfShape = std::make_shared<Primitive>("ConstantOfShape");
|
||||
inline const PrimitivePtr kPrimSquaredDifference = std::make_shared<Primitive>("SquaredDifference");
|
||||
inline const PrimitivePtr kPrimReverseV2 = std::make_shared<Primitive>("ReverseV2");
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ops/scatter_non_aliasing_add.h"
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr ScatterNonAliasingAddInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
auto input_x_shape_ptr = input_args[kInputIndex0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
|
||||
auto indices_shape_ptr = input_args[kInputIndex1]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(indices_shape_ptr);
|
||||
auto updates_shape_ptr = input_args[kInputIndex2]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(updates_shape_ptr);
|
||||
if (input_x_shape_ptr->IsDynamic() || indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) {
|
||||
return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
|
||||
}
|
||||
auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape];
|
||||
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices_shape_ptr)[kShape];
|
||||
auto updates_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(updates_shape_ptr)[kShape];
|
||||
if (indices_shape.size() == 1) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("indices_shape", indices_shape[0], kNotEqual, -1);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]->BuildShape());
|
||||
auto last_dim = indices_shape.back();
|
||||
indices_shape.pop_back();
|
||||
indices_shape.insert(indices_shape.end(), input_x_shape.begin() + last_dim, input_x_shape.end());
|
||||
(void)CheckAndConvertUtils::CheckInteger("length of updates_shape and indices_shape + x_shape[1:]",
|
||||
updates_shape.size(), kEqual, indices_shape.size(), prim_name);
|
||||
for (size_t i = 0; i < updates_shape.size(); i++) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("elements of updates_shape and indices_shape + x_shape[1:]",
|
||||
updates_shape[i], kEqual, indices_shape[i], prim_name);
|
||||
}
|
||||
return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
|
||||
}
|
||||
|
||||
TypePtr ScatterNonAliasingAddInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
auto indiecs_type_ptr = input_args[kInputIndex1]->BuildType();
|
||||
std::set<TypePtr> type_set = {kInt32};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indiecs_type_ptr, type_set, prim_name);
|
||||
std::map<std::string, TypePtr> type_dict;
|
||||
type_dict.emplace("input_x", input_args[kInputIndex0]->BuildType());
|
||||
type_dict.emplace("updates", input_args[kInputIndex2]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(type_dict, common_valid_types, prim_name);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr ScatterNonAliasingAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t kInputNum = 3;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name());
|
||||
auto infer_type = ScatterNonAliasingAddInferType(primitive, input_args);
|
||||
auto infer_shape = ScatterNonAliasingAddInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ScatterNonAliasingAdd, prim::kPrimScatterNonAliasingAdd, ScatterNonAliasingAddInfer,
|
||||
nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SCATTER_NON_ALIASING_ADD_H_
|
||||
#define MINDSPORE_CORE_OPS_SCATTER_NON_ALIASING_ADD_H_
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameScatterNonAliasingAdd = "ScatterNonAliasingAdd";
|
||||
class ScatterNonAliasingAdd : public PrimitiveC {
|
||||
public:
|
||||
ScatterNonAliasingAdd() : PrimitiveC(kNameScatterNonAliasingAdd) {
|
||||
InitIOName({"input_x", "indices", "updates"}, {"y"});
|
||||
}
|
||||
~ScatterNonAliasingAdd() = default;
|
||||
MS_DECLARE_PARENT(ScatterNonAliasingAdd, PrimitiveC);
|
||||
};
|
||||
|
||||
AbstractBasePtr ScatterNonAliasingAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using kPrimScatterNonAliasingAddPtr = std::shared_ptr<ScatterNonAliasingAdd>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_SCATTER_NON_ALIASING_ADD_H_
|
|
@ -142,6 +142,7 @@ from .scatter_nd_add import _scatter_nd_add_tbe
|
|||
from .scatter_nd_add_ds import _scatter_nd_add_ds_tbe
|
||||
from .scatter_nd_sub import _scatter_nd_sub_tbe
|
||||
from .scatter_non_aliasing_add import _scatter_non_aliasing_add_tbe
|
||||
from .scatter_non_aliasing_add_ds import _scatter_non_aliasing_add_ds_tbe
|
||||
from .reduce_mean import _reduce_mean_tbe
|
||||
from .tile import _tile_tbe
|
||||
from .tile_ds import _tile_ds_tbe
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ScatterNonAliasingAdd op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
scatter_non_aliasing_add_ds_op_info = TBERegOp("ScatterNonAliasingAdd") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("scatter_non_aliasing_add.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("scatter_non_aliasing_add") \
|
||||
.partial_flag(True) \
|
||||
.dynamic_shape(True) \
|
||||
.input(0, "input_x", False, "required", "all") \
|
||||
.input(1, "indices", False, "required", "all") \
|
||||
.input(2, "updates", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(scatter_non_aliasing_add_ds_op_info)
|
||||
def _scatter_non_aliasing_add_ds_tbe():
|
||||
"""ScatterNonAliasingAdd TBE register"""
|
||||
return
|
|
@ -4929,7 +4929,7 @@ class ScatterNdSub(_ScatterNdOp):
|
|||
"""
|
||||
|
||||
|
||||
class ScatterNonAliasingAdd(_ScatterNdOp):
|
||||
class ScatterNonAliasingAdd(Primitive):
|
||||
"""
|
||||
Applies sparse addition to the input using individual values or slices.
|
||||
|
||||
|
@ -4969,18 +4969,18 @@ class ScatterNonAliasingAdd(_ScatterNdOp):
|
|||
[ 1. 10. 9. 4. 12. 6. 7. 17.]
|
||||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
sig.make_sig('input_x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('indices', dtype=sig.sig_dtype.T1),
|
||||
sig.make_sig('updates', dtype=sig.sig_dtype.T)
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize ScatterNonAliasingAdd"""
|
||||
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
|
||||
self.init_prim_io_names(inputs=['input_x', 'indices', 'updates'], outputs=['y'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
||||
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
|
||||
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
|
||||
args = {"x": x_dtype, "updates": updates_dtype}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32, mstype.int32], self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
class SpaceToDepth(PrimitiveWithInfer):
|
||||
r"""
|
||||
|
|
Loading…
Reference in New Issue