forked from mindspore-Ecosystem/mindspore

!2241 Adapting operator named AccumulateNV2
Merge pull request !2241 from zhangzheng/accumulate

commit 932b7649e7
@@ -82,6 +82,7 @@ const std::map<std::string, OperatorType> DictOpType{
   {"Abs", OperatorType::kRecElmWiseOp},
   {"Acosh", OperatorType::kRecElmWiseOp},
   {"AddN", OperatorType::kRecElmWiseOp},
+  {"AccumulateNV2", OperatorType::kRecElmWiseOp},
   {"Atan2", OperatorType::kRecElmWiseOp},
   {"Erf", OperatorType::kRecElmWiseOp},
   {"Floor", OperatorType::kRecElmWiseOp},
@@ -932,6 +932,18 @@ def get_bprop_scalar_cast(self):
     return bprop


+@bprop_getters.register(P.AccumulateNV2)
+def get_bprop_scalar_accumulatenv2(self):
+    """Generate bprop for AccumulateNV2"""
+
+    def bprop(x, out, dout):
+        dx = ()
+        for _ in range(len(x)):
+            dx = dx + (dout,)
+        return dx
+    return bprop
+
+
 @bprop_getters.register(P.AddN)
 def get_bprop_scalar_addn(self):
     """Generate bprop for AddN"""
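Note on the backward rule above: for y = x_0 + x_1 + ... + x_{n-1}, the partial
derivative with respect to every addend is 1, so each input simply receives the
incoming gradient dout unchanged. A minimal plain-NumPy sketch (hypothetical
helper names, independent of MindSpore) of the same forward sum and gradient
fan-out:

import numpy as np

def accumulate_n(inputs):
    """Element-wise sum of a tuple of same-shaped arrays (forward pass)."""
    out = np.zeros_like(inputs[0])
    for x in inputs:
        out = out + x
    return out

def accumulate_n_bprop(inputs, dout):
    """Mirror of get_bprop_scalar_accumulatenv2: fan dout out to every input."""
    return tuple(dout for _ in inputs)

xs = (np.array([1., 2.]), np.array([3., 4.]), np.array([5., 6.]))
dout = np.array([0.1, 0.2])
print(accumulate_n(xs))              # [ 9. 12.]
print(accumulate_n_bprop(xs, dout))  # dout repeated once per input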
@@ -26,6 +26,7 @@ from .adam_apply_one_with_decay import _adam_apply_one_with_decay_tbe
 from .add import _add_tbe
 from .apply_centered_rms_prop import _apply_centered_rms_prop_tbe
 from .add_n import _add_n_tbe
+from .accumulate_n_v2 import _accumulate_n_v2_tbe
 from .apply_ftrl import _apply_ftrl_tbe
 from .apply_momentum import _apply_momentum_tbe
 from .apply_adam import _apply_adam_tbe
@@ -0,0 +1,41 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""AccumulateNV2 op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+accumulate_n_v2_op_info = TBERegOp("AccumulateNV2") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("accumulate_n_v2.so") \
+    .compute_cost(10) \
+    .kernel_name("accumulate_n_v2") \
+    .partial_flag(True) \
+    .attr("n", "required", "int", "all") \
+    .input(0, "x", False, "dynamic", "all") \
+    .output(0, "y", False, "required", "all") \
+    .op_pattern("broadcast") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
+    .get_op_info()
+
+
+@op_info_register(accumulate_n_v2_op_info)
+def _accumulate_n_v2_tbe():
+    """AccumulateNV2 TBE register"""
+    return
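Note: the TBERegOp builder above only assembles a declarative description of the
kernel: its binary and kernel names, fusion type, one dynamic input "x", one
output "y", and the supported dtype/format pairs. As a rough picture of what
get_op_info() produces, the record can be sketched as the Python dict below; the
exact key names MindSpore emits may differ, this is only an illustration
inferred from the builder calls:

accumulate_n_v2_op_info_sketch = {
    "op_name": "AccumulateNV2",          # TBERegOp("AccumulateNV2")
    "fusion_type": "ELEMWISE",
    "async_flag": False,
    "binfile_name": "accumulate_n_v2.so",
    "compute_cost": 10,
    "kernel_name": "accumulate_n_v2",
    "partial_flag": True,
    "op_pattern": "broadcast",
    "attr": [{"name": "n", "param_type": "required", "type": "int"}],
    "inputs": [{"index": 0, "name": "x", "param_type": "dynamic"}],
    "outputs": [{"index": 0, "name": "y", "param_type": "required"}],
    # One entry per .dtype_format(...) call: (input dtype, output dtype).
    "dtype_format": [("float16", "float16"), ("float32", "float32"),
                     ("int32", "int32"), ("int8", "int8"), ("uint8", "uint8")],
}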
@@ -41,7 +41,7 @@ from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSumm
 from .control_ops import ControlDepend, GeSwitch, Merge
 from .inner_ops import ScalarCast

-from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr,
+from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr,
                        BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub,
                        ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd,
                        Cos, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Ceil,
@@ -88,6 +88,7 @@ __all__ = [
     'ArgMaxWithValue',
     'ArgMinWithValue',
     'AddN',
+    'AccumulateNV2',
     'Sub',
     'CumSum',
     'MatMul',
@@ -798,6 +798,64 @@ class AddN(PrimitiveWithInfer):
         return Tensor(out)


+class AccumulateNV2(PrimitiveWithInfer):
+    """
+    Computes accumulation of all input tensors element-wise.
+
+    AccumulateNV2 is like AddN with a significant difference: AccumulateNV2 won't
+    wait for all of its inputs to be ready before beginning to sum. That is to say,
+    AccumulateNV2 will be able to save memory when inputs are ready at different
+    times since minimum temporary storage is proportional to the output size rather
+    than the inputs size.
+
+    Inputs:
+        - **input_x** (Union(tuple[Tensor], list[Tensor])) - The input tuple or list
+          is made up of multiple tensors whose dtype is number to be added together.
+
+    Outputs:
+        Tensor, has the same shape and dtype as each entry of the `input_x`.
+
+    Examples:
+        >>> class NetAccumulateNV2(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(NetAccumulateNV2, self).__init__()
+        >>>         self.accumulateNV2 = P.AccumulateNV2()
+        >>>
+        >>>     def construct(self, *z):
+        >>>         return self.accumulateNV2(z)
+        >>>
+        >>> net = NetAccumulateNV2()
+        >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float32)
+        >>> input_y = Tensor(np.array([4, 5, 6]), mindspore.float32)
+        >>> net(input_x, input_y, input_x, input_y)
+        Tensor([10., 14., 18.], shape=(3,), dtype=mindspore.float32)
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        self.__setattr_flag__ = True
+        self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
+
+    def infer_shape(self, inputs):
+        cls_name = self.name
+        validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
+        self.add_prim_attr('n', len(inputs))
+        shp0 = inputs[0]
+        for i, shp in enumerate(inputs):
+            validator.check(f"shape of inputs[{i}]", shp, 'shape of inputs[0]', shp0, Rel.EQ, cls_name)
+        return shp0
+
+    def infer_dtype(self, inputs):
+        cls_name = self.name
+        validator.check_value_type("inputs", inputs, [tuple, list], cls_name)
+        validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
+        args = {}
+        for i, dtype in enumerate(inputs):
+            args[f"inputs[{i}]"] = dtype
+        validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name)
+        return inputs[0]
+
+
 class Neg(PrimitiveWithInfer):
     """
     Returns a tensor with negative values of the input tensor element-wise.
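Note: infer_shape and infer_dtype above enforce that at least one input is
given, that all inputs share a single shape, and that all dtypes are identical
and drawn from the supported set; the output then takes its shape and dtype from
inputs[0], and the input count is recorded as attribute n. A standalone sketch
of the same checks (hypothetical function, not MindSpore API):

def check_accumulate_inputs(shapes, dtypes):
    """Stand-in for AccumulateNV2.infer_shape/infer_dtype (illustrative only)."""
    # Representative subset of mstype.number_type + bool.
    supported = {"float16", "float32", "int32", "int8", "uint8", "bool"}
    if len(shapes) < 1:
        raise ValueError("AccumulateNV2 expects at least one input")
    if any(shape != shapes[0] for shape in shapes):
        raise ValueError("all inputs must have the same shape")
    if any(dtype != dtypes[0] for dtype in dtypes) or dtypes[0] not in supported:
        raise TypeError("inputs must share one supported dtype")
    # Output shape/dtype come from the first input; n is the input count.
    return shapes[0], dtypes[0], len(shapes)

print(check_accumulate_inputs([(2, 3), (2, 3)], ["float32", "float32"]))
# ((2, 3), 'float32', 2)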
@@ -1415,6 +1415,11 @@ test_case_array_ops = [
         'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]],
         'desc_bprop': [[2, 3, 3, 5]],
         'skip': ['backward']}),
+    ('AccumulateNV2', {
+        'block': NetForTupleInput(P.AccumulateNV2()),
+        'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]],
+        'desc_bprop': [[2, 3, 3, 5]],
+        'skip': ['backward']}),
     ('Shape', {
         'block': P.Shape(),
        'desc_inputs': [[3, 3, 2, 2]],
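The new test entry feeds two [2, 3, 3, 5] tensors through NetForTupleInput, a
test-suite helper that passes its inputs to the wrapped primitive as a tuple,
and skips the backward check. Outside the harness, the equivalent forward call
mirrors the docstring example; a sketch assuming a working MindSpore install:

import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

class NetAccumulateNV2(nn.Cell):
    def __init__(self):
        super(NetAccumulateNV2, self).__init__()
        self.accumulate = P.AccumulateNV2()

    def construct(self, *z):
        # Pack the positional tensors into a tuple, as NetForTupleInput does.
        return self.accumulate(z)

net = NetAccumulateNV2()
a = Tensor(np.ones([2, 3, 3, 5]), mindspore.float32)
b = Tensor(np.ones([2, 3, 3, 5]), mindspore.float32)
out = net(a, b)  # element-wise sum; every element is 2.0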