vm for prelu and prelugrad

This commit is contained in:
jiangjinsheng 2020-05-19 15:17:32 +08:00
parent f73867222e
commit ce09f5e15a
5 changed files with 121 additions and 0 deletions

View File

@ -36,6 +36,8 @@ static std::map<string, string> tbe_func_adapter_map = {
{"re_lu6_grad", "relu6_grad"},
{"re_lu", "relu"},
{"re_luv2", "relu_v2"},
{"p_re_lu", "prelu"},
{"p_re_lu_grad", "prelu_grad"},
{"tensor_add", "add"},
{"reduce_mean", "reduce_mean_d"},
{"reduce_max", "reduce_max_d"},

View File

@ -184,3 +184,5 @@ from .bn_training_update_v2 import _bn_training_update_v2_tbe
from .square_sum_all import square_sum_all_op_info
from .pack import _pack_tbe
from .unpack import _unpack_tbe
from .prelu import _prelu_tbe
from .prelu_grad import _prelu_grad_tbe

View File

@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""PReLU op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
prelu_op_info = TBERegOp("PReLU") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("prelu.so") \
.compute_cost(10) \
.kernel_name("prelu") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.input(1, "weight", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_NCHW, DataType.F16_Default, DataType.F16_NCHW) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_NCHW, DataType.F32_Default, DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(prelu_op_info)
def _prelu_tbe():
"""PReLU TBE register"""
return

View File

@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""PReLUGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
prelu_grad_op_info = TBERegOp("PReLUGrad") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("prelu_grad.so") \
.compute_cost(10) \
.kernel_name("prelu_grad") \
.partial_flag(True) \
.input(0, "grads", False, "required", "all") \
.input(1, "features", False, "required", "all") \
.input(2, "weights", False, "required", "all") \
.output(0, "dx", False, "required", "all") \
.output(0, "da", False, "required", "all") \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_Default,
DataType.F32_NCHW, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(prelu_grad_op_info)
def _prelu_grad_tbe():
"""PReLUGrad TBE register"""
return

View File

@ -24,6 +24,7 @@ from mindspore.common.initializer import initializer
from mindspore.ops import Primitive
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
@ -456,6 +457,28 @@ class FlattenNet(nn.Cell):
return self.flatten(x)
class PReLUNet(nn.Cell):
""" PReLUNet definition """
def __init__(self):
super(PReLUNet, self).__init__()
self.prelu = P.PReLU()
self.w = Tensor(np.ones(3, np.float32))
def construct(self, x):
return self.prelu(x, self.w)
class PReLUGradNet(nn.Cell):
""" PReLUGradNet definition """
def __init__(self):
super(PReLUGradNet, self).__init__()
self.prelu_grad = G.PReLUGrad()
def construct(self, dout, x, w):
return self.prelu_grad(dout, x, w)
test_cases = [
('SoftMaxGrad', {
'block': SoftMaxGrad(VirtualNetWithLoss(P.Softmax())),
@ -545,6 +568,16 @@ test_cases = [
'block': FlattenNet(),
'desc_inputs': [Tensor(np.ones([1, 2, 3, 4], np.float32))],
}),
('PReLUNet', {
'block': PReLUNet(),
'desc_inputs': [Tensor(np.ones([1, 3, 4, 4], np.float32))],
}),
('PReLUGradNet', {
'block': PReLUGradNet(),
'desc_inputs': [Tensor(np.ones([1, 3, 4, 4], np.float32)),
Tensor(np.ones([1, 3, 4, 4], np.float32)),
Tensor(np.ones(3, np.float32))],
}),
]
test_cases_for_verify_exception = [