vm for erfc

This commit is contained in:
jiangjinsheng 2020-05-21 10:08:23 +08:00
parent 93e7c97a96
commit 2d3cd8276e
6 changed files with 102 additions and 1 deletions

View File

@ -361,6 +361,23 @@ def get_bprop_erf(self):
return bprop return bprop
@bprop_getters.register(P.Erfc)
def get_bprop_erfc(self):
"""Grad definition for `Erfc` operation."""
exp = P.Exp()
square = P.Square()
sqrt = P.Sqrt()
cast = P.Cast()
dtype = P.DType()
def bprop(x, out, dout):
half_root_pi = cast(2 / sqrt(F.scalar_to_tensor(np.pi)), dtype(x))
x_square = square(x)
dx = dout * (-half_root_pi * exp(-x_square))
return (dx,)
return bprop
@bprop_getters.register(P.Pow) @bprop_getters.register(P.Pow)
def get_bprop_pow(self): def get_bprop_pow(self):
"""Grad definition for `Pow` operation.""" """Grad definition for `Pow` operation."""

View File

@ -152,6 +152,7 @@ from .fused_mul_add_n import _fused_mul_add_n_tbe
from .fused_mul_apply_momentum import _fused_mul_apply_momentum_tbe from .fused_mul_apply_momentum import _fused_mul_apply_momentum_tbe
from .fill import _fill_op_tbe from .fill import _fill_op_tbe
from .erf import _erf_op_tbe from .erf import _erf_op_tbe
from .erfc import _erfc_op_tbe
from .depthwise_conv2d import _depthwise_conv2d_tbe from .depthwise_conv2d import _depthwise_conv2d_tbe
from .depthwise_conv2d_backprop_filter import _depthwise_conv2d_backprop_filter_tbe from .depthwise_conv2d_backprop_filter import _depthwise_conv2d_backprop_filter_tbe
from .depthwise_conv2d_backprop_input import _depthwise_conv2d_backprop_input_tbe from .depthwise_conv2d_backprop_input import _depthwise_conv2d_backprop_input_tbe

View File

@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Erfc op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
erfc_op_info = TBERegOp("Erfc") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("erfc.so") \
.compute_cost(10) \
.kernel_name("erfc") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(erfc_op_info)
def _erfc_op_tbe():
"""Erfc TBE register"""
return

View File

@ -39,7 +39,7 @@ from .control_ops import ControlDepend, GeSwitch, Merge
from .inner_ops import ScalarCast from .inner_ops import ScalarCast
from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul, from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul,
ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd,
Cos, Div, Equal, EqualCount, Exp, Erf, Floor, FloorDiv, FloorMod, Acosh, Cos, Div, Equal, EqualCount, Exp, Erf, Erfc, Floor, FloorDiv, FloorMod, Acosh,
Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd, Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd,
LogicalNot, LogicalOr, MatMul, Maximum, LogicalNot, LogicalOr, MatMul, Maximum,
Minimum, Mul, Neg, NMSWithMask, NotEqual, Minimum, Mul, Neg, NMSWithMask, NotEqual,

View File

@ -1067,6 +1067,36 @@ class Erf(PrimitiveWithInfer):
return x_type return x_type
class Erfc(PrimitiveWithInfer):
r"""
Computes the complementary error function of `input_x` element-wise.
Inputs:
- **input_x** (Tensor) - The input tensor.
Outputs:
Tensor, has the same shape and dtype as the `input_x`.
Examples:
>>> input_x = Tensor(np.array([-1, 0, 1, 2, 3]), mindspore.float32)
>>> erfc = P.Erfc()
>>> erfc(input_x)
[1.8427168, 0., 0.1572832, 0.00469124, 0.00002235]
"""
@prim_attr_register
def __init__(self):
"""init Erfc"""
self.init_prim_io_names(inputs=['x'], outputs=['y'])
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name)
return x_type
class Minimum(_MathBinaryOp): class Minimum(_MathBinaryOp):
""" """
Computes the element-wise minimum of input tensors. Computes the element-wise minimum of input tensors.

View File

@ -372,6 +372,15 @@ class Log1pNet(nn.Cell):
return self.log1p(x) return self.log1p(x)
class ErfcNet(nn.Cell):
def __init__(self):
super(ErfcNet, self).__init__()
self.erfc = P.Erfc()
def construct(self, x):
return self.erfc(x)
test_case_math_ops = [ test_case_math_ops = [
('MatMulGrad', { ('MatMulGrad', {
'block': GradWrap(NetWithLoss(MatMulNet())), 'block': GradWrap(NetWithLoss(MatMulNet())),
@ -422,6 +431,11 @@ test_case_math_ops = [
'desc_inputs': [Tensor(np.array([[1.0, 2.0, 4.0]], np.float32))], 'desc_inputs': [Tensor(np.array([[1.0, 2.0, 4.0]], np.float32))],
'desc_bprop': [Tensor(np.array([[1.0, 2.0, 4.0]], np.float32))], 'desc_bprop': [Tensor(np.array([[1.0, 2.0, 4.0]], np.float32))],
'skip': ['backward']}), 'skip': ['backward']}),
('Erfc', {
'block': ErfcNet(),
'desc_inputs': [Tensor(np.array([[1.0, 2.0, 4.0]], np.float32))],
'desc_bprop': [Tensor(np.array([[1.0, 2.0, 4.0]], np.float32))],
}),
] ]
test_case_lists = [test_case_math_ops] test_case_lists = [test_case_math_ops]