forked from mindspore-Ecosystem/mindspore
!959 Add Elu/EluGrad ops for VM.
Merge pull request !959 from liuxiao/ops-for-VM
commit e75d75854d
@@ -600,7 +600,6 @@ def get_bprop_roi_align(self):
     sample_num = self.sample_num

     def bprop(inputs, rois, out, dout):
-        rois_shape = shape_op(rois)
         inputs_shape = shape_op(inputs)
         dx = G.ROIAlignGrad(inputs_shape,
                             pooled_height,
@@ -608,7 +607,7 @@ def get_bprop_roi_align(self):
                             spatial_scale,
                             sample_num,
                             )(dout, rois)
-        return dx, zeros_like(rois_shape)
+        return dx, zeros_like(rois)

     return bprop
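Note on the fix above: a bprop must return one gradient per forward input (here `inputs` and `rois`), each matching the corresponding tensor. `rois_shape` is the shape tuple produced by `shape_op(rois)`, not the tensor itself, so zeroing it produces a wrongly shaped gradient. A minimal NumPy sketch of the distinction (illustrative stand-in values, not MindSpore code):

    import numpy as np

    rois = np.zeros((4, 5), dtype=np.float32)  # stand-in for the rois tensor
    rois_shape = rois.shape                    # (4, 5): a tuple, not a tensor

    grad_rois = np.zeros_like(rois)       # shape (4, 5) -- matches the input
    bad_grad = np.zeros_like(rois_shape)  # shape (2,)   -- wrong rank entirely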
@@ -76,6 +76,8 @@ from .strided_slice_d import _strided_slice_d_tbe
 from .strided_slice_grad_d import _strided_slice_grad_d_tbe
 from .split_d import _split_d_tbe
 from .exp import _exp_tbe
+from .elu import _elu_tbe
+from .elu_grad import _elu_grad_tbe
 from .div import _div_tbe
 from .log import _log_tbe
 from .floor_div import _floor_div_tbe
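These two import lines are what make the new kernels reachable: each module below wraps its registration stub in @op_info_register, and the decorator runs at import time, so an op-info file that is never imported is never registered.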
@@ -0,0 +1,40 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Elu op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+elu_op_info = TBERegOp("Elu") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("elu.so") \
+    .compute_cost(10) \
+    .kernel_name("elu") \
+    .partial_flag(True) \
+    .op_pattern("formatAgnostic") \
+    .attr("alpha", "optional", "float", "all", "1.0") \
+    .input(0, "x", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
+    .get_op_info()
+
+
+@op_info_register(elu_op_info)
+def _elu_tbe():
+    """Elu TBE register"""
+    return
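The Elu registration above only describes the kernel interface (inputs, attrs, supported dtype/format pairs); for reference, the math behind the name is the standard exponential linear unit, with `alpha` defaulting to 1.0. A NumPy sketch of the element-wise formula:

    import numpy as np

    def elu(x, alpha=1.0):
        """Standard ELU: identity for positive x, alpha * (exp(x) - 1) otherwise."""
        return np.where(x > 0, x, alpha * (np.exp(x) - 1.0))

    print(elu(np.array([-2.0, -0.5, 0.0, 3.0], dtype=np.float32)))
    # approx. [-0.8647 -0.3935  0.      3.    ]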
@@ -0,0 +1,43 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""EluGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+elu_grad_op_info = TBERegOp("EluGrad") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("elu_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("elu_grad") \
+    .partial_flag(True) \
+    .input(0, "grads", False, "required", "all") \
+    .input(1, "activations", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
+    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
+    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(elu_grad_op_info)
+def _elu_grad_tbe():
+    """EluGrad TBE register"""
+    return
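EluGrad's two required inputs are the incoming gradients and the forward activations, which suggests the usual trick of expressing the ELU derivative through the forward output y: dy/dx = 1 for x > 0, and alpha * exp(x) = y + alpha otherwise. A sketch under that assumption, with alpha fixed at 1.0 (the registration above exposes no alpha attribute):

    import numpy as np

    def elu_grad(grads, activations, alpha=1.0):
        """ELU backward written in terms of the forward output y:
        dy/dx = 1 where y > 0, and y + alpha elsewhere."""
        return grads * np.where(activations > 0, 1.0, activations + alpha)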
@@ -1527,7 +1527,8 @@ class L2Loss(PrimitiveWithInfer):

     def infer_dtype(self, x_type):
         validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
-        validator.check_tensor_type_same({'x_type': x_type}, [mstype.double, mstype.float_, mstype.float16], self.name)
+        valid_types = [mstype.float16, mstype.float32, mstype.double]
+        validator.check_tensor_type_same({'x_type': x_type}, valid_types, self.name)
         return x_type
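The infer_dtype change replaces mstype.float_ with the explicit mstype.float32 and pulls the accepted types into a named valid_types list, keeping the validator call readable. For context, L2Loss computes half the sum of squares; a NumPy sketch:

    import numpy as np

    def l2_loss(x):
        """L2 loss as commonly defined: sum(x ** 2) / 2."""
        return np.sum(np.square(x)) / 2.0

    print(l2_loss(np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)))  # 15.0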
@@ -874,7 +874,7 @@ test_case_nn_ops = [
         'skip': ['backward']}),
     ('L2Loss_1', {
         'block': P.L2Loss(),
-        'desc_inputs': [Tensor(np.array([1, 2, 3, 4]), mstype.float16)],
+        'desc_inputs': [Tensor(np.array([1, 2, 3, 4]), mstype.float32)],
         'desc_bprop': []}),
     ('L2Loss_2', {
         'block': P.L2Loss(),