forked from mindspore-Ecosystem/mindspore
!27184 [assistant][ops]Add Cummax
Merge pull request !27184 from 陈钧/cummax
This commit is contained in:
commit
47a4860043
|
@ -0,0 +1,102 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/cpu/cummax_cpu_kernel.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
void CummaxCPUKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
output1_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
||||
output2_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 1);
|
||||
dim_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "dim");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool CummaxCPUKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
auto input_data_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto output1_data_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
auto output2_data_addr = reinterpret_cast<int64_t *>(outputs[1]->addr);
|
||||
|
||||
const size_t dims = input_shape_.size();
|
||||
if (dims == 0) {
|
||||
MS_LOG(EXCEPTION) << "The value of `dims` can not be 0";
|
||||
}
|
||||
dim_ = (dim_ + dims) % dims;
|
||||
std::vector<size_t> p{1};
|
||||
|
||||
for (int64_t i = (int64_t)input_shape_.size() - 1; i >= 0; i--)
|
||||
p.push_back(p[(int64_t)input_shape_.size() - 1 - i] * input_shape_[i]);
|
||||
reverse(p.begin(), p.end());
|
||||
|
||||
size_t input_stride = p[dim_ + 1];
|
||||
size_t output1_stride = p[dim_ + 1];
|
||||
size_t output2_stride = p[dim_ + 1];
|
||||
size_t input_dim_size = input_shape_[dim_];
|
||||
|
||||
int exit_ok = 0;
|
||||
std::vector<size_t> counter(dims, 0);
|
||||
|
||||
while (!exit_ok) {
|
||||
T out = input_data_addr[0];
|
||||
int idx = 0;
|
||||
for (size_t i = 0; i < input_dim_size; i++) {
|
||||
T cur = input_data_addr[i * input_stride];
|
||||
if (cur >= out) {
|
||||
out = cur;
|
||||
idx = i;
|
||||
}
|
||||
output1_data_addr[i * output1_stride] = out;
|
||||
output2_data_addr[i * output2_stride] = idx;
|
||||
}
|
||||
|
||||
if (dims == 1) break;
|
||||
for (size_t dim_i = 0; dim_i < dims; dim_i++) {
|
||||
if (dim_i == dim_) {
|
||||
if (dim_i == dims - 1) {
|
||||
exit_ok = 1;
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
counter[dim_i]++;
|
||||
input_data_addr += p[dim_i + 1];
|
||||
output1_data_addr += p[dim_i + 1];
|
||||
output2_data_addr += p[dim_i + 1];
|
||||
|
||||
if (counter[dim_i] == input_shape_[dim_i]) {
|
||||
if (dim_i == dims - 1) {
|
||||
exit_ok = 1;
|
||||
break;
|
||||
} else {
|
||||
input_data_addr -= counter[dim_i] * p[dim_i + 1];
|
||||
output1_data_addr -= counter[dim_i] * p[dim_i + 1];
|
||||
output2_data_addr -= counter[dim_i] * p[dim_i + 1];
|
||||
counter[dim_i] = 0;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,69 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMMAX_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMMAX_CPU_KERNEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class CummaxCPUKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
CummaxCPUKernelMod() = default;
|
||||
~CummaxCPUKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
std::vector<size_t> input_shape_;
|
||||
std::vector<size_t> output1_shape_;
|
||||
std::vector<size_t> output2_shape_;
|
||||
size_t dim_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Cummax,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
|
||||
CummaxCPUKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Cummax,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64),
|
||||
CummaxCPUKernelMod, float16);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Cummax, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
|
||||
CummaxCPUKernelMod, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Cummax, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
CummaxCPUKernelMod, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Cummax, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64),
|
||||
CummaxCPUKernelMod, int8_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Cummax, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64),
|
||||
CummaxCPUKernelMod, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Cummax, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
|
||||
CummaxCPUKernelMod, uint32_t);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMMAX_CPU_KERNEL_H_
|
|
@ -119,6 +119,7 @@ constexpr auto kReshape = "Reshape";
|
|||
constexpr auto kLstsq = "Lstsq";
|
||||
constexpr auto kLowerBound = "LowerBound";
|
||||
constexpr auto kUpperBound = "UpperBound";
|
||||
constexpr auto kCummax = "Cummax";
|
||||
|
||||
// NN
|
||||
constexpr auto kCTCLoss = "CTCLoss";
|
||||
|
@ -350,6 +351,7 @@ MS_CORE_API inline const PrimitivePtr kPrimExtractVolumePatches = std::make_shar
|
|||
MS_CORE_API inline const PrimitivePtr kPrimLstsq = std::make_shared<Primitive>(kLstsq);
|
||||
MS_CORE_API inline const PrimitivePtr kPrimLowerBound = std::make_shared<Primitive>(kLowerBound);
|
||||
MS_CORE_API inline const PrimitivePtr kPrimUpperBound = std::make_shared<Primitive>(kUpperBound);
|
||||
MS_CORE_API inline const PrimitivePtr kPrimCummax = std::make_shared<Primitive>(kCummax);
|
||||
|
||||
// NN
|
||||
MS_CORE_API inline const PrimitivePtr kPrimCeLU = std::make_shared<Primitive>("CeLU");
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/cummax.h"
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto x_shape = input_args[0]->BuildShape();
|
||||
auto x_shape_value = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape)[kShape];
|
||||
auto dim = GetValue<int64_t>(primitive->GetAttr("dim"));
|
||||
if (x_shape_value.size() <= 0) {
|
||||
MS_EXCEPTION(ValueError) << "Inputs should not be a " << x_shape_value.size() << " dimensional tensor.";
|
||||
}
|
||||
if (dim >= static_cast<int64_t>(x_shape_value.size()) || dim < -static_cast<int64_t>(x_shape_value.size())) {
|
||||
MS_EXCEPTION(ValueError) << "The value of `dim` should be in the range of ["
|
||||
<< -static_cast<int64_t>(x_shape_value.size()) << ","
|
||||
<< static_cast<int64_t>(x_shape_value.size()) << ")";
|
||||
}
|
||||
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{x_shape, x_shape});
|
||||
}
|
||||
|
||||
TuplePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = primitive->name();
|
||||
const std::set<TypePtr> valid_types = {kInt8, kInt32, kInt64, kUInt8, kUInt32, kFloat16, kFloat32};
|
||||
auto y_type = CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, op_name);
|
||||
auto indices_type = kInt64;
|
||||
return std::make_shared<Tuple>(std::vector<TypePtr>{y_type, indices_type});
|
||||
}
|
||||
|
||||
AbstractBasePtr CummaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto types = InferType(primitive, input_args);
|
||||
auto shapes = InferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shapes, types);
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Cummax, prim::kPrimCummax, CummaxInfer, nullptr, true);
|
||||
} // namespace
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_CUMMAX_H_
|
||||
#define MINDSPORE_CORE_OPS_CUMMAX_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCummax = "Cummax";
|
||||
class Cummax : public PrimitiveC {
|
||||
public:
|
||||
Cummax() : PrimitiveC(kNameCummax) { InitIOName({"x"}, {"y", "indices"}); }
|
||||
~Cummax() = default;
|
||||
MS_DECLARE_PARENT(Cummax, PrimitiveC);
|
||||
};
|
||||
AbstractBasePtr CummaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimCummaxPtr = std::shared_ptr<Cummax>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_OPS_CUMMAX_H_
|
|
@ -103,3 +103,4 @@ from .upper_bound import _upper_bound_aicpu
|
|||
from .grid_sampler_3d import _grid_sampler_3d_aicpu
|
||||
from .grid_sampler_3d_grad import _grid_sampler_3d_grad_aicpu
|
||||
from .cross import _cross_aicpu
|
||||
from .cummax import _cummax_aicpu
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Cummax op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
cummax_op_info = AiCPURegOp("Cummax") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.output(1, "indices", "required") \
|
||||
.attr("dim", "int") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I64_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(cummax_op_info)
|
||||
def _cummax_aicpu():
|
||||
"""Cummax AiCPU register"""
|
||||
return
|
|
@ -35,9 +35,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
|
|||
BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence,
|
||||
EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedFill, MaskedSelect, SearchSorted,
|
||||
TensorScatterMax, TensorScatterMin, TensorScatterSub, ScatterElements, ExtractVolumePatches,
|
||||
LowerBound, UpperBound)
|
||||
from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, _AllSwap, ReduceScatter,
|
||||
Broadcast,
|
||||
LowerBound, UpperBound, Cummax)
|
||||
from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, _AllSwap, ReduceScatter, Broadcast,
|
||||
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
|
||||
_VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad,
|
||||
_HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather)
|
||||
|
@ -517,6 +516,7 @@ __all__ = [
|
|||
"Custom",
|
||||
"LuSolve",
|
||||
"CholeskyInverse",
|
||||
"Cummax",
|
||||
]
|
||||
|
||||
__sponge__ = [
|
||||
|
|
|
@ -7017,3 +7017,63 @@ class UpperBound(Primitive):
|
|||
valid_values = (mstype.int32, mstype.int64)
|
||||
validator.check_type_name("out_type", out_type, valid_values, self.name)
|
||||
self.init_prim_io_names(inputs=['sorted_x', 'values'], outputs=['y'])
|
||||
|
||||
|
||||
class Cummax(Primitive):
|
||||
"""
|
||||
Computes the cumulative max and indice of input tensor along dim.Returns a tuple (values,indices) where 'values'
|
||||
is the cumulative maximum value of input elements in the dimension 'dim'and 'indices' is the index position for
|
||||
each maximum value.
|
||||
|
||||
.. warning::
|
||||
This is an experimental prototype that is subject to change and/or deletion.
|
||||
|
||||
.. math::
|
||||
|
||||
y_i = max(x_1 , x_2 , x_3 ,... ,x_i)
|
||||
|
||||
Args:
|
||||
dim (int): The dim to accumulate the tensor's value. Must be in the range [-rank(input), rank(input)).
|
||||
The default value is -1.
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - The input tensor whose dtype is int8, int32, int64, uint8, uint32, float16, float32.
|
||||
|
||||
Outputs:
|
||||
- **values** (Tensor), the shape of the output tensor is consistent with the input tensor's.
|
||||
- **indices** (Tensor), the shape of the output tensor is consistent with the input tensor's.
|
||||
|
||||
Raises:
|
||||
TypeError: If `input` is not a Tensor.
|
||||
TypeError: If `dim` is not an int.
|
||||
ValueError: If `dim` is out of range, `dim` should be [-len(input.shape), len(input.shape)-1].
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor
|
||||
>>> import mindspore.ops as ops
|
||||
>>> cummax = ops.Cummax(dim=0)
|
||||
>>> x = Tensor(np.array([[3, 4, 6, 10], [1, 6, 7, 9], [4, 3, 8, 7], [1, 3, 7, 9]]).astype(np.float32))
|
||||
>>> output = cummax(x)
|
||||
>>> print(output)
|
||||
values:
|
||||
[[ 3. 4. 6. 10.]
|
||||
[ 3. 6. 7. 10.]
|
||||
[ 4. 6. 8. 10.]
|
||||
[ 4. 6. 8. 10.]]
|
||||
indices:
|
||||
[[0 0 0 0]
|
||||
[0 1 1 0]
|
||||
[2 1 2 0]
|
||||
[2 1 2 0]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, dim=-1):
|
||||
"""Initialize Cummax"""
|
||||
validator.check_value_type("dim", dim, [int], self.name)
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['y', 'indices'])
|
||||
|
|
|
@ -2803,6 +2803,11 @@ test_case_array_ops = [
|
|||
Tensor([[3], [6], [7], [8]], mstype.int8)],
|
||||
'skip': ['backward'],
|
||||
}),
|
||||
('Cummax', {
|
||||
'block': P.Cummax(dim=-1),
|
||||
'desc_inputs': [Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])],
|
||||
'skip': ['backward'],
|
||||
}),
|
||||
]
|
||||
|
||||
test_case_image_ops = [
|
||||
|
|
Loading…
Reference in New Issue