!25271 [feat] [assistant] [I48OC4] add dynamic shape for ReLU6Grad operator

Merge pull request !25271 from 路雄博/Relu6Grad
i-robot 2022-01-11 14:26:58 +00:00 committed by Gitee
commit a9f46e99e2
6 changed files with 198 additions and 10 deletions
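
For reference, the backward rule this PR touches is elementwise: ReLU6Grad passes the incoming gradient through only where the forward input lies strictly inside (0, 6). A minimal NumPy sketch of that contract (illustrative names, not part of this PR):

import numpy as np

def relu6_grad_reference(gradients, features):
    """d/dx relu6(x) is 1 on (0, 6) and 0 elsewhere."""
    mask = (features > 0) & (features < 6)
    return gradients * mask.astype(gradients.dtype)

# relu6_grad_reference(np.ones(3), np.array([-1.0, 2.0, 7.0])) -> [0., 1., 0.]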


@@ -0,0 +1,72 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/grad/relu6_grad.h"
#include <string>
#include <algorithm>
#include <map>
#include <set>
#include <vector>
#include "abstract/param_validator.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr ReLU6GradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  const int64_t input_num = 2;
  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto x = input_args[0]->BuildShape();
  MS_EXCEPTION_IF_NULL(x);
  auto shape_element = x->cast<abstract::ShapePtr>();
  MS_EXCEPTION_IF_NULL(shape_element);
  return shape_element;
}
TypePtr ReLU6GradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = prim->name();
  const int64_t input_num = 2;
  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
  MS_EXCEPTION_IF_NULL(input_args[0]);
  auto dout = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
  auto out = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
  (void)abstract::CheckDtypeSame(prim_name, out, dout);
  auto x_type = input_args[0]->BuildType();
  MS_EXCEPTION_IF_NULL(x_type);
  if (!x_type->isa<TensorType>()) {
    MS_EXCEPTION(TypeError) << "The " << prim_name << "'s input must be a tensor type, but got " << x_type->ToString();
  }
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
  (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim_name);
  return x_type;
}
} // namespace
AbstractBasePtr ReLU6GradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto type = ReLU6GradInferType(primitive, input_args);
  auto shape = ReLU6GradInferShape(primitive, input_args);
  return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ReLU6Grad, prim::kPrimRelu6Grad, ReLU6GradInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
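
The C++ registration above takes over from the Python-side infer (removed at the bottom of this PR): the output abstract takes its shape and dtype from input 0, after checking that both inputs are same-dtype float16/float32 tensors, which is the contract the dynamic-shape kernel relies on. A rough Python sketch of that contract (hypothetical helper, for illustration only):

def relu6_grad_infer(dout_shape, dout_dtype, x_dtype):
    # Mirrors ReLU6GradInferType/InferShape: validate dtypes, echo input 0's shape.
    valid = ("float16", "float32")
    if dout_dtype not in valid:
        raise TypeError(f"ReLU6Grad expects float16/float32, got {dout_dtype}")
    if dout_dtype != x_dtype:
        raise TypeError("ReLU6Grad inputs must share one dtype")
    return dout_shape, dout_dtype  # output mirrors the gradient input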


@@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_RELU6_GRAD_H_
#define MINDSPORE_CORE_OPS_RELU6_GRAD_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameReLU6Grad = "ReLU6Grad";
class ReLU6Grad : public PrimitiveC {
 public:
  ReLU6Grad() : PrimitiveC(kNameReLU6Grad) {}
  ~ReLU6Grad() = default;
  MS_DECLARE_PARENT(ReLU6Grad, PrimitiveC);
  void Init() {}
};
} // namespace ops
} // namespace mindspore
#endif  // MINDSPORE_CORE_OPS_RELU6_GRAD_H_


@@ -122,7 +122,9 @@ from .relu_ds import _relu_ds_tbe
 from .relu_grad import _relu_grad_tbe
 from .relu_grad_ds import _relu_grad_ds_tbe
 from .relu6 import _relu6_tbe
+from .relu6_ds import _relu6_ds_tbe
 from .relu6_grad import _relu6_grad_tbe
+from .relu6_grad_ds import _relu6_grad_ds_tbe
 from .relu_v2 import _relu_v2_tbe
 from .relu_grad_v2 import _relu_grad_v2_tbe
 from .relu_v2_ds import _relu_v2_ds_tbe


@@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReLU6 op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

relu6_ds_op_info = TBERegOp("ReLU6") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("relu6.so") \
    .compute_cost(10) \
    .kernel_name("relu6") \
    .partial_flag(True) \
    .dynamic_shape(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("formatAgnostic") \
    .dtype_format(DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None) \
    .dtype_format(DataType.I32_None, DataType.I32_None) \
    .get_op_info()


@op_info_register(relu6_ds_op_info)
def _relu6_ds_tbe():
    """ReLU6 TBE register."""
    return
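
The dynamic_shape(True) flag plus the format-agnostic pattern registers a dynamic-shape TBE kernel for the forward ReLU6. A quick forward sanity check through the public API (the TBE path itself assumes a configured Ascend backend; the same call runs elsewhere with other kernels):

import numpy as np
import mindspore as ms
import mindspore.nn as nn

relu6 = nn.ReLU6()
x = ms.Tensor(np.array([-1.0, 3.0, 8.0]), ms.float32)
print(relu6(x))  # [0. 3. 6.] -- outputs clipped to the [0, 6] range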


@@ -0,0 +1,44 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReLU6Grad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

relu6_grad_ds_op_info = TBERegOp("ReLU6Grad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("relu6_grad.so") \
    .compute_cost(10) \
    .kernel_name("relu6_grad") \
    .partial_flag(True) \
    .dynamic_shape(True) \
    .input(0, "gradients", False, "required", "all") \
    .input(1, "features", False, "required", "all") \
    .output(0, "backprops", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
    .get_op_info()


@op_info_register(relu6_grad_ds_op_info)
def _relu6_grad_ds_tbe():
    """ReLU6Grad TBE register."""
    return


@@ -1629,7 +1629,7 @@ class ReluGrad(Primitive):
         raise NotImplementedError
 
 
-class ReLU6Grad(PrimitiveWithInfer):
+class ReLU6Grad(Primitive):
     """Performs grad of ReLU6 operation."""
 
     @prim_attr_register
@@ -1639,15 +1639,6 @@ class ReLU6Grad(PrimitiveWithInfer):
     def __call__(self, y_grad, x):
         raise NotImplementedError
 
-    def infer_shape(self, y_grad_shape, x_shape):
-        return x_shape
-
-    def infer_dtype(self, y_grad_dtype, x_dtype):
-        valid_dtypes = (mstype.float16, mstype.float32)
-        validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name)
-        validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
-        return x_dtype
-
 
 class ReluGradV2(Primitive):
     """Performs grad of ReLUV2 operation."""