forked from mindspore-Ecosystem/mindspore

add graph kernel div, floordiv, mod, floormod, floor

parent bdc687d774
commit c30b1e6d06

@@ -24,6 +24,8 @@ from .clip_by_norm_no_div_sum import ClipByNormNoDivSum
 from .conv2d import Conv2D
+from .complex import CAbs, CAdd, CDiv, CMul, CSub
 from .dropout_grad import DropoutGrad
 from .equal_count import EqualCount
+from .erfc import Erfc
 from .expand_dims import ExpandDims
 from .fused_adam import FusedAdam
 from .fused_adam_weight_decay import FusedAdamWeightDecay

@@ -57,4 +59,3 @@ from .squared_difference import SquaredDifference
 from .squeeze import Squeeze
 from .tanh_grad import TanhGrad
 from .tile import Tile
-from .equal_count import EqualCount

@@ -0,0 +1,35 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""generate json desc for erfc"""
+from ._utils import Expander
+
+
+class Erfc(Expander):
+    """Erfc expander"""
+
+    def _expand(self, graph_builder):
+        input_x = self.inputs[0]
+        result = None
+        if input_x.dtype == "float16":
+            const_one = graph_builder.value("float32", 1)
+            input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': "float32"})
+            erf_result = graph_builder.emit('Erf', [input_x])
+            result = graph_builder.emit('Sub', [const_one, erf_result])
+            result = graph_builder.emit('Cast', [result], attrs={'dst_type': "float16"})
+            return result
+        const_one = graph_builder.value(input_x.dtype, 1)
+        erf_result = graph_builder.emit('Erf', [input_x])
+        result = graph_builder.emit('Sub', [const_one, erf_result])
+        return result

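The expander lowers Erfc into the identity erfc(x) = 1 - erf(x). float16 inputs are first cast to float32 so the subtraction does not lose precision where erf(x) is close to 1, then cast back. A minimal NumPy/SciPy sketch of the same dataflow (scipy.special.erf stands in for the emitted 'Erf' op; this is an illustration of the math, not the graph-kernel IR):

import numpy as np
from scipy.special import erf

def erfc_expanded(x):
    # Mirror the expander: upcast fp16, compute 1 - erf(x), cast back.
    compute_dtype = np.float32 if x.dtype == np.float16 else x.dtype
    y = 1.0 - erf(x.astype(compute_dtype))   # 'Sub'(const_one, 'Erf'(x))
    return y.astype(x.dtype)                 # final 'Cast' for fp16 inputs

x = np.linspace(-2, 2, 5, dtype=np.float16)
print(erfc_expanded(x))
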
@@ -158,7 +158,14 @@ class PrimLib:
         'Exp': Prim(ELEMWISE),
         'Rsqrt': Prim(ELEMWISE),
         'Sqrt': Prim(ELEMWISE),
+        'Div': Prim(ELEMWISE),
+        'FloorDiv': Prim(ELEMWISE),
         'RealDiv': Prim(ELEMWISE),
+        'Mod': Prim(ELEMWISE),
+        'Floor': Prim(ELEMWISE),
+        'FloorMod': Prim(ELEMWISE),
+        'Erf': Prim(ELEMWISE),
+        'Erfc': Prim(ELEMWISE),
         'Cast': Prim(ELEMWISE),
         'Pow': Prim(ELEMWISE),
         'Minimum': Prim(ELEMWISE),

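All of the new entries are registered as ELEMWISE, i.e. plain per-element ops with broadcasting. For reference, a sketch of their usual semantics in NumPy terms, assuming the common MindSpore convention that Mod keeps the dividend's sign (like np.fmod) while FloorMod follows the divisor's sign (like np.mod):

import numpy as np

x = np.array([7., -7.], dtype=np.float32)
y = np.array([3.,  3.], dtype=np.float32)

div       = x / y                  # Div / RealDiv: true division
floor_div = np.floor_divide(x, y)  # FloorDiv: floor(x / y)            -> [ 2., -3.]
mod       = np.fmod(x, y)          # Mod: remainder, dividend's sign   -> [ 1., -1.]
floor     = np.floor(x / y)        # Floor: round toward -inf
floor_mod = np.mod(x, y)           # FloorMod: remainder, divisor's sign -> [ 1.,  2.]
assert np.allclose(floor_div * y + floor_mod, x)  # FloorDiv/FloorMod identity
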
@@ -60,17 +60,8 @@ std::vector<PrimitivePtr> GetClusterableOpList() {
     prim::kPrimRealDiv,
     prim::kPrimReduceSum,
     prim::kPrimEqual,
-    prim::kPrimNotEqual,
-    prim::kPrimLogicalAnd,
-    prim::kPrimLogicalOr,
-    prim::kPrimLogicalNot,
     prim::kPrimAssign,
     prim::kPrimInplaceAssign,
-    prim::kPrimAtan,
-    prim::kPrimAtan2,
-    prim::kPrimExpm1,
-    prim::kPrimAsinh,
-    prim::kPrimAcosh,
 #if ENABLE_D
     prim::kPrimMatMul,
     prim::KPrimTransData,

@@ -88,6 +79,21 @@ std::vector<PrimitivePtr> GetClusterableOpList() {
     prim::kPrimGreaterEqual,
     prim::kPrimLessEqual,
     prim::kPrimSelect,
+    prim::kPrimAtan,
+    prim::kPrimAtan2,
+    prim::kPrimExpm1,
+    prim::kPrimAsinh,
+    prim::kPrimAcosh,
+    prim::kPrimDiv,
+    prim::kPrimFloorDiv,
+    prim::kPrimMod,
+    prim::kPrimFloor,
+    prim::kPrimFloorMod,
+    prim::kPrimErf,
+    prim::kPrimNotEqual,
+    prim::kPrimLogicalAnd,
+    prim::kPrimLogicalOr,
+    prim::kPrimLogicalNot,
 #endif
   };
   const auto &flags = context::GraphKernelFlags::GetInstance();

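Once listed here, these primitives become candidates for graph-kernel clustering, so a network that chains them can be fused into a single kernel; the fusion is transparent to user code. A hypothetical check in the style of the tests added below (FloorOpsNet is an illustrative name, not part of this commit; P.FloorDiv and P.FloorMod are standard MindSpore operations):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

class FloorOpsNet(nn.Cell):
    def __init__(self):
        super(FloorOpsNet, self).__init__()
        self.floor_div = P.FloorDiv()
        self.floor_mod = P.FloorMod()

    def construct(self, x, y):
        # Both ops are now clusterable, so they may be fused into one kernel.
        return self.floor_div(x, y) + self.floor_mod(x, y)

context.set_context(mode=context.GRAPH_MODE, device_target="GPU", enable_graph_kernel=True)
x = Tensor(np.array([7.0, -7.0], dtype=np.float32))
y = Tensor(np.array([3.0, 3.0], dtype=np.float32))
print(FloorOpsNet()(x, y))
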
@@ -60,6 +60,7 @@ std::vector<PrimitivePtr> GetExpandOps() {
     prim::kPrimTile,
     prim::kPrimMatMul,
     prim::kPrimBatchMatMul,
+    prim::kPrimErfc,
 #if ENABLE_D
     prim::kPrimSqrtGrad,
     prim::kPrimClipByNormNoDivSum,

@@ -622,6 +622,7 @@ inline const PrimitivePtr kPrimLayerNormFusion = std::make_shared<Primitive>("La
 inline const PrimitivePtr kPrimDType = std::make_shared<Primitive>("DType");
 inline const PrimitivePtr kPrimDivFusion = std::make_shared<Primitive>("DivFusion");
 inline const PrimitivePtr kPrimErf = std::make_shared<Primitive>("Erf");
+inline const PrimitivePtr kPrimErfc = std::make_shared<Primitive>("Erfc");
 inline const PrimitivePtr kPrimSplice = std::make_shared<Primitive>("Splice");
 inline const PrimitivePtr kPrimAffine = std::make_shared<Primitive>("Affine");
 
@@ -0,0 +1,73 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+class ErfNet(nn.Cell):
+    def __init__(self):
+        super(ErfNet, self).__init__()
+        self.erf = P.Erf()
+
+    def construct(self, x):
+        return self.erf(x)
+
+class ErfcNet(nn.Cell):
+    def __init__(self):
+        super(ErfcNet, self).__init__()
+        self.erfc = P.Erfc()
+
+    def construct(self, x):
+        return self.erfc(x)
+
+def get_output(net, inp, enable_graph_kernel=False):
+    context.set_context(enable_graph_kernel=enable_graph_kernel)
+    output = net()(inp)
+    return output
+
+def basic_test(net, datatype):
+    inp = Tensor(np.random.random((2, 3)).astype(datatype))
+    expect = get_output(net, inp, False)
+    output = get_output(net, inp, True)
+    expect_np = expect.asnumpy().copy()
+    output_np = output.asnumpy().copy()
+    assert np.allclose(expect_np, output_np, 1.e-4, 1.e-7)
+
+    inp = Tensor(np.random.random((2, 3, 3, 4, 5)).astype(datatype))
+    expect = get_output(net, inp, False)
+    output = get_output(net, inp, True)
+    expect_np = expect.asnumpy().copy()
+    output_np = output.asnumpy().copy()
+    assert np.allclose(expect_np, output_np, 1.e-4, 1.e-7)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gpu_fp16():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    basic_test(ErfNet, np.float16)
+    basic_test(ErfcNet, np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gpu_fp32():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    basic_test(ErfNet, np.float32)
+    basic_test(ErfcNet, np.float32)
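A note on the tolerance check: in np.allclose the third and fourth positional arguments are rtol and atol, so the test compares the graph-kernel output against the unfused baseline with a relative tolerance of 1e-4 and an absolute tolerance of 1e-7. Expanded, the assertion is equivalent to:

import numpy as np

def allclose_used(expect_np, output_np, rtol=1.e-4, atol=1.e-7):
    # Equivalent to np.allclose(expect_np, output_np, 1.e-4, 1.e-7):
    # |expect - output| <= atol + rtol * |output|, elementwise.
    return np.all(np.abs(expect_np - output_np) <= atol + rtol * np.abs(output_np))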