forked from mindspore-Ecosystem/mindspore
add function and Tensor clip clamp
This commit is contained in:
parent
a86ad15939
commit
d8068b112e
|
@ -17,6 +17,7 @@
|
|||
"mindspore/mindspore/python/mindspore/common/hook_handle.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/common/dtype.py" "undefined-all-variable"
|
||||
"mindspore/mindspore/python/mindspore/context.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/ops/function/clip_func.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/ops/operations" "super-init-not-called"
|
||||
"mindspore/mindspore/python/mindspore/ops/operations/_quant_ops.py" "unused-import"
|
||||
"mindspore/mindspore/python/mindspore/ops/operations/nn_ops.py" "redefined-builtin"
|
||||
|
@ -58,6 +59,7 @@
|
|||
"mindspore/mindspore/python/mindspore/ops/function/__init__.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/ops/function/array_func.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/ops/function/array_func.py" "redefined-outer-name"
|
||||
"mindspore/mindspore/python/mindspore/ops/function/clip_func.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/ops/function/math_func.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/ops/functional.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/ops/functional.py" "wildcard-import"
|
||||
|
@ -65,6 +67,7 @@
|
|||
"mindspore/mindspore/python/mindspore/_extends/parse/standard_method.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/_extends/parse/standard_method.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/_extends/parse/standard_method.py" "len-as-condition"
|
||||
"mindspore/mindspore/python/mindspore/_extends/parse/standard_method.py" "redefined-outer-name"
|
||||
"mindspore/mindspore/python/mindspore/common/tensor.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/ops/function/array_func.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/ops/operations/array_ops.py" "redefined-builtin"
|
||||
|
|
|
@ -184,6 +184,8 @@ mindspore.ops
|
|||
mindspore.ops.bitwise_or
|
||||
mindspore.ops.bitwise_xor
|
||||
mindspore.ops.ceil
|
||||
mindspore.ops.clamp
|
||||
mindspore.ops.clip
|
||||
mindspore.ops.copysign
|
||||
mindspore.ops.cos
|
||||
mindspore.ops.cosh
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.clamp
|
||||
=======================
|
||||
|
||||
.. py:method:: mindspore.Tensor.clamp(min, max)
|
||||
|
||||
详情请参考 :func:`mindspore.ops.clamp`。
|
|
@ -1,24 +1,6 @@
|
|||
mindspore.Tensor.clip
|
||||
=====================
|
||||
|
||||
.. py:method:: mindspore.Tensor.clip(xmin, xmax, dtype=None)
|
||||
.. py:method:: mindspore.Tensor.clip(min, max)
|
||||
|
||||
裁剪Tensor中的值。
|
||||
|
||||
给定一个区间,区间外的值将被裁剪到区间边缘。
|
||||
例如,如果指定的间隔为 :math:`[0, 1]` ,则小于0的值将变为0,大于1的值将变为1。
|
||||
|
||||
.. note::
|
||||
目前不支持裁剪 `xmin=nan` 或 `xmax=nan` 。
|
||||
|
||||
参数:
|
||||
- **xmin** (Tensor, scalar, None) - 最小值。如果值为None,则不在间隔的下边缘执行裁剪操作。`xmin` 或 `xmax` 只能有一个为None。
|
||||
- **xmax** (Tensor, scalar, None) - 最大值。如果值为None,则不在间隔的上边缘执行裁剪操作。`xmin` 或 `xmax` 只能有一个为None。如果 `xmin` 或 `xmax` 是Tensor,则三个Tensor将被广播进行shape匹配。
|
||||
- **dtype** (mindspore.dtype, 可选) - 覆盖输出Tensor的dtype。默认值为None。
|
||||
|
||||
返回:
|
||||
Tensor,含有输入Tensor的元素,其中values < `xmin` 被替换为 `xmin` ,values > `xmax` 被替换为 `xmax` 。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 输入的类型与Tensor不一致。
|
||||
- **ValueError** - 输入与Tensor的shape不能广播,或者 `xmin` 和 `xmax` 都是 `None` 。
|
||||
:func:`mindspore.Tensor.clamp` 的别名。
|
|
@ -68,6 +68,7 @@ mindspore.Tensor
|
|||
mindspore.Tensor.cholesky
|
||||
mindspore.Tensor.cholesky_inverse
|
||||
mindspore.Tensor.choose
|
||||
mindspore.Tensor.clamp
|
||||
mindspore.Tensor.clip
|
||||
mindspore.Tensor.col2im
|
||||
mindspore.Tensor.conj
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
mindspore.ops.clamp
|
||||
====================
|
||||
|
||||
.. py:function:: mindspore.ops.clamp(x, min=None, max=None)
|
||||
|
||||
将输入Tensor值裁剪到指定的最小值和最大值之间。
|
||||
|
||||
限制 :math:`x` 的范围,其 :math:`x` 的最小值为 `min` ,最大值为 `max` 。
|
||||
|
||||
.. math::
|
||||
out_i= \left\{
|
||||
\begin{array}{align}
|
||||
max & \text{ if } x_i\ge max \\
|
||||
x_i & \text{ if } min \lt x_i \lt max \\
|
||||
min & \text{ if } x_i \le min \\
|
||||
\end{array}\right.
|
||||
|
||||
.. note::
|
||||
- `min` 和 `max` 不能同时为None;
|
||||
- 当 `min` 为None,`max` 不为None时,Tensor中大于 `max` 的元素会变为 `max`;
|
||||
- 当 `min` 不为None,`max` 为None时,Tensor中小于 `min` 的元素会变为 `min`;
|
||||
- 当 `min` 大于 `max` 时,Tensor中所有元素的值会被置为 `max`;
|
||||
- :math:`x` , `min` 和 `max` 的数据类型需支持隐式类型转换,且不能为布尔型。
|
||||
|
||||
参数:
|
||||
- **x** (Union(Tensor, list[Tensor], tuple[Tensor])) - `clamp` 的输入,类型为Tensor、Tensor的列表或元组。支持任意维度的Tensor。
|
||||
- **min** (Union(Tensor, float, int)) - 指定最小值。默认值为None。
|
||||
- **max** (Union(Tensor, float, int)) - 指定最大值。默认值为None。
|
||||
|
||||
返回:
|
||||
Tensor、Tensor的列表或元组,表示裁剪后的Tensor。其shape和数据类型和 `x` 相同。
|
||||
|
||||
异常:
|
||||
- **ValueError** - 如果 `min` 和 `max` 都为None。
|
||||
- **TypeError** - 如果 `x` 的数据类型不在Tensor、list[Tensor]或tuple[Tensor]中。
|
||||
- **TypeError** - 如果 `min` 的数据类型不为None、Tensor、float或int。
|
||||
- **TypeError** - 如果 `max` 的数据类型不为None、Tensor、float或int。
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.ops.clip
|
||||
===================
|
||||
|
||||
.. py:function:: mindspore.ops.clip(x, min=None, max=None)
|
||||
|
||||
:func:`mindspore.ops.clamp` 的别名。
|
|
@ -17,8 +17,8 @@
|
|||
|
||||
.. note::
|
||||
- `clip_value_min` 和 `clip_value_max` 不能同时为None;
|
||||
- 当 `clip_value_min` 为None, `clip_value_max` 不为None时,Tensor中大于 `clip_value_max` 的元素会变为 `clip_value_max`;
|
||||
- 当 `clip_value_min` 不为None, `clip_value_max` 为None时,Tensor中小于 `clip_value_min` 的元素会变为 `clip_value_min`;
|
||||
- 当 `clip_value_min` 为None,`clip_value_max` 不为None时,Tensor中大于 `clip_value_max` 的元素会变为 `clip_value_max`;
|
||||
- 当 `clip_value_min` 不为None,`clip_value_max` 为None时,Tensor中小于 `clip_value_min` 的元素会变为 `clip_value_min`;
|
||||
- 当 `clip_value_min` 大于 `clip_value_max` 时,Tensor中所有元素的值会被置为 `clip_value_max`;
|
||||
- :math:`x` , `clip_value_min` 和 `clip_value_max` 的数据类型需支持隐式类型转换,且不能为布尔型。
|
||||
|
||||
|
|
|
@ -74,6 +74,7 @@
|
|||
mindspore.Tensor.cholesky
|
||||
mindspore.Tensor.cholesky_inverse
|
||||
mindspore.Tensor.choose
|
||||
mindspore.Tensor.clamp
|
||||
mindspore.Tensor.clip
|
||||
mindspore.Tensor.col2im
|
||||
mindspore.Tensor.conj
|
||||
|
|
|
@ -185,6 +185,8 @@ Element-by-Element Operations
|
|||
mindspore.ops.bitwise_or
|
||||
mindspore.ops.bitwise_xor
|
||||
mindspore.ops.ceil
|
||||
mindspore.ops.clip
|
||||
mindspore.ops.clamp
|
||||
mindspore.ops.copysign
|
||||
mindspore.ops.cos
|
||||
mindspore.ops.cosh
|
||||
|
|
|
@ -272,7 +272,8 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"fill", std::string("fill")}, // P.fill()
|
||||
{"fills", std::string("fills")}, // P.fills
|
||||
{"ptp", std::string("ptp")}, // P.reduce_max() - P.reduce_min()
|
||||
{"clip", std::string("clip")}, // P.maximum(P.minimum)
|
||||
{"clamp", std::string("clamp")}, // clamp()
|
||||
{"clip", std::string("clamp")}, // clamp()
|
||||
{"__bool__", std::string("tensor_bool")}, // C.tensor_bool
|
||||
{"argmax", std::string("argmax")}, // P.Argmax()
|
||||
{"argmin", std::string("argmin")}, // P.Argmax()
|
||||
|
|
|
@ -1744,64 +1744,18 @@ def ptp(x, axis=None, keepdims=False):
|
|||
return x.max(axis, keepdims) - x.min(axis, keepdims)
|
||||
|
||||
|
||||
def clip(x, xmin, xmax, dtype=None):
|
||||
def clamp(x, min=None, max=None):
|
||||
"""
|
||||
Clips (limits) the values in an array.
|
||||
|
||||
Given an interval, values outside the interval are clipped to the interval edges.
|
||||
For example, if an interval of :math:`[0, 1]` is specified, values smaller than 0 become 0,
|
||||
and values larger than 1 become 1.
|
||||
|
||||
Note:
|
||||
Currently, clip with `nan` is not supported.
|
||||
|
||||
Args:
|
||||
x (Tensor): Tensor containing elements to clip.
|
||||
xmin (Tensor, scalar, None): Minimum value. If None, clipping is not performed
|
||||
on lower interval edge. Not more than one of `xmin` and `xmax` may be None.
|
||||
xmax (Tensor, scalar, None): Maximum value. If None, clipping is not performed
|
||||
on upper interval edge. Not more than one of `xmin` and `xmax` may be None.
|
||||
If `xmin` or `xmax` are tensors, then the three tensors will be broadcasted
|
||||
to match their shapes.
|
||||
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
|
||||
output Tensor.
|
||||
|
||||
Returns:
|
||||
Tensor, a tensor with the elements of `x`, but where values
|
||||
< `xmin` are replaced with `xmin`, and those > `xmax` with `xmax`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore import Tensor
|
||||
>>> x = Tensor([1, 2, 3, -4, 0, 3, 2, 0]).astype("float32")
|
||||
>>> output = x.clip(0, 2)
|
||||
>>> print(output)
|
||||
[1 2 2 0 0 2 2 0]
|
||||
Clamps all elements in `x` into the range `[min, max]`.
|
||||
"""
|
||||
if xmin is None and xmax is None:
|
||||
const_utils.raise_value_error("One of max or min must be given.")
|
||||
is_scalar = False
|
||||
if xmin is not None:
|
||||
xmin = const_utils.make_tensor(xmin, x.dtype)
|
||||
if x.ndim == 0 and xmin.ndim == 0:
|
||||
x = F.maximum(x.reshape((1,)), xmin).squeeze()
|
||||
else:
|
||||
x = F.maximum(x, xmin)
|
||||
if xmax is not None:
|
||||
xmax = const_utils.make_tensor(xmax, x.dtype)
|
||||
if x.ndim == 0 and xmax.ndim == 0:
|
||||
x = F.minimum(x.reshape((1,)), xmax).squeeze()
|
||||
else:
|
||||
x = F.minimum(x, xmax)
|
||||
if is_scalar:
|
||||
return x.squeeze()
|
||||
if dtype is not None:
|
||||
dtype = check_astype_dtype_const(dtype)
|
||||
if dtype != x.dtype:
|
||||
return x.astype(dtype)
|
||||
return x
|
||||
return F.clamp(x, min, max)
|
||||
|
||||
|
||||
def clip(x, min=None, max=None):
|
||||
"""
|
||||
Clamps all elements in `x` into the range `[min, max]`.
|
||||
"""
|
||||
return F.clamp(x, min, max)
|
||||
|
||||
|
||||
def var(x, axis=None, ddof=0, keepdims=False):
|
||||
|
|
|
@ -2286,68 +2286,18 @@ class Tensor(Tensor_):
|
|||
"""
|
||||
return tensor_operator_registry.get('minimum')()(self, other)
|
||||
|
||||
def clip(self, xmin, xmax, dtype=None):
|
||||
def clamp(self, min=None, max=None):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.clamp`.
|
||||
"""
|
||||
Clips (limits) the values in a Tensor.
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('clamp')(self, min, max)
|
||||
|
||||
Given an interval, values outside the interval are clipped to the interval edges.
|
||||
For example, if an interval of :math:`[0, 1]` is specified, values smaller than 0 become 0,
|
||||
and values larger than 1 become 1.
|
||||
|
||||
Note:
|
||||
Currently, clip with `xmin=nan` or `xmax=nan` is not supported.
|
||||
|
||||
Args:
|
||||
xmin (Tensor, scalar, None): Minimum value. If None, clipping is not performed
|
||||
on the lower interval edge. Not more than one of `xmin` and `xmax` may be None.
|
||||
xmax (Tensor, scalar, None): Maximum value. If None, clipping is not performed
|
||||
on the upper interval edge. Not more than one of `xmin` and `xmax` may be None.
|
||||
If `xmin` or `xmax` are tensors, then `xmin`, `xmax` and the given tensor
|
||||
will be broadcasted to match their shapes.
|
||||
dtype (:class:`mindspore.dtype`, optional): Overrides the dtype of the
|
||||
output Tensor. Default is None.
|
||||
|
||||
Returns:
|
||||
Tensor, a tensor with the elements of the input tensor, but where values
|
||||
< `xmin` are replaced with `xmin`, and those > `xmax` with `xmax`.
|
||||
|
||||
Raises:
|
||||
TypeError: If inputs have types not specified above.
|
||||
ValueError: If the shapes of `x1` and `x2` cannot broadcast, or both `xmin` and `xmax` are `None`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore import Tensor
|
||||
>>> x = Tensor([1, 2, 3, -4, 0, 3, 2, 0]).astype("float32")
|
||||
>>> y = x.clip(0, 2)
|
||||
>>> print(y)
|
||||
[1. 2. 2. 0. 0. 2. 2. 0.]
|
||||
>>> t = Tensor([1, 1, 1, 1, 1, 1, 1, 1])
|
||||
>>> y = x.clip(t, 2)
|
||||
>>> print(y)
|
||||
[1. 2. 2. 1. 1. 2. 2. 1.]
|
||||
def clip(self, min=None, max=None):
|
||||
r"""
|
||||
Alias for :func:`mindspore.Tensor.clamp`.
|
||||
"""
|
||||
if xmin is None and xmax is None:
|
||||
raise ValueError("For 'Tensor.clip', the argument 'xmin' and 'xman' cannot all be None.")
|
||||
x = self
|
||||
# F.maximum/minimum does not support when both operands are scalar
|
||||
if xmin is not None:
|
||||
xmin = Tensor(xmin).astype(x.dtype)
|
||||
if x.ndim == 0 and xmin.ndim == 0:
|
||||
x = tensor_operator_registry.get("maximum")(x.reshape((1,)), xmin).squeeze()
|
||||
else:
|
||||
x = tensor_operator_registry.get("maximum")(x, xmin)
|
||||
if xmax is not None:
|
||||
xmax = Tensor(xmax).astype(x.dtype)
|
||||
if x.ndim == 0 and xmax.ndim == 0:
|
||||
x = tensor_operator_registry.get("minimum")()(x.reshape((1,)), xmax).squeeze()
|
||||
else:
|
||||
x = tensor_operator_registry.get("minimum")()(x, xmax)
|
||||
if dtype is not None and dtype != x.dtype:
|
||||
return x.astype(dtype)
|
||||
return x
|
||||
return self.clamp(min, max)
|
||||
|
||||
def _init_check(self):
|
||||
if self.has_init:
|
||||
|
|
|
@ -541,6 +541,8 @@ from .sparse_unary_func import (
|
|||
)
|
||||
from .clip_func import (
|
||||
clip_by_value,
|
||||
clamp,
|
||||
clip,
|
||||
)
|
||||
|
||||
__all__ = []
|
||||
|
|
|
@ -22,6 +22,8 @@ from mindspore.common.tensor import Tensor
|
|||
|
||||
__all__ = [
|
||||
'clip_by_value',
|
||||
'clamp',
|
||||
'clip',
|
||||
]
|
||||
|
||||
hyper_map = C.HyperMap()
|
||||
|
@ -81,6 +83,7 @@ def clip_by_value(x, clip_value_min=None, clip_value_max=None):
|
|||
|
||||
Examples:
|
||||
>>> # case 1: the data type of x is Tensor
|
||||
>>> import mindspore
|
||||
>>> from mindspore import Tensor, ops
|
||||
>>> import numpy as np
|
||||
>>> min_value = Tensor(5, mindspore.float32)
|
||||
|
@ -131,3 +134,80 @@ def clip_by_value(x, clip_value_min=None, clip_value_max=None):
|
|||
if isinstance(x, tuple):
|
||||
results = tuple(results)
|
||||
return results
|
||||
|
||||
|
||||
def clamp(x, min=None, max=None):
|
||||
r"""
|
||||
Clamps tensor values to a specified min and max.
|
||||
|
||||
Limits the value of :math:`x` to a range, whose lower limit is `min` and upper limit is `max` .
|
||||
|
||||
.. math::
|
||||
|
||||
out_i= \left\{
|
||||
\begin{array}{align}
|
||||
max & \text{ if } x_i\ge max \\
|
||||
x_i & \text{ if } min \lt x_i \lt max \\
|
||||
min & \text{ if } x_i \le min \\
|
||||
\end{array}\right.
|
||||
|
||||
Note:
|
||||
- `min` and `max` cannot be None at the same time;
|
||||
- When `min` is None and `max` is not None, the elements in Tensor larger than `max` will become `max`;
|
||||
- When `min` is not None and `max` is None, the elements in Tensor smaller than `min` will become `min`;
|
||||
- If `min` is greater than `max`, the value of all elements in Tensor will be set to `max`;
|
||||
- The data type of `x`, `min` and `max` should support implicit type conversion and cannot be bool type.
|
||||
|
||||
Args:
|
||||
x (Union(Tensor, list[Tensor], tuple[Tensor])): Input data, which type is Tensor or a list or tuple of Tensor.
|
||||
The shape of Tensor is :math:`(N,*)` where :math:`*` means,
|
||||
any number of additional dimensions.
|
||||
min (Union(Tensor, float, int)): The minimum value. Default: None.
|
||||
max (Union(Tensor, float, int)): The maximum value. Default: None.
|
||||
|
||||
Returns:
|
||||
(Union(Tensor, tuple[Tensor], list[Tensor])), a clipped Tensor or a tuple or a list of clipped Tensor.
|
||||
The data type and shape are the same as x.
|
||||
|
||||
Raises:
|
||||
ValueError: If both `min` and `max` are None.
|
||||
TypeError: If the type of `x` is not in Tensor or list[Tensor] or tuple[Tensor].
|
||||
TypeError: If the type of `min` is not in None, Tensor, float or int.
|
||||
TypeError: If the type of `max` is not in None, Tensor, float or int.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> # case 1: the data type of x is Tensor
|
||||
>>> import mindspore
|
||||
>>> from mindspore import Tensor, ops
|
||||
>>> import numpy as np
|
||||
>>> min_value = Tensor(5, mindspore.float32)
|
||||
>>> max_value = Tensor(20, mindspore.float32)
|
||||
>>> x = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mindspore.float32)
|
||||
>>> output = ops.clamp(x, min_value, max_value)
|
||||
>>> print(output)
|
||||
[[ 5. 20. 5. 7.]
|
||||
[ 5. 11. 6. 20.]]
|
||||
>>> # case 2: the data type of x is list[Tensor]
|
||||
>>> min_value = 5
|
||||
>>> max_value = 20
|
||||
>>> x = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mindspore.float32)
|
||||
>>> y = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mindspore.float32)
|
||||
>>> output = ops.clamp([x,y], min_value, max_value)
|
||||
>>> print(output)
|
||||
[[[ 5. 20. 5. 7.]
|
||||
[ 5. 11. 6. 20.]],
|
||||
[[ 5. 20. 5. 7.]
|
||||
[ 5. 11. 6. 20.]]]
|
||||
"""
|
||||
return clip_by_value(x, min, max)
|
||||
|
||||
|
||||
def clip(x, min=None, max=None):
|
||||
"""
|
||||
Alias for ops.clamp.
|
||||
For details, please refer to :func:`mindspore.ops.clamp`.
|
||||
"""
|
||||
return clamp(x, min, max)
|
||||
|
|
|
@ -238,6 +238,7 @@ tensor_operator_registry.register('erfinv', erfinv)
|
|||
tensor_operator_registry.register('less_equal', less_equal)
|
||||
tensor_operator_registry.register('lcm', lcm)
|
||||
tensor_operator_registry.register('ldexp', ldexp)
|
||||
tensor_operator_registry.register('clamp', clamp)
|
||||
tensor_operator_registry.register('fold', fold)
|
||||
tensor_operator_registry.register('unfold', unfold)
|
||||
tensor_operator_registry.register('index_add', index_add)
|
||||
|
|
|
@ -24,6 +24,16 @@ class NetWorkClipByValue(nn.Cell):
|
|||
return ops.clip_by_value(x, min_value, max_value)
|
||||
|
||||
|
||||
class NetWorkClamp(nn.Cell):
|
||||
def construct(self, x, min_value, max_value):
|
||||
return ops.clamp(x, min_value, max_value)
|
||||
|
||||
|
||||
class NetWorkClip(nn.Cell):
|
||||
def construct(self, x, min_value, max_value):
|
||||
return ops.clip(x, min_value, max_value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
|
@ -73,3 +83,53 @@ def test_clip_by_value_list_tensor(mode):
|
|||
assert np.allclose(output[0].asnumpy(), expect_output[0])
|
||||
assert np.allclose(output[1].asnumpy(), expect_output[1])
|
||||
assert np.allclose(output[2].asnumpy(), expect_output[2])
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_ops_clamp(mode):
|
||||
"""
|
||||
Feature: ops.clamp
|
||||
Description: Verify the result of clamp
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x = Tensor(np.array([-0.5962, 0.4985, 0.2349, -0.4396, 0.4525]), ms.float32)
|
||||
net = NetWorkClamp()
|
||||
output_case_1 = net(x, -0.3, 0.4)
|
||||
expect_output_case_1 = [-0.3, 0.4, 0.2349, -0.3, 0.4]
|
||||
output_case_2 = net(x, 0.4, -0.3)
|
||||
expect_output_case_2 = [-0.3, -0.3, -0.3, -0.3, -0.3]
|
||||
assert np.allclose(output_case_1.asnumpy(), expect_output_case_1)
|
||||
assert np.allclose(output_case_2.asnumpy(), expect_output_case_2)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_ops_clip(mode):
|
||||
"""
|
||||
Feature: ops.clip
|
||||
Description: Verify the result of clip
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x = Tensor(np.array([-0.5962, 0.4985, 0.2349, -0.4396, 0.4525]), ms.float32)
|
||||
net = NetWorkClip()
|
||||
output_case_1 = net(x, -0.3, 0.4)
|
||||
expect_output_case_1 = [-0.3, 0.4, 0.2349, -0.3, 0.4]
|
||||
output_case_2 = net(x, 0.4, -0.3)
|
||||
expect_output_case_2 = [-0.3, -0.3, -0.3, -0.3, -0.3]
|
||||
assert np.allclose(output_case_1.asnumpy(), expect_output_case_1)
|
||||
assert np.allclose(output_case_2.asnumpy(), expect_output_case_2)
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
class ClampNet(nn.Cell):
|
||||
def construct(self, x, min_value, max_value):
|
||||
return x.clamp(min_value, max_value)
|
||||
|
||||
|
||||
class ClipNet(nn.Cell):
|
||||
def construct(self, x, min_value, max_value):
|
||||
return x.clip(min_value, max_value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_clamp(mode):
|
||||
"""
|
||||
Feature: test Tensor.clamp
|
||||
Description: Verify the result of Tensor.clamp
|
||||
Expectation: expect correct forward result
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x_np = np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]).astype(np.float32)
|
||||
x = Tensor(x_np, ms.float32)
|
||||
net = ClampNet()
|
||||
output_ms_case_1 = net(x, 5, 20)
|
||||
expect_output_case_1 = np.clip(x_np, 5, 20)
|
||||
output_ms_case_2 = net(x, 20, 5)
|
||||
expect_output_case_2 = np.clip(x_np, 20, 5)
|
||||
assert np.allclose(output_ms_case_1.asnumpy(), expect_output_case_1)
|
||||
assert np.allclose(output_ms_case_2.asnumpy(), expect_output_case_2)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_clip(mode):
|
||||
"""
|
||||
Feature: test Tensor.clip
|
||||
Description: Verify the result of Tensor.clip
|
||||
Expectation: expect correct forward result
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x_np = np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]).astype(np.float32)
|
||||
x = Tensor(x_np, ms.float32)
|
||||
net = ClipNet()
|
||||
output_ms_case_1 = net(x, 5, 20)
|
||||
expect_output_case_1 = np.clip(x_np, 5, 20)
|
||||
output_ms_case_2 = net(x, 20, 5)
|
||||
expect_output_case_2 = np.clip(x_np, 20, 5)
|
||||
assert np.allclose(output_ms_case_1.asnumpy(), expect_output_case_1)
|
||||
assert np.allclose(output_ms_case_2.asnumpy(), expect_output_case_2)
|
Loading…
Reference in New Issue