!22831 [assistant][ops]add math ops Ger

Merge pull request !22831 from yangwm/ger
This commit is contained in:
i-robot 2021-09-28 06:27:35 +00:00 committed by Gitee
commit 2085ab1774
9 changed files with 219 additions and 1 deletions

View File

@ -67,6 +67,7 @@ constexpr auto kAbs = "Abs";
constexpr auto kTrunc = "Trunc";
constexpr auto kSquare = "Square";
constexpr auto kReal = "Real";
constexpr auto kGer = "Ger";
// Arrays
constexpr auto kDynamicShape = "DynamicShape";
@ -467,6 +468,7 @@ inline const PrimitivePtr kPrimTensorListStack = std::make_shared<Primitive>("Te
inline const PrimitivePtr kPrimTensorListSetItem = std::make_shared<Primitive>("TensorListSetItem");
// Maths
inline const PrimitivePtr kPrimGer = std::make_shared<Primitive>("Ger");
inline const PrimitivePtr kPrimCeil = std::make_shared<Primitive>("Ceil");
inline const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd");
inline const PrimitivePtr kPrimAdd = std::make_shared<Primitive>(kAdd);

67
mindspore/core/ops/ger.cc Normal file
View File

@ -0,0 +1,67 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <set>
#include "ops/ger.h"
#include "ops/op_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto first_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto second_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("x1 rank", SizeToLong(first_input_shape.size()), kEqual, 1, prim_name);
(void)CheckAndConvertUtils::CheckInteger("x2 rank", SizeToLong(second_input_shape.size()), kEqual, 1, prim_name);
std::vector<int64_t> out_shape = {first_input_shape[0], second_input_shape[0]};
return std::make_shared<abstract::Shape>(out_shape);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
std::map<std::string, TypePtr> types;
types.emplace("x1", input_args[0]->BuildType());
types.emplace("x2", input_args[1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
}
} // namespace
AbstractBasePtr GerInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto type = InferType(primitive, input_args);
auto shape = InferShape(primitive, input_args);
return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Ger, prim::kPrimGer, GerInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

44
mindspore/core/ops/ger.h Normal file
View File

@ -0,0 +1,44 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_GER_H_
#define MINDSPORE_CORE_OPS_GER_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameGer = "Ger";
class MS_CORE_API Ger : public PrimitiveC {
public:
Ger() : PrimitiveC(kNameGer) { InitIOName({"x", "y"}, {"output"}); }
~Ger() = default;
MS_DECLARE_PARENT(Ger, PrimitiveC);
};
AbstractBasePtr GerInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimGerPtr = std::shared_ptr<Ger>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_GER_H_

View File

@ -108,3 +108,19 @@ def get_bprop_trunc(self):
return (bc_x,)
return bprop
@bprop_getters.register(P.Ger)
def get_bprop_ger(self):
"""Grad definition for 'Ger' operation"""
transpose_op = P.Transpose()
matmul = P.MatMul()
expand_dims = P.ExpandDims()
squeeze = P.Squeeze(1)
def bprop(input_x, input_y, out, dout):
dx = squeeze(matmul(dout, expand_dims(input_y, 1)))
dy = squeeze(matmul(transpose_op(dout, (1, 0)), expand_dims(input_x, 1)))
return dx, dy
return bprop

View File

@ -65,6 +65,7 @@ from .dropout_do_mask import _dropout_do_mask_tbe
from .dropout_do_mask_ds import _dropout_do_mask_ds_tbe
from .gelu import _gelu_tbe
from .gelu_grad import _gelu_grad_tbe
from .ger import _ger_tbe
from .fast_gelu import _fast_gelu_tbe
from .fast_gelu_grad import _fast_gelu_grad_tbe
from .max_pool import _max_pool_tbe

View File

@ -0,0 +1,43 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Ger op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
ger_op_info = TBERegOp("Ger") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("ger.so") \
.compute_cost(10) \
.kernel_name("ger") \
.partial_flag(True) \
.input(0, "x1", False, "required", "all") \
.input(0, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_NHWC, DataType.F16_NHWC, DataType.F16_NHWC) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(ger_op_info)
def _ger_tbe():
"""Ger TBE register"""
return

View File

@ -47,7 +47,7 @@ from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerA
FusedWeightScaleApplyMomentum, FusedCastAdamWeightDecay)
from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
BitwiseAnd, BitwiseOr,
BitwiseAnd, BitwiseOr, Ger,
BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub,
ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, Cdist, ReduceAny,
Cos, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Ceil,
@ -123,6 +123,7 @@ from .rl_ops import (BufferAppend, BufferGetItem, BufferSample)
from ._inner_ops import (MatmulDDS, DSDMatmul, NonZero)
__all__ = [
'Ger',
'Unique',
'ReverseSequence',
'Sort',

View File

@ -136,6 +136,46 @@ class _BitwiseBinaryOp(_MathBinaryOp):
return _BitwiseBinaryOp._check_bitwise_op_input_type(x1_type, x2_type, self.name)
class Ger(Primitive):
r"""
Ger product of `x1` and `x2`. Calculate the outer product of two one-dimensional arrays.If `x1` is a 1D Tensor of
shape :math:`(m,)` and `x2` is a 1D Tensor of shape :math:`(n,)`,then `output` must be a Tensor of shape
:math:`(m * n)`.
Inputs:
- **x1** - (Tensor) - 1-D input Tensor, with dtype of float16 or float32.
- **x2** - (Tensor) - 1-D input Tensor, with dtype of float16 or float32.
Outputs:
Tensor, output matrix with the same dtype as inputs.With `x1` shape :math:`(m,)` and
`x2` shape of :math:`(n,)`,`output` has shape :math:`(m * n)`.
Raises:
TypeError: If `x1` or `x2` is not a Tensor.
TypeError: If the dtype of `x1` and `x2` is neither float16 nor float32.
ValueError: If `x1` or `x2` is not a 1D Tensor.
Supported Platforms:
``Ascend``
Examples:
>>> x1 = Tensor([1., 2., 3., 4.], mindspore.float32)
>>> x2 = Tensor([1., 2., 3.], mindspore.float32)
>>> ger = ops.Ger()
>>> output = ger(x1, x2)
>>> print(output)
[[ 1. 2. 3.]
[ 2. 4. 6.]
[ 3. 6. 9.]
[ 4. 8. 12.]]
"""
@prim_attr_register
def __init__(self):
"""Initialize Ger"""
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
class Add(_MathBinaryOp):
r"""

View File

@ -1043,6 +1043,10 @@ class SparseApplyRMSPropNet(nn.Cell):
return out
test_case_math_ops = [
('Ger', {
'block': P.Ger(),
'desc_inputs': [[3,], [4,]],
'desc_bprop': [[3, 4]]}),
('BitwiseAnd', {
'block': P.BitwiseAnd(),
'desc_inputs': [Tensor(np.array([0, 0, 1, -1, 1, 1, 1]), mstype.int16),