add graph kernel ops Div, FloorDiv, Mod, FloorMod, Floor; add Erfc expander

yanglf1121 2021-06-29 17:22:51 +08:00
parent bdc687d774
commit c30b1e6d06
7 changed files with 134 additions and 10 deletions

View File

@@ -24,6 +24,8 @@ from .clip_by_norm_no_div_sum import ClipByNormNoDivSum
 from .conv2d import Conv2D
 from .complex import CAbs, CAdd, CDiv, CMul, CSub
 from .dropout_grad import DropoutGrad
+from .equal_count import EqualCount
+from .erfc import Erfc
 from .expand_dims import ExpandDims
 from .fused_adam import FusedAdam
 from .fused_adam_weight_decay import FusedAdamWeightDecay
@@ -57,4 +59,3 @@ from .squared_difference import SquaredDifference
 from .squeeze import Squeeze
 from .tanh_grad import TanhGrad
 from .tile import Tile
-from .equal_count import EqualCount

View File

@@ -0,0 +1,35 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""generate json desc for erfc"""
+from ._utils import Expander
+
+
+class Erfc(Expander):
+    """Erfc expander"""
+
+    def _expand(self, graph_builder):
+        input_x = self.inputs[0]
+        result = None
+        if input_x.dtype == "float16":
+            const_one = graph_builder.value("float32", 1)
+            input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': "float32"})
+            erf_result = graph_builder.emit('Erf', [input_x])
+            result = graph_builder.emit('Sub', [const_one, erf_result])
+            result = graph_builder.emit('Cast', [result], attrs={'dst_type': "float16"})
+            return result
+        const_one = graph_builder.value(input_x.dtype, 1)
+        erf_result = graph_builder.emit('Erf', [input_x])
+        result = graph_builder.emit('Sub', [const_one, erf_result])
+        return result
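
The expander relies on the identity erfc(x) = 1 - erf(x), and the float16 branch exists because subtracting erf(x) from 1 in half precision loses most significant bits once erf(x) is near 1, so it computes in float32 and casts back. A minimal NumPy sketch of the same cast-compute-cast pattern (illustrative only, not part of the commit):

import math
import numpy as np

def erfc_ref(x):
    """Reference erfc: upcast float16 to float32, as the expander does."""
    compute_dtype = np.float32 if x.dtype == np.float16 else x.dtype
    x32 = x.astype(compute_dtype)
    erf = np.vectorize(math.erf)(x32)   # elementwise erf
    return (1.0 - erf).astype(x.dtype)  # 1 - erf(x), cast back to input dtype

print(erfc_ref(np.array([0.0, 0.5, 2.0], dtype=np.float16)))  # ~[1.0, 0.4795, 0.0047]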

View File

@@ -158,7 +158,14 @@ class PrimLib:
         'Exp': Prim(ELEMWISE),
         'Rsqrt': Prim(ELEMWISE),
         'Sqrt': Prim(ELEMWISE),
+        'Div': Prim(ELEMWISE),
+        'FloorDiv': Prim(ELEMWISE),
         'RealDiv': Prim(ELEMWISE),
+        'Mod': Prim(ELEMWISE),
+        'Floor': Prim(ELEMWISE),
+        'FloorMod': Prim(ELEMWISE),
+        'Erf': Prim(ELEMWISE),
+        'Erfc': Prim(ELEMWISE),
         'Cast': Prim(ELEMWISE),
         'Pow': Prim(ELEMWISE),
         'Minimum': Prim(ELEMWISE),
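
All of the new entries are pointwise, which is what the ELEMWISE classification requires for fusion. They differ in how the remainder is signed: FloorDiv/FloorMod round toward negative infinity (remainder takes the divisor's sign), while Mod is conventionally the truncated, C-style remainder (sign of the dividend). A hedged NumPy illustration of these identities (assuming truncated semantics for Mod, matching fmod):

import numpy as np

x, y = np.float32(-7.0), np.float32(3.0)
floor_div = np.floor(x / y)          # FloorDiv: floor(x / y)         -> -3.0
floor_mod = x - np.floor(x / y) * y  # FloorMod: x - floor(x/y) * y   ->  2.0
trunc_mod = np.fmod(x, y)            # Mod (truncated, C-style)       -> -1.0
print(floor_div, floor_mod, trunc_mod)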

View File

@@ -60,17 +60,8 @@ std::vector<PrimitivePtr> GetClusterableOpList() {
     prim::kPrimRealDiv,
     prim::kPrimReduceSum,
     prim::kPrimEqual,
-    prim::kPrimNotEqual,
-    prim::kPrimLogicalAnd,
-    prim::kPrimLogicalOr,
-    prim::kPrimLogicalNot,
     prim::kPrimAssign,
     prim::kPrimInplaceAssign,
-    prim::kPrimAtan,
-    prim::kPrimAtan2,
-    prim::kPrimExpm1,
-    prim::kPrimAsinh,
-    prim::kPrimAcosh,
 #if ENABLE_D
     prim::kPrimMatMul,
     prim::KPrimTransData,
@@ -88,6 +79,21 @@ std::vector<PrimitivePtr> GetClusterableOpList() {
     prim::kPrimGreaterEqual,
     prim::kPrimLessEqual,
     prim::kPrimSelect,
+    prim::kPrimAtan,
+    prim::kPrimAtan2,
+    prim::kPrimExpm1,
+    prim::kPrimAsinh,
+    prim::kPrimAcosh,
+    prim::kPrimDiv,
+    prim::kPrimFloorDiv,
+    prim::kPrimMod,
+    prim::kPrimFloor,
+    prim::kPrimFloorMod,
+    prim::kPrimErf,
+    prim::kPrimNotEqual,
+    prim::kPrimLogicalAnd,
+    prim::kPrimLogicalOr,
+    prim::kPrimLogicalNot,
 #endif
   };
   const auto &flags = context::GraphKernelFlags::GetInstance();
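
With these primitives clusterable, adjacent elementwise ops become candidates for fusion into a single generated kernel once graph kernel fusion is enabled. A small usage sketch (FloorDivModNet is a hypothetical name; whether fusion actually triggers depends on which backend section the ops were moved into):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

class FloorDivModNet(nn.Cell):
    def __init__(self):
        super(FloorDivModNet, self).__init__()
        self.floor_div = P.FloorDiv()
        self.floor_mod = P.FloorMod()

    def construct(self, x, y):
        # Two pointwise ops over the same inputs: a fusion candidate.
        return self.floor_div(x, y), self.floor_mod(x, y)

context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True)
x = Tensor(np.array([7.0, -7.0], np.float32))
y = Tensor(np.array([3.0, 3.0], np.float32))
print(FloorDivModNet()(x, y))  # floor quotients (2, -3) and floor remainders (1, 2)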

View File

@@ -60,6 +60,7 @@ std::vector<PrimitivePtr> GetExpandOps() {
     prim::kPrimTile,
     prim::kPrimMatMul,
     prim::kPrimBatchMatMul,
+    prim::kPrimErfc,
 #if ENABLE_D
     prim::kPrimSqrtGrad,
     prim::kPrimClipByNormNoDivSum,
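
Note the division of labor: ops registered in GetExpandOps (like Erfc here) are decomposed into simpler primitives by a Python expander, while ops in GetClusterableOpList above are kept intact and only fused. Purely for illustration, a FloorMod expander in the style of the Erfc one might look like this (hypothetical FloorModSketch class; the commit clusters FloorMod rather than expanding it):

from ._utils import Expander

class FloorModSketch(Expander):
    """Illustrative only: floormod(x, y) = x - floor(x / y) * y"""

    def _expand(self, graph_builder):
        input_x, input_y = self.inputs[0], self.inputs[1]
        quotient = graph_builder.emit('RealDiv', [input_x, input_y])
        floored = graph_builder.emit('Floor', [quotient])
        scaled = graph_builder.emit('Mul', [floored, input_y])
        return graph_builder.emit('Sub', [input_x, scaled])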

View File

@@ -622,6 +622,7 @@ inline const PrimitivePtr kPrimLayerNormFusion = std::make_shared<Primitive>("LayerNormFusion");
 inline const PrimitivePtr kPrimDType = std::make_shared<Primitive>("DType");
 inline const PrimitivePtr kPrimDivFusion = std::make_shared<Primitive>("DivFusion");
 inline const PrimitivePtr kPrimErf = std::make_shared<Primitive>("Erf");
+inline const PrimitivePtr kPrimErfc = std::make_shared<Primitive>("Erfc");
 inline const PrimitivePtr kPrimSplice = std::make_shared<Primitive>("Splice");
 inline const PrimitivePtr kPrimAffine = std::make_shared<Primitive>("Affine");

View File

@@ -0,0 +1,73 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+class ErfNet(nn.Cell):
+    def __init__(self):
+        super(ErfNet, self).__init__()
+        self.erf = P.Erf()
+
+    def construct(self, x):
+        return self.erf(x)
+
+
+class ErfcNet(nn.Cell):
+    def __init__(self):
+        super(ErfcNet, self).__init__()
+        self.erfc = P.Erfc()
+
+    def construct(self, x):
+        return self.erfc(x)
+
+
+def get_output(net, inp, enable_graph_kernel=False):
+    context.set_context(enable_graph_kernel=enable_graph_kernel)
+    output = net()(inp)
+    return output
+
+
+def basic_test(net, datatype):
+    inp = Tensor(np.random.random((2, 3)).astype(datatype))
+    expect = get_output(net, inp, False)
+    output = get_output(net, inp, True)
+    expect_np = expect.asnumpy().copy()
+    output_np = output.asnumpy().copy()
+    assert np.allclose(expect_np, output_np, 1.e-4, 1.e-7)
+
+    inp = Tensor(np.random.random((2, 3, 3, 4, 5)).astype(datatype))
+    expect = get_output(net, inp, False)
+    output = get_output(net, inp, True)
+    expect_np = expect.asnumpy().copy()
+    output_np = output.asnumpy().copy()
+    assert np.allclose(expect_np, output_np, 1.e-4, 1.e-7)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gpu_fp16():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    basic_test(ErfNet, np.float16)
+    basic_test(ErfcNet, np.float16)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gpu_fp32():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    basic_test(ErfNet, np.float32)
+    basic_test(ErfcNet, np.float32)
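
The harness compares graph-kernel-on against graph-kernel-off outputs, so it extends directly to the other ops this commit enables. A hedged sketch for the unary Floor op, reusing the same file's imports and basic_test (hypothetical test, not part of the commit):

class FloorNet(nn.Cell):
    def __init__(self):
        super(FloorNet, self).__init__()
        self.floor = P.Floor()

    def construct(self, x):
        return self.floor(x)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_floor_fp32():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    basic_test(FloorNet, np.float32)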