add pdist_cpu_kernel

This commit is contained in:
hemaohua 2022-05-19 16:29:07 +08:00
parent f0fc501cf0
commit cd9c40f807
21 changed files with 625 additions and 1 deletions

View File

@ -36,6 +36,16 @@ functional算子是经过初始化后的Primitive可以直接作为函数使
神经网络层算子
----------------
神经网络
^^^^^^^^
.. mscnplatformautosummary::
:toctree: ops
:nosignatures:
:template: classtemplate.rst
mindspore.ops.pdist
激活函数
^^^^^^^^^^

View File

@ -77,6 +77,7 @@ MindSpore中 `mindspore.ops` 接口与上一版本相比,新增、删除和支
mindspore.ops.Padding
mindspore.ops.ResizeNearestNeighbor
mindspore.ops.ResizeBilinear
mindspore.ops.Pdist
损失函数
^^^^^^^^^^

View File

@ -1259,3 +1259,19 @@ mindspore.Tensor
**返回:**
Tensor具有与入参 `shape` 相同的维度。
.. py:method:: pdist(p=2.0)
计算输入中每对行向量之间的p-范数距离。
.. math::
y[n] = \sqrt[p]{{\mid x_{i} - x_{j} \mid}^p}
**参数:**
- **p** (float) - P -范数距离的P值P∈[0∞]。默认值:2.0。
**返回:**
Tensor类型与 `x` 一致。

View File

@ -0,0 +1,9 @@
mindspore.ops.Pdist
==================
.. py:function:: mindspore.ops.Pdist()
计算输入中每对行向量之间的p-范数距离。
更多参考详见 :func:`mindspore.ops.pdist`

View File

@ -0,0 +1,30 @@
mindspore.ops.pdist
==================
.. py:function:: mindspore.ops.pdist(x, p)
计算输入中每对行向量之间的p-范数距离。如果输入`x`的shape为 :math:`(N, M)`那么输出就是一个shape为 :math:`(N * (N - 1) / 2,)`
的Tensor。如果`x`的shape为 :math:`(*B, N, M)`那么输出就是一个shape为 :math:`(*B, N * (N - 1) / 2)`的Tensor。
.. math::
y[n] = \sqrt[p]{{\mid x_{i} - x_{j} \mid}^p}
**参数:**
- **x** (tensor) - 输入tensor x其shape为 :math:`(*B, N, M)`,其中 :math:`*B`表示批处理大小可以是多维度。类型float16float32或float64。
- **p** (float) - P -范数距离的P值P∈[0∞]。默认值:2.0。
**返回:**
Tensor类型与 `x` 一致。
**异常:**
- **TypeError** - `x` 不是tensor。
- **TypeError** - `x` 的数据类型不是float16float32float64。
- **TypeError** - `p` 不是float。
- **ValueError** - `p` 是负数。
- **ValueError** - `x` 的维度小于2。
**支持平台:**
``CPU``

View File

@ -36,6 +36,16 @@ The functional operators are initialized Primitives and can be used directly as
Neural Network Layer Operators
------------------------------
Neural Network
^^^^^^^^^^^^^^
.. msplatformautosummary::
:toctree: ops
:nosignatures:
:template: classtemplate.rst
mindspore.ops.pdist
Activation Functions
^^^^^^^^^^^^^^^^^^^^

View File

@ -77,6 +77,7 @@ Neural Network
mindspore.ops.Padding
mindspore.ops.ResizeNearestNeighbor
mindspore.ops.ResizeBilinear
mindspore.ops.Pdist
Loss Function
^^^^^^^^^^^^^

View File

@ -240,6 +240,7 @@ BuiltInTypeMap &GetMethodMap() {
{"gather_nd", std::string("gather_nd")}, // P.GatherNd()
{"unique_consecutive", std::string("unique_consecutive")}, // UniqueConsecutive()
{"diag", std::string("diag")}, // P.Diag()
{"pdist", std::string("pdist")}, // F.pdist()
}},
{kObjectTypeRowTensorType,
{

View File

@ -0,0 +1,162 @@
/**
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cmath>
#include <functional>
#include <map>
#include <algorithm>
#include "plugin/device/cpu/kernel/pdist_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "mindspore/core/ops/pdist.h"
#include "abstract/utils.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kPdistInputsNum = 1;
constexpr size_t kPdistOutputsNum = 1;
} // namespace
template <typename T>
void PdistZeroNormalcompute(const T *input, T *output, size_t start_x, size_t start_y, float p, size_t col,
size_t idx) {
double res = 0;
for (size_t i = 0; i < col; i++) {
res += (input[start_x + i] == input[start_y + i]) ? 0 : 1;
}
output[idx] = static_cast<T>(res);
}
template <typename T>
void PdistInfNormalcompute(const T *input, T *output, size_t start_x, size_t start_y, float p, size_t col, size_t idx) {
double res = 0;
for (size_t i = 0; i < col; i++) {
double x = static_cast<double>(input[start_x + i]);
double y = static_cast<double>(input[start_y + i]);
res = std::max(std::abs(x - y), res);
}
output[idx] = static_cast<T>(res);
}
template <typename T>
void PdistOneNormalcompute(const T *input, T *output, size_t start_x, size_t start_y, float p, size_t col, size_t idx) {
double res = 0;
for (size_t i = 0; i < col; i++) {
double x = static_cast<double>(input[start_x + i]);
double y = static_cast<double>(input[start_y + i]);
res += std::abs(x - y);
}
output[idx] = static_cast<T>(res);
}
template <typename T>
void PdistNormalcompute(const T *input, T *output, size_t start_x, size_t start_y, float p, size_t col, size_t idx) {
double res = 0;
for (size_t i = 0; i < col; i++) {
double x = static_cast<double>(input[start_x + i]);
double y = static_cast<double>(input[start_y + i]);
res += std::pow(std::abs(x - y), p);
}
res = std::pow(res, 1.0 / p);
output[idx] = static_cast<T>(res);
}
bool PdistCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::Pdist>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "cast Pdist ops failed!";
return false;
}
kernel_name_ = kernel_ptr->name();
p_ = kernel_ptr->get_p();
if (inputs.size() != kPdistInputsNum || outputs.size() != kPdistOutputsNum) {
MS_LOG(ERROR) << kernel_name_ << ": input and output size should be " << kPdistInputsNum << " and "
<< kPdistOutputsNum << ", but get " << inputs.size() << " and " << outputs.size();
return false;
}
auto input_shape = inputs[0]->GetShapeVector();
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_), LongToSize);
input_dim_ = input_shape_.size();
input_size_ = std::accumulate(input_shape_.begin(), input_shape_.end(), 1, std::multiplies<size_t>());
auto input_dtype_ = inputs[0]->GetDtype();
switch (input_dtype_) {
case kNumberTypeFloat64:
kernel_func_ = &PdistCpuKernelMod::LaunchKernel<double>;
break;
case kNumberTypeFloat32:
kernel_func_ = &PdistCpuKernelMod::LaunchKernel<float>;
break;
case kNumberTypeFloat16:
kernel_func_ = &PdistCpuKernelMod::LaunchKernel<float16>;
break;
default:
MS_LOG(ERROR) << "Pdist kernel does not support " << TypeIdToString(input_dtype_);
return false;
}
return true;
}
int PdistCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &others) {
if (NativeCpuKernelMod::Resize(base_operator, inputs, outputs, others) == KRET_RESIZE_FAILED) {
MS_LOG(WARNING) << kernel_name_ << " reinit failed.";
return KRET_RESIZE_FAILED;
}
return 0;
}
template <typename T>
bool PdistCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
const auto *input = reinterpret_cast<T *>(inputs[0]->addr);
auto output = reinterpret_cast<T *>(outputs[0]->addr);
auto col = input_shape_[input_dim_ - 1];
auto temp = input_shape_[input_dim_ - 1] * input_shape_[input_dim_ - 2];
auto task = [this, &input, &output, col, temp](size_t start, size_t end) {
size_t idx = 0;
for (size_t i = start; i < end; i = i + temp) {
for (size_t j = i; j < i + temp; j = j + col) {
for (size_t k = j + col; k < i + temp; k = k + col) {
if (p_ == 0.0) {
PdistZeroNormalcompute(input, output, j, k, p_, col, idx);
} else if (std::isinf(p_)) {
PdistInfNormalcompute(input, output, j, k, p_, col, idx);
} else if (p_ == 1.0) {
PdistOneNormalcompute(input, output, j, k, p_, col, idx);
} else {
PdistNormalcompute(input, output, j, k, p_, col, idx);
}
idx++;
}
}
}
};
ParallelLaunchAutoSearch(task, input_size_, this, &parallel_search_info_, pool_);
return true;
}
std::vector<KernelAttr> PdistCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16)};
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Pdist, PdistCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,60 @@
/**
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PDIST_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PDIST_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include <map>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class PdistCpuKernelMod : public NativeCpuKernelMod {
public:
PdistCpuKernelMod() = default;
~PdistCpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &others = std::map<uint32_t, tensor::TensorPtr>()) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
using PdistKernel = std::function<bool(PdistCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
PdistKernel kernel_func_;
std::vector<size_t> input_shape_;
size_t input_size_;
size_t input_dim_;
float p_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PDIST_CPU_KERNEL_H_

View File

@ -660,6 +660,7 @@ GVAR_DEF(PrimitivePtr, kPrimFractionalAvgPoolGrad, std::make_shared<Primitive>("
GVAR_DEF(PrimitivePtr, kPrimNthElement, std::make_shared<Primitive>("NthElement"));
GVAR_DEF(PrimitivePtr, kPrimGridSampler2D, std::make_shared<Primitive>(kGridSampler2D));
GVAR_DEF(PrimitivePtr, kPrimGridSampler2DGrad, std::make_shared<Primitive>(kGridSampler2DGrad));
GVAR_DEF(PrimitivePtr, kPrimPdist, std::make_shared<Primitive>("Pdist"));
// Comm ops
GVAR_DEF(PrimitivePtr, kPrimMirror, std::make_shared<Primitive>("_MirrorOperator"));

View File

@ -0,0 +1,75 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/pdist.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr PdistInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_size = x_shape.size();
const int64_t input_dim = 2;
CheckAndConvertUtils::CheckInteger("x dim", x_size, kGreaterEqual, input_dim, "Pdist");
int64_t dim_R = x_shape[x_size - 2];
const float out_shape_used = 0.5;
dim_R = dim_R * (dim_R - 1) * out_shape_used;
std::vector<int64_t> out_shape;
for (size_t i = 0; i < x_size - input_dim; i++) {
out_shape.push_back(x_shape[i]);
}
out_shape.push_back(dim_R);
return std::make_shared<abstract::Shape>(out_shape);
}
TypePtr PdistInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypePtr> valid_types = {kFloat64, kFloat32, kFloat16};
auto x_dtype = input_args[0]->BuildType();
return CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, valid_types, primitive->name());
}
} // namespace
float Pdist::get_p() const {
auto value_ptr = this->GetAttr(kP);
return GetValue<float>(value_ptr);
}
void Pdist::set_p(const float p) { (void)this->AddAttr(kP, api::MakeValue(p)); }
MIND_API_OPERATOR_IMPL(Pdist, BaseOperator);
AbstractBasePtr PdistInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto infer_type = PdistInferType(primitive, input_args);
auto infer_shape = PdistInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Pdist, prim::kPrimPdist, PdistInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,53 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_PDIST_H_
#define MINDSPORE_CORE_OPS_PDIST_H_
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNamePdist = "Pdist";
/// \brief Computes batched the p norm distance between each pair of row vectors in one collection.
/// Refer to Python API @ref mindspore.ops.Pdist for more details.
class MIND_API Pdist : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Pdist);
/// \brief Constructor.
Pdist() : BaseOperator(kNamePdist) { InitIOName({"x"}, {"y"}); }
void set_p(const float p);
/// \brief Get p.
///
/// \return p.
float get_p() const;
};
abstract::AbstractBasePtr PdistInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_PDIST_H_

View File

@ -2185,3 +2185,11 @@ def gather_nd(input_x, indices):
Refer to :func:`mindspore.ops.gather_nd` for more detail.
"""
return F.gather_nd(input_x, indices)
def pdist(x, p=2.0):
r"""
Computes the p-norm distance between each pair of row vectors in the input.
Refer to :func:`mindspore.ops.pdist` for more detail.
"""
return F.pdist(x, p=p)

View File

@ -3346,6 +3346,39 @@ class Tensor(Tensor_):
return output, counts
return output
def pdist(self, p=2.0):
r"""
Computes the p-norm distance between each pair of row vectors in the input.
.. math::
y[n] = \sqrt[p]{{\mid x_{i} - x_{j} \mid}^p}
where :math:`x_{i}, x_{j}` are two different row vectors in the input.
Args:
p (float): p value for the p norm distance to calculate between each vector pair [0,]. Default: 2.0.
Returns:
Tensor, has the same dtype as self.
Raises:
TypeError: If dtype of Tensor is float16, float32 or float64.
TypeError: If `p` is not a float.
ValueError: If `p` is a negative float.
ValueError: If dimension of Tensor is less than 2.
Supported Platforms:
``CPU``
Examples:
>>> x = Tensor(np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]).astype(np.float32))
>>> y = x.pdist(p=2.0)
>>> print(y)
[1.4142135 2.828427 1.4142135]
"""
self._init_check()
return tensor_operator_registry.get('pdist')(p)(self)
def diag(self):
r"""
Constructs a diagonal tensor with a given diagonal values.

View File

@ -17,6 +17,7 @@
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.operations import nn_ops as NN
from mindspore.ops import functional as F
from mindspore.ops import constexpr
from ..primitive import Primitive
@ -156,6 +157,23 @@ def get_fast_gelu_grad_vmap_rule(prim, axis_size):
return vmap_rule
@vmap_rules_getters.register(NN.Pdist)
def get_pdist_vmap_rule(prim, axis_size):
"""VmapRule for `Pdist`"""
if isinstance(prim, str):
prim = Primitive(prim)
prim.add_prim_attr('p', 2.0)
def vmap_rule(x_bdim):
is_all_none, result = vmap_general_preprocess(prim, x_bdim)
if is_all_none:
return result
x, x_dim = x_bdim
x = _bdim_at_front(x, x_dim, axis_size)
out = prim(x)
return out, 0
return vmap_rule
get_unop_vmap_rule = vmap_rules_getters.register(P.Elu)(get_unop_vmap_rule)
get_unop_vmap_rule = vmap_rules_getters.register(P.ReLU)(get_unop_vmap_rule)
get_unop_vmap_rule = vmap_rules_getters.register(P.ReLU6)(get_unop_vmap_rule)

View File

@ -179,6 +179,7 @@ from .nn_func import (
fast_gelu,
hardshrink,
softsign,
pdist,
)
from .linalg_func import (
svd,

View File

@ -16,6 +16,7 @@
"""Defines nn operators with functional form."""
from mindspore.ops import operations as P
from mindspore.ops.operations import nn_ops as NN
fast_gelu_ = P.FastGeLU()
softsign_ = P.Softsign()
@ -126,9 +127,49 @@ def softsign(x):
return softsign_(x)
def pdist(x, p=2.0):
r"""
Computes the p-norm distance between each pair of row vectors in the input. If `x` is a 2D Tensor of
shape :math:`(N, M)`, then `output` must be a 1D Tensor of shape :math:`(N * (N - 1) / 2,)`. If `x` id a
Tensor of shape :math:`(*B, N, M)`, then `output` must be a Tensor of shape :math:`(*B, N * (N - 1) / 2)`.
.. math::
y[n] = \sqrt[p]{{\mid x_{i} - x_{j} \mid}^p}
where :math:`x_{i}, x_{j}` are two different row vectors in the input.
Args:
x (Tensor) - Input tensor of shape :math:`(*B, N, M)`. *B: batch size, one-dim or multi-dim.
dtype: float16, float32, float64.
p (float): p value for the p norm distance to calculate between each vector pair [0,]. Default: 2.0.
Returns:
Tensor, has the same dtype as `x`.
Raises:
TypeError: If `x` is not a Tensor.
TypeError: If dtype of `x` is float16, float32 or float64.
TypeError: If `p` is not a float.
ValueError: If `p` is a negative float.
ValueError: If dimension of `x` is less than 2.
Supported Platforms:
``CPU``
Examples:
>>> x = Tensor(np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]).astype(np.float32))
>>> y = ops.pdist(x, p=2.0)
>>> print(y)
[1.4142135 2.828427 1.4142135]
"""
pdist_ = NN.Pdist(p=p)
return pdist_(x)
__all__ = [
'fast_gelu',
'hardshrink',
'softsign'
'softsign',
'pdist',
]
__all__.sort()

View File

@ -35,6 +35,7 @@ from .operations import _grad_ops
from .operations import _csr_ops
from .operations import linalg_ops
from .operations.array_ops import UniqueConsecutive
from .operations import nn_ops as NN
from .composite import _Grad, Shard, _Vmap, _TaylorOperation
from .._c_expression import security
@ -909,6 +910,7 @@ tensor_operator_registry.register('hardshrink', P.HShrink)
tensor_operator_registry.register('svd', linalg_ops.Svd)
tensor_operator_registry.register('diag', P.Diag)
tensor_operator_registry.register('unique_consecutive', UniqueConsecutive)
tensor_operator_registry.register('pdist', NN.Pdist)
# ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal)

View File

@ -9626,3 +9626,18 @@ class GridSampler2D(Primitive):
validator.check_string(padding_mode, ['zeros', 'border', 'reflection'], 'padding_mode', self.name)
validator.check_bool(align_corners, 'align_corners', self.name)
self.init_prim_io_names(inputs=['input', 'grid'], outputs=['output'])
class Pdist(Primitive):
r"""
Computes the p-norm distance between each pair of row vectors in the input.
Refer to :func:`mindspore.ops.pdist` for more detail.
"""
@prim_attr_register
def __init__(self, p=2.0):
"""Initialize Pdist"""
validator.check_value_type("p", p, [float], self.name)
if p < 0:
raise ValueError('Pdist p must be a non-negative value, but got `{p}`.')
self.init_prim_io_names(inputs=['x'], outputs=['y'])

View File

@ -0,0 +1,77 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
from mindspore import Tensor
from mindspore.ops import functional as F
import mindspore.context as context
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)])
def test_pdist_normal(dtype, eps):
"""
Feature: Pdist cpu kernel
Description: test the Pdist p = 2.0.
Expectation: the output matches numpy
"""
x = Tensor(np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=dtype))
error = np.ones(shape=(3,)) * eps
output = F.pdist(x, p=2.0)
expect = np.array([1.41421356, 2.82842712, 1.41421356], dtype=dtype)
diff = np.abs(output.asnumpy() - expect)
assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)])
def test_pdist_zero(dtype, eps):
"""
Feature: Pdist cpu kernel
Description: test the Pdist p = 0.0.
Expectation: the output matches numpy
"""
x = Tensor(np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=dtype))
error = np.ones(shape=(3,)) * eps
output = F.pdist(x, p=0.0)
expect = np.array([2., 2., 2.], dtype=dtype)
diff = np.abs(output.asnumpy() - expect)
assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)])
def test_pdist_inf(dtype, eps):
"""
Feature: Pdist cpu kernel
Description: test the Pdist p = inf.
Expectation: the output matches numpy
"""
x = Tensor(np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=dtype))
error = np.ones(shape=(3,)) * eps
output = F.pdist(x, p=float('inf'))
expect = np.array([1., 2., 1.], dtype=dtype)
diff = np.abs(output.asnumpy() - expect)
assert np.all(diff < error)