[feat] [assistant] [I40GGH] add new ascend operator Cummin

This commit is contained in:
luon 2021-09-02 18:22:54 +08:00
parent 50f4f96c77
commit 5485ff0cb6
10 changed files with 273 additions and 7 deletions

View File

@ -173,6 +173,7 @@ inline const PrimitivePtr kPrimStackPop = std::make_shared<Primitive>("StackPop"
// Arrays
inline const PrimitivePtr kPrimDynamicBroadcastTo = std::make_shared<Primitive>(kDynamicBroadcastTo);
inline const PrimitivePtr kPrimCummin = std::make_shared<Primitive>("Cummin");
inline const PrimitivePtr kPrimBroadcastTo = std::make_shared<Primitive>("BroadcastTo");
inline const PrimitivePtr kPrimScalarToArray = std::make_shared<Primitive>("scalar_to_array");
inline const PrimitivePtr kPrimTopK = std::make_shared<Primitive>("TopK");

View File

@ -0,0 +1,63 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <set>
#include "abstract/primitive_infer_map.h"
#include "utils/check_convert_utils.h"
#include "ops/cummin.h"
namespace mindspore {
namespace ops {
namespace {
abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto y_shape = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 0);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto y_rank = x_shape.size();
const int64_t min_dim = 0;
(void)CheckAndConvertUtils::CheckInteger("the rank of input", SizeToLong(x_shape.size()), kGreaterThan, min_dim,
prim_name);
int64_t axis = GetValue<int64_t>(primitive->GetAttr("axis"));
CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeBoth, {-y_rank, y_rank - 1}, prim_name);
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{y_shape, y_shape});
}
TuplePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
MS_EXCEPTION_IF_NULL(input_args[0]);
const std::set<TypePtr> valid_types = {kFloat32, kFloat16, kInt32, kInt8, kUInt8};
auto x_type = input_args[0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim_name);
TypePtr argmin_type = kInt32;
return std::make_shared<Tuple>(std::vector{x_type, argmin_type});
}
} // namespace
AbstractBasePtr CumminInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 1, prim_name);
MS_EXCEPTION_IF_NULL(primitive);
auto type = InferType(primitive, input_args);
auto shape = InferShape(primitive, input_args);
return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Cummin, prim::kPrimCummin, CumminInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,44 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_CUMMIN_H_
#define MINDSPORE_CORE_OPS_CUMMIN_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameCummin = "kCummin";
class Cummin : public PrimitiveC {
public:
Cummin() : PrimitiveC(kNameCummin) { InitIOName({"x"}, {"y"}); }
~Cummin() = default;
MS_DECLARE_PARENT(Cummin, PrimitiveC);
};
AbstractBasePtr CumminInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimCumminPtr = std::shared_ptr<Cummin>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_CUMMIN_H_

View File

@ -15,6 +15,7 @@
"""tbe ops"""
from .celu import _celu_tbe
from .cummin import _cummin_tbe
from .abs import _abs_tbe
from .abs_ds import _abs_ds_tbe
from .inplace_add import _inplace_add_tbe

View File

@ -0,0 +1,41 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cummin op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
celu_op_info = TBERegOp("Cummin") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("cummin.so") \
.compute_cost(10) \
.kernel_name("cummin") \
.partial_flag(True) \
.attr("axis", "required", "int", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.output(0, "argmin", False, "required", "all") \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I32_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(celu_op_info)
def _cummin_tbe():
"""Cummin TBE register"""
return

View File

@ -27,7 +27,7 @@ from .multitype_ops.add_impl import hyper_add
from .multitype_ops.ones_like_impl import ones_like
from .multitype_ops.zeros_like_impl import zeros_like
from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial
from .math_ops import count_nonzero, tensor_dot, dot, batch_dot, matmul
from .math_ops import count_nonzero, tensor_dot, dot, batch_dot, matmul, cummin
from .array_ops import repeat_elements, sequence_mask
@ -53,6 +53,7 @@ __all__ = [
'clip_by_value',
'clip_by_global_norm',
'count_nonzero',
'cummin',
'tensor_dot',
'dot',
'batch_dot',

View File

@ -19,12 +19,11 @@ import numpy as np
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
from mindspore.common import dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.primitive import constexpr
from mindspore.ops import functional as F
from .. import operations as P
# count_nonzero
@constexpr
def _check_validate_axis(axis, name):
@ -104,8 +103,6 @@ def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
return nonzero_num
# tensor dot
@constexpr
def _int_to_tuple_conv(axes):
@ -207,7 +204,7 @@ def _validate_axes(x1_shape, x2_shape, axes, prim_name=None):
for i in range(len(axes[0])): # sizes already validated
if x1_shape[axes[0][i]] != x2_shape[axes[1][i]]:
invalid_a = True
if x1_shape[axes[0][i]] != x2_shape[axes[1][len(axes[0])-1-i]]:
if x1_shape[axes[0][i]] != x2_shape[axes[1][len(axes[0]) - 1 - i]]:
invalid_b = True
if invalid_a and invalid_b:
raise ValueError(f"{msg_prefix} 'i' should exist such that 'x1_shape[axes[0][i]]' is equal to "
@ -836,3 +833,71 @@ def matmul(x1, x2, dtype=None):
if dtype is not None:
res = res.astype(dtype)
return F.reshape(res, shape_out)
@constexpr
def _create_cummin_perm(axis, x_shape):
"""Insure axis is in [-len(x_shape),len(s_shape)-1]"""
len_axis = len(x_shape)
if not isinstance(axis, int):
raise TypeError(f"The date type of 'axis' should be Int, but got {axis}.")
if axis < -len_axis or axis > len_axis:
raise ValueError(f"The value of axis should be in [{-len_axis}, {len_axis}], but got {axis}.")
prem = [i for i in range(len_axis)]
if axis < 0:
axis = axis + len_axis
prem[0], prem[axis] = axis, 0
prem = tuple(prem)
return prem
def cummin(x, axis):
r"""
Computation of the cumulative minimum of elements of 'x' in the dimension axis,
and the index location of each maximum value found in the dimension 'axis'.
It returns the cumulative minimum of elements and the index.
..math::
y{i} = min(x{1}, x{2}, ... , x{i})
Args:
x (Tensor): The input tensor, rank of `input_x` > 0.
axis (Int): The dimension to do the operation, The axis is in the range from -len(`input_x`.shape)
to len(`input_x`.shape) - 1. When it's in the range from 0 to len(`input_x`.shape) - 1, it means starting
from the first dimension and counting forwards, When it's less than 0, it means we're counting backwards
from the last dimension. for example, -1 means the last dimension.
Outputs:
- **output** (Tensor) - The output tensor of the cumulative minimum of elements.
- **indices** (Tensor) - The result tensor of the index of each minimum value been found.
Raises:
TypeError: If `input_x` is not a Tensor.
TypeError: If 'axis' is not a int.
ValueError:If 'axis' is out the range of [-len(`input_x`.shape) to len(`input_x`.shape) - 1]
Supported Platforms:
``Ascend``
Examples:
>>> a = Tensor([-0.2284, -0.6628, 0.0975, 0.2680, -1.3298, -0.4220], mindspore.float32)
>>> output = ops.cummin(a, axis=0)
>>> print(output[0])
[-0.2284 -0.6628 -0.6628 -0.6628 -1.3298 -1.3298]
>>> print(output[1])
[0 1 1 1 4 4]
"""
cummin_op = inner.Cummin(axis=0)
if axis == 0:
out1, out2 = cummin_op(x)
else:
transpose = P.Transpose()
x_shape = P.Shape()(x)
prem = _create_cummin_perm(axis, x_shape)
x = transpose(x, prem)
out1, out2 = cummin_op(x)
out1 = transpose(out1, prem)
out2 = transpose(out2, prem)
return [out1, out2]

View File

@ -59,7 +59,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan,
MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex, Trunc,)
MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex, Trunc)
from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,

View File

@ -1429,3 +1429,49 @@ class DynamicBroadcastTo(Primitive):
def __init__(self):
"""Initialize DynamicBroadcastTo"""
self.init_prim_io_names(inputs=['x', 'shape'], outputs=['y'])
class Cummin(Primitive):
r"""
Computation of the cumulative minimum of elements of 'input' in the dimension dim,
and the index location of each maximum value found in the dimension 'dim'.
It returns the cumulative minimum of elements and the index.
..math::
y{i} = min(x{1}, x{2}, ... , x{i})
Args:
- **axis** (int) - The dimension to do the operation, The axis is in the range from -len(`input_x`.shape)
to len(`input_x`.shape) - 1. When it's in the range from 0 to len(`input_x`.shape) - 1, it means starting
from the first dimension and counting forwards, When it's less than 0, it means we're counting backwards
from the last dimension. for example, -1 means the last dimension.
Inputs:
- **input_x** (Tensor) - The input tensor, rank of `input_x` > 0.
Outputs:
- **output** (Tensor) - The output tensor of the cumulative minimum of elements.
- **indices** (Tensor) - The result tensor of the index of each minimum value been found.
Raises:
TypeError: If `input_x` is not a Tensor.
TypeError: If 'axis' is not a int.
ValueError:If 'axis' is out the range from -len(`input_x`.shape) to len(`input_x`.shape) - 1
Supported Platforms:
``Ascend``
Examples:
>>> a = Tensor([-0.2284, -0.6628, 0.0975, 0.2680, -1.3298, -0.4220], mindspore.float32)
>>> output = ops.cummin(a, axis=0)
>>> print(output[0])
[-0.2284 -0.6628 -0.6628 -0.6628 -1.3298 -1.3298]
>>> print(output[1])
[0 1 1 1 4 4]
"""
@prim_attr_register
def __init__(self, axis):
"""Initialize Cummin"""
validator.check_value_type('axis', axis, [int], self.name)

View File

@ -2180,6 +2180,10 @@ test_case_nn_ops = [
'desc_const': [0],
'desc_inputs': [[3, 2]],
'desc_bprop': [[3, 2]]}),
('Cummin', {
'block': inner.Cummin(axis=0),
'desc_inputs': [[1, 3, 3, 3]],
'skip': ['backward']}),
('ApplyFtrl', {
'block': ApplyFtrlNet(),
'desc_inputs': [[3, 3]],