!20055 [assistant][ops] Add math operator IndexAdd

Merge pull request !20055 from 孟权令/IndexAdd
i-robot 2021-08-26 01:15:39 +00:00 committed by Gitee
commit 457a01fd09
7 changed files with 192 additions and 28 deletions

View File

@ -503,6 +503,7 @@ inline const PrimitivePtr kPrimAsinhGrad = std::make_shared<Primitive>("AsinhGra
inline const PrimitivePtr kPrimAcoshGrad = std::make_shared<Primitive>("AcoshGrad");
inline const PrimitivePtr kPrimFloorMod = std::make_shared<Primitive>("FloorMod");
inline const PrimitivePtr kPrimWhere = std::make_shared<Primitive>("Where");
inline const PrimitivePtr kPrimIndexAdd = std::make_shared<Primitive>("IndexAdd");
inline const PrimitivePtr kPrimIdentityMath = std::make_shared<Primitive>("Identity", kSideEffectPropagate);
inline const PrimitivePtr kPrimErfinv = std::make_shared<Primitive>("Erfinv");
inline const PrimitivePtr kPrimIsNan = std::make_shared<Primitive>("IsNan");

View File

@ -0,0 +1,80 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/index_add.h"
#include <algorithm>
#include "ops/op_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr IndexAddInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 3, prim_name);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
  auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
  auto x_rank = SizeToLong(x_shape.size());
  auto y_rank = SizeToLong(y_shape.size());
  CheckAndConvertUtils::Check("x rank", x_rank, kEqual, "y rank", y_rank, prim_name);
  auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
  CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeNeither, {-x_rank - 1, x_rank}, prim_name);
  auto idx_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
  auto idx_rank = SizeToLong(idx_shape.size());
  CheckAndConvertUtils::CheckInteger("idx size", idx_rank, kEqual, 1, prim_name);
  auto axis_rank = axis;
  if (axis < 0) {
    axis_rank = axis + x_rank;
  }
  CheckAndConvertUtils::Check("size of indices", idx_shape[0], kEqual, "dimension of y[axis]", y_shape[axis_rank],
                              prim_name);
  for (int dim = 0; dim < x_rank; dim = dim + 1) {
    if (dim != axis_rank) {
      CheckAndConvertUtils::Check("x dim", x_shape[dim], kEqual, "y dim", y_shape[dim], prim_name);
    }
  }
  return std::make_shared<abstract::Shape>(x_shape);
}

TypePtr IndexAddInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  auto op_name = prim->name();
  CheckAndConvertUtils::CheckInteger("IndexAdd infer", input_args.size(), kEqual, 3, op_name);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kUInt8, kFloat16, kFloat32, kFloat64};
  const std::set<TypePtr> indices_types = {kInt32};
  auto var_type = input_args[0]->BuildType();
  auto indices_type = input_args[1]->BuildType();
  auto updates_type = input_args[2]->BuildType();
  CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices_type, indices_types, prim->name());
  CheckAndConvertUtils::CheckTensorTypeValid("input_y type", updates_type, valid_types, prim->name());
  return CheckAndConvertUtils::CheckTensorTypeValid("input_x type", var_type, valid_types, prim->name());
}
}  // namespace

AbstractBasePtr IndexAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const std::vector<AbstractBasePtr> &input_args) {
  return abstract::MakeAbstract(IndexAddInferShape(primitive, input_args), IndexAddInferType(primitive, input_args));
}

REGISTER_PRIMITIVE_EVAL_IMPL(IndexAdd, prim::kPrimIndexAdd, IndexAddInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
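The checker calls above encode the operator's shape contract somewhat tersely. The following plain-Python sketch restates the same rules for readability; it is illustrative only, not part of this commit, and the helper name and messages are made up:

def check_index_add_shapes(x_shape, indices_shape, y_shape, axis):
    # x and y must have the same rank (IndexAddInferShape's first check).
    x_rank = len(x_shape)
    assert x_rank == len(y_shape), "x rank must equal y rank"
    # axis must lie strictly between -x_rank - 1 and x_rank.
    assert -x_rank - 1 < axis < x_rank, "axis out of range"
    # indices must be 1-D and as long as y's dimension along axis.
    assert len(indices_shape) == 1, "indices must be 1-D"
    pos_axis = axis if axis >= 0 else axis + x_rank
    assert indices_shape[0] == y_shape[pos_axis], "len(indices) must equal y.shape[axis]"
    # every dimension other than axis must agree between x and y.
    for dim in range(x_rank):
        if dim != pos_axis:
            assert x_shape[dim] == y_shape[dim], "x and y differ outside axis"
    # the inferred output shape is simply x's shape.
    return tuple(x_shape)

For example, check_index_add_shapes((3, 3), (3,), (3, 3), 1) passes and returns (3, 3), which matches the test case added later in this commit.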

View File

@ -0,0 +1,45 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_INDEX_ADD_H_
#define MINDSPORE_CORE_OPS_INDEX_ADD_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include <set>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameIndexAdd = "IndexAdd";
class IndexAdd : public PrimitiveC {
 public:
  IndexAdd() : PrimitiveC(kNameIndexAdd) { InitIOName({"input_x", "indices", "input_y"}, {"output"}); }
  ~IndexAdd() = default;
  MS_DECLARE_PARENT(IndexAdd, PrimitiveC);
};

AbstractBasePtr IndexAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const std::vector<AbstractBasePtr> &input_args);
using PrimIndexAddPtr = std::shared_ptr<IndexAdd>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_INDEX_ADD_H_

View File

@ -399,6 +399,7 @@ from .roll import _roll_tbe
from .soft_shrink import _soft_shrink_tbe
from .erfinv import _erfinv_tbe
from .soft_shrink_grad import _soft_shrink_grad_tbe
from .index_add import _index_add_tbe
from .hsigmoid_grad import _hsigmoid_grad_tbe
from .hsigmoid import _hsigmoid_tbe
from .hshrink import _hshrink_tbe

View File

@ -0,0 +1,44 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""IndexAdd op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
index_add_op_info = TBERegOp("IndexAdd") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("inplace_index_add.so") \
    .compute_cost(10) \
    .kernel_name("inplace_index_add") \
    .partial_flag(True) \
    .attr("axis", "required", "int", "all") \
    .input(0, "input_x", False, "required", "all") \
    .input(1, "indices", False, "required", "all") \
    .input(2, "input_y", False, "required", "all") \
    .output(0, "input_x", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default, DataType.F64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .get_op_info()


@op_info_register(index_add_op_info)
def _index_add_tbe():
    """IndexAdd TBE register"""
    return
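The registration above binds the IndexAdd primitive to the TBE kernel inplace_index_add and pairs each supported input_x/input_y dtype with int32 indices. As a rough NumPy sketch of the semantics that kernel is expected to provide (illustrative only; the function below is not part of MindSpore or of this commit):

import numpy as np

def index_add_reference(x, indices, y, axis):
    """Out-of-place sketch: accumulate y into x at `indices` along `axis`."""
    out = np.array(x, copy=True)
    index = [slice(None)] * out.ndim
    index[axis] = np.asarray(indices)
    # np.add.at accumulates even when indices repeat, matching index-add semantics.
    np.add.at(out, tuple(index), y)
    return out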

View File

@ -5177,7 +5177,7 @@ class MatrixInverse(PrimitiveWithInfer):
return x_shape
class IndexAdd(PrimitiveWithInfer):
class IndexAdd(Primitive):
"""
Adds tensor y to specified axis and indices of tensor x. The axis should be in the range from 0 to len(x.dim) - 1,
and indices should be in the range from 0 to the size of x at the axis dimension.
@ -5186,8 +5186,7 @@ class IndexAdd(PrimitiveWithInfer):
        axis (int): The dimension along which to index.

    Inputs:
        - **x** (Parameter) - The input tensor to add to, with data type float64, float32, float16, int32, int16,
          int8, uint8.
        - **x** (Parameter) - The input tensor to add to.
        - **indices** (Tensor) - The index of `x` on the `axis` th dimension to add to, with data type int32.
          The `indices` must be 1D with the same size as the size of the `axis` th dimension of `y`. The values
          of `indices` should be in the range of 0 to the size of the `axis` th dimension of `x`.
@ -5198,16 +5197,15 @@ class IndexAdd(PrimitiveWithInfer):
        Tensor, has the same shape and dtype as x.

    Raises:
        TypeError: If dtype of `x` is not one of: float64, float32, float16, int32, int16, int8, uint8.
        TypeError: If `x` is not a Tensor.
        TypeError: If neither `indices` nor `y` is a Tensor.
        TypeError: If shape of `y` is not same as the `x`.
        ValueError: If axis is out of `x` rank's range.
        ValueError: If `x` rank is not the same as `y` rank.
        ValueError: If size of `indices` is not equal to dimension of y[axis].
        ValueError: If `y`'s shape is not the same as `x` except the `axis` th dimension.

    Supported Platforms:
        ``GPU``
        ``Ascend`` ``GPU``
Examples:
>>> class Net(nn.Cell):
@ -5241,28 +5239,6 @@ class IndexAdd(PrimitiveWithInfer):
        self.axis = axis
        validator.check_value_type('axis', axis, [int], self.name)

    def infer_dtype(self, x_dtype, idx_type, y_dtype):
        args = {'input_x': x_dtype, 'input_y': y_dtype}
        valid_type = [mstype.float64, mstype.float32, mstype.float16, mstype.int32, mstype.int16, mstype.int8,
                      mstype.uint8]
        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
        valid_idx_type = [mstype.int32]
        validator.check_tensor_dtype_valid('indices', idx_type, valid_idx_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, idx_shape, y_shape):
        validator.check("x rank", len(x_shape), "y rank", len(y_shape), Rel.EQ, self.name)
        x_rank = len(x_shape)
        validator.check_int_range(self.axis, -x_rank - 1, x_rank, Rel.INC_NEITHER, 'axis', self.name)
        validator.check_equal_int(len(idx_shape), 1, "rank of idx_shape", self.name)
        validator.check("size of indices", idx_shape[0], "dimension of y[axis]", y_shape[self.axis],
                        Rel.EQ, self.name)
        axis = self.axis if self.axis >= 0 else x_rank + self.axis
        for dim in range(x_rank):
            if dim != axis:
                validator.check('x dim %d' % dim, x_shape[dim], "y dim %d" % dim, y_shape[dim], Rel.EQ, self.name)
        return x_shape
class Erfinv(Primitive):
r"""

View File

@ -81,6 +81,18 @@ def test_recursive_grad():
f2(x, y)
class IndexAdd(nn.Cell):
    """IndexAdd net definition"""

    def __init__(self, axis):
        super(IndexAdd, self).__init__()
        self.index_add = P.IndexAdd(axis)
        self.input_x = Parameter(Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32)))

    def construct(self, indices, updates):
        return self.index_add(self.input_x, indices, updates)


class InputBackward(nn.Cell):
    def __init__(self, network):
        super(InputBackward, self).__init__()
@ -1671,6 +1683,11 @@ test_case_math_ops = [
        'block': P.Erfinv(),
        'desc_inputs': [Tensor(np.array([0.1, 0.1, 0.1]).astype(np.float16))],
        'desc_bprop': [Tensor(np.array([1, 1, 1]).astype(np.float16))]}),
    ('IndexAdd', {
        'block': IndexAdd(1),
        'desc_inputs': (Tensor(np.array([0, 1, 2]).astype(np.int32)),
                        Tensor(np.array([[0.5, 1.0, 1.5], [1.0, 1.5, 2.0], [2.0, 2.5, 3.0]]).astype(np.float32))),
        'desc_bprop': [Tensor(np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]]).astype(np.float32))]}),
]
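The ('IndexAdd', {...}) entry above drives the IndexAdd cell defined earlier in this file with axis=1, indices [0, 1, 2] and a 3x3 float32 update against a 3x3 parameter. A hedged standalone sketch of that call follows; the expected values are hand-derived for illustration and are not asserted by the test entry itself:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

class IndexAddNet(nn.Cell):
    """Hypothetical standalone copy of the test's IndexAdd cell."""
    def __init__(self, axis):
        super(IndexAddNet, self).__init__()
        self.index_add = P.IndexAdd(axis)
        self.input_x = Parameter(Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32)))

    def construct(self, indices, updates):
        return self.index_add(self.input_x, indices, updates)

indices = Tensor(np.array([0, 1, 2]).astype(np.int32))
updates = Tensor(np.array([[0.5, 1.0, 1.5], [1.0, 1.5, 2.0], [2.0, 2.5, 3.0]]).astype(np.float32))
out = IndexAddNet(1)(indices, updates)
# Because the indices cover every column in order, this reduces to an elementwise add:
# [[1.5, 3.0, 4.5], [5.0, 6.5, 8.0], [9.0, 10.5, 12.0]]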
test_case_nn_ops = [