forked from mindspore-Ecosystem/mindspore
!35560 Move ClipByNorm to inner interface
Merge pull request !35560 from JoyLvliang/move_clipbynorm_to_inner_interface
Commit 0f27520ddb
@@ -498,7 +498,6 @@ Parameter操作算子
     mindspore.ops.batch_dot
     mindspore.ops.clip_by_global_norm
     mindspore.ops.clip_by_value
-    mindspore.ops.clip_by_norm
     mindspore.ops.core
     mindspore.ops.count_nonzero
     mindspore.ops.cummin
@@ -1,28 +0,0 @@
-mindspore.ops.clip_by_norm
-============================
-
-.. py:function:: mindspore.ops.clip_by_norm(x, clip_norm, axis=None)
-
-    Clips the input Tensor :math:`x` based on its :math:`L_2`-norm. If the :math:`L_2`-norm of the input
-    Tensor :math:`x` is less than or equal to :math:`clip_norm`, :math:`x` is returned unchanged.
-    Otherwise, the clipped Tensor is computed as:
-
-    .. math::
-        \text{output}(x) = \frac{\text{clip\_norm} * x}{L_2\text{-norm}(x)}.
-
-    .. note::
-        :math:`L_2`-norm is the :math:`L_2` norm computed over the input Tensor.
-
-    **Inputs:**
-
-    - **x** (Tensor) - Tensor of arbitrary dimension. The data type is `float16` or `float32`.
-    - **clip_norm** (Tensor) - A Tensor giving the clipping ratio; its values should be greater than 0.
-      Its shape must be broadcastable to the shape of `x`. The data type is `float16` or `float32`.
-    - **axis** (Union[None, int, tuple(int), list(int)]) - The dimensions along which the
-      :math:`L_2`-norm is computed. Default: `None`, meaning all dimensions.
-
-    **Raises:**
-
-    - **TypeError** - The data type of `x` is neither `float16` nor `float32`.
-    - **TypeError** - The data type of `clip_norm` is neither `float16` nor `float32`.
-    - **TypeError** - `axis` is not one of `None`, `int`, `tuple(int)`, or `list(int)`.
-
-    **Returns:**
-
-    Tensor, the clipped Tensor, with the same shape and data type as `x`.
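For readers cross-checking the formula above, here is a minimal NumPy sketch of the documented clipping rule (an illustrative reference only, not MindSpore code; the sample values are assumptions):

import numpy as np

def clip_by_norm_ref(x, clip_norm, axis=None):
    """NumPy reference for the documented clipping rule."""
    # L2-norm over the given axes; axis=None reduces over all dimensions.
    l2norm = np.sqrt(np.sum(np.square(x), axis=axis, keepdims=True))
    # Return x unchanged where its norm is within clip_norm; rescale otherwise.
    return np.where(l2norm <= clip_norm, x, clip_norm * x / l2norm)

x = np.arange(6, dtype=np.float32).reshape(2, 3)  # global L2-norm ~ 7.42
print(clip_by_norm_ref(x, np.float32(1.0)))       # rescaled so the norm is 1.0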
@@ -497,7 +497,6 @@ Other Operators
     mindspore.ops.batch_dot
     mindspore.ops.clip_by_global_norm
     mindspore.ops.clip_by_value
-    mindspore.ops.clip_by_norm
     mindspore.ops.core
     mindspore.ops.count_nonzero
     mindspore.ops.cummin
@@ -23,7 +23,6 @@ from mindspore.common.tensor import Tensor
 from mindspore.common.initializer import initializer
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
-from mindspore.ops.functional import identity
 from mindspore.ops.operations import _inner_ops as inner
 from mindspore.ops.primitive import constexpr, Primitive
 from mindspore.common.parameter import Parameter
@@ -448,44 +447,10 @@ class ClipByNorm(Cell):
     def __init__(self, axis=None):
         """Initialize ClipByNorm."""
         super(ClipByNorm, self).__init__()
-        if axis is None:
-            axis = ()
-        if isinstance(axis, tuple):
-            for idx, item in enumerate(axis):
-                Validator.check_value_type("axis[%d]" % idx, item, [int], self.cls_name)
-        self.axis = Validator.check_value_type('axis', axis, [int, tuple], self.cls_name)
-        self.reduce_sum = P.ReduceSum(keep_dims=True)
-        self.select_ = P.Select()
-        self.greater_ = P.Greater()
-        self.cast = P.Cast()
-        self.sqrt = P.Sqrt()
-        self.max_op = P.Maximum()
-        self.shape = P.Shape()
-        self.reshape = P.Reshape()
-        self.fill = P.Fill()
-        self.expand_dims = P.ExpandDims()
-        self.dtype = P.DType()
+        self.clip_by_norm = inner.ClipByNorm(axis)

     def construct(self, x, clip_norm):
-        mul_x = F.square(x)
-        l2sum = self.cast(self.reduce_sum(mul_x, self.axis), mstype.float32)
-        cond = self.greater_(l2sum, 0)
-        ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0)
-        l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum)))
-        l2norm = self.select_(cond, self.sqrt(l2sum_safe), l2sum)
-
-        _dtype_check(self.dtype(x), self.cls_name)
-        if _is_equal_one(clip_norm):
-            intermediate = x
-        else:
-            intermediate = x * clip_norm
-
-        max_norm = self.max_op(l2norm, clip_norm)
-        if _need_reduce_all(self.axis):
-            max_norm = self.expand_dims(max_norm, -1)
-        values_clip = self.cast(intermediate, mstype.float32) / max_norm
-        values_clip = self.reshape(values_clip, self.shape(x))
-        values_clip = identity(values_clip)
+        values_clip = self.clip_by_norm(x, clip_norm)
         return values_clip
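After this refactor, nn.ClipByNorm is a thin wrapper that delegates the whole computation to the inner primitive. A minimal usage sketch (the shapes and values are illustrative assumptions):

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor

net = nn.ClipByNorm()                     # axis=None clips by the global L2-norm
x = Tensor(np.ones((2, 3)), ms.float32)   # global L2-norm is sqrt(6) ~ 2.449
clip_norm = Tensor(np.array([1.0]).astype(np.float32))
out = net(x, clip_norm)                   # each element scaled by 1 / sqrt(6)
print(out.shape)                          # (2, 3)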
@@ -17,14 +17,14 @@
 """Define the grad rules of clip operations."""
 from .grad_base import bprop_getters
 from .. import operations as P
-from ..operations import clip_ops
+from ..operations import _inner_ops as inner
 from ..operations import _grad_ops as G
 from ...common import dtype as mstype
 from .._grad.grad_math_ops import _sum_grad
 from .._grad.grad_math_ops import binop_grad_common


-@bprop_getters.register(clip_ops.ClipByNorm)
+@bprop_getters.register(inner.ClipByNorm)
 def get_bprop_clip_by_norm(self):
     """Grad definition for `ClipByNorm` operation."""
     neg_op = P.Neg()
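The body of the gradient rule is truncated above after `neg_op = P.Neg()`. For orientation, a bprop getter in this file returns a closure over the forward inputs, the forward output, and the incoming gradient; the sketch below shows only that structure, with placeholder gradients (the real ClipByNorm rule is not visible in this diff):

from mindspore.ops import functional as F

def get_bprop_clip_by_norm_sketch(self):
    """Structure-only sketch of a bprop getter; not the real gradient rule."""
    def bprop(x, clip_norm, out, dout):
        # Placeholders only: the actual rule (truncated above) combines dout
        # with the norm terms from the forward computation.
        dx = dout
        dclip_norm = F.zeros_like(clip_norm)
        return dx, dclip_norm
    return bprop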
@@ -25,7 +25,6 @@ from . import (
     math_func,
     nn_func,
     linalg_func,
-    clip_func,
 )
 from .array_func import (
     unique,
@@ -217,9 +216,6 @@ from .nn_func import (
 from .linalg_func import (
     svd,
 )
-from .clip_func import (
-    clip_by_norm,
-)

 __all__ = []
 __all__.extend(array_func.__all__)
@@ -227,5 +223,4 @@ __all__.extend(parameter_func.__all__)
 __all__.extend(math_func.__all__)
 __all__.extend(nn_func.__all__)
 __all__.extend(linalg_func.__all__)
-__all__.extend(clip_func.__all__)
 __all__.sort()
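Taken together, the three hunks above remove `clip_by_norm` from the public `mindspore.ops` namespace. A hedged sketch of the resulting call sites (the import paths follow the diff; everything else is illustrative):

import numpy as np
import mindspore as ms
from mindspore import Tensor, nn
from mindspore.ops.operations import _inner_ops as inner

x = Tensor(np.ones((4, 16)), ms.float32)
clip_norm = Tensor(np.array([10.0]).astype(np.float32))

# Before this commit: output = mindspore.ops.clip_by_norm(x, clip_norm)
# After it, the same computation is reached through either interface below.
out_inner = inner.ClipByNorm()(x, clip_norm)  # inner primitive
out_nn = nn.ClipByNorm()(x, clip_norm)        # public nn cell wrapper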
@@ -1,67 +0,0 @@
-# Copyright 2022 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""Defines clip operators with functional form."""
-
-from ..operations import clip_ops
-
-
-def clip_by_norm(x, clip_norm, axis=None):
-    r"""
-    This function clips a tensor to a maximum :math:`L_2`-norm. If the :math:`L_2`-norm of the input `x` is
-    not greater than the input `clip_norm`, the output tensor remains unchanged. Otherwise the output tensor
-    is normalized as:
-
-    .. math::
-        \text{output}(x) = \frac{\text{clip_norm} * x}{L_2(x)},
-
-    where :math:`L_2(x)` is the :math:`L_2`-norm of :math:`x`.
-
-    Inputs:
-        - **x** (Tensor) - Tensor of shape N-D. The type must be float16 or float32.
-        - **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`,
-          or a Tensor whose shape can be broadcast to the shape of `x`. The type must be float16 or float32.
-        - **axis** (Union[None, int, tuple(int), list(int)]) - Compute the `L_2`-norm along the specified
-          dimensions. Default: None, meaning all dimensions are used.
-
-    Outputs:
-        Tensor, the clipped Tensor, whose shape is the same as `x` and whose type is float32.
-
-    Raises:
-        TypeError: If dtype of `x` is neither float16 nor float32.
-        TypeError: If dtype of `clip_norm` is neither float16 nor float32.
-        TypeError: If `axis` is not one of None, int, tuple(int) and list(int).
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-
-    Examples:
-        >>> import numpy as np
-        >>> import mindspore
-        >>> from mindspore import Tensor
-        >>> from mindspore.ops import functional as F
-        >>> x = Tensor(np.random.randint(0, 10, [6, 16]), mindspore.float32)
-        >>> clip_norm = Tensor(np.array([10]).astype(np.float32))
-        >>> output = F.clip_by_norm(x, clip_norm)
-        >>> print(output.shape)
-        (6, 16)
-    """
-    return clip_ops.ClipByNorm(axis=axis)(x, clip_norm)
-
-
-__all__ = [
-    'clip_by_norm'
-]
-__all__.sort()
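The deleted docstring states that `clip_norm` may have shape `()`, `(1)`, or any shape broadcastable to `x`. A small NumPy illustration of that broadcast behaviour for the global-norm case (a reference sketch; the shapes chosen are assumptions):

import numpy as np

x = np.ones((2, 3), dtype=np.float32)
l2 = np.sqrt(np.sum(np.square(x)))  # global L2-norm, axis=None
for clip_norm in (np.float32(1.0),                     # shape ()
                  np.ones(1, dtype=np.float32),        # shape (1,)
                  np.ones((2, 1), dtype=np.float32)):  # broadcasts to (2, 3)
    out = np.where(l2 <= clip_norm, x, clip_norm * x / l2)
    print(np.shape(clip_norm), out.shape)  # output keeps the shape of x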
@@ -15,6 +15,7 @@

 """Inner operators."""
 from types import FunctionType, MethodType
+from collections.abc import Iterable
 import numpy as np

 from mindspore.common import Tensor
@@ -2038,3 +2039,77 @@ class KMeansCentroids(PrimitiveWithInfer):
         k = y_shape[0]
         em_size = x_shape[1]
         return (k, em_size), (k, 1), (1)
+
+
+class ClipByNorm(PrimitiveWithInfer):
+    r"""
+    Clips tensor values to a maximum :math:`L_2`-norm.
+
+    Note:
+        The output tensor of this operator remains the same as the input tensor if the :math:`L_2`-norm of
+        the input tensor is not greater than the argument `clip_norm`. Otherwise the output tensor is
+        normalized as:
+
+        .. math::
+            \text{output}(X) = \frac{\text{clip_norm} * X}{L_2(X)},
+
+        where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`.
+
+    Args:
+        axis (Union[None, int, tuple(int), list(int)]): Compute the `L_2`-norm along the specified
+            dimensions. Default: None, meaning all dimensions are used.
+
+    Inputs:
+        - **x** (Tensor) - Tensor of shape N-D. The type must be float16 or float32.
+        - **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`,
+          or a Tensor whose shape can be broadcast to the shape of `x`. The type must be float16 or float32.
+
+    Outputs:
+        Tensor, the clipped Tensor, with the same shape as `x` and type float32.
+
+    Raises:
+        TypeError: If `axis` is not one of None, int, tuple(int) and list(int).
+        TypeError: If dtype of `x` is neither float16 nor float32.
+        TypeError: If dtype of `clip_norm` is neither float16 nor float32.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops.operations import _inner_ops as inner
+        >>> clip_by_norm = inner.ClipByNorm()
+        >>> x = Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32)
+        >>> clip_norm = Tensor(np.array([100]).astype(np.float32))
+        >>> output = clip_by_norm(x, clip_norm)
+        >>> print(output.shape)
+        (4, 16)
+    """
+
+    @prim_attr_register
+    def __init__(self, axis=None):
+        """Initialize ClipByNorm"""
+        self.axis = () if axis is None else axis
+        validator.check_value_type('axis', self.axis, [int, tuple, list], self.name)
+        axis_check = self.axis if isinstance(self.axis, Iterable) else (self.axis,)
+        for i, value in enumerate(axis_check):
+            validator.check_value_type('axis[%d]' % i, value, [int], self.name)
+        self.init_attrs['axis'] = self.axis
+        self.add_prim_attr('axis', self.axis)
+        self.init_prim_io_names(inputs=['x', 'clip_norm'], outputs=['output'])
+
+    def infer_shape(self, x_shape, clip_norm_shape):
+        """Infer shape for ClipByNorm"""
+        x_dim = len(x_shape)
+        axis = self.axis if isinstance(self.axis, Iterable) else (self.axis,)
+        for _, value in enumerate(axis):
+            validator.check_int_range(value, -x_dim, x_dim, Rel.INC_LEFT, 'axis', self.name)
+        return x_shape
+
+    def infer_dtype(self, x_type, clip_norm_type):
+        """Infer data type for ClipByNorm"""
+        validator.check_tensor_dtype_valid("x_type", x_type, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensor_dtype_valid("clip_norm_type", clip_norm_type,
+                                           [mstype.float16, mstype.float32], self.name)
+        return mstype.float32
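The `axis` attribute controls which dimensions enter the norm. A hedged usage sketch of the new inner primitive with a per-row norm (the shape and threshold are assumptions):

import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.ops.operations import _inner_ops as inner

# Clip each row of a (4, 16) tensor to an L2-norm of at most 5 by
# reducing over axis 1; infer_shape keeps the input shape.
row_clip = inner.ClipByNorm(axis=1)
x = Tensor(np.random.rand(4, 16).astype(np.float32))
clip_norm = Tensor(np.array([5.0]).astype(np.float32))
out = row_clip(x, clip_norm)
print(out.shape)  # (4, 16); dtype is float32 per infer_dtype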
@@ -1,96 +0,0 @@
-# Copyright 2022 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""Operators for clip."""
-
-from collections.abc import Iterable
-from ..._checkparam import Validator as validator
-from ..._checkparam import Rel
-from ...common import dtype as mstype
-from ..primitive import PrimitiveWithInfer, prim_attr_register
-
-
-class ClipByNorm(PrimitiveWithInfer):
-    r"""
-    Clips tensor values to a maximum :math:`L_2`-norm.
-
-    Note:
-        The output tensor of this operator remains the same as the input tensor if the :math:`L_2`-norm of
-        the input tensor is not greater than the argument `clip_norm`. Otherwise the output tensor is
-        normalized as:
-
-        .. math::
-            \text{output}(X) = \frac{\text{clip_norm} * X}{L_2(X)},
-
-        where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`.
-
-    Args:
-        axis (Union[None, int, tuple(int), list(int)]): Compute the `L_2`-norm along the specified
-            dimensions. Default: None, meaning all dimensions are used.
-
-    Inputs:
-        - **x** (Tensor) - Tensor of shape N-D. The type must be float16 or float32.
-        - **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`,
-          or a Tensor whose shape can be broadcast to the shape of `x`. The type must be float16 or float32.
-
-    Outputs:
-        Tensor, the clipped Tensor, with the same shape as `x` and type float32.
-
-    Raises:
-        TypeError: If `axis` is not one of None, int, tuple(int) and list(int).
-        TypeError: If dtype of `x` is neither float16 nor float32.
-        TypeError: If dtype of `clip_norm` is neither float16 nor float32.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-
-    Examples:
-        >>> import numpy as np
-        >>> import mindspore
-        >>> from mindspore import Tensor
-        >>> from mindspore.ops.operations import clip_ops
-        >>> clip_by_norm = clip_ops.ClipByNorm()
-        >>> x = Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32)
-        >>> clip_norm = Tensor(np.array([100]).astype(np.float32))
-        >>> output = clip_by_norm(x, clip_norm)
-        >>> print(output.shape)
-        (4, 16)
-    """
-
-    @prim_attr_register
-    def __init__(self, axis=None):
-        """Initialize ClipByNorm"""
-        self.axis = () if axis is None else axis
-        validator.check_value_type('axis', self.axis, [int, tuple, list], self.name)
-        axis_check = self.axis if isinstance(self.axis, Iterable) else (self.axis,)
-        for i, value in enumerate(axis_check):
-            validator.check_value_type('axis[%d]' % i, value, [int], self.name)
-        self.init_attrs['axis'] = self.axis
-        self.add_prim_attr('axis', self.axis)
-        self.init_prim_io_names(inputs=['x', 'clip_norm'], outputs=['output'])
-
-    def infer_shape(self, x_shape, clip_norm_shape):
-        """Infer shape for ClipByNorm"""
-        x_dim = len(x_shape)
-        axis = self.axis if isinstance(self.axis, Iterable) else (self.axis,)
-        for _, value in enumerate(axis):
-            validator.check_int_range(value, -x_dim, x_dim, Rel.INC_LEFT, 'axis', self.name)
-        return x_shape
-
-    def infer_dtype(self, x_type, clip_norm_type):
-        """Infer data type for ClipByNorm"""
-        validator.check_tensor_dtype_valid("x_type", x_type, [mstype.float16, mstype.float32], self.name)
-        validator.check_tensor_dtype_valid("clip_norm_type", clip_norm_type,
-                                           [mstype.float16, mstype.float32], self.name)
-        return mstype.float32
@@ -1,107 +0,0 @@
-# Copyright 2022 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-import numpy as np
-import pytest
-
-import mindspore as ms
-import mindspore.nn as nn
-import mindspore.context as context
-from mindspore.ops import functional as F
-from mindspore.common.tensor import Tensor
-
-
-@pytest.mark.level0
-@pytest.mark.platform_arm_ascend_training
-@pytest.mark.platform_x86_ascend_training
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_clip_by_norm_graph():
-    """
-    Feature: ClipByNorm operation function verification in GRAPH mode.
-    Description: The results of `F.clip_by_norm` should match those of `nn.ClipByNorm`.
-    Expectation: success, with no assertion errors.
-    """
-    context.set_context(mode=context.GRAPH_MODE)
-    # test input args with data types float32 and float32
-    x1 = np.random.rand(2, 3, 6, 16) * 10 - 5
-    x1 = Tensor(x1, ms.float32)
-    clip_norm_g = Tensor(np.array([1.0]).astype(np.float32))
-    actual_out1 = F.clip_by_norm(x1, clip_norm_g)
-    expected_out1 = nn.ClipByNorm()(x1, clip_norm_g)
-    assert np.allclose(actual_out1.asnumpy(), expected_out1.asnumpy(), 0.0001, 0.0001)
-    # test input args with data types float16 and float16
-    x2 = np.random.rand(2, 3, 6, 16) * 10 - 5
-    x2 = Tensor(x2, ms.float16)
-    clip_norm_g = Tensor(np.array([1.0]).astype(np.float16))
-    actual_out2 = F.clip_by_norm(x2, clip_norm_g)
-    expected_out2 = nn.ClipByNorm()(x2, clip_norm_g)
-    assert np.allclose(actual_out2.asnumpy(), expected_out2.asnumpy(), 0.001, 0.001)
-    # test input args with data types float32 and float16
-    x3 = np.random.rand(2, 3, 6, 16) * 10 - 5
-    x3 = Tensor(x3, ms.float32)
-    clip_norm_g = Tensor(np.array([2.0]).astype(np.float16))
-    actual_out3 = F.clip_by_norm(x3, clip_norm_g)
-    expected_out3 = nn.ClipByNorm()(x3, clip_norm_g)
-    assert np.allclose(actual_out3.asnumpy(), expected_out3.asnumpy(), 0.0001, 0.0001)
-    # test input args with data types float16 and float32
-    x4 = np.random.rand(2, 3, 6, 16) * 10 - 5
-    x4 = Tensor(x4, ms.float16)
-    clip_norm_g = Tensor(np.array([2.0]).astype(np.float32))
-    actual_out4 = F.clip_by_norm(x4, clip_norm_g)
-    expected_out4 = nn.ClipByNorm()(x4, clip_norm_g)
-    assert np.allclose(actual_out4.asnumpy(), expected_out4.asnumpy(), 0.001, 0.001)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_arm_ascend_training
-@pytest.mark.platform_x86_ascend_training
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_clip_by_norm_pynative():
-    """
-    Feature: ClipByNorm operation function verification in PyNative mode.
-    Description: The results of `F.clip_by_norm` should match those of `nn.ClipByNorm`.
-    Expectation: success, with no assertion errors.
-    """
-    context.set_context(mode=context.PYNATIVE_MODE)
-    # test input args with data types float32 and float32
-    x5 = np.random.rand(2, 3, 6, 16) * 10 - 5
-    x5 = Tensor(x5, ms.float32)
-    clip_norm_y = Tensor(np.array([1.0]).astype(np.float32))
-    actual_out5 = F.clip_by_norm(x5, clip_norm_y)
-    expected_out5 = nn.ClipByNorm()(x5, clip_norm_y)
-    assert np.allclose(actual_out5.asnumpy(), expected_out5.asnumpy(), 0.0001, 0.0001)
-    # test input args with data types float16 and float16
-    x6 = np.random.rand(2, 3, 6, 16) * 10 - 5
-    x6 = Tensor(x6, ms.float16)
-    clip_norm_y = Tensor(np.array([1.0]).astype(np.float16))
-    actual_out6 = F.clip_by_norm(x6, clip_norm_y)
-    expected_out6 = nn.ClipByNorm()(x6, clip_norm_y)
-    assert np.allclose(actual_out6.asnumpy(), expected_out6.asnumpy(), 0.001, 0.001)
-    # test input args with data types float32 and float16
-    x7 = np.random.rand(2, 3, 6, 16) * 10 - 5
-    x7 = Tensor(x7, ms.float32)
-    clip_norm_y = Tensor(np.array([2.0]).astype(np.float16))
-    actual_out7 = F.clip_by_norm(x7, clip_norm_y)
-    expected_out7 = nn.ClipByNorm()(x7, clip_norm_y)
-    assert np.allclose(actual_out7.asnumpy(), expected_out7.asnumpy(), 0.0001, 0.0001)
-    # test input args with data types float16 and float32
-    x8 = np.random.rand(2, 3, 6, 16) * 10 - 5
-    x8 = Tensor(x8, ms.float16)
-    clip_norm_y = Tensor(np.array([2.0]).astype(np.float32))
-    actual_out8 = F.clip_by_norm(x8, clip_norm_y)
-    expected_out8 = nn.ClipByNorm()(x8, clip_norm_y)
-    assert np.allclose(actual_out8.asnumpy(), expected_out8.asnumpy(), 0.001, 0.001)
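These deleted tests check the functional and nn paths against each other but not against an independent reference. A hedged sketch of such a cross-check (the helper name and tolerances are assumptions, not part of the deleted file):

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore.common.tensor import Tensor

def numpy_clip_by_norm(x, clip_norm):
    # Independent NumPy reference for the global-norm case (axis=None).
    l2 = np.sqrt(np.sum(np.square(x)))
    return np.where(l2 <= clip_norm, x, clip_norm * x / l2).astype(np.float32)

def test_clip_by_norm_against_numpy():
    x = (np.random.rand(2, 3, 6, 16) * 10 - 5).astype(np.float32)
    expected = numpy_clip_by_norm(x, np.float32(1.0))
    actual = nn.ClipByNorm()(Tensor(x, ms.float32),
                             Tensor(np.array([1.0]).astype(np.float32)))
    assert np.allclose(actual.asnumpy(), expected, 0.0001, 0.0001)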