add vm support for Expm1

This commit is contained in:
zhouneng 2020-06-01 09:30:33 +08:00
parent fd045e9115
commit e5419f7bd1
6 changed files with 89 additions and 1 deletions

View File

@ -422,6 +422,19 @@ def get_bprop_exp(self):
return bprop
@bprop_getters.register(P.Expm1)
def get_bprop_expm1(self):
"""Grad definition for `Expm1` operation."""
exp_ = P.Exp()
def bprop(x, out, dout):
g = exp_(x)
dx = g * dout
return (dx,)
return bprop
@bprop_getters.register(P.Minimum)
def get_bprop_minimum(self):
"""Grad definition for `Minimum` operation."""

View File

@ -83,6 +83,7 @@ from .strided_slice_d import _strided_slice_d_tbe
from .strided_slice_grad_d import _strided_slice_grad_d_tbe
from .split_d import _split_d_tbe
from .exp import _exp_tbe
from .expm1 import _expm1_tbe
from .elu import _elu_tbe
from .elu_grad import _elu_grad_tbe
from .div import _div_tbe

View File

@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Expm1 op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
expm1_op_info = TBERegOp("Expm1") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("expm1.so") \
.compute_cost(10) \
.kernel_name("expm1") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
@op_info_register(expm1_op_info)
def _expm1_tbe():
"""Expm1 TBE register"""
return

View File

@ -42,7 +42,7 @@ from .inner_ops import ScalarCast
from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr, BitwiseXor,
ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd,
Cos, Div, Equal, EqualCount, Exp, Erf, Erfc, Floor, FloorDiv, FloorMod, Acosh,
Cos, Div, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Acosh,
Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd,
LogicalNot, LogicalOr, MatMul, Maximum,
Minimum, Mul, Neg, NMSWithMask, NotEqual,
@ -89,6 +89,7 @@ __all__ = [
'Mul',
'Pow',
'Exp',
'Expm1',
'Rsqrt',
'Sqrt',
'Square',

View File

@ -1004,6 +1004,36 @@ class Exp(PrimitiveWithInfer):
return x_type
class Expm1(PrimitiveWithInfer):
"""
Returns exponential then minus 1 of a tensor element-wise.
Inputs:
- **input_x** (Tensor) - The input tensor.
Outputs:
Tensor, has the same shape as the `input_x`.
Examples:
>>> input_x = Tensor(np.array([0.0, 1.0, 2.0, 4.0]), mindspore.float32)
>>> expm1 = P.Expm1()
>>> expm1(input_x)
[ 0., 1.71828183, 6.3890561 , 53.59815003]
"""
@prim_attr_register
def __init__(self):
"""init Exp"""
self.init_prim_io_names(inputs=['x'], outputs=['y'])
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_type):
validator.check_subclass("x", x_type, mstype.tensor, self.name)
return x_type
class Log(PrimitiveWithInfer):
"""
Returns the natural logarithm of a tensor element-wise.

View File

@ -348,6 +348,10 @@ test_case_math_ops = [
'block': P.Exp(),
'desc_inputs': [[2, 3]],
'desc_bprop': [[2, 3]]}),
('Expm1', {
'block': P.Expm1(),
'desc_inputs': [[2, 3]],
'desc_bprop': [[2, 3]]}),
('Erf', {
'block': P.Erf(),
'desc_inputs': [Tensor(np.array([-2, -1, 0, 1, 2]).astype(np.float16))],