forked from mindspore-Ecosystem/mindspore
!33569 add hswish operator
Merge pull request !33569 from jjfeing/add_hswish_operator
This commit is contained in:
commit
eafd35217a
|
@ -58,6 +58,7 @@ functional算子是经过初始化后的Primitive,可以直接作为函数使
|
|||
|
||||
mindspore.ops.fast_gelu
|
||||
mindspore.ops.hardshrink
|
||||
mindspore.ops.hardswish
|
||||
mindspore.ops.padding
|
||||
mindspore.ops.tanh
|
||||
|
||||
|
|
|
@ -521,6 +521,22 @@ mindspore.Tensor
|
|||
- **TypeError** - `lambd` 不是float。
|
||||
- **TypeError** - 原始Tensor的dtype既不是float16也不是float32。
|
||||
|
||||
.. py:method:: hardswish()
|
||||
|
||||
Hard Swish激活函数。
|
||||
|
||||
对输入的每个元素计算Hard Swish。
|
||||
|
||||
更多细节参考 :func:`mindspore.ops.hardswish`。
|
||||
|
||||
**返回:**
|
||||
|
||||
Tensor,具有与输入Tensor相同的数据类型和shape。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - 输入Tensor的数据类型既不是float16也不是float32。
|
||||
|
||||
.. py:method:: has_init
|
||||
:property:
|
||||
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
mindspore.ops.HSwish
|
||||
=====================
|
||||
|
||||
.. py:class:: mindspore.ops.HSwish
|
||||
|
||||
Hard Swish激活函数。
|
||||
|
||||
更多参考详见 :func:`mindspore.ops.hardswish`。
|
|
@ -0,0 +1,27 @@
|
|||
mindspore.ops.hardswish
|
||||
=======================
|
||||
|
||||
.. py:function:: mindspore.ops.hardswish(x)
|
||||
|
||||
Hard Swish激活函数。
|
||||
|
||||
对输入的每个元素计算Hard Swish。
|
||||
|
||||
Hard Swish定义如下:
|
||||
|
||||
.. math::
|
||||
\text{hardswish}(x_{i}) = x_{i} * \frac{ReLU6(x_{i} + 3)}{6}
|
||||
|
||||
其中, :math:`x_i` 是输入的元素。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **x** (Tensor) - 用于计算Hard Swish的Tensor。数据类型必须是float16或float32。
|
||||
|
||||
**返回:**
|
||||
|
||||
Tensor,shape和数据类型与输入相同。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - `x` 的数据类型既不是float16也不是float32。
|
|
@ -58,6 +58,7 @@ Activation Functions
|
|||
|
||||
mindspore.ops.fast_gelu
|
||||
mindspore.ops.hardshrink
|
||||
mindspore.ops.hardswish
|
||||
mindspore.ops.softsign
|
||||
mindspore.ops.tanh
|
||||
|
||||
|
|
|
@ -247,6 +247,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"bernoulli", prim::kPrimBernoulli}, // P.Bernoulli()
|
||||
{"cdist", std::string("cdist")}, // P.cdist
|
||||
{"hardshrink", std::string("hardshrink")}, // P.hshrink
|
||||
{"hardswish", std::string("hardswish")}, // P.HSwish
|
||||
{"soft_shrink", std::string("soft_shrink")}, // P.SoftShrink
|
||||
{"one_hot", std::string("one_hot")}, // P.OneHot
|
||||
{"intopk", std::string("intopk")}, // P.InTopK
|
||||
|
|
|
@ -1542,6 +1542,34 @@ def adaptive_avgpool2d(x, output_size):
|
|||
return F.adaptive_avgpool2d(x, output_size)
|
||||
|
||||
|
||||
def hardswish(x):
|
||||
r"""
|
||||
Hard swish activation function.
|
||||
|
||||
Calculate Hard Swish for each element of input.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor, with the same type and shape as the `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x` is not a Tensor.
|
||||
TypeError: If dtype of `x` is neither float16 nor float32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> x = np.array([-1, -2, 0, 2, 1])
|
||||
>>> print(x.hardswish())
|
||||
[-0.3333 -0.3333 0 1.666 0.6665]
|
||||
"""
|
||||
return P.HSwish()(x)
|
||||
|
||||
|
||||
def getitem(data, index):
|
||||
"""Implementation of `getitem`."""
|
||||
return data.__getitem__(index)
|
||||
|
|
|
@ -3730,6 +3730,34 @@ class Tensor(Tensor_):
|
|||
self._init_check()
|
||||
return tensor_operator_registry.get('hardshrink')(lambd)(self)
|
||||
|
||||
def hardswish(self):
|
||||
r"""
|
||||
Hard swish activation function.
|
||||
|
||||
Calculate Hard Swish for each element of input.
|
||||
|
||||
Refer to :func:`mindspore.ops.hardswish` for more detail.
|
||||
|
||||
Returns:
|
||||
Tensor, with the same type and shape as the input.
|
||||
|
||||
Raises:
|
||||
TypeError: If dtype of input is neither float16 nor float32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore import Tensor
|
||||
>>> a = Tensor([-1, -2, 0, 2, 1]).astype("float16")
|
||||
>>> output = a.hardswish()
|
||||
>>> print(output)
|
||||
[-0.3333 -0.3333 0 1.666 0.6665]
|
||||
"""
|
||||
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('hardswish')()(self)
|
||||
|
||||
def soft_shrink(self, lambd=0.5):
|
||||
"""
|
||||
Apply the soft shrink function for a tensor. Calculates the output according to the input elements.
|
||||
|
|
|
@ -585,3 +585,7 @@ from .adam_apply_one_ds import _adam_apply_one_ds_tbe
|
|||
from .adam_apply_one_with_decay_ds import _adam_apply_one_with_decay_ds_tbe
|
||||
from .adaptive_max_pool2d import _adaptive_max_pool2d_tbe
|
||||
from .pooling import _pooling_tbe
|
||||
from .hard_swish import _hard_swish_tbe
|
||||
from .hard_swish_grad import _hard_swish_grad_tbe
|
||||
from .hard_swish_ds import _hard_swish_ds_tbe
|
||||
from .hard_swish_grad_ds import _hard_swish_grad_ds_tbe
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Hard Swish op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
hard_swish_op_info = TBERegOp("HSwish") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("hard_swish.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("hard_swish") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.op_pattern("formatAgnostic") \
|
||||
.dtype_format(DataType.F16_None, DataType.F16_None) \
|
||||
.dtype_format(DataType.F32_None, DataType.F32_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(hard_swish_op_info)
|
||||
def _hard_swish_tbe():
|
||||
"""Hard Swish TBE register"""
|
||||
return
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Hard Swish op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
hard_swish_op_info = TBERegOp("HSwish") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("hard_swish.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("hard_swish") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.op_pattern("formatAgnostic") \
|
||||
.dynamic_shape(True) \
|
||||
.dtype_format(DataType.F16_None, DataType.F16_None) \
|
||||
.dtype_format(DataType.F32_None, DataType.F32_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(hard_swish_op_info)
|
||||
def _hard_swish_ds_tbe():
|
||||
"""Hard Swish TBE register"""
|
||||
return
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Hard Swish Grad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
hard_swish_grad_op_info = TBERegOp("HSwishGrad") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("hard_swish_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("hard_swish_grad") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "grad", False, "required", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
|
||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(hard_swish_grad_op_info)
|
||||
def _hard_swish_grad_tbe():
|
||||
"""Hard Swish Grad TBE register"""
|
||||
return
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Hard Swish Grad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
hard_swish_grad_op_info = TBERegOp("HSwishGrad") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("hard_swish_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("hard_swish_grad") \
|
||||
.partial_flag(True) \
|
||||
.dynamic_shape(True) \
|
||||
.input(0, "grad", False, "required", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
|
||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(hard_swish_grad_op_info)
|
||||
def _hard_swish_grad_ds_tbe():
|
||||
"""Hard Swish Grad TBE register"""
|
||||
return
|
|
@ -204,6 +204,7 @@ def get_in_top_k_vmap_rule(prim, axis_size):
|
|||
|
||||
@vmap_rules_getters.register(G.FastGeLUGrad)
|
||||
@vmap_rules_getters.register(G.HShrinkGrad)
|
||||
@vmap_rules_getters.register(G.HSwishGrad)
|
||||
@vmap_rules_getters.register(G.SoftShrinkGrad)
|
||||
def get_fast_gelu_grad_vmap_rule(prim, axis_size):
|
||||
"""VmapRule for common activation grad operation."""
|
||||
|
|
|
@ -208,6 +208,7 @@ from .nn_func import (
|
|||
hardshrink,
|
||||
soft_shrink,
|
||||
intopk,
|
||||
hardswish,
|
||||
softsign,
|
||||
pdist,
|
||||
nll_loss,
|
||||
|
|
|
@ -111,6 +111,7 @@ def adaptive_avgpool2d(x, output_size):
|
|||
|
||||
fast_gelu_ = P.FastGeLU()
|
||||
softsign_ = P.Softsign()
|
||||
hardswish_ = P.HSwish()
|
||||
|
||||
|
||||
def fast_gelu(x):
|
||||
|
@ -185,6 +186,42 @@ def hardshrink(x, lambd=0.5):
|
|||
return hshrink_op(x)
|
||||
|
||||
|
||||
def hardswish(x):
|
||||
r"""
|
||||
Hard swish activation function.
|
||||
|
||||
Applies hswish-type activation element-wise. The input is a Tensor with any valid shape.
|
||||
|
||||
Hard swish is defined as:
|
||||
|
||||
.. math::
|
||||
|
||||
\text{hswish}(x_{i}) = x_{i} * \frac{ReLU6(x_{i} + 3)}{6},
|
||||
|
||||
where :math:`x_i` is an element of the input Tensor.
|
||||
|
||||
Args:
|
||||
x (Tensor): The input to compute the Hard Swish with data type of float16 or float32.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same data type and shape as the input.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x` is not a Tensor.
|
||||
TypeError: If dtype of `x` is neither float16 nor float32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
|
||||
>>> output = ops.hardswish(x)
|
||||
>>> print(result)
|
||||
[-0.3333 -0.3333 0 1.666 0.6665]
|
||||
"""
|
||||
return hardswish_(x)
|
||||
|
||||
|
||||
def softsign(x):
|
||||
r"""
|
||||
Softsign activation function.
|
||||
|
@ -701,6 +738,7 @@ __all__ = [
|
|||
'hardshrink',
|
||||
'soft_shrink',
|
||||
'intopk',
|
||||
'hardswish',
|
||||
'softsign',
|
||||
'pdist',
|
||||
'cross_entropy',
|
||||
|
|
|
@ -963,6 +963,7 @@ tensor_operator_registry.register('invert', invert)
|
|||
tensor_operator_registry.register('matrix_band_part', matrix_band_part)
|
||||
tensor_operator_registry.register('padding', padding)
|
||||
tensor_operator_registry.register('hardshrink', P.HShrink)
|
||||
tensor_operator_registry.register('hardswish', P.HSwish)
|
||||
tensor_operator_registry.register('soft_shrink', P.SoftShrink)
|
||||
tensor_operator_registry.register('svd', linalg_ops.Svd)
|
||||
tensor_operator_registry.register('diag', P.Diag)
|
||||
|
|
|
@ -820,29 +820,10 @@ class HSwish(PrimitiveWithInfer):
|
|||
r"""
|
||||
Hard swish activation function.
|
||||
|
||||
Applies hswish-type activation element-wise. The input is a Tensor with any valid shape.
|
||||
|
||||
Hard swish is defined as:
|
||||
|
||||
.. math::
|
||||
|
||||
\text{hswish}(x_{i}) = x_{i} * \frac{ReLU6(x_{i} + 3)}{6},
|
||||
|
||||
where :math:`x_i` is an element of the input Tensor.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
|
||||
additional dimensions, with float16 or float32 data type.
|
||||
|
||||
Outputs:
|
||||
Tensor, with the same type and shape as the `input_x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `input_x` is not a Tensor.
|
||||
TypeError: If dtype of `input_x` is neither float16 nor float32.
|
||||
Refer to :func:`mindspore.ops.hardswish` for more detail.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU`` ``CPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> hswish = ops.HSwish()
|
||||
|
|
|
@ -0,0 +1,116 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.composite import GradOperation
|
||||
|
||||
|
||||
class Grad(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(Grad, self).__init__()
|
||||
self.grad = GradOperation(get_all=True, sens_param=True)
|
||||
self.network = network
|
||||
|
||||
@ms_function
|
||||
def construct(self, input_, output_grad):
|
||||
return self.grad(self.network)(input_, output_grad)
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.hswish = P.HSwish()
|
||||
|
||||
def construct(self, x):
|
||||
return self.hswish(x)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_net():
|
||||
"""
|
||||
Feature: Monitor the accuracy of hswish operator.
|
||||
Description: Input Tensor with [-1, -2, 0, 2, 1], run in ascend.
|
||||
Expectation: success
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
x = np.array([-1, -2, 0, 2, 1]).astype(np.float32)
|
||||
hswish = Net()
|
||||
y = hswish(Tensor(x))
|
||||
expect = np.array([-0.33333334, -0.33333334, 0., 1.6666666, 0.6666667]).astype(np.float32)
|
||||
error = np.ones(shape=expect.shape) * 1.0e-5
|
||||
diff = y.asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
sens = np.random.randn(5).astype(np.float32)
|
||||
backward_net = Grad(Net())
|
||||
output = backward_net(Tensor(x), Tensor(sens))
|
||||
print(len(output))
|
||||
print(output[0].asnumpy())
|
||||
|
||||
|
||||
def expect_hswish_forward_result(x):
|
||||
return np.where(x <= -3, 0, np.where(x >= 3, x, x * (x + 3) / 6))
|
||||
|
||||
|
||||
def expect_hswish_backward_result(x, dout):
|
||||
return np.where(x <= -3, 0, np.where(x >= 3, 1, x / 3 + 0.5)) * dout
|
||||
|
||||
|
||||
def judge_result_correct(result, expect):
|
||||
assert result.dtype == expect.dtype
|
||||
assert result.shape == expect.shape
|
||||
assert np.allclose(result, expect)
|
||||
|
||||
|
||||
def generate_test_cases(np_type, mode):
|
||||
context.set_context(mode=mode, device_target="Ascend")
|
||||
x = np.array([-1, -2, 0, 4, 5]).astype(np_type)
|
||||
net = Net()
|
||||
output = net(Tensor(x))
|
||||
expect = expect_hswish_forward_result(x)
|
||||
judge_result_correct(output.asnumpy(), expect)
|
||||
|
||||
sens = np.array([-1.45, 0.63, 0.34, 6.43, 34.6]).astype(np_type)
|
||||
backward_net = Grad(Net())
|
||||
output = backward_net(Tensor(x), Tensor(sens))
|
||||
expect = expect_hswish_backward_result(x, sens)
|
||||
judge_result_correct(output[0].asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_hardswish_forward_and_backward():
|
||||
"""
|
||||
Feature: Monitor the accuracy of hswish operator.
|
||||
Description: Input Tensor with [-1, -2, 0, 2, 1], run in ascend.
|
||||
Expectation: success
|
||||
"""
|
||||
modes = (context.GRAPH_MODE, context.PYNATIVE_MODE)
|
||||
dtypes = (np.float32, np.float16)
|
||||
for mode in modes:
|
||||
for dtype in dtypes:
|
||||
generate_test_cases(dtype, mode)
|
Loading…
Reference in New Issue