!43978 Fix doc && bug of CumSum.

Merge pull request !43978 from hezhenhao1/add_cumsum
i-robot 2022-10-18 02:26:10 +00:00 committed by Gitee
commit 792c41d95c
10 changed files with 111 additions and 5 deletions

View File

@@ -208,6 +208,7 @@ Reduction functions
mindspore.ops.argmin
mindspore.ops.cummax
mindspore.ops.cummin
mindspore.ops.cumsum
mindspore.ops.logsumexp
mindspore.ops.max
mindspore.ops.mean

View File

@@ -0,0 +1,24 @@
mindspore.ops.cumsum
====================
.. py:function:: mindspore.ops.cumsum(x, axis, dtype=None)
Computes the cumulative sum of the input tensor `x` along dimension `axis`.
.. math::
y_i = x_1 + x_2 + x_3 + ... + x_i
.. note::
On the Ascend platform, for static-shape scenarios the dtype of `x` currently only supports int8, uint8, int32, float32 and float16; for dynamic-shape scenarios the dtype of `x` currently only supports int32, float32 and float16.
Parameters:
- **x** (Tensor) - The input Tensor to accumulate.
- **axis** (int) - Axis along which the cumulative sum is computed.
- **dtype** (:class:`mindspore.dtype`, optional) - The output dtype. If not None, the input is cast to `dtype` before the computation, which helps prevent numerical overflow. If None, the output has the same dtype as the input. Default: None.
Returns:
Tensor, with the same shape as the input Tensor.
Raises:
- **TypeError** - If `x` is not a Tensor.
- **ValueError** - If `axis` is out of range.
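The rst page above stops at the Raises section; as a quick illustration of the documented signature, here is a minimal usage sketch on an arbitrary 1-D float32 tensor (the values and printed output are illustrative, not part of the original page):
>>> import numpy as np
>>> from mindspore import Tensor
>>> import mindspore.ops as ops
>>> x = Tensor(np.array([1., 2., 3., 4.], dtype=np.float32))
>>> print(ops.cumsum(x, 0))
[ 1.  3.  6. 10.]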

View File

@@ -208,6 +208,7 @@ Reduction Functions
mindspore.ops.argmin
mindspore.ops.cummax
mindspore.ops.cummin
mindspore.ops.cumsum
mindspore.ops.logsumexp
mindspore.ops.max
mindspore.ops.mean

View File

@@ -178,3 +178,11 @@ template CUDA_LIB_EXPORT void CumSum<float>(const float *input, float *output, f
template CUDA_LIB_EXPORT void CumSum<half>(const half *input, half *output, half *workspace, size_t dim0, size_t dim1,
size_t dim2, size_t stride, size_t stride2, bool exclusive_, bool reverse_,
const uint32_t &device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void CumSum<Complex<float>>(const Complex<float> *input, Complex<float> *output,
Complex<float> *workspace, size_t dim0, size_t dim1, size_t dim2,
size_t stride, size_t stride2, bool exclusive_, bool reverse_,
const uint32_t &device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void CumSum<Complex<double>>(const Complex<double> *input, Complex<double> *output,
Complex<double> *workspace, size_t dim0, size_t dim1, size_t dim2,
size_t stride, size_t stride2, bool exclusive_, bool reverse_,
const uint32_t &device_id, cudaStream_t stream);

View File

@@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUMSUM_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUMSUM_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
template <typename T>
CUDA_LIB_EXPORT void CumSum(const T *input, T *output, T *workspace, size_t dim0, size_t dim1, size_t dim2,
size_t stride, size_t stride2, bool exclusive_, bool reverse_, const uint32_t &device_id,

View File

@@ -150,6 +150,10 @@ std::vector<std::pair<KernelAttr, CumSumGpuKernelMod::CumSumLaunchFunc>> CumSumG
&CumSumGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&CumSumGpuKernelMod::LaunchKernel<double>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
&CumSumGpuKernelMod::LaunchKernel<utils::Complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
&CumSumGpuKernelMod::LaunchKernel<utils::Complex<double>>},
// Dynamic shape related.
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
&CumSumGpuKernelMod::LaunchKernel<int8_t>},
@@ -173,6 +177,10 @@ std::vector<std::pair<KernelAttr, CumSumGpuKernelMod::CumSumLaunchFunc>> CumSumG
&CumSumGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
&CumSumGpuKernelMod::LaunchKernel<double>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
&CumSumGpuKernelMod::LaunchKernel<utils::Complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex128),
&CumSumGpuKernelMod::LaunchKernel<utils::Complex<double>>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
&CumSumGpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
@@ -195,6 +203,10 @@ std::vector<std::pair<KernelAttr, CumSumGpuKernelMod::CumSumLaunchFunc>> CumSumG
&CumSumGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
&CumSumGpuKernelMod::LaunchKernel<double>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
&CumSumGpuKernelMod::LaunchKernel<utils::Complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex128),
&CumSumGpuKernelMod::LaunchKernel<utils::Complex<double>>},
};
std::vector<KernelAttr> CumSumGpuKernelMod::GetOpSupport() {
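The four KernelAttr entries added above register complex64/complex128 launch functions for both the static-shape path and the dynamic-shape paths that take the axis as an extra int64/int32 input. A minimal sketch of how the new complex path could be exercised from Python; the GPU context call and the complex64 input are illustrative assumptions, not part of this diff:
>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor
>>> import mindspore.ops as ops
>>> mindspore.set_context(device_target="GPU")  # assumes a GPU build with the kernels registered above
>>> x = Tensor(np.array([1 + 1j, 2 + 2j, 3 + 3j], dtype=np.complex64))
>>> y = ops.cumsum(x, 0)  # expected values: [1+1j, 3+3j, 6+6j]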

View File

@@ -1,5 +1,6 @@
# cuda
find_package(CUDA)
add_compile_definitions(ENABLE_GPU)
file(GLOB_RECURSE CUDA_KERNEL_SRC
${CMAKE_CURRENT_SOURCE_DIR}/*.cu
${CCSRC_DIR}/plugin/device/gpu/kernel/cuda_impl/cuda_ops/gather.cu
@@ -10,6 +11,6 @@ file(GLOB_RECURSE CUDA_KERNEL_SRC
)
set_source_files_properties(${CUDA_KERNEL_SRC} PROPERTIES CUDA_SOURCE_PROPERTY_FORMAT OBJ)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGES} -std=c++14 -fPIC")
SET(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-std=c++14;-arch=sm_53)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGES} -fPIC")
SET(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-arch=sm_53)
cuda_add_library(cuda_kernel_mid STATIC ${CUDA_KERNEL_SRC})

View File

@@ -255,6 +255,7 @@ from .math_func import (
baddbmm,
cummin,
cummax,
cumsum,
amin,
amax,
mean,

View File

@@ -163,6 +163,7 @@ sparse_segment_mean_ = SparseSegmentMean()
xlogy_ = P.Xlogy()
square_ = P.Square()
sqrt_ = P.Sqrt()
cumsum_ = P.CumSum()
#####################################
@@ -4253,6 +4254,62 @@ def cummax(x, axis):
return _cummax(x)
def cumsum(x, axis, dtype=None):
"""
Computes the cumulative sum of the input Tensor along `axis`.
.. math::
y_i = x_1 + x_2 + x_3 + ... + x_i
Note:
On Ascend, the dtype of `x` only supports int8, uint8, int32, float16 or float32 in the case of static shape.
For the case of dynamic shape, the dtype of `x` only supports int32, float16 or float32.
Args:
x (Tensor): The input Tensor to accumulate.
axis (int): Axis along which the cumulative sum is computed.
dtype (:class:`mindspore.dtype`, optional): The desired dtype of returned Tensor. If specified,
the input Tensor will be cast to `dtype` before the computation. This is useful for preventing overflows.
If not specified, the dtype stays the same as that of the original Tensor. Default: None.
Returns:
Tensor, the shape of the output Tensor is consistent with the input Tensor's.
Raises:
TypeError: If `x` is not a Tensor.
ValueError: If `axis` is out of range.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor
>>> import mindspore.ops as ops
>>> x = Tensor(np.array([[3, 4, 6, 10], [1, 6, 7, 9], [4, 3, 8, 7], [1, 3, 7, 9]]).astype(np.float32))
>>> # case 1: along the axis 0
>>> y = ops.cumsum(x, 0)
>>> print(y)
[[ 3. 4. 6. 10.]
[ 4. 10. 13. 19.]
[ 8. 13. 21. 26.]
[ 9. 16. 28. 35.]]
>>> # case 2: along the axis 1
>>> y = ops.cumsum(x, 1)
>>> print(y)
[[ 3. 7. 13. 23.]
[ 1. 7. 14. 23.]
[ 4. 7. 15. 22.]
[ 1. 4. 11. 20.]]
"""
if dtype is not None and x.dtype != dtype:
x = x.astype(dtype, copy=False)
validator.check_axis_in_range(axis, x.ndim)
return cumsum_(x, axis)
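The docstring states that `dtype` is useful for preventing overflows, but none of the examples above exercise it. A small hedged sketch (the int8 values and the expected output are illustrative only): casting to int32 lets the running sum exceed the int8 range.
>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor
>>> import mindspore.ops as ops
>>> x = Tensor(np.array([100, 100, 100], dtype=np.int8))
>>> y = ops.cumsum(x, 0, dtype=mindspore.int32)  # without the cast, 300 would overflow int8
>>> print(y)
[100 200 300]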
def sparse_segment_mean(x, indices, segment_ids):
r"""
Computes a Tensor such that :math:`output_i = \frac{\sum_j x_{indices[j]}}{N}` where mean is over :math:`j` such
@@ -6069,7 +6126,7 @@ def dotrapezoid(y, dx, dim):
y_left = select_(y, dim, 0)
y_right = select_(y, dim, -1)
y_sum = y.sum(dim)
return (y_sum - (y_left + y_right) * 0.5) * dx
return (y_sum - (y_left + y_right) * 0.5) * dx
def dotrapezoid_tensor(y, dx, dim):
@@ -6096,7 +6153,7 @@ def add_padding_to_shape(curr_shape, target_n_dim):
def zeros_like_except(y, dim):
_check_dim_in_range(dim, y.ndim)
dim = dim + y.ndim if dim < 0 else dim
sizes = y.shape[:dim] + y.shape[dim+1:]
sizes = y.shape[:dim] + y.shape[dim + 1:]
zeros = P.Zeros()(sizes, y.dtype)
return zeros
@@ -6587,6 +6644,7 @@ __all__ = [
'baddbmm',
'cummin',
'cummax',
'cumsum',
'amin',
'amax',
'mean',

View File

@@ -300,7 +300,6 @@ bool_not = Primitive("bool_not")
bool_or = Primitive("bool_or")
bool_and = Primitive("bool_and")
bool_eq = Primitive("bool_eq")
cumsum = P.CumSum()
cumprod = P.CumProd()
array_to_scalar = Primitive('array_to_scalar')
is_ = Primitive("is_")