Fix several operator issues:

  FFTWithSize API description
  SquareSumAll CN & EN docs inconsistent
  ctc_greedy_decoder, trunc, population_count API descriptions
  ctc_greedy_decoder: change RuntimeError to ValueError
  trunc: dynamic shape fix
panfengfeng 2022-07-31 22:15:05 -04:00
parent 38bb09add2
commit 26af16f65b
18 changed files with 54 additions and 177 deletions

View File

@ -187,12 +187,8 @@ Compared with the previous version, `mindspore.ops` interfaces that are added, removed, and ...
:template: classtemplate.rst
mindspore.ops.ComputeAccidentalHits
mindspore.ops.GridSampler2D
mindspore.ops.GridSampler3D
mindspore.ops.LogUniformCandidateSampler
mindspore.ops.UniformCandidateSampler
mindspore.ops.UpsampleNearest3D
mindspore.ops.UpsampleTrilinear3D
Image Processing
^^^^^^^^^^^^^^^^

View File

@ -1,8 +0,0 @@
mindspore.ops.GridSampler2D
===========================
.. py:class:: mindspore.ops.GridSampler2D(interpolation_mode='bilinear', padding_mode='zeros', align_corners=False)
Given an input and a grid, computes the output using the input values and the pixel locations from the grid.
For more details, refer to :func:`mindspore.ops.grid_sample`.
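Since the class mirrors :func:`mindspore.ops.grid_sample`, here is a minimal functional sketch of the sampling it describes, reusing the shapes from the grid_sample example elsewhere in this commit (values are illustrative):

import numpy as np
from mindspore import Tensor, ops

# input (N, C, H_in, W_in), grid (N, H_out, W_out, 2) -> output (N, C, H_out, W_out)
input_x = Tensor(np.arange(16).reshape((2, 2, 2, 2)).astype(np.float32))
grid = Tensor(np.arange(0.2, 1, 0.1).reshape((2, 2, 1, 2)).astype(np.float32))
output = ops.grid_sample(input_x, grid, interpolation_mode='bilinear',
                         padding_mode='zeros', align_corners=True)
print(output.shape)  # (2, 2, 2, 1)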

View File

@ -1,8 +0,0 @@
mindspore.ops.GridSampler3D
===========================
.. py:class:: mindspore.ops.GridSampler3D(interpolation_mode='bilinear', padding_mode='zeros', align_corners=False)
Given an input and a grid, computes the output using the input values and the pixel locations from the grid.
For more details, refer to :func:`mindspore.ops.grid_sample`.
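The 3D variant follows the same convention with volumetric data; a minimal sketch, again via :func:`mindspore.ops.grid_sample`, assuming the functional interface accepts 5D inputs as described here (shapes are illustrative):

import numpy as np
from mindspore import Tensor, ops

# input (N, C, D_in, H_in, W_in), grid (N, D_out, H_out, W_out, 3) -> output (N, C, D_out, H_out, W_out)
input_x = Tensor(np.arange(32).reshape((1, 1, 2, 4, 4)).astype(np.float32))
grid = Tensor(np.random.uniform(-1, 1, size=(1, 2, 2, 2, 3)).astype(np.float32))
output = ops.grid_sample(input_x, grid, interpolation_mode='bilinear',
                         padding_mode='zeros', align_corners=False)
print(output.shape)  # expected: (1, 1, 2, 2, 2)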

View File

@ -11,7 +11,7 @@
\end{matrix}\right.
.. note::
SquareSumAll only supports inputs of type float32 and float64.
SquareSumAll only supports inputs of type float16 and float32.
Inputs:
- **x** (Tensor) - The input of SquareSumAll, with a numeric data type and shape :math:`(N, *)`, where :math:`*` means any number of additional dimensions.
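A minimal usage sketch of the operator this note belongs to; the primitive also takes a second input `y` of the same shape, which is not shown in this hunk, so treat that detail as an assumption:

import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.array([1.0, 2.0, 3.0]).astype(np.float32))
y = Tensor(np.array([2.0, 3.0, 4.0]).astype(np.float32))
square_sum_all = ops.SquareSumAll()
output_x, output_y = square_sum_all(x, y)  # sum of squares of each input
print(output_x, output_y)                  # expected: 14.0 29.0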

View File

@ -1,28 +0,0 @@
mindspore.ops.UpsampleNearest3D
===============================
.. py:class:: mindspore.ops.UpsampleNearest3D(output_size=None, scales=None)
Performs nearest-neighbor upsampling.
Scales the input with the nearest-neighbor algorithm using either the specified `output_size` or the `scales` factors; exactly one of `output_size` and `scales` must be given, and they cannot be specified at the same time.
Args:
- **output_size** (Union[tuple[int], list[int]]) - A list of ints specifying the output volume size. Default: None.
- **scales** (Union[tuple[float], list[float]]) - A list of floats specifying the upsampling factors. Default: None.
Inputs:
- **x** (Tensor) - A tensor of shape :math:`(N, C, D_{in}, H_{in}, W_{in})`.
Outputs:
Tensor of shape :math:`(N, C, D_{out}, H_{out}, W_{out})`, with the same data type as the input `x`.
Raises:
- **TypeError** - The dimension of `x` is not 5.
- **TypeError** - The data type of `x` is neither float16 nor float32.
- **TypeError** - `output_size` is not a list of ints.
- **TypeError** - `scales` is not a list of floats.
- **ValueError** - `output_size` is a list whose length is not 3.
- **ValueError** - `scales` is a list whose length is not 3.
- **ValueError** - Both `output_size` and `scales` are None.
- **ValueError** - Both `output_size` and `scales` are non-empty lists.
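A minimal usage sketch of the interface documented above. The operator is being removed from the public API list in this commit, so the `ops.UpsampleNearest3D` path is an assumption; shapes are illustrative:

import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.arange(8).reshape((1, 1, 2, 2, 2)).astype(np.float32))
upsample = ops.UpsampleNearest3D(output_size=[4, 4, 4])  # assumed import path
y = upsample(x)
print(y.shape)  # expected: (1, 1, 4, 4, 4)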

View File

@ -1,30 +0,0 @@
mindspore.ops.UpsampleTrilinear3D
=================================
.. py:class:: mindspore.ops.UpsampleTrilinear3D(output_size=None, scales=None, align_corners=False)
Performs 3D trilinear-interpolation upsampling on a 5D input.
Scales the volumetric input using either the specified `output_size` or the `scales` factors; exactly one of `output_size` and `scales` must be given, and they cannot be specified at the same time.
Args:
- **output_size** (Union[tuple[int], list[int]]) - A list of ints specifying the output volume size. Default: None.
- **scales** (Union[tuple[float], list[float]]) - A list of floats specifying the upsampling factors. Default: None.
- **align_corners** (bool) - If True, scaling uses :math:`(new\_height - 1) / (height - 1)` so that the corners of the input image and the resized image are aligned; if False, scaling uses :math:`new\_height / height`. Default: False.
Inputs:
- **x** (Tensor) - A tensor of shape :math:`(N, C, D_{in}, H_{in}, W_{in})`.
Outputs:
Tensor of shape :math:`(N, C, D_{out}, H_{out}, W_{out})`, with the same data type as the input `x`.
Raises:
- **TypeError** - The dimension of `x` is not 5.
- **TypeError** - The data type of `x` is neither float16 nor float32.
- **TypeError** - `output_size` is not a list of ints.
- **TypeError** - `scales` is not a list of floats.
- **TypeError** - `align_corners` is not a bool.
- **ValueError** - `output_size` is a list whose length is not 3.
- **ValueError** - `scales` is a list whose length is not 3.
- **ValueError** - Both `output_size` and `scales` are None.
- **ValueError** - Both `output_size` and `scales` are non-empty lists.
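A matching sketch for the trilinear variant, under the same assumption that the primitive remains importable; here the `scales` form of the constructor is used instead of `output_size`:

import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.arange(8).reshape((1, 1, 2, 2, 2)).astype(np.float32))
upsample = ops.UpsampleTrilinear3D(scales=[2.0, 2.0, 2.0], align_corners=False)  # assumed import path
y = upsample(x)
print(y.shape)  # expected: (1, 1, 4, 4, 4)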

View File

@ -6,11 +6,11 @@ mindspore.ops.population_count
Counts the number of 1 bits in the binary representation of each element.
Args:
- **input_x** (Tensor) - A Tensor of any dimension. The supported data types are int16 and uint16 on Ascend, and int8, int16, int32, int64, uint8, uint16, uint32 and uint64 on CPU.
- **input_x** (Tensor) - A Tensor of any dimension. The supported data types are int16 and uint16 on Ascend, and int8, int16, int32, int64, uint8, uint16, uint32 and uint64 on CPU and GPU.
Returns:
Tensor, with the same shape as `input_x` and data type uint8.
Raises:
- **TypeError** - `input_x` is not a Tensor.
- **TypeError** - The data type of `input_x` is not int16 or uint16 (Ascend), or not int8, int16, int32, int64, uint8, uint16, uint32 or uint64 (CPU).
- **TypeError** - The data type of `input_x` is not int16 or uint16 (Ascend), or not int8, int16, int32, int64, uint8, uint16, uint32 or uint64 (CPU and GPU).
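A small sketch of the functional interface documented above, with values that differ from the English example later in this commit; the bit counts follow directly from the binary representations:

import numpy as np
import mindspore
from mindspore import Tensor, ops

input_x = Tensor(np.array([7, 8, 255]), mindspore.int16)
output = ops.population_count(input_x)
print(output)  # expected: [3 1 8]  (0b111, 0b1000, 0b11111111)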

View File

@ -186,12 +186,8 @@ Sampling Operator
:template: classtemplate.rst
mindspore.ops.ComputeAccidentalHits
mindspore.ops.GridSampler2D
mindspore.ops.GridSampler3D
mindspore.ops.LogUniformCandidateSampler
mindspore.ops.UniformCandidateSampler
mindspore.ops.UpsampleNearest3D
mindspore.ops.UpsampleTrilinear3D
Image Processing
^^^^^^^^^^^^^^^^

View File

@ -45,12 +45,23 @@ bool TruncCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::ve
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
auto input_shape = inputs[kZero]->GetShapeVector();
input_size_ = std::accumulate(input_shape.begin(), input_shape.end(), size_t(1), std::multiplies<size_t>());
dtype_ = inputs[kZero]->GetDtype();
return true;
}
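// Recompute the flattened element count from the current input shape on every
// resize so that dynamic-shape inputs are handled correctly (the "trunc dynamic
// shape fix" referenced in the commit message).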
int TruncCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
auto input_shape = inputs[kZero]->GetShapeVector();
input_size_ = std::accumulate(input_shape.begin(), input_shape.end(), size_t(1), std::multiplies<size_t>());
return KRET_OK;
}
bool TruncCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {

View File

@ -20,6 +20,7 @@
#include <vector>
#include <utility>
#include <memory>
#include <map>
#include "mindspore/core/ops/trunc.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
@ -34,6 +35,9 @@ class TruncCpuKernelMod : public NativeCpuKernelMod {
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;

View File

@ -84,19 +84,20 @@ AbstractBasePtr CTCGreedyDecoderInfer(const abstract::AnalysisEnginePtr &, const
}
if (inputs_x_shape.size() != kInputsRank) {
MS_LOG(EXCEPTION) << "For '" << prim_name << "', inputs's dim must be 3, but got: " << inputs_x_shape.size() << ".";
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', inputs's dim must be 3, but got: " << inputs_x_shape.size()
<< ".";
}
if (sequence_length_shape.size() != kSeqLenRank) {
MS_LOG(EXCEPTION) << "For '" << prim_name
<< "', sequence_length's dims must be 1, but got: " << sequence_length_shape.size() << ".";
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', sequence_length's dims must be 1, but got: " << sequence_length_shape.size() << ".";
}
if (inputs_x_shape[1] != sequence_length_shape[0]) {
MS_LOG(EXCEPTION) << "For '" << prim_name
<< "', inputs batch_size must be the same with sequence_length batch_size, "
<< "but now inputs batch_size: " << inputs_x_shape[1]
<< " and sequence_length batch_size: " << sequence_length_shape[0] << ".";
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', inputs batch_size must be the same with sequence_length batch_size, "
<< "but now inputs batch_size: " << inputs_x_shape[1]
<< " and sequence_length batch_size: " << sequence_length_shape[0] << ".";
}
return std::make_shared<abstract::AbstractTuple>(ret);
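For context, a sketch of how the checks above are expected to surface at the Python level after this change (hypothetical shapes; note the deliberately mismatched batch sizes):

import numpy as np
import mindspore
from mindspore import Tensor, ops

inputs = Tensor(np.zeros((4, 2, 3)), mindspore.float32)      # (max_time, batch_size, num_classes)
bad_seq_len = Tensor(np.array([2, 2, 2]), mindspore.int32)   # batch_size 3 does not match 2
try:
    ops.ctc_greedy_decoder(inputs, bad_seq_len)
except ValueError as err:                                    # previously surfaced as a RuntimeError
    print(err)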

View File

@ -32,7 +32,7 @@ abstract::ShapePtr UpsampleNearest3DInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto x_shape_ptr = input_args[kInputIndex0]->BuildShape();
(void)CheckAndConvertUtils::CheckInteger("dimension of x", SizeToLong(x_shape.size()), kEqual, SizeToLong(kDim5),
prim_name);
@ -68,6 +68,9 @@ abstract::ShapePtr UpsampleNearest3DInferShape(const PrimitivePtr &primitive,
<< " But get both.";
}
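// When the input shape is dynamic, the inferred output dimensions may still be
// unknown, so skip the positivity check below and return the shape as-is.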
if (x_shape_ptr->IsDynamic()) {
return std::make_shared<abstract::Shape>(y_shape);
}
for (size_t i = 0; i < y_shape.size(); i++) {
(void)CheckAndConvertUtils::CheckInteger("output shape", y_shape[i], kGreaterThan, 0, prim_name);
}

View File

@ -255,7 +255,7 @@ from .math_func import (
approximate_equal,
frac,
kron,
rot90
rot90,
)
from .nn_func import (
adaptive_avg_pool2d,

View File

@ -3538,15 +3538,14 @@ def population_count(input_x):
Raises:
TypeError: If `input_x` is not a Tensor.
TypeError: If dtype of `input_x` is not int16, uint16 (Ascend).
If dtype of `input_x` is not int8, int16, int32, int64, uint8, uint16, uint32, uint64 (CPU).
If dtype of `input_x` is not int8, int16, int32, int64, uint8, uint16, uint32, uint64 (CPU and GPU).
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x_input = Tensor([0, 1, 3], mindspore.int16)
>>> population_count = ops.PopulationCount()
>>> output = population_count(x_input)
>>> input_x = Tensor([0, 1, 3], mindspore.int16)
>>> output = ops.population_count(input_x)
>>> print(output)
[0 1 2]
"""

View File

@ -2340,7 +2340,7 @@ def trunc(x):
Returns a new tensor with the truncated integer values of the elements of input.
Args:
- **x** (Tensor) - Input_x is a tensor.
- **input_x** (Tensor) - input_x is a tensor.
Returns:
Tensor, the same shape and data type as the input.
@ -2352,9 +2352,8 @@ def trunc(x):
``Ascend`` ``CPU``
Examples:
>>> x = Tensor(np.array([3.4742, 0.5466, -0.8008, -3.9079]),mindspore.float32)
>>> trunc = ops.Trunc()
>>> output = trunc(x)
>>> input_x = Tensor(np.array([3.4742, 0.5466, -0.8008, -3.9079]),mindspore.float32)
>>> output = ops.trunc(input_x)
>>> print(output)
[ 3. 0. 0. -3.]
"""
@ -5402,6 +5401,6 @@ __all__ = [
'approximate_equal',
'frac',
'kron',
'rot90'
'rot90',
]
__all__.sort()

View File

@ -1819,7 +1819,7 @@ def grid_sample(input_x, grid, interpolation_mode='bilinear', padding_mode='zero
Examples:
>>> input_x = Tensor(np.arange(16).reshape((2, 2, 2, 2)).astype(np.float32))
>>> grid = Tensor(np.arange(0.2, 1, 0.1).reshape((2, 2, 1, 2)).astype(np.float32))
>>> output = grid_sample(input_x, grid, interpolation_mode='bilinear', padding_mode='zeros',
>>> output = ops.grid_sample(input_x, grid, interpolation_mode='bilinear', padding_mode='zeros',
...     align_corners=True)
>>> print(output)
[[[[ 1.9 ]
@ -1943,7 +1943,8 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
>>> inputs = Tensor(np.array([[[0.6, 0.4, 0.2], [0.8, 0.6, 0.3]],
... [[0.0, 0.6, 0.0], [0.5, 0.4, 0.5]]]), mindspore.float32)
>>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32)
>>> decoded_indices, decoded_values, decoded_shape, log_probability = ctc_greedy_decode(inputs, sequence_length)
>>> decoded_indices, decoded_values, decoded_shape, log_probability = ops.ctc_greedy_decoder(inputs,
...     sequence_length)
>>> print(decoded_indices)
[[0 0]
[0 1]

View File

@ -5722,23 +5722,7 @@ class Trunc(Primitive):
"""
Returns a new tensor with the truncated integer values of the elements of input.
Inputs:
- **input_x** (Tensor) - Input_x is a tensor.
Outputs:
Tensor, the same shape and data type as the input.
Raises:
TypeError: If `input_x` is not a Tensor.
Supported Platforms:
``Ascend`` ``CPU`` ``GPU``
Examples:
>>> trunc = ops.Trunc()
>>> output = trunc(Tensor(np.array([3.4742, 0.5466, -0.8008, -3.9079]),mindspore.float32))
>>> print(output)
[ 3. 0. 0. -3.]
Refer to :func:`mindspore.ops.trunc` for more detail.
"""
@prim_attr_register
@ -6685,17 +6669,19 @@ class FFTWithSize(Primitive):
IRFFT requires complex64 or complex128 inputs and returns float32 or float64 outputs.
Args:
signal_ndim(int): The number of dimensions in each signal, this controls how many dimensions of the fourier
transform are realized, can only be 1, 2 or 3.
inverse(bool): Whether it is the inverse transformation, used to select FFT or IFFT and RFFT or IRFFT.
real(bool): Whether it is the real transformation, used to select FFT/IFFT or RFFT/IRFFT.
norm(str): The default normalization ("backward") has the direct (forward) transforms unscaled
signal_ndim (int): The number of dimensions in each signal, this controls how many dimensions of the fourier
transform are realized, can only be 1, 2 or 3.
inverse (bool): Whether it is the inverse transformation, used to select FFT or IFFT and RFFT or IRFFT.
inverse=False means FFT or RFFT, inverse=True means IFFT or IRFFT.
real (bool): Whether it is the real transformation, used to select FFT/IFFT or RFFT/IRFFT.
real=False means FFT or IFFT, real=True means RFFT or IRFFT.
norm (str): The default normalization ("backward") has the direct (forward) transforms unscaled
and the inverse (backward) transforms scaled by 1/n.
"ortho" has both direct and inverse transforms are scaled by 1/sqrt(n).
"forward" has the direct transforms scaled by 1/n and the inverse transforms unscaled.
n is the input x's element numbers.
onesided(bool): Controls whether the input is halved to avoid redundancy. Default: True.
signal_sizes(list): Size of the original signal (the signal before rfft, no batch dimension),
onesided (bool): Controls whether the input is halved to avoid redundancy. Default: True.
signal_sizes (list): Size of the original signal (the signal before rfft, no batch dimension),
required only in irfft mode with onesided=True. Default: [].
Inputs:
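A minimal sketch of a forward complex FFT with the primitive described above; the keyword names follow the Args list, while the `ops.FFTWithSize` export path and the 1D example are assumptions rather than something shown in this hunk:

import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.exp(2j * np.pi * np.arange(8) / 8).astype(np.complex64))
fft = ops.FFTWithSize(signal_ndim=1, inverse=False, real=False)  # assumed export path
y = fft(x)
print(y.shape, y.dtype)  # a 1D complex transform keeps the input length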

View File

@ -7228,52 +7228,7 @@ class CTCGreedyDecoder(Primitive):
r"""
Performs greedy decoding on the logits given in inputs.
Args:
merge_repeated (bool): If true, merge repeated classes in output. Default: True.
Inputs:
- **inputs** (Tensor) - The input Tensor must be a 3-D tensor whose shape is
:math:`(max\_time, batch\_size, num\_classes)`. `num_classes` must be `num_labels + 1` classes,
`num_labels` indicates the number of actual labels. Blank labels are reserved.
Default blank label is `num_classes - 1`. Data type must be float32 or float64.
- **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of :math:`(batch\_size, )`.
The type must be int32. Each value in the tensor must be equal to or less than `max_time`.
Outputs:
- **decoded_indices** (Tensor) - A tensor with shape of :math:`(total\_decoded\_outputs, 2)`.
Data type is int64.
- **decoded_values** (Tensor) - A tensor with shape of :math:`(total\_decoded\_outputs, )`,
it stores the decoded classes. Data type is int64.
- **decoded_shape** (Tensor) - A tensor with shape of :math:`(batch\_size, max\_decoded\_length)`.
Data type is int64.
- **log_probability** (Tensor) - A tensor with shape of :math:`(batch\_size, 1)`,
containing sequence log-probability, has the same type as `inputs`.
Raises:
TypeError: If `merge_repeated` is not a bool.
ValueError: If length of shape of `inputs` is not equal to 3.
ValueError: If length of shape of `sequence_length` is not equal to 1.
Supported Platforms:
``Ascend``
Examples:
>>> inputs = Tensor(np.array([[[0.6, 0.4, 0.2], [0.8, 0.6, 0.3]],
... [[0.0, 0.6, 0.0], [0.5, 0.4, 0.5]]]), mindspore.float32)
>>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32)
>>> ctc_greedyDecoder = ops.CTCGreedyDecoder()
>>> decoded_indices, decoded_values, decoded_shape, log_probability = ctc_greedyDecoder(inputs, sequence_length)
>>> print(decoded_indices)
[[0 0]
[0 1]
[1 0]]
>>> print(decoded_values)
[0 1 0]
>>> print(decoded_shape)
[2 2]
>>> print(log_probability)
[[-1.2]
[-1.3]]
Refer to :func:`mindspore.ops.ctc_greedy_decoder` for more detail.
"""
@prim_attr_register