!32929 [assistant][ops]Add math operator CumulativeLogsumexp

Merge pull request !32929 from 孟权令/CumulativeLogsumexp
This commit is contained in:
i-robot 2022-09-22 11:17:47 +00:00 committed by Gitee
commit 57343fad77
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 514 additions and 1 deletions

View File

@ -0,0 +1,180 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/cumulative_logsumexp_cpu_kernel.h"
#include <cmath>
#include <string>
#include <thread>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kCumulativeLogsumexpInputsNum = 2;
constexpr size_t kCumulativeLogsumexpOutputsNum = 1;
constexpr size_t kAxisDimension = 1;
constexpr size_t kAxisShapeSize = 1;
const float float16_exclusive_data = -65504e+0;
const float float_exclusive_data = -3.4028235e+38;
const double double_exclusive_data = -1.7976931348623157e+308;
} // namespace
void CumulativeLogsumexpCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
exclusive_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, EXCLUSIVE);
reverse_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, REVERSE);
}
bool CumulativeLogsumexpCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCumulativeLogsumexpInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCumulativeLogsumexpOutputsNum, kernel_name_);
if (dtype_ == kNumberTypeFloat64) {
LaunchKernel<double>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat) {
LaunchKernel<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat16) {
LaunchKernel<float16>(inputs, outputs);
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', kernel data type " << TypeIdLabel(dtype_) << "not support.";
}
return true;
}
template <typename t>
void CumulativeLogsumexpCpuKernelMod::CumulativeProcess(t *input_data, t *output_data, uint32_t outer, uint32_t inner,
uint32_t depth) {
for (size_t outer_index = 0; outer_index < outer; ++outer_index) {
size_t outer_index_adj;
if (reverse_) {
outer_index_adj = (outer - 1) - outer_index;
} else {
outer_index_adj = outer_index;
}
for (size_t inner_index = 0; inner_index < inner; ++inner_index) {
double one = 1;
double temp = 0;
size_t inner_index_adj;
if (reverse_) {
inner_index_adj = (inner - 1) - inner_index;
} else {
inner_index_adj = inner_index;
}
for (size_t depth_index = 0; depth_index < depth; ++depth_index) {
size_t depth_index_adj;
if (reverse_) {
depth_index_adj = (depth - 1) - depth_index;
} else {
depth_index_adj = depth_index;
}
size_t index = outer_index_adj;
index += inner_index_adj * depth * outer;
index += depth_index_adj * outer;
if (exclusive_) {
if (depth_index == 0) {
if (dtype_ == kNumberTypeFloat16) {
output_data[index] = static_cast<t>(float16_exclusive_data);
} else if (dtype_ == kNumberTypeFloat32) {
output_data[index] = static_cast<t>(float_exclusive_data);
} else {
output_data[index] = static_cast<t>(double_exclusive_data);
}
temp = static_cast<double>(input_data[index]);
} else {
output_data[index] = static_cast<t>(temp);
double a = temp;
double b, min, max;
b = static_cast<double>(input_data[index]);
if (a < b) {
min = a;
max = b;
} else {
min = b;
max = a;
}
temp = log(one + exp(min - max)) + max;
}
} else {
if (depth_index == 0) {
output_data[index] = input_data[index];
temp = static_cast<double>(input_data[index]);
} else {
double a = temp;
double b, min, max;
b = static_cast<double>(input_data[index]);
if (a < b) {
min = a;
max = b;
} else {
min = b;
max = a;
}
output_data[index] = static_cast<t>(log(one + exp(min - max)) + max);
temp = log(one + exp(min - max)) + max;
}
}
}
}
}
}
template <typename T>
void CumulativeLogsumexpCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto *input_data = static_cast<T *>(inputs[kIndex0]->addr);
auto axis_ = static_cast<int32_t *>(inputs[kIndex1]->addr);
auto *output_data = static_cast<T *>(outputs[kIndex0]->addr);
size_t lens = inputs[kIndex0]->size > 0 ? static_cast<size_t>(inputs[kIndex0]->size / sizeof(T)) : 1;
auto task = [this, input_data, axis_, output_data](const size_t start, const size_t end) {
int32_t x_rank = SizeToInt(shape_.size());
if (axis_[0] >= x_rank || axis_[0] < -x_rank) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << ", 'axis' must be in range [" << -x_rank << ", " << x_rank
<< "), but got: " << axis_[0];
}
if (axis_[0] < 0) {
axis_[0] += x_rank;
}
uint32_t inner = 1;
uint32_t depth = shape_[IntToSize(axis_[0])];
uint32_t outer = 1;
for (size_t i = 0; i < IntToSize(axis_[0]); i++) {
inner *= shape_[i];
}
for (size_t i = IntToSize(axis_[0]) + 1; i < shape_.size(); i++) {
outer *= shape_[i];
}
CumulativeProcess<T>(input_data, output_data, outer, inner, depth);
};
ParallelLaunchAutoSearch(task, lens, this, &parallel_search_info_);
}
std::vector<KernelAttr> CumulativeLogsumexpCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64)};
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CumulativeLogsumexp, CumulativeLogsumexpCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,55 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMULATIVE_LOGSUMEXP_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMULATIVE_LOGSUMEXP_CPU_KERNEL_H_
#include <memory>
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class CumulativeLogsumexpCpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
CumulativeLogsumexpCpuKernelMod() = default;
~CumulativeLogsumexpCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
template <typename t>
void CumulativeProcess(t *input_data, t *output_data, uint32_t outer, uint32_t inner, uint32_t depth);
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
ShapeVector shape_;
bool exclusive_{false};
bool reverse_{false};
TypeId dtype_{kTypeUnknown};
};
} // namespace kernel
} // namespace mindspore
#endif

View File

@ -80,6 +80,7 @@ constexpr auto kExp = "Exp";
constexpr auto kEqual = "Equal";
constexpr auto kNotEqual = "NotEqual";
constexpr auto kNeg = "Neg";
constexpr auto kCumulativeLogsumexp = "CumulativeLogsumexp";
constexpr auto kSub = "Sub";
constexpr auto kMedian = "Median";
constexpr auto kMedianGrad = "MedianGrad";
@ -1142,6 +1143,7 @@ GVAR_DEF(PrimitivePtr, kPrimHistogram, std::make_shared<Primitive>("Histogram"))
GVAR_DEF(PrimitivePtr, kPrimMaximum, std::make_shared<Primitive>("Maximum"));
GVAR_DEF(PrimitivePtr, kPrimSquare, std::make_shared<Primitive>(kSquare));
GVAR_DEF(PrimitivePtr, kPrimCumSum, std::make_shared<Primitive>("CumSum"));
GVAR_DEF(PrimitivePtr, kPrimCumulativeLogsumexp, std::make_shared<Primitive>(kCumulativeLogsumexp));
GVAR_DEF(PrimitivePtr, kPrimCumProd, std::make_shared<Primitive>("CumProd"));
GVAR_DEF(PrimitivePtr, kPrimSubscalar, std::make_shared<Primitive>("Subscalar"));
GVAR_DEF(PrimitivePtr, kPrimInplaceAdd, std::make_shared<Primitive>("InplaceAdd"));

View File

@ -0,0 +1,66 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/cumulative_logsumexp.h"
#include <set>
#include "ops/op_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr CumulativeLogsumexpInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
const int64_t min_dim = 1;
const int64_t kAxisDim = 0;
(void)CheckAndConvertUtils::CheckInteger("input x rank", SizeToLong(x_shape.size()), kGreaterEqual, min_dim,
prim_name);
auto axis_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto axis_dim = SizeToLong(axis_shape.size());
(void)CheckAndConvertUtils::CheckInteger("axis dimension", axis_dim, kEqual, kAxisDim, prim_name);
return std::make_shared<abstract::Shape>(x_shape);
}
TypePtr CumulativeLogsumexpInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = prim->name();
const std::set<TypePtr> valid_types = {kFloat32, kFloat16, kFloat64};
auto x_type = input_args[kInputIndex0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name);
const std::set<TypePtr> axis_valid_types = {kInt64, kInt32, kInt16};
auto axis_type = input_args[kInputIndex1]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("axis", axis_type, axis_valid_types, prim_name);
return x_type;
}
} // namespace
MIND_API_OPERATOR_IMPL(CumulativeLogsumexp, BaseOperator);
AbstractBasePtr CumulativeLogsumexpInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t kInputsNum = 2;
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
auto type = CumulativeLogsumexpInferType(primitive, input_args);
auto shape = CumulativeLogsumexpInferShape(primitive, input_args);
return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(CumulativeLogsumexp, prim::kPrimCumulativeLogsumexp, CumulativeLogsumexpInfer, nullptr,
true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,40 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_CUMULATIVE_LOGSUMEXP_H_
#define MINDSPORE_CORE_OPS_CUMULATIVE_LOGSUMEXP_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameCumulativeLogsumexp = "CumulativeLogsumexp";
class MIND_API CumulativeLogsumexp : public BaseOperator {
public:
MIND_API_BASE_MEMBER(CumulativeLogsumexp);
CumulativeLogsumexp() : BaseOperator(kNameCumulativeLogsumexp) { InitIOName({"x", "axis"}, {"y"}); }
void Init() const {}
};
abstract::AbstractBasePtr CumulativeLogsumexpInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_CUMULATIVE_LOGSUMEXP_H_

View File

@ -18,7 +18,7 @@
from mindspore.common import dtype as mstype
from mindspore.scipy.ops import SolveTriangular
from mindspore.ops.operations.math_ops import Trace, Bernoulli, Renorm
from mindspore import nn
from mindspore import nn, ops, Tensor
import mindspore.numpy as mnp
import numpy as np
from ...nn import LGamma
@ -49,6 +49,7 @@ from ..operations.math_ops import Hypot
from ..operations.math_ops import ReduceStd
from ..operations.math_ops import LuUnpack
from ..operations.math_ops import MatrixExp
from ..operations.math_ops import CumulativeLogsumexp
from ..operations.math_ops import MatrixSolve
from ..operations.math_ops import MatrixPower
from ..operations.math_ops import Median
@ -387,6 +388,56 @@ def get_bprop_lp_norm(self):
return bprop
@bprop_getters.register(CumulativeLogsumexp)
def get_brop_cumulative_logsumexp(self):
"""Generate bprop for CumulativeLogsumexp"""
exp_op = P.Exp()
greater_op = P.Greater()
log_op = P.Log()
cumulative_op = CumulativeLogsumexp(self.exclusive, not self.reverse)
less_op = P.Less()
neg_op = P.Neg()
def where_v2(condition, x=None, y=None):
return_all = None
if x is None and y is None:
return_all = mnp.where(condition, x, y)
elif x is not None and y is not None:
shape_ = x.shape
input_y = np.resize(y, shape_)
input_y = Tensor(input_y).astype(x.dtype)
return_all = ops.select(condition, x, input_y)
else:
raise ValueError("x and y must both be non-None or both be None.")
return return_all
def bprop(x, axis, out, dout):
dtype_min = 0
fp64_flag = False
if x.dtype == mstype.float16:
dtype_min = -65500e+0
elif x.dtype == mstype.float32:
dtype_min = -3.4028235e+38
elif x.dtype == mstype.float64:
dout = F.cast(dout, mstype.float32)
x = F.cast(x, mstype.float32)
out = F.cast(out, mstype.float32)
dtype_min = -3.4028235e+38
fp64_flag = True
log_grad_positive = where_v2(greater_op(dout, 0), log_op(dout), dtype_min)
log_grad_negative = where_v2(less_op(dout, 0), log_op(neg_op(dout)), dtype_min)
output_pos = exp_op(cumulative_op(log_grad_positive - out, axis) + x)
output_neg = exp_op(cumulative_op(log_grad_negative - out, axis) + x)
if fp64_flag:
output_pos = F.cast(output_pos, mstype.float64)
output_neg = F.cast(output_neg, mstype.float64)
x = F.cast(x, mstype.float64)
return (output_pos - output_neg, zeros_like(x))
return bprop
@bprop_getters.register(MatrixExp)
def get_bprop_matrix_exp(self):
"""Gegerate brop for MatrixExp"""

View File

@ -0,0 +1,39 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CumulativeLogsumexp op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
cumulative_logsumexp_op_info = AiCPURegOp("CumulativeLogsumexp") \
.fusion_type("OPAQUE") \
.attr("exclusive", "bool") \
.attr("reverse", "bool") \
.input(0, "x", "required") \
.input(1, "axis", "required")\
.output(0, "y", "required") \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \
.dtype_format(DataType.F16_Default, DataType.I16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I16_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.I16_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(cumulative_logsumexp_op_info)
def _cumulative_logsumexp_aicpu():
"""CumulativeLogsumexp aicpu register"""
return

View File

@ -615,6 +615,80 @@ class ReduceMean(_Reduce):
super(ReduceMean, self).__init__(keep_dims)
class CumulativeLogsumexp(Primitive):
"""
Compute the cumulative product of the tensor `x` along `axis`.
When `exclusive` is set `False`, this operation performs an inclusive cumulative log-sum-exp, which means that the
first element of the input is identical to the first element of the output. For example, when takes a tensor
[a, b, c] as input, this operation outputs [a, log(exp(a) + exp(b)), log(exp(a) + exp(b) + exp(c))]. When `reverse`
is set `True`, the cumulative log-sum-exp is performed in the opposite direction and thus get the output
[log(exp(a) + exp(b) + exp(c)), log(exp(b) + exp(c)), c].
When `exclusive` is set `True`, this operation performs an exclusive cumulative log-sum-exp instead. For example,
when takes a tensor [a, b, c] as input, this operation outputs [-inf, a, log(exp(a) * exp(b))]. Note that the
neutral element of the log-sum-exp operation is -inf, however, for performance reasons, the minimal value
representable by the floating point type is used instead. When `reverse` is set `True`, the cumulative log-sum-exp
is performed in the opposite direction and thus get the output [log(exp(b) * exp(c)), c, -inf].
Args:
exclusive (bool): If true, perform exclusive cumulative log-sum-exp.
If false, perform inclusive cumulative log-sum-exp. Default: False.
reverse (bool): If true, the cumulative log-sum-exp is performed in the opposite direction.
If false, the cumulative log-sum-exp is performed in the forward direction. Default: False.
Inputs:
- **x** (Tensor) - The input tensor. Must be one of the following types: float16, float32, float64.
The dimension of `x` must greater than 0.
- **axis** (Tensor) - A 0-D tensor describing the dimension to compute the cumulative product. Must be one of
the following types: int64, int32, int16. Must be in the range [-rank(x), rank(x)). Default: 0.
Outputs:
Tensor, has the same dtype and shape as the `x`.
Raises:
TypeError: If `x` is not a Tensor.
TypeError: If dtype of `x` is not in [float16, float32, float64].
TypeError: If `axis` is not a Tensor.
TypeError: If dtype of `axis` is not in [int16, int32, int64].
TypeError: If `exclusive` or `reverse` is not a bool.
ValueError: If the dimension of `x` is not greater than 1.
RuntimeError: If `axis` is out of range [-rank(x), rank(x)).
Supported Platforms:
``Ascend`` ``CPU``
Examples:
>>> x = Tensor(np.array([1.0, 2.0, 3.0]).astype(np.float32))
>>> op = ops.CumulativeLogsumexp(exclusive=False, reverse=False)
>>> output = op(x, Tensor(0))
>>> print(output)
[1. 2.3132617 3.407606 ]
>>> x = Tensor(np.array([1.0, 2.0, 3.0]).astype(np.float32))
>>> op = ops.CumulativeLogsumexp(exclusive=True, reverse=False)
>>> output = op(x, Tensor(0))
>>> print(output)
[-3.4028235e+38 1.0000000e+00 2.3132617e+00]
>>> x = Tensor(np.array([1.0, 2.0, 3.0]).astype(np.float32))
>>> op = ops.CumulativeLogsumexp(exclusive=False, reverse=True)
>>> output = op(x, Tensor(0))
>>> print(output)
[3.407606 3.3132617 3. ]
>>> x = Tensor(np.array([1.0, 2.0, 3.0]).astype(np.float32))
>>> op = ops.CumulativeLogsumexp(exclusive=True, reverse=True)
>>> output = op(x, Tensor(0))
>>> print(output)
[ 3.3132617e+00 3.0000000e+00 -3.4028235e+38]
"""
@prim_attr_register
def __init__(self, exclusive=False, reverse=False):
"""Initialize CumulativeLogsumexp"""
self.init_prim_io_names(inputs=['x', 'axis'], outputs=['y'])
validator.check_bool(exclusive, "exclusive", self.name)
validator.check_bool(reverse, "reverse", self.name)
class ReduceSum(_Reduce):
"""
Reduces a dimension of a tensor by summing all elements in the dimension, by default. And also can reduce a

View File

@ -37,6 +37,7 @@ from mindspore.ops.operations import _quant_ops as Q
from mindspore.ops.operations.math_ops import BesselJ0, BesselJ1, BesselK0, BesselK1, BesselK0e, \
BesselI0, BesselI1, BesselK1e, BesselY0, BesselY1, Bucketize
from mindspore.ops.operations.math_ops import ReduceStd
from mindspore.ops.operations.math_ops import CumulativeLogsumexp
from mindspore.ops.operations.math_ops import Sinc
from mindspore.ops.operations.array_ops import ConjugateTranspose
from mindspore.ops.operations.array_ops import UnravelIndex
@ -2226,6 +2227,11 @@ test_case_math_ops = [
'block': P.Round(),
'desc_inputs': [[3]],
'desc_bprop': [[3]]}),
('CumulativeLogsumexp', {
'block': CumulativeLogsumexp(exclusive=False, reverse=False),
'desc_inputs': [Tensor(np.array([1.0, 2.0, 3.0], np.float32)),
Tensor(0, dtype=mstype.int32)],
'desc_bprop': [Tensor(np.array([1.0, 2.0, 3.0], np.float32))]}),
('Atan2', {
'block': P.Atan2(),
'desc_inputs': [Tensor(np.array([0, 1]).astype(np.float32)),