[fix][assistant][I3PYD4] fix bug in Ascend operator HShrink and HShrinkGrad

danansheng 2021-07-20 15:26:16 +08:00
parent b04036e13c
commit 556e67402d
11 changed files with 44 additions and 42 deletions


@@ -18,19 +18,17 @@
#include <string>
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
abstract::ShapePtr HShrinkGradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, primitive->name());
auto prim_name = primitive->name();
auto gradients_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto features_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];


@@ -14,11 +14,9 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_HShrink_GRAD_H_
#define MINDSPORE_CORE_OPS_HShrink_GRAD_H_
#include <map>
#ifndef MINDSPORE_CORE_OPS_HSHRINK_GRAD_H_
#define MINDSPORE_CORE_OPS_HSHRINK_GRAD_H_
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
@@ -39,5 +37,4 @@ AbstractBasePtr HShrinkGradInfer(const abstract::AnalysisEnginePtr &, const Prim
using PrimHShrinkGradPtr = std::shared_ptr<HShrinkGrad>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_HShrink_GRAD_H_
#endif // MINDSPORE_CORE_OPS_HSHRINK_GRAD_H_


@@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -25,10 +25,10 @@
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, primitive->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
@@ -36,13 +36,14 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
return std::make_shared<abstract::Shape>(in_shape);
}
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, primitive->name());
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
MS_LOG(EXCEPTION) << "nullptr";
}
std::map<std::string, TypePtr> types;
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
types.emplace("input_x", input_args[0]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, primitive->name());
return CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types,
primitive->name());
}
} // namespace


@@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -28,17 +28,14 @@ namespace ops {
constexpr auto kNameHShrink = "HShrink";
class HShrink : public PrimitiveC {
public:
HShrink() : PrimitiveC(kNameHShrink) {
InitIOName({"input_x"}, {"output"});
}
HShrink() : PrimitiveC(kNameHShrink) { InitIOName({"input_x"}, {"output"}); }
~HShrink() = default;
MS_DECLARE_PARENT(HShrink, PrimitiveC);
};
AbstractBasePtr HShrinkInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
const std::vector<AbstractBasePtr> &input_args);
using PrimHShrinkPtr = std::shared_ptr<HShrink>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_HSHRINK_H


@@ -806,7 +806,7 @@ class SoftShrink(Cell):
class HShrink(Cell):
r"""
Applies the hard shrinkage function element-wise, each element comply the follow function:
Applies the hard shrinkage function element-wise, each element complies with the following function:
.. math::
\text{HardShrink}(x) =
@@ -817,16 +817,16 @@ class HShrink(Cell):
\end{cases}
Args:
lambd (float): The value for the Hardshrink formulation. Default: 0.5
lambd (float): The value for the HardShrink formulation. Default: 0.5
Inputs:
- **input_x** (Tensor) - The input of hshrink with data type of float16 or float32.
- **input_x** (Tensor) - The input of HardShrink with data type of float16 or float32.
Outputs:
Tensor, the same shape as the input.
Tensor, the same shape and data type as the input.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
``Ascend``
Raises:
TypeError: If `lambd` is not a float.
@@ -840,9 +840,11 @@ class HShrink(Cell):
[[ 0. 1. 2. ]
[ 0. 0. -2.1233]]
"""
def __init__(self, lambd=0.5):
super(HShrink, self).__init__()
self.hshrink = P.HShrink(lambd)
def construct(self, input_x):
return self.hshrink(input_x)
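
For readers checking the formula in the docstring above, here is a minimal NumPy sketch of the hard shrink rule (a reference only, not the MindSpore kernel; the function name is made up for illustration): values with |x| <= lambd are zeroed, everything else passes through unchanged.

import numpy as np

def hard_shrink_reference(x, lambd=0.5):
    """Reference hard shrink: keep x where |x| > lambd, zero it otherwise."""
    return np.where(np.abs(x) > lambd, x, 0.0)

# hard_shrink_reference(np.array([0.3, 1.2, -2.0]))  ->  [ 0. ,  1.2, -2. ]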


@@ -49,7 +49,7 @@ def get_bprop_softshrink(self):
@bprop_getters.register(P.HShrink)
def get_bprop_hshrink(self):
"""Grad definition for `HShrinkGrad` operation."""
grad = G.HShrinkGrad()
grad = G.HShrinkGrad(self.lambd)
def bprop(features, out, gradients):
dx = grad(gradients, features)
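
The single-line change above is the substantive fix in this file: G.HShrinkGrad() was previously constructed with no arguments, so the backward op always ran with the default lambd of 0.5 even when the forward HShrink had been built with a different threshold. A minimal NumPy sketch of the reference formulas (not MindSpore code; names are illustrative) shows why the mismatch matters:

import numpy as np

def hard_shrink(x, lambd):
    # forward: zero out values inside [-lambd, lambd]
    return np.where(np.abs(x) > lambd, x, 0.0)

def hard_shrink_grad(dy, x, lambd):
    # backward: gradients flow only where the forward output was not zeroed
    return np.where(np.abs(x) > lambd, dy, 0.0)

x = np.array([0.3, 0.8, 1.5])
dy = np.ones_like(x)
# With lambd=1.0 the forward pass zeroes 0.3 and 0.8, so their gradients must be zero too.
print(hard_shrink_grad(dy, x, lambd=1.0))  # [0. 0. 1.]  -- lambd forwarded, consistent
print(hard_shrink_grad(dy, x, lambd=0.5))  # [0. 1. 1.]  -- stale default, wrong for 0.8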


@@ -15,7 +15,7 @@
"""HardShrink op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
hshrink_op_info = TBERegOp("HShrink") \
.fusion_type("ELEMWISE") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("hard_shrink.so") \
.compute_cost(10) \


@@ -22,7 +22,7 @@ hshrink_grad_op_info = TBERegOp("HShrinkGrad") \
.compute_cost(10) \
.kernel_name("hard_shrink_grad") \
.partial_flag(True) \
.attr("lambda", "optional", "float", "all", "0.5") \
.attr("lambd", "optional", "float", "all", "0.5") \
.input(0, "gradients", False, "required", "all") \
.input(1, "features", False, "required", "all") \
.output(0, "backprops", False, "required", "all") \


@@ -76,7 +76,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
MaxPool, DataFormatDimMap,
AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
MaxPoolWithArgmax, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
ResizeBilinear, Sigmoid, SeLU,
ResizeBilinear, Sigmoid, SeLU, HShrink,
SigmoidCrossEntropyWithLogits, NLLLoss, BCEWithLogitsLoss,
SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2,
SoftmaxCrossEntropyWithLogits, ROIAlign,
@@ -86,7 +86,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
FusedSparseFtrl, FusedSparseProximalAdagrad,
ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2,
ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent,
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, AdaptiveAvgPool2D, SoftShrink, HShrink)
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, AdaptiveAvgPool2D, SoftShrink)
from . import _quant_ops
from ._quant_ops import *
from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingBoxEncode,
@@ -485,7 +485,7 @@ __all__ = [
"TensorScatterSub",
"SoftShrink",
"FFT3D",
"IFFT3D"
"IFFT3D",
"HShrink"
]


@@ -2219,20 +2219,20 @@ class HShrinkGrad(Primitive):
Computes gradients for the HShrink operation.
Args:
lambd (float): the λ value for the Hardshrink formulation. Default: 0.5
Lambd (float): the λ value for the Hardshrink formulation. Default: 0.5
Inputs:
- **gradients** (Tensor) - the gradients of loss to output of HShrink function.
- **Gradients** (Tensor) - the gradients of loss to output of HShrink function.
Currently the gradients data type only supports float16 and float32.
- **features** (Tensor) - Must be the input `input_x` of the forward operator HShrink.
- **Features** (Tensor) - Must be the input `input_x` of the forward operator HShrink.
Currently the features data type only supports float16 and float32.
Outputs:
backprops - Tensor, with the same shape and data type as `features`.
Raises:
TypeError: If `lambd` is not a float.
TypeError: If shape of `gradients` is not the same as `features`.
ValueError: If `lambd` is not a float.
ValueError: If shape of `gradients` is not the same as `features`.
TypeError: If dtype of `gradients` is not the same as `features`.
TypeError: If dtype of `gradients` or `features` is neither float16 nor float32.
@@ -2243,3 +2243,6 @@ class HShrinkGrad(Primitive):
@prim_attr_register
def __init__(self, lambd=0.5):
validator.check_value_type("lambd", lambd, [float], self.name)
if lambd < 0.0:
lambd = 0.0
self.add_prim_attr('lambd', lambd)
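
The clamp added above (and mirrored in the forward HShrink __init__ later in this commit) maps any negative lambd to 0.0. A hedged reading of the intent: a negative threshold would make |x| > lambd hold everywhere, so normalizing it to zero keeps the attribute passed to the kernel well defined without changing the mathematical result. A small sketch of the same guard (the helper name is illustrative):

def normalize_lambd(lambd):
    """Negative thresholds behave like 0.0, so clamp before storing the attribute."""
    if not isinstance(lambd, float):
        raise TypeError("lambd must be a float, got {}".format(type(lambd).__name__))
    return max(lambd, 0.0)

# normalize_lambd(-0.3) -> 0.0    normalize_lambd(0.5) -> 0.5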


@@ -8702,7 +8702,7 @@ class SoftShrink(Primitive):
class HShrink(Primitive):
r"""
Applies the hard shrinkage function element-wise, each element comply the follow function:
Applies the hard shrinkage function element-wise, each element complies with the following function:
.. math::
\text{HardShrink}(x) =
@@ -8711,18 +8711,18 @@ class HShrink(Primitive):
x, & \text{ if } x < -\lambda \\
0, & \text{ otherwise }
\end{cases}
Args:
lambd (float): The value for the Hardshrink formulation. Default: 0.5
lambd (float): The value for the HardShrink formulation. Default: 0.5
Inputs:
- **input_x** (Tensor) - The input of hshrink with data type of float16 or float32.
- **input_x** (Tensor) - The input of HardShrink with data type of float16 or float32.
Outputs:
Tensor, the same shape as the input.
Tensor, the same shape and data type as the input.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
``Ascend``
Raises:
TypeError: If `lambd` is not a float.
@@ -8736,7 +8736,11 @@ class HShrink(Primitive):
[[ 0. 1. 2. ]
[ 0. 0. -2.1233]]
"""
@prim_attr_register
def __init__(self, lambd=0.5):
"""Initialize HShrink"""
validator.check_value_type('lambd', lambd, [float], self.name)
if lambd < 0.0:
lambd = 0.0
self.add_prim_attr('lambd', lambd)
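
Putting the pieces of this commit together, a sanity-check sketch of the fixed operator (assumes a MindSpore build with an Ascend backend, since the kernels registered here are TBE ops; exact import paths may vary by version):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

net = nn.HShrink(lambd=1.0)
x = Tensor(np.array([0.3, 0.8, 1.5], dtype=np.float32))
out = net(x)
# Per the formula above, only 1.5 exceeds the threshold, so the expected
# result is [0. 0. 1.5]; with the bprop fix, gradients for 0.3 and 0.8
# are likewise zero instead of leaking through with the old default lambd.
print(out)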