!18260 [assistant][ops] New operator implementation, including HShrink and HShrinkGrad

Merge pull request !18260 from 张璇/hshrink
This commit is contained in:
i-robot 2021-08-03 11:00:11 +00:00 committed by Gitee
commit a70b0f3bf4
14 changed files with 428 additions and 3 deletions


@ -376,6 +376,8 @@ inline const PrimitivePtr kSquareSumV1 = std::make_shared<Primitive>("SquareSumV
inline const PrimitivePtr kFusedMulAdd = std::make_shared<Primitive>("FusedMulAdd");
inline const PrimitivePtr kPrimSoftShrink = std::make_shared<Primitive>("SoftShrink");
inline const PrimitivePtr kPrimSoftShrinkGrad = std::make_shared<Primitive>("SoftShrinkGrad");
inline const PrimitivePtr kPrimHShrink = std::make_shared<Primitive>("HShrink");
inline const PrimitivePtr kPrimHShrinkGrad = std::make_shared<Primitive>("HShrinkGrad");
// Comm ops
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");


@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/grad/hshrink_grad.h"
#include <string>
#include <algorithm>
#include <map>
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
abstract::ShapePtr HShrinkGradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, primitive->name());
auto prim_name = primitive->name();
auto gradients_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto features_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::Check("gradients_shape", gradients_shape, kEqual, "features_shape", features_shape, prim_name,
TypeError);
return std::make_shared<abstract::Shape>(gradients_shape);
}
TypePtr HShrinkGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
std::map<std::string, TypePtr> types;
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
types.emplace("gradients", input_args[0]->BuildType());
types.emplace("features", input_args[1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
AbstractBasePtr HShrinkGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(HShrinkGradInferType(primitive, input_args),
HShrinkGradInferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(HShrinkGrad, prim::kPrimHShrinkGrad, HShrinkGradInfer, nullptr, true);
} // namespace ops
} // namespace mindspore


@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_HSHRINK_GRAD_H_
#define MINDSPORE_CORE_OPS_HSHRINK_GRAD_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameHShrinkGrad = "HShrinkGrad";
class HShrinkGrad : public PrimitiveC {
public:
HShrinkGrad() : PrimitiveC(kNameHShrinkGrad) { InitIOName({"gradients", "features"}, {"backprops"}); }
~HShrinkGrad() = default;
MS_DECLARE_PARENT(HShrinkGrad, PrimitiveC);
};
AbstractBasePtr HShrinkGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimHShrinkGradPtr = std::shared_ptr<HShrinkGrad>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_HSHRINK_GRAD_H_


@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <map>
#include <set>
#include <string>
#include "ops/hshrink.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, primitive->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
return std::make_shared<abstract::Shape>(in_shape);
}
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, primitive->name());
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
MS_LOG(EXCEPTION) << "For '" << primitive->name() << "', the input arguments must not be nullptr.";
}
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
return CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types,
primitive->name());
}
} // namespace
AbstractBasePtr HShrinkInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(HShrink, prim::kPrimHShrink, HShrinkInfer, nullptr, true);
} // namespace ops
} // namespace mindspore


@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_HSHRINK_H
#define MINDSPORE_CORE_OPS_HSHRINK_H
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameHShrink = "HShrink";
class HShrink : public PrimitiveC {
public:
HShrink() : PrimitiveC(kNameHShrink) { InitIOName({"input_x"}, {"output"}); }
~HShrink() = default;
MS_DECLARE_PARENT(HShrink, PrimitiveC);
};
AbstractBasePtr HShrinkInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimHShrinkPtr = std::shared_ptr<HShrink>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_HSHRINK_H


@ -40,6 +40,7 @@ __all__ = ['Softmax',
'ELU',
'LogSigmoid',
'SoftShrink',
'HShrink',
]
@ -803,6 +804,51 @@ class SoftShrink(Cell):
output = self.softshrink(input_x)
return output
class HShrink(Cell):
r"""
Applies the hard shrinkage function element-wise; each element complies with the following function:
.. math::
\text{HardShrink}(x) =
\begin{cases}
x, & \text{ if } x > \lambda \\
x, & \text{ if } x < -\lambda \\
0, & \text{ otherwise }
\end{cases}
Args:
lambd (float): The :math:`\lambda` value for the HardShrink formulation. Default: 0.5.
Inputs:
- **input_x** (Tensor) - The input of HardShrink with data type of float16 or float32.
Outputs:
Tensor, the same shape and data type as the input.
Supported Platforms:
``Ascend``
Raises:
TypeError: If `lambd` is not a float.
TypeError: If dtype of `input_x` is neither float16 nor float32.
Examples:
>>> input_x = Tensor(np.array([[0.5, 1, 2.0], [0.0533, 0.0776, -2.1233]]), mstype.float32)
>>> hshrink = nn.HShrink()
>>> output = hshrink(input_x)
>>> print(output)
[[ 0. 1. 2. ]
[ 0. 0. -2.1233]]
"""
def __init__(self, lambd=0.5):
super(HShrink, self).__init__()
self.hshrink = P.HShrink(lambd)
def construct(self, input_x):
return self.hshrink(input_x)
_activation = {
'softmax': Softmax,
'logsoftmax': LogSoftmax,
@ -819,6 +865,7 @@ _activation = {
'hsigmoid': HSigmoid,
'logsigmoid': LogSigmoid,
'softshrink': SoftShrink,
'hshrink': HShrink,
}
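
The HardShrink rule documented in the nn.HShrink docstring above keeps an element only when its magnitude strictly exceeds lambd; a minimal NumPy sketch of the same rule (illustrative only, hard_shrink_ref is not part of this change) reproduces the docstring example:

import numpy as np

def hard_shrink_ref(x, lambd=0.5):
    """Reference HardShrink: keep x where |x| > lambd, zero elsewhere."""
    return np.where(np.abs(x) > lambd, x, np.zeros_like(x))

x = np.array([[0.5, 1.0, 2.0], [0.0533, 0.0776, -2.1233]], dtype=np.float32)
print(hard_shrink_ref(x))
# [[ 0.      1.      2.    ]
#  [ 0.      0.     -2.1233]]
# 0.5 maps to 0 because the comparison is strict (x > lambd, not >=).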


@ -44,3 +44,15 @@ def get_bprop_softshrink(self):
return (dx,)
return bprop
@bprop_getters.register(P.HShrink)
def get_bprop_hshrink(self):
"""Grad definition for `HShrinkGrad` operation."""
grad = G.HShrinkGrad(self.lambd)
def bprop(features, out, gradients):
dx = grad(gradients, features)
return (dx,)
return bprop
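
The bprop above routes the backward pass of P.HShrink through G.HShrinkGrad, which amounts to passing `gradients` through wherever |features| > lambd and zeroing it elsewhere; a NumPy sketch of that rule (illustrative, hard_shrink_grad_ref is not the TBE kernel):

import numpy as np

def hard_shrink_grad_ref(gradients, features, lambd=0.5):
    """Backward rule implied by HardShrink: pass the upstream gradient where |features| > lambd."""
    return np.where(np.abs(features) > lambd, gradients, np.zeros_like(gradients))

features = np.array([[0.5, 1.0, 2.0], [0.0533, 0.0776, -2.1233]], dtype=np.float32)
gradients = np.ones_like(features)
print(hard_shrink_grad_ref(gradients, features))
# [[0. 1. 1.]
#  [0. 0. 1.]]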


@ -395,3 +395,5 @@ from .soft_shrink import _soft_shrink_tbe
from .soft_shrink_grad import _soft_shrink_grad_tbe
from .hsigmoid_grad import _hsigmoid_grad_tbe
from .hsigmoid import _hsigmoid_tbe
from .hshrink import _hshrink_tbe
from .hshrink_grad import _hshrink_grad_tbe


@ -0,0 +1,33 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HardShrink op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
hshrink_op_info = TBERegOp("HShrink") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("hard_shrink.so") \
.compute_cost(10) \
.kernel_name("hard_shrink") \
.partial_flag(True) \
.attr("lambd", "optional", "float", "all", "0.5") \
.input(0, "input_x", False, "required", "all") \
.output(0, "output", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(hshrink_op_info)
def _hshrink_tbe():
"""HShrink TBE register"""
return


@ -0,0 +1,37 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HShrinkGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
hshrink_grad_op_info = TBERegOp("HShrinkGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("hard_shrink_grad.so") \
.compute_cost(10) \
.kernel_name("hard_shrink_grad") \
.partial_flag(True) \
.attr("lambd", "optional", "float", "all", "0.5") \
.input(0, "gradients", False, "required", "all") \
.input(1, "features", False, "required", "all") \
.output(0, "backprops", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(hshrink_grad_op_info)
def _hshrink_grad_tbe():
"""HShrinkGrad TBE register"""
return


@ -76,7 +76,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
MaxPool, DataFormatDimMap,
AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
MaxPoolWithArgmax, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
ResizeBilinear, Sigmoid, SeLU,
ResizeBilinear, Sigmoid, SeLU, HShrink,
SigmoidCrossEntropyWithLogits, NLLLoss, BCEWithLogitsLoss,
SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2,
SoftmaxCrossEntropyWithLogits, ROIAlign,
@ -485,7 +485,9 @@ __all__ = [
"TensorScatterSub",
"SoftShrink",
"FFT3D",
"IFFT3D"
"IFFT3D",
"HShrink"
]
__all__.sort()


@ -2212,3 +2212,37 @@ class SoftShrinkGrad(Primitive):
self.init_prim_io_names(inputs=['input_grad', 'input_x'], outputs=['output'])
validator.check_value_type("lambd", lambd, [float], self.name)
validator.check_number("lambd", lambd, 0, Rel.GE, self.name)
class HShrinkGrad(Primitive):
"""
Computes gradients for the HShrink operation.
Args:
lambd (float): The :math:`\lambda` value for the HardShrink formulation. Default: 0.5.
Inputs:
- **gradients** (Tensor) - The gradient of the loss with respect to the output of the HShrink function.
Currently the gradients data type only supports float16 and float32.
- **features** (Tensor) - Must be the input `input_x` of the forward operator HShrink.
Currently the features data type only supports float16 and float32.
Outputs:
backprops - Tensor, with the same shape and data type as `features`.
Raises:
TypeError: If `lambd` is not a float.
ValueError: If shape of `gradients` is not the same as `features`.
TypeError: If dtype of `gradients` is not the same as `features`.
TypeError: If dtype of `gradients` or `features` is neither float16 nor float32.
Supported Platforms:
``Ascend``
"""
@prim_attr_register
def __init__(self, lambd=0.5):
"""Initialize HShrinkGrad"""
validator.check_value_type("lambd", lambd, [float], self.name)
if lambd < 0.0:
lambd = 0.0
self.add_prim_attr('lambd', lambd)
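
HShrinkGrad has no Examples block above; a minimal usage sketch, assuming an Ascend context where the hard_shrink_grad TBE kernel registered in this change is available (the values follow the pass-through rule with the default lambd of 0.5):

import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G

gradients = Tensor(np.array([[1.0, 1.0, 1.0]]), mstype.float32)
features = Tensor(np.array([[0.3, -0.6, 2.0]]), mstype.float32)
hshrink_grad = G.HShrinkGrad()  # lambd defaults to 0.5
print(hshrink_grad(gradients, features))
# [[0. 1. 1.]]  (the gradient passes through only where |features| > 0.5)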


@ -8606,7 +8606,6 @@ class SoftShrink(Primitive):
x + \lambda, & \text{ if } x < -\lambda \\
0, & \text{ otherwise }
\end{cases}
Args:
lambd: the :math:`\lambda` must be no less than zero value for the Softshrink formulation. Default: 0.5.
@ -8640,3 +8639,49 @@ class SoftShrink(Primitive):
"""Initialize SoftShrink"""
validator.check_value_type("lambd", lambd, [float], self.name)
validator.check_number("lambd", lambd, 0, Rel.GE, self.name)
class HShrink(Primitive):
r"""
Applies the hard shrinkage function element-wise; each element complies with the following function:
.. math::
\text{HardShrink}(x) =
\begin{cases}
x, & \text{ if } x > \lambda \\
x, & \text{ if } x < -\lambda \\
0, & \text{ otherwise }
\end{cases}
Args:
lambd (float): The :math:`\lambda` value for the HardShrink formulation. Default: 0.5.
Inputs:
- **input_x** (Tensor) - The input of HardShrink with data type of float16 or float32.
Outputs:
Tensor, the same shape and data type as the input.
Supported Platforms:
``Ascend``
Raises:
TypeError: If `lambd` is not a float.
TypeError: If dtype of `input_x` is neither float16 nor float32.
Examples:
>>> input_x = Tensor(np.array([[0.5, 1, 2.0], [0.0533, 0.0776, -2.1233]]), mstype.float32)
>>> hshrink = P.HShrink()
>>> output = hshrink(input_x)
>>> print(output)
[[ 0. 1. 2. ]
[ 0. 0. -2.1233]]
"""
@prim_attr_register
def __init__(self, lambd=0.5):
"""Initialize HShrink"""
validator.check_value_type('lambd', lambd, [float], self.name)
if lambd < 0.0:
lambd = 0.0
self.add_prim_attr('lambd', lambd)
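
As an end-to-end check of the forward and backward definitions added in this change, GradOperation can differentiate a small Cell that wraps P.HShrink; this is a hedged sketch (it requires an Ascend device, and HShrinkNet is only an illustrative wrapper name):

import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, context, nn
from mindspore.ops import GradOperation
from mindspore.ops import operations as P

context.set_context(device_target="Ascend")  # HShrink is only registered for Ascend in this change

class HShrinkNet(nn.Cell):
    """Thin wrapper so GradOperation can differentiate the primitive."""
    def __init__(self, lambd=0.5):
        super(HShrinkNet, self).__init__()
        self.hshrink = P.HShrink(lambd)

    def construct(self, x):
        return self.hshrink(x)

x = Tensor(np.array([[0.5, 1.0, 2.0], [0.0533, 0.0776, -2.1233]]), mstype.float32)
grad_fn = GradOperation()(HShrinkNet())
print(grad_fn(x))
# With the default all-ones sensitivity, the gradient should be 1 where |x| > 0.5 and 0 elsewhere:
# [[0. 1. 1.]
#  [0. 0. 1.]]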


@ -2204,6 +2204,16 @@ test_case_nn_ops = [
'desc_inputs': [Tensor(np.array([[-4, 4, 1]]), mstype.float32)],
'desc_bprop': [Tensor(np.array([[0, 1, 0.6666]]), mstype.float32)],
'skip': ['backward']}),
('HardShrink', {
'block': P.HShrink(),
'desc_inputs': [Tensor(np.array([[0.5, 1, 2.0], [0.0533, 0.0776, -2.1233]]), mstype.float32)],
'desc_bprop': [],
'skip': ['backward']}),
('HShrinkGrad', {
'block': G.HShrinkGrad(),
'desc_inputs': [Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]), mstype.float16),
Tensor(np.array([[-4, -3, -2], [1, 2, 4]]), mstype.float16)],
'skip': ['backward']}),
]
test_case_array_ops = [