forked from mindspore-Ecosystem/mindspore
add cdist vmap、ceil tensor api
This commit is contained in:
parent
e420e22839
commit
dab8dde706
|
@ -113,6 +113,7 @@ functional算子是经过初始化后的Primitive,可以直接作为函数使
|
|||
mindspore.ops.bitwise_and
|
||||
mindspore.ops.bitwise_or
|
||||
mindspore.ops.bitwise_xor
|
||||
mindspore.ops.ceil
|
||||
mindspore.ops.cos
|
||||
mindspore.ops.cosh
|
||||
mindspore.ops.div
|
||||
|
|
|
@ -235,49 +235,17 @@ mindspore.Tensor
|
|||
|
||||
- **ValueError** - 输入Tensor和任一 `choices` 无法广播。
|
||||
|
||||
.. py:method:: cdist(input_y, p=2.0)
|
||||
.. py:method:: ceil()
|
||||
|
||||
计算两个tensor的p-范数距离。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **input_y** (tensor) - 。输入的向量。
|
||||
- **p** (float) - P -范数距离的P值,P∈[0,∞]。默认值:2.0。
|
||||
向上取整。
|
||||
|
||||
**返回:**
|
||||
|
||||
Tensor。p-范数距离,数据类型与输入一致。
|
||||
Tensor。向上取整的结果。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - 如果输入参数 `input_y` 不是Tensor类型。
|
||||
- **TypeError** - 如果当前Tensor或输入参数 `input_y` 的数据类型不是float16或者float32。
|
||||
- **TypeError** - 如果参数 `p` 不是一个float值。
|
||||
- **ValueError** - 如果参数 `p` 是负数。
|
||||
- **ValueError** - 如果当前Tensor的维度信息与输入参数 `input_y` 不相同。
|
||||
- **ValueError** - 如果当前Tensor或输入参数 `input_y` 不是2维或3维。
|
||||
|
||||
.. py:method:: celu(alpha=1.0)
|
||||
|
||||
celu激活函数,按输入元素计算输出,公式定义如下:
|
||||
|
||||
.. math::
|
||||
\text{CeLU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))
|
||||
|
||||
**参数:**
|
||||
|
||||
- **alpha** (float) - celu公式定义的阈值 :math:`\alpha` 。默认值:1.0。
|
||||
|
||||
**返回:**
|
||||
|
||||
Tensor,shape和数据类型与输入相同。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - `alpha` 不是float。
|
||||
- **ValueError** - `alpha` 的值为零。
|
||||
- **TypeError** - `x` 不是tensor。
|
||||
- **TypeError** - `x` 的dtype既不是float16也不是float32。
|
||||
- **TypeError** - 如果当前Tensor的数据类型不是float16或者float32。
|
||||
|
||||
.. py:method:: clip(xmin, xmax, dtype=None)
|
||||
|
||||
|
|
|
@ -242,8 +242,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"sum", std::string("sum")}, // P.ReduceSum
|
||||
{"repeat", std::string("repeat")}, // C.repeat_elements
|
||||
{"bernoulli", prim::kPrimBernoulli}, // P.Bernoulli()
|
||||
{"cdist", std::string("cdist")}, // P.cdist
|
||||
{"celu", std::string("celu")}, // P.celu
|
||||
{"ceil", std::string("ceil")}, // P.Ceil
|
||||
{"hardshrink", std::string("hardshrink")}, // P.hshrink
|
||||
{"soft_shrink", std::string("soft_shrink")}, // P.SoftShrink
|
||||
{"one_hot", std::string("one_hot")}, // P.OneHot
|
||||
|
|
|
@ -20,10 +20,12 @@
|
|||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "kernel/common_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr size_t kCdistInputDimsMin = 2;
|
||||
abstract::ShapePtr CdistInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
for (const auto &item : input_args) {
|
||||
|
@ -38,8 +40,22 @@ abstract::ShapePtr CdistInferShape(const PrimitivePtr &primitive, const std::vec
|
|||
<< "', rank of input_x and input_y must be equal, but got rank of input_x: " << x_size
|
||||
<< ", rank of input_y: " << y_size << ".";
|
||||
}
|
||||
CheckAndConvertUtils::CheckInRange("input_x dim", x_size, kIncludeBoth, {2, 3}, "Cdist");
|
||||
int64_t dim_R = y_shape[y_size - 2];
|
||||
if (x_size < kCdistInputDimsMin) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', rank of input must be greater than "
|
||||
<< kCdistInputDimsMin << ", but got rank of input: " << x_size << ".";
|
||||
}
|
||||
|
||||
if (x_size > kCdistInputDimsMin) {
|
||||
for (size_t i = 0; i < x_size - kCdistInputDimsMin; i++) {
|
||||
if (x_shape[i] != y_shape[i]) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << primitive->name()
|
||||
<< "', the batch shape of 'x' must be the same as the shape of 'y', "
|
||||
"but got 'x_shape["
|
||||
<< i << "]': " << x_shape[i] << " and 'y_shape[" << i << "]': " << y_shape[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
int64_t dim_R = y_shape[y_size - kCdistInputDimsMin];
|
||||
auto out_shape = x_shape;
|
||||
out_shape.pop_back();
|
||||
out_shape.push_back(dim_R);
|
||||
|
|
|
@ -1482,13 +1482,6 @@ def repeat(x, repeats, axis=None):
|
|||
return P.Concat(axis)(repeated_subs)
|
||||
|
||||
|
||||
def celu(x, alpha=1.0):
|
||||
r"""
|
||||
Apply the Hard Shrink function for a tensor. Calculates the output according to the input elements.
|
||||
"""
|
||||
return P.CeLU(alpha)(x)
|
||||
|
||||
|
||||
def hardshrink(x, lambd=0.5):
|
||||
r"""
|
||||
Apply the Hard Shrink function for a tensor. Calculates the output according to the input elements.
|
||||
|
@ -2083,11 +2076,11 @@ def float_floordiv(x, y):
|
|||
return floor(x / y)
|
||||
|
||||
|
||||
def cdist(x, y, p=2.0):
|
||||
def ceil(x):
|
||||
"""
|
||||
Computes batched the p-norm distance between each pair of the two collections of row vectors.
|
||||
Rounds a tensor up to the closest integer element-wise.
|
||||
"""
|
||||
return F.cdist(x, y, p)
|
||||
return F.ceil(x)
|
||||
|
||||
|
||||
#############
|
||||
|
|
|
@ -439,6 +439,7 @@ class Tensor(Tensor_):
|
|||
self.assign_value_cpp(value)
|
||||
return self
|
||||
|
||||
|
||||
def item(self, index=None):
|
||||
"""
|
||||
Get the item at the specified index of the tensor.
|
||||
|
@ -612,38 +613,6 @@ class Tensor(Tensor_):
|
|||
axis = ()
|
||||
return tensor_operator_registry.get('any')(keep_dims)(self, axis)
|
||||
|
||||
def cdist(self, input_y, p=2.0):
|
||||
"""
|
||||
Computes batched the p-norm distance between each pair of the two collections of row vectors.
|
||||
|
||||
Args:
|
||||
input_y (Tensor): as the same dtype as `self`, Input tensor of shape :math:`(B, R, M)`.
|
||||
p (float): P value for the p-norm distance to calculate between each vector pair, P ∈ [0,∞]. Default: 2.0.
|
||||
|
||||
Returns:
|
||||
ensor, has the same dtype as `input_y`, which shape is :math:`(B, P, R)`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
|
||||
Raises:
|
||||
TypeError: If `input_y` is not a Tensor.
|
||||
TypeError: If dtype of `self` or `input_y` is neither float16 nor float32.
|
||||
TypeError: If `p` is not a float.
|
||||
ValueError: If `p` is a negative float.
|
||||
ValueError: If dimension of `self` is not the same as `input_y`.
|
||||
ValueError: If dimension of `self` or `input_y` is neither 2 nor 3.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore import Tensor
|
||||
>>> a = Tensor(np.array([[[1.0, 1.0], [2.0, 2.0]]]).astype(np.float32))
|
||||
>>> y = Tensor(np.array([[[3.0, 3.0], [3.0, 3.0]]]).astype(np.float32))
|
||||
>>> output = a.cdist(y)
|
||||
>>> print(output)
|
||||
"""
|
||||
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('cdist')(p)(self, input_y)
|
||||
|
||||
def view(self, *shape):
|
||||
"""
|
||||
|
@ -1013,6 +982,26 @@ class Tensor(Tensor_):
|
|||
self._init_check()
|
||||
return tensor_operator_registry.get('abs')()(self)
|
||||
|
||||
def ceil(self):
|
||||
"""
|
||||
Rounds a tensor up to the closest integer element-wise.
|
||||
|
||||
Returns:
|
||||
Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore import Tensor
|
||||
>>> a = Tensor([1.1, 2.5, -1.5]).astype("float32")
|
||||
>>> output = a.ceil()
|
||||
>>> print(output)
|
||||
[ 2. 3. -1.]
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('ceil')()(self)
|
||||
|
||||
def lerp(self, end, weight):
|
||||
"""
|
||||
Does a linear interpolation of two tensors start and end based on a float or tensor weight.
|
||||
|
@ -3605,45 +3594,6 @@ class Tensor(Tensor_):
|
|||
s, _, _ = svd_op(full_matrices, compute_uv)(self)
|
||||
return s
|
||||
|
||||
def celu(self, alpha=1.0):
|
||||
r"""
|
||||
Computes celu (Continuously differentiable exponential linear units) of input tensors element-wise.
|
||||
|
||||
The formula is defined as follows:
|
||||
|
||||
.. math::
|
||||
\text{CeLU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))
|
||||
|
||||
It returns :math:`\max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))` element-wise.
|
||||
|
||||
The picture about celu looks like this `celu <https://arxiv.org/abs/1704.07483>`_.
|
||||
|
||||
Args:
|
||||
alpha (float): The :math:`\alpha` value for the Celu formulation. Default: 1.0
|
||||
|
||||
Returns:
|
||||
Tensor, has the same shape and data type as self.
|
||||
|
||||
Raises:
|
||||
TypeError: If `alpha` is not a float.
|
||||
ValueError: If `alpha` has the value of 0.
|
||||
TypeError: If `x` is not a Tensor.
|
||||
TypeError: If dtype of `x` is neither float16 nor float32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore import Tensor
|
||||
>>> x = Tensor(np.array([-2.0, -1.0, 1.0, 2.0]), mindspore.float32)
|
||||
>>> print(x.celu())
|
||||
[-0.86466473 -0.63212055 1. 2. ]
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('celu')(alpha)(self)
|
||||
|
||||
def hardshrink(self, lambd=0.5):
|
||||
r"""
|
||||
Apply the Hard Shrink function for tensor. Calculates the output according to the input elements.
|
||||
|
|
|
@ -89,6 +89,27 @@ def get_broadcast_binary_op_vmap_rule(prim, axis_size):
|
|||
return vmap_rule
|
||||
|
||||
|
||||
|
||||
@vmap_rules_getters.register(P.Cdist)
|
||||
def get_cdist_vmap_rule(prim, axis_size):
|
||||
"""VmapRule for `cdist` operation."""
|
||||
def vmap_rule(x_bdim, y_bdim):
|
||||
x, x_dim = x_bdim
|
||||
y, y_dim = y_bdim
|
||||
|
||||
if x_dim is None and y_dim is None:
|
||||
out = prim(x, y)
|
||||
return (out, None)
|
||||
|
||||
x = _bdim_at_front(x, x_dim, axis_size)
|
||||
y = _bdim_at_front(y, y_dim, axis_size)
|
||||
|
||||
out = prim(x, y)
|
||||
return (out, 0)
|
||||
|
||||
return vmap_rule
|
||||
|
||||
|
||||
@vmap_rules_getters.register(math_ops.Lerp)
|
||||
def get_lerp_vamp_rule(prim, axis_size):
|
||||
"""VmapRule for ternary operations with broadcasting, such as `Lerp` ."""
|
||||
|
|
|
@ -186,6 +186,7 @@ from .math_func import (
|
|||
erf,
|
||||
erfc,
|
||||
cdist,
|
||||
ceil,
|
||||
bernoulli,
|
||||
bessel_i0,
|
||||
bessel_i0e,
|
||||
|
|
|
@ -51,6 +51,7 @@ def get_x_shape(x_shape):
|
|||
# Public Operation Functions.
|
||||
#####################################
|
||||
absolute = P.Abs()
|
||||
tensor_ceil = P.Ceil()
|
||||
tensor_add = P.Add()
|
||||
neg_tensor = P.Neg()
|
||||
tensor_sub = P.Sub()
|
||||
|
@ -337,6 +338,38 @@ def neg(x):
|
|||
return neg_tensor(x)
|
||||
|
||||
|
||||
def ceil(x):
|
||||
r"""
|
||||
Rounds a tensor up to the closest integer element-wise.
|
||||
|
||||
.. math::
|
||||
|
||||
out_i = \lceil x_i \rceil = \lfloor x_i \rfloor + 1
|
||||
|
||||
Args:
|
||||
x (Tensor): The input tensor. It's element data type must be float16 or float32.
|
||||
:math:`(N,*)` where :math:`*` means, any number of additional dimensions, its rank should be less than 8.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same shape as the `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x` is not a Tensor.
|
||||
TypeError: If dtype of `x` is not float16 or float32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.ops import functional as F
|
||||
>>> x = Tensor(np.array([1.1, 2.5, -1.5]), mindspore.float32)
|
||||
>>> output = F.ceil(x)
|
||||
>>> print(output)
|
||||
[ 2. 3. -1.]
|
||||
"""
|
||||
return tensor_ceil(x)
|
||||
|
||||
|
||||
def round(x):
|
||||
r"""
|
||||
Returns half to even of a tensor element-wise.
|
||||
|
@ -2929,9 +2962,10 @@ def cdist(x, y, p=2.0):
|
|||
``Ascend`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.ops import functional as F
|
||||
>>> x = Tensor(np.array([[[1.0, 1.0], [2.0, 2.0]]]).astype(np.float32))
|
||||
>>> y = Tensor(np.array([[[3.0, 3.0], [3.0, 3.0]]]).astype(np.float32))
|
||||
>>> output = ops.cdist(x, y, 2.0)
|
||||
>>> output = F.cdist(x, y, 2.0)
|
||||
>>> print(output)
|
||||
[[[2.8284273 2.8284273]
|
||||
[1.4142137 1.4142137]]]
|
||||
|
@ -3781,6 +3815,7 @@ __all__ = [
|
|||
'erf',
|
||||
'erfc',
|
||||
'cdist',
|
||||
'ceil',
|
||||
'bernoulli',
|
||||
'bessel_j0',
|
||||
'bessel_j1',
|
||||
|
|
|
@ -942,7 +942,7 @@ tensor_operator_registry.register('maximum', P.Maximum)
|
|||
tensor_operator_registry.register('minimum', P.Minimum)
|
||||
tensor_operator_registry.register('matrix_determinant', matrix_determinant)
|
||||
tensor_operator_registry.register('log_matrix_determinant', log_matrix_determinant)
|
||||
tensor_operator_registry.register('cdist', P.Cdist)
|
||||
tensor_operator_registry.register('ceil', P.Ceil)
|
||||
tensor_operator_registry.register('fill', P.Fill)
|
||||
tensor_operator_registry.register('tile', P.Tile)
|
||||
tensor_operator_registry.register('logical_not', P.LogicalNot)
|
||||
|
@ -958,7 +958,6 @@ tensor_operator_registry.register('nonzero', nonzero)
|
|||
tensor_operator_registry.register('isclose', isclose)
|
||||
tensor_operator_registry.register('inv', inv)
|
||||
tensor_operator_registry.register('invert', invert)
|
||||
tensor_operator_registry.register('celu', P.CeLU)
|
||||
tensor_operator_registry.register('hardshrink', P.HShrink)
|
||||
tensor_operator_registry.register('soft_shrink', P.SoftShrink)
|
||||
tensor_operator_registry.register('svd', linalg_ops.Svd)
|
||||
|
|
|
@ -20,8 +20,14 @@ from mindspore.ops import operations as P
|
|||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
|
||||
from mindspore.ops.functional import vmap
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.api import ms_function
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
|
||||
|
||||
class CdistTEST(nn.Cell):
|
||||
def __init__(self, p):
|
||||
super(CdistTEST, self).__init__()
|
||||
|
@ -30,6 +36,7 @@ class CdistTEST(nn.Cell):
|
|||
def construct(self, x1, x2):
|
||||
return self.cdist(x1, x2)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -43,10 +50,12 @@ def test_CdistP2_float32():
|
|||
x1 = Tensor(np.array([[[1.0, 1.0], [2.0, 2.0]]]).astype(np.float32))
|
||||
x2 = Tensor(np.array([[[3.0, 3.0], [3.0, 3.0]]]).astype(np.float32))
|
||||
output = cdist(x1, x2)
|
||||
expect = np.array([[[2.828427, 2.828427], [1.4142135, 1.4142135]]]).astype(np.float32)
|
||||
expect = np.array(
|
||||
[[[2.828427, 2.828427], [1.4142135, 1.4142135]]]).astype(np.float32)
|
||||
print(output)
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -64,6 +73,7 @@ def test_CdistP0_float32():
|
|||
print(output)
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -95,10 +105,12 @@ def test_CdistP8_float32():
|
|||
x1 = Tensor(np.array([[[1.0, 1.0], [2.0, 2.0]]]).astype(np.float32))
|
||||
x2 = Tensor(np.array([[[3.0, 3.0], [3.0, 3.0]]]).astype(np.float32))
|
||||
output = cdist(x1, x2)
|
||||
expect = np.array([[[2.1810155, 2.1810155], [1.0905077, 1.0905077]]]).astype(np.float32)
|
||||
expect = np.array(
|
||||
[[[2.1810155, 2.1810155], [1.0905077, 1.0905077]]]).astype(np.float32)
|
||||
print(output)
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -115,3 +127,116 @@ def test_CdistPinf_float32():
|
|||
expect = np.array([[[2., 2.], [1., 1.]]]).astype(np.float32)
|
||||
print(output)
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_cdist_p2_float32_func():
|
||||
"""
|
||||
Feature: Cdist cpu kernel
|
||||
Description: test the cdist p = 2.0.
|
||||
Expectation: the output[0] is same as numpy
|
||||
"""
|
||||
x1 = Tensor(np.array([[[1.0, 1.0], [2.0, 2.0]]]).astype(np.float32))
|
||||
x2 = Tensor(np.array([[[3.0, 3.0], [3.0, 3.0]]]).astype(np.float32))
|
||||
expect = np.array(
|
||||
[[[2.828427, 2.828427], [1.4142135, 1.4142135]]]).astype(np.float32)
|
||||
output = F.cdist(x1, x2, 2.0)
|
||||
assert (output.asnumpy() == expect).all()
|
||||
print(output)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_vmap():
|
||||
"""
|
||||
Feature: cdist vmap.
|
||||
Description: test the rightness of cdist vmap feature.
|
||||
Expectation: Success.
|
||||
"""
|
||||
|
||||
def cal_cdist(x, y):
|
||||
return P.Cdist(2.0)(x, y)
|
||||
|
||||
x1 = Tensor(np.array([[[1.0, 1.0], [2.0, 2.0]], [[1.0, 1.0], [2.0, 2.0]],
|
||||
[[1.0, 1.0], [2.0, 2.0]], [[1.0, 1.0], [2.0, 2.0]],
|
||||
[[1.0, 1.0], [2.0, 2.0]], [[1.0, 1.0], [2.0, 2.0]],
|
||||
[[1.0, 1.0], [2.0, 2.0]], [[1.0, 1.0], [2.0, 2.0]]]).astype(np.float32))
|
||||
x2 = Tensor(np.array([[[3.0, 3.0], [3.0, 3.0]], [[3.0, 3.0], [3.0, 3.0]],
|
||||
[[3.0, 3.0], [3.0, 3.0]], [[3.0, 3.0], [3.0, 3.0]],
|
||||
[[3.0, 3.0], [3.0, 3.0]], [[3.0, 3.0], [3.0, 3.0]],
|
||||
[[3.0, 3.0], [3.0, 3.0]], [[3.0, 3.0], [3.0, 3.0]]]).astype(np.float32))
|
||||
|
||||
expect = np.array([[[2.828427, 2.828427], [1.4142135, 1.4142135]],
|
||||
[[2.828427, 2.828427], [1.4142135, 1.4142135]],
|
||||
[[2.828427, 2.828427], [1.4142135, 1.4142135]],
|
||||
[[2.828427, 2.828427], [1.4142135, 1.4142135]],
|
||||
[[2.828427, 2.828427], [1.4142135, 1.4142135]],
|
||||
[[2.828427, 2.828427], [1.4142135, 1.4142135]],
|
||||
[[2.828427, 2.828427], [1.4142135, 1.4142135]],
|
||||
[[2.828427, 2.828427], [1.4142135, 1.4142135]]]).astype(np.float32)
|
||||
|
||||
vmap_cdist = vmap(cal_cdist, in_axes=(0), out_axes=0)
|
||||
output = vmap_cdist(x1, x2)
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
# 【注意】由于Vmap特性在PyNative模式下基于ms_function实现,统一基准,for循环实现也基于ms_function
|
||||
@ms_function
|
||||
def manually_batched(xs, ys):
|
||||
output = []
|
||||
for i in range(xs.shape[0]):
|
||||
output.append(cal_cdist(xs[i], ys[i]))
|
||||
return F.stack(output)
|
||||
|
||||
expect_m = manually_batched(x1, x2)
|
||||
assert (output.asnumpy() == expect_m.asnumpy()).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_vmap2():
|
||||
"""
|
||||
Feature: cdist vmap.
|
||||
Description: test the rightness of cdist vmap feature.
|
||||
Expectation: Success.
|
||||
"""
|
||||
|
||||
def cal_cdist(x, y):
|
||||
return P.Cdist(2.0)(x, y)
|
||||
|
||||
x1 = Tensor(np.array([[[[1.0, 1.0], [2.0, 2.0]], [[1.0, 1.0], [2.0, 2.0]],
|
||||
[[1.0, 1.0], [2.0, 2.0]], [[1.0, 1.0], [2.0, 2.0]],
|
||||
[[1.0, 1.0], [2.0, 2.0]], [[1.0, 1.0], [2.0, 2.0]],
|
||||
[[1.0, 1.0], [2.0, 2.0]], [[1.0, 1.0], [2.0, 2.0]]],
|
||||
[[[1.0, 1.0], [2.0, 2.0]], [[1.0, 1.0], [2.0, 2.0]],
|
||||
[[1.0, 1.0], [2.0, 2.0]], [[1.0, 1.0], [2.0, 2.0]],
|
||||
[[1.0, 1.0], [2.0, 2.0]], [[1.0, 1.0], [2.0, 2.0]],
|
||||
[[1.0, 1.0], [2.0, 2.0]], [[1.0, 1.0], [2.0, 2.0]]]]).astype(np.float32))
|
||||
x2 = Tensor(np.array([[[3.0, 3.0], [3.0, 3.0]], [[3.0, 3.0], [3.0, 3.0]],
|
||||
[[3.0, 3.0], [3.0, 3.0]], [[3.0, 3.0], [3.0, 3.0]],
|
||||
[[3.0, 3.0], [3.0, 3.0]], [[3.0, 3.0], [3.0, 3.0]],
|
||||
[[3.0, 3.0], [3.0, 3.0]], [[3.0, 3.0], [3.0, 3.0]]]).astype(np.float32))
|
||||
|
||||
expect = np.array([[[[2.828427, 2.828427], [1.4142135, 1.4142135]],
|
||||
[[2.828427, 2.828427], [1.4142135, 1.4142135]],
|
||||
[[2.828427, 2.828427], [1.4142135, 1.4142135]],
|
||||
[[2.828427, 2.828427], [1.4142135, 1.4142135]],
|
||||
[[2.828427, 2.828427], [1.4142135, 1.4142135]],
|
||||
[[2.828427, 2.828427], [1.4142135, 1.4142135]],
|
||||
[[2.828427, 2.828427], [1.4142135, 1.4142135]],
|
||||
[[2.828427, 2.828427], [1.4142135, 1.4142135]]],
|
||||
[[[2.828427, 2.828427], [1.4142135, 1.4142135]],
|
||||
[[2.828427, 2.828427], [1.4142135, 1.4142135]],
|
||||
[[2.828427, 2.828427], [1.4142135, 1.4142135]],
|
||||
[[2.828427, 2.828427], [1.4142135, 1.4142135]],
|
||||
[[2.828427, 2.828427], [1.4142135, 1.4142135]],
|
||||
[[2.828427, 2.828427], [1.4142135, 1.4142135]],
|
||||
[[2.828427, 2.828427], [1.4142135, 1.4142135]],
|
||||
[[2.828427, 2.828427], [1.4142135, 1.4142135]]]]).astype(np.float32)
|
||||
|
||||
vmap_cdist = vmap(vmap(cal_cdist, in_axes=(0), out_axes=0), in_axes=(0, None), out_axes=0)
|
||||
output = vmap_cdist(x1, x2)
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
|
|
@ -81,29 +81,6 @@ def test_celu_func(data_type):
|
|||
np.testing.assert_allclose(output.asnumpy(), expect, rtol=error)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.parametrize("data_type", [np.float32, np.float16])
|
||||
def test_celu_tensor(data_type):
|
||||
"""
|
||||
Feature: Celu cpu kernel
|
||||
Description: test the celu alpha = 1.0.
|
||||
Expectation: match to np benchmark.
|
||||
"""
|
||||
error = 1e-3
|
||||
x = Tensor(np.array([-2.0, -1.0, 1.0, 2.0]).astype(data_type))
|
||||
expect = np.array([-0.86468184, -0.6321212, 1., 2.]).astype(data_type)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
output = x.celu(1.0)
|
||||
print(output)
|
||||
np.testing.assert_allclose(output.asnumpy(), expect, rtol=error)
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
output = x.celu(1.0)
|
||||
print(output)
|
||||
np.testing.assert_allclose(output.asnumpy(), expect, rtol=error)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.platform_x86_cpu
|
||||
|
|
|
@ -80,29 +80,6 @@ def test_celu_func(data_type):
|
|||
np.testing.assert_allclose(output.asnumpy(), expect, rtol=error)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.parametrize("data_type", [np.float32, np.float16])
|
||||
def test_celu_tensor(data_type):
|
||||
"""
|
||||
Feature: Celu gpu kernel
|
||||
Description: test the celu alpha = 1.0.
|
||||
Expectation: match to np benchmark.
|
||||
"""
|
||||
error = 1e-3
|
||||
x = Tensor(np.array([-2.0, -1.0, 1.0, 2.0]).astype(data_type))
|
||||
expect = np.array([-0.86468184, -0.6321212, 1., 2.]).astype(data_type)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
output = x.celu(1.0)
|
||||
print(output)
|
||||
np.testing.assert_allclose(output.asnumpy(), expect, rtol=error)
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
output = x.celu(1.0)
|
||||
print(output)
|
||||
np.testing.assert_allclose(output.asnumpy(), expect, rtol=error)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
|
Loading…
Reference in New Issue