!27184 [assistant][ops]Add Cummax

Merge pull request !27184 from 陈钧/cummax
This commit is contained in:
i-robot 2022-01-28 06:40:25 +00:00 committed by Gitee
commit 47a4860043
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 384 additions and 3 deletions

View File

@ -0,0 +1,102 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/cummax_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
template <typename T>
void CummaxCPUKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
output1_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
output2_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 1);
dim_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "dim");
}
template <typename T>
bool CummaxCPUKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
auto input_data_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto output1_data_addr = reinterpret_cast<T *>(outputs[0]->addr);
auto output2_data_addr = reinterpret_cast<int64_t *>(outputs[1]->addr);
const size_t dims = input_shape_.size();
if (dims == 0) {
MS_LOG(EXCEPTION) << "The value of `dims` can not be 0";
}
dim_ = (dim_ + dims) % dims;
std::vector<size_t> p{1};
for (int64_t i = (int64_t)input_shape_.size() - 1; i >= 0; i--)
p.push_back(p[(int64_t)input_shape_.size() - 1 - i] * input_shape_[i]);
reverse(p.begin(), p.end());
size_t input_stride = p[dim_ + 1];
size_t output1_stride = p[dim_ + 1];
size_t output2_stride = p[dim_ + 1];
size_t input_dim_size = input_shape_[dim_];
int exit_ok = 0;
std::vector<size_t> counter(dims, 0);
while (!exit_ok) {
T out = input_data_addr[0];
int idx = 0;
for (size_t i = 0; i < input_dim_size; i++) {
T cur = input_data_addr[i * input_stride];
if (cur >= out) {
out = cur;
idx = i;
}
output1_data_addr[i * output1_stride] = out;
output2_data_addr[i * output2_stride] = idx;
}
if (dims == 1) break;
for (size_t dim_i = 0; dim_i < dims; dim_i++) {
if (dim_i == dim_) {
if (dim_i == dims - 1) {
exit_ok = 1;
break;
}
continue;
}
counter[dim_i]++;
input_data_addr += p[dim_i + 1];
output1_data_addr += p[dim_i + 1];
output2_data_addr += p[dim_i + 1];
if (counter[dim_i] == input_shape_[dim_i]) {
if (dim_i == dims - 1) {
exit_ok = 1;
break;
} else {
input_data_addr -= counter[dim_i] * p[dim_i + 1];
output1_data_addr -= counter[dim_i] * p[dim_i + 1];
output2_data_addr -= counter[dim_i] * p[dim_i + 1];
counter[dim_i] = 0;
}
} else {
break;
}
}
}
return true;
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,69 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMMAX_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMMAX_CPU_KERNEL_H_
#include <memory>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class CummaxCPUKernelMod : public NativeCpuKernelMod {
public:
CummaxCPUKernelMod() = default;
~CummaxCPUKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
private:
std::vector<size_t> input_shape_;
std::vector<size_t> output1_shape_;
std::vector<size_t> output2_shape_;
size_t dim_;
};
MS_REG_CPU_KERNEL_T(
Cummax,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
CummaxCPUKernelMod, float);
MS_REG_CPU_KERNEL_T(
Cummax,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64),
CummaxCPUKernelMod, float16);
MS_REG_CPU_KERNEL_T(
Cummax, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
CummaxCPUKernelMod, int32_t);
MS_REG_CPU_KERNEL_T(
Cummax, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
CummaxCPUKernelMod, int64_t);
MS_REG_CPU_KERNEL_T(
Cummax, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64),
CummaxCPUKernelMod, int8_t);
MS_REG_CPU_KERNEL_T(
Cummax, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64),
CummaxCPUKernelMod, uint8_t);
MS_REG_CPU_KERNEL_T(
Cummax, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
CummaxCPUKernelMod, uint32_t);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMMAX_CPU_KERNEL_H_

View File

@ -119,6 +119,7 @@ constexpr auto kReshape = "Reshape";
constexpr auto kLstsq = "Lstsq";
constexpr auto kLowerBound = "LowerBound";
constexpr auto kUpperBound = "UpperBound";
constexpr auto kCummax = "Cummax";
// NN
constexpr auto kCTCLoss = "CTCLoss";
@ -350,6 +351,7 @@ MS_CORE_API inline const PrimitivePtr kPrimExtractVolumePatches = std::make_shar
MS_CORE_API inline const PrimitivePtr kPrimLstsq = std::make_shared<Primitive>(kLstsq);
MS_CORE_API inline const PrimitivePtr kPrimLowerBound = std::make_shared<Primitive>(kLowerBound);
MS_CORE_API inline const PrimitivePtr kPrimUpperBound = std::make_shared<Primitive>(kUpperBound);
MS_CORE_API inline const PrimitivePtr kPrimCummax = std::make_shared<Primitive>(kCummax);
// NN
MS_CORE_API inline const PrimitivePtr kPrimCeLU = std::make_shared<Primitive>("CeLU");

View File

@ -0,0 +1,64 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/cummax.h"
#include <map>
#include <string>
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto x_shape = input_args[0]->BuildShape();
auto x_shape_value = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape)[kShape];
auto dim = GetValue<int64_t>(primitive->GetAttr("dim"));
if (x_shape_value.size() <= 0) {
MS_EXCEPTION(ValueError) << "Inputs should not be a " << x_shape_value.size() << " dimensional tensor.";
}
if (dim >= static_cast<int64_t>(x_shape_value.size()) || dim < -static_cast<int64_t>(x_shape_value.size())) {
MS_EXCEPTION(ValueError) << "The value of `dim` should be in the range of ["
<< -static_cast<int64_t>(x_shape_value.size()) << ","
<< static_cast<int64_t>(x_shape_value.size()) << ")";
}
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{x_shape, x_shape});
}
TuplePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto op_name = primitive->name();
const std::set<TypePtr> valid_types = {kInt8, kInt32, kInt64, kUInt8, kUInt32, kFloat16, kFloat32};
auto y_type = CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, op_name);
auto indices_type = kInt64;
return std::make_shared<Tuple>(std::vector<TypePtr>{y_type, indices_type});
}
AbstractBasePtr CummaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto types = InferType(primitive, input_args);
auto shapes = InferShape(primitive, input_args);
return abstract::MakeAbstract(shapes, types);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Cummax, prim::kPrimCummax, CummaxInfer, nullptr, true);
} // namespace
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_CUMMAX_H_
#define MINDSPORE_CORE_OPS_CUMMAX_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameCummax = "Cummax";
class Cummax : public PrimitiveC {
public:
Cummax() : PrimitiveC(kNameCummax) { InitIOName({"x"}, {"y", "indices"}); }
~Cummax() = default;
MS_DECLARE_PARENT(Cummax, PrimitiveC);
};
AbstractBasePtr CummaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimCummaxPtr = std::shared_ptr<Cummax>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_CUMMAX_H_

View File

@ -103,3 +103,4 @@ from .upper_bound import _upper_bound_aicpu
from .grid_sampler_3d import _grid_sampler_3d_aicpu
from .grid_sampler_3d_grad import _grid_sampler_3d_grad_aicpu
from .cross import _cross_aicpu
from .cummax import _cummax_aicpu

View File

@ -0,0 +1,37 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cummax op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
cummax_op_info = AiCPURegOp("Cummax") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.output(1, "indices", "required") \
.attr("dim", "int") \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I64_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.I64_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.I64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I64_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I64_Default) \
.get_op_info()
@op_info_register(cummax_op_info)
def _cummax_aicpu():
"""Cummax AiCPU register"""
return

View File

@ -35,9 +35,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence,
EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedFill, MaskedSelect, SearchSorted,
TensorScatterMax, TensorScatterMin, TensorScatterSub, ScatterElements, ExtractVolumePatches,
LowerBound, UpperBound)
from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, _AllSwap, ReduceScatter,
Broadcast,
LowerBound, UpperBound, Cummax)
from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, _AllSwap, ReduceScatter, Broadcast,
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
_VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad,
_HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather)
@ -517,6 +516,7 @@ __all__ = [
"Custom",
"LuSolve",
"CholeskyInverse",
"Cummax",
]
__sponge__ = [

View File

@ -7017,3 +7017,63 @@ class UpperBound(Primitive):
valid_values = (mstype.int32, mstype.int64)
validator.check_type_name("out_type", out_type, valid_values, self.name)
self.init_prim_io_names(inputs=['sorted_x', 'values'], outputs=['y'])
class Cummax(Primitive):
"""
Computes the cumulative max and indice of input tensor along dim.Returns a tuple (values,indices) where 'values'
is the cumulative maximum value of input elements in the dimension 'dim'and 'indices' is the index position for
each maximum value.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
.. math::
y_i = max(x_1 , x_2 , x_3 ,... ,x_i)
Args:
dim (int): The dim to accumulate the tensor's value. Must be in the range [-rank(input), rank(input)).
The default value is -1.
Inputs:
- **input** (Tensor) - The input tensor whose dtype is int8, int32, int64, uint8, uint32, float16, float32.
Outputs:
- **values** (Tensor), the shape of the output tensor is consistent with the input tensor's.
- **indices** (Tensor), the shape of the output tensor is consistent with the input tensor's.
Raises:
TypeError: If `input` is not a Tensor.
TypeError: If `dim` is not an int.
ValueError: If `dim` is out of range, `dim` should be [-len(input.shape), len(input.shape)-1].
Supported Platforms:
``CPU``
Examples:
>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor
>>> import mindspore.ops as ops
>>> cummax = ops.Cummax(dim=0)
>>> x = Tensor(np.array([[3, 4, 6, 10], [1, 6, 7, 9], [4, 3, 8, 7], [1, 3, 7, 9]]).astype(np.float32))
>>> output = cummax(x)
>>> print(output)
values:
[[ 3. 4. 6. 10.]
[ 3. 6. 7. 10.]
[ 4. 6. 8. 10.]
[ 4. 6. 8. 10.]]
indices:
[[0 0 0 0]
[0 1 1 0]
[2 1 2 0]
[2 1 2 0]]
"""
@prim_attr_register
def __init__(self, dim=-1):
"""Initialize Cummax"""
validator.check_value_type("dim", dim, [int], self.name)
self.init_prim_io_names(inputs=['x'], outputs=['y', 'indices'])

View File

@ -2803,6 +2803,11 @@ test_case_array_ops = [
Tensor([[3], [6], [7], [8]], mstype.int8)],
'skip': ['backward'],
}),
('Cummax', {
'block': P.Cummax(dim=-1),
'desc_inputs': [Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])],
'skip': ['backward'],
}),
]
test_case_image_ops = [