!25703 [assistant][ops] Add math operator LpNorm

Merge pull request !25703 from 孟权令/LpNorm
i-robot 2021-11-15 02:36:54 +00:00 committed by Gitee
commit 9522ee9686
9 changed files with 295 additions and 1 deletion

View File

@@ -65,6 +65,7 @@ constexpr auto kBiasAddGrad = "BiasAddGrad";
constexpr auto kCos = "Cos";
constexpr auto kAbs = "Abs";
constexpr auto kTrunc = "Trunc";
constexpr auto kLpNorm = "LpNorm";
constexpr auto kSquare = "Square";
constexpr auto kReal = "Real";
constexpr auto kGer = "Ger";
@@ -514,6 +515,7 @@ inline const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum");
inline const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd");
inline const PrimitivePtr kPrimSubscalar = std::make_shared<Primitive>("Subscalar");
inline const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd");
inline const PrimitivePtr kPrimLpNorm = std::make_shared<Primitive>(kLpNorm);
inline const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub");
inline const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow");
inline const PrimitivePtr kPrimPower = std::make_shared<Primitive>("Power");

View File

@@ -0,0 +1,111 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <map>
#include <string>
#include <vector>
#include <algorithm>
#include <set>
#include "ops/lp_norm.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  const int64_t input_num = 1;
  (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, input_num,
                                           prim_name);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
  auto output_shape = input_shape;
  auto input_rank = SizeToLong(input_shape.size());
  auto axis = GetValue<std::vector<int64_t>>(primitive->GetAttr("axis"));
  auto keep_dims = GetValue<bool>(primitive->GetAttr("keep_dims"));
  if (input_rank == 0) {
    (void)CheckAndConvertUtils::CheckInteger("axis size", SizeToLong(axis.size()), kEqual, input_rank + 1, prim_name);
    return std::make_shared<abstract::Shape>(input_shape);
  } else {
    CheckAndConvertUtils::CheckInRange("axis size", axis.size(), kIncludeNeither, {0, input_rank + 1}, prim_name);
  }
  if (axis.size() > 1) {
    // Validate each axis and normalize negative values into [0, input_rank).
    for (size_t i = 0; i < axis.size(); ++i) {
      CheckAndConvertUtils::CheckInRange("axis value", axis[i], kIncludeLeft, {-input_rank, input_rank}, prim_name);
      if (axis[i] < 0) {
        axis[i] += input_rank;
      }
    }
    // Reject duplicate axes, then mark every reduced dimension.
    for (size_t i = 0; i < axis.size(); ++i) {
      auto temp = axis;
      auto idx = std::find(temp.begin(), temp.end(), axis[i]);
      (void)temp.erase(idx);
      auto re_idx = std::find(temp.begin(), temp.end(), axis[i]);
      if (re_idx != temp.end()) {
        MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the elements of axis must all be different.";
      }
      if (!keep_dims) {
        output_shape[axis[i]] = -1;
      } else {
        output_shape[axis[i]] = 1;
      }
    }
    // Without keep_dims, erase the dimensions marked -1 above.
    if (!keep_dims) {
      for (auto iter = output_shape.begin(); iter != output_shape.end();) {
        if (*iter == -1) {
          iter = output_shape.erase(iter);
        } else {
          ++iter;
        }
      }
    }
  } else {
    CheckAndConvertUtils::CheckInRange("axis value", axis[0], kIncludeLeft, {-input_rank, input_rank}, prim_name);
    if (axis[0] < 0) {
      axis[0] += input_rank;
    }
    if (!keep_dims) {
      (void)output_shape.erase(output_shape.begin() + axis[0]);
    } else {
      output_shape[axis[0]] = 1;
    }
  }
  return std::make_shared<abstract::Shape>(output_shape);
}

TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto infer_type = input_args[0]->BuildType();
  MS_EXCEPTION_IF_NULL(infer_type);
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
  (void)CheckAndConvertUtils::CheckTensorTypeValid("input", infer_type, valid_types, prim->name());
  return infer_type;
}
}  // namespace

AbstractBasePtr LpNormInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                            const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(LpNorm, prim::kPrimLpNorm, LpNormInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
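
As a cross-check on the shape inference above, here is a minimal standalone Python sketch of the same output-shape rule; infer_lp_norm_shape is a hypothetical helper written for illustration only, not part of this change.

# Hypothetical Python mirror of the C++ InferShape above; illustration only.
def infer_lp_norm_shape(input_shape, axis, keep_dims):
    rank = len(input_shape)
    # Normalize negative axes, as the C++ code does.
    axis = [a + rank if a < 0 else a for a in axis]
    if len(set(axis)) != len(axis):
        raise ValueError("The elements of axis must all be different.")
    if keep_dims:
        # Reduced dimensions are retained with size 1.
        return [1 if i in axis else d for i, d in enumerate(input_shape)]
    # Reduced dimensions are dropped.
    return [d for i, d in enumerate(input_shape) if i not in axis]

# (2, 3, 4) reduced over axes [0, 1] gives (4,), matching the LpNorm docstring example below.
assert infer_lp_norm_shape([2, 3, 4], [0, 1], False) == [4]
assert infer_lp_norm_shape([2, 3, 4], [0, 1], True) == [1, 1, 4]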

View File

@@ -0,0 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_LP_NORM_H_
#define MINDSPORE_CORE_OPS_LP_NORM_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameLpNorm = "LpNorm";
class LpNorm : public PrimitiveC {
 public:
  LpNorm() : PrimitiveC(kNameLpNorm) { InitIOName({"input"}, {"output"}); }
  ~LpNorm() = default;
  MS_DECLARE_PARENT(LpNorm, PrimitiveC);
};

AbstractBasePtr LpNormInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                            const std::vector<AbstractBasePtr> &input_args);
using PrimLpNormPtr = std::shared_ptr<LpNorm>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_LP_NORM_H_

View File

@@ -81,6 +81,41 @@ def get_bprop_index_lerp(self):
    return bprop


@bprop_getters.register(P.LpNorm)
def get_bprop_lp_norm(self):
    """Grad definition for `LpNorm` operation."""
    p = self.p
    keep_dims = self.keep_dims
    axis = self.axis
    if isinstance(axis, int):
        axis = [axis]
    sign_op = P.Sign()
    abs_op = P.Abs()
    zeros_like_op = P.ZerosLike()
    expand_dims_op = P.ExpandDims()
    pow_op = P.Pow()

    def bprop(input_x, out, dout):
        # Re-expand the reduced axes so `dout` and `out` broadcast against `input_x`.
        if not keep_dims and input_x.shape != ():
            for i in axis:
                dout = expand_dims_op(dout, i)
                out = expand_dims_op(out, i)
        if p == 0:
            # The zero-"norm" is piecewise constant, so its gradient is zero everywhere.
            return (zeros_like_op(input_x),)
        if p == 1:
            return (dout * sign_op(input_x),)
        if p == 2:
            input_scaled = input_x
            scale_v = dout / out
        else:
            input_scaled = pow_op(abs_op(input_x), (p - 2)) * input_x
            scale_v = dout / pow_op(out, (p - 1))
        return (input_scaled * scale_v,)

    return bprop

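For p > 2, the branch above implements d||x||_p/dx = |x|^(p-2) * x * dout / ||x||_p^(p-1). A standalone NumPy finite-difference check of that formula (a verification sketch, not MindSpore code):

# Standalone NumPy check of the p > 2 gradient formula used above;
# a sketch for verification only, not part of this change.
import numpy as np

p = 3
x = np.array([1.0, -2.0, 3.0])
norm = np.sum(np.abs(x) ** p) ** (1.0 / p)

# Analytic gradient: |x|^(p-2) * x / ||x||_p^(p-1),
# i.e. input_scaled * scale_v with dout = 1.
analytic = np.abs(x) ** (p - 2) * x / norm ** (p - 1)

# Central finite differences.
eps = 1e-6
numeric = np.zeros_like(x)
for i in range(x.size):
    xp, xm = x.copy(), x.copy()
    xp[i] += eps
    xm[i] -= eps
    fp = np.sum(np.abs(xp) ** p) ** (1.0 / p)
    fm = np.sum(np.abs(xm) ** p) ** (1.0 / p)
    numeric[i] = (fp - fm) / (2 * eps)

assert np.allclose(analytic, numeric, atol=1e-5)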
@bprop_getters.register(P.Erfinv)
def get_bprop_erfinv(self):
"""Grad definition for `Erfinv` operation."""

View File

@@ -428,6 +428,7 @@ from .cdist_grad import _cdist_grad_tbe
from .neg_ds import _neg_ds_tbe
from .not_equal_ds import _not_ds_equal_tbe
from .reciprocal_ds import _reciprocal_ds_tbe
from .lp_norm import _lp_norm_tbe
from .ctc_loss_v2 import _ctc_loss_v2_tbe
from .ctc_loss_v2_grad import _ctc_loss_v2_grad_tbe
from .roll import _roll_tbe

View File

@@ -0,0 +1,40 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LpNorm op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
lp_norm_op_info = TBERegOp("LpNorm") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("lp_norm.so") \
    .compute_cost(10) \
    .kernel_name("lp_norm") \
    .partial_flag(True) \
    .attr("p", "optional", "int", "all", "2") \
    .attr("axis", "optional", "listInt", "all") \
    .attr("keep_dims", "optional", "bool", "all", "false") \
    .attr("epsilon", "optional", "float", "all", "1e-12") \
    .input(0, "input", False, "required", "all") \
    .output(0, "output", False, "required", "all") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .get_op_info()


@op_info_register(lp_norm_op_info)
def _lp_norm_tbe():
    """LpNorm TBE register"""
    return

View File

@@ -52,7 +52,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
                       ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, Cdist, ReduceAny,
                       Cos, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Ceil,
                       Acosh, Greater, GreaterEqual, Lerp, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
                       LogicalNot, LogicalOr, LpNorm, MatMul, Maximum, MulNoNan,
                       Minimum, Mul, Neg, NMSWithMask, NotEqual,
                       NPUAllocFloatStatus, NPUClearFloatStatus, LinSpace,
                       NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
@@ -406,6 +406,7 @@ __all__ = [
    "Conv2DBackpropInput",
    "ComputeAccidentalHits",
    "Sign",
    "LpNorm",
    "LARSUpdate",
    "Round",
    "Eps",

View File

@@ -1145,6 +1145,63 @@ class Cdist(Primitive):
        self.init_prim_io_names(inputs=['input_x', 'input_y'], outputs=['output'])

class LpNorm(Primitive):
    r"""
    Returns the matrix norm or vector norm of a given tensor.

    .. math::
        output = \left(\sum_{i} \left|input_{i}\right|^{p}\right)^{1/p}

    Args:
        axis (Union[int, tuple(int), list(int)]): Specifies the dimension or dimensions of `input`
            across which the norm is calculated.
        p (int): The order of the norm. Default: 2.
        keep_dims (bool): Whether the output tensor retains the reduced dimensions. Default: False.
        epsilon (float): A small value used for numerical stability. Default: 1e-12.

    Inputs:
        - **input** (Tensor) - Input tensor.

    Outputs:
        Tensor, has the same dtype as `input`; its shape depends on `axis`. For example, if the shape
        of `input` is (2, 3, 4) and `axis` is [0, 1], the output shape will be (4,).

    Raises:
        TypeError: If `input` is not a Tensor.
        TypeError: If dtype of `input` is neither float16 nor float32.
        TypeError: If `p` is not an int.
        TypeError: If `axis` is not an int, a tuple or a list.
        TypeError: If `axis` is a tuple or a list whose elements are not all int.
        TypeError: If `keep_dims` is not a bool.
        ValueError: If any element of `axis` is out of the range [-len(input.shape), len(input.shape)).
        ValueError: If the number of elements of `axis` is greater than the rank of `input`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> input = Tensor(np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]).astype(np.float32))
        >>> op = ops.LpNorm(axis=[0, 1], p=2, keep_dims=False)
        >>> output = op(input)
        >>> print(output)
        [ 9.165152 10.954452]
    """
    @prim_attr_register
    def __init__(self, axis, p=2, keep_dims=False, epsilon=1e-12):
        """Initialize LpNorm"""
        validator.check_value_type("p", p, [int], self.name)
        validator.check_value_type("axis", axis, [int, tuple, list], self.name)
        validator.check_value_type("keep_dims", keep_dims, [bool], self.name)
        validator.check_value_type("epsilon", epsilon, [float], self.name)
        validator.check_non_negative_int(p, "p", self.name)
        validator.check_non_negative_float(epsilon, "epsilon", self.name)
        if isinstance(axis, int):
            self.add_prim_attr('axis', [self.axis])
        else:
            for element_of_axis in axis:
                validator.check_value_type("element_of_axis", element_of_axis, [int], self.name)
        self.init_prim_io_names(inputs=['input'], outputs=['output'])
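
As a complement to the docstring example above (which uses keep_dims=False), a small sketch of the keep_dims=True behavior; the shape follows from the rule described in the docstring:

# keep_dims=True retains each reduced axis with size 1: reducing a (2, 2, 2)
# input over axes [0, 1] yields shape (1, 1, 2) instead of (2,).
import numpy as np
from mindspore import Tensor, ops

input_x = Tensor(np.ones((2, 2, 2)).astype(np.float32))
net = ops.LpNorm(axis=[0, 1], p=2, keep_dims=True)
print(net(input_x).shape)  # (1, 1, 2)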
class MatMul(PrimitiveWithCheck):
r"""
Multiplies matrix `x` and matrix `y`.

View File

@@ -1283,6 +1283,10 @@ test_case_math_ops = [
        'block': LaplaceNet((3, 2, 4), 0),
        'desc_inputs': [Tensor(1.0, mstype.float32), Tensor(1.0, mstype.float32)],
        'skip': ['backward']}),
    ('LpNorm', {
        'block': P.LpNorm(axis=[0, 1], p=2, keep_dims=False),
        'desc_inputs': [Tensor([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], mstype.float32)],
        'desc_bprop': [Tensor([1, 2], mstype.float32)]}),
    ('Gamma', {
        'block': GammaNet((3, 2, 4), 0),
        'desc_inputs': [Tensor(1.0, mstype.float32), Tensor(1.0, mstype.float32)],
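
As a closing sanity check, the forward value expected from the LpNorm test entry above can be reproduced by hand with NumPy (a standalone sketch, not part of the test file):

# Standalone NumPy check of the forward value for the LpNorm test input above.
import numpy as np

x = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dtype=np.float32)
# L2 norm over axes (0, 1): sqrt(0 + 4 + 16 + 36) and sqrt(1 + 9 + 25 + 49).
print(np.sqrt((x ** 2).sum(axis=(0, 1))))  # approximately [7.4833 9.1652]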