!35675 adapt dynamic shape of HSwish & HSwishGrad

Merge pull request !35675 from jjfeing/add_hswish_operator
i-robot 2022-06-13 12:45:54 +00:00 committed by Gitee
commit 6d6327882f
10 changed files with 217 additions and 9 deletions

View File

@@ -24,4 +24,5 @@ mindspore.ops.hardswish
**Raises:**
- **TypeError** - `x` is not a Tensor.
- **TypeError** - The dtype of `x` is neither float16 nor float32.
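A minimal usage sketch of the documented behavior (not part of this diff): hardswish computes x * relu6(x + 3) / 6 elementwise and accepts only float16/float32 tensors.

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.array([-4.0, -1.0, 0.0, 1.0, 4.0]), ms.float32)
print(ops.hardswish(x))  # [-0.  -0.33333334  0.  0.6666667  4.]

try:
    ops.hardswish(Tensor(np.array([1, 2, 3]), ms.int32))  # int32 is rejected
except TypeError as e:
    print(e)
```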

View File

@@ -55,6 +55,8 @@ const std::map<std::string, std::string> opTypeAdapter = {{"ReLUV2", "ReluV2"},
{"ParallelResizeBilinear", "SyncResizeBilinearV2"}, {"ParallelResizeBilinear", "SyncResizeBilinearV2"},
{"ParallelResizeBilinearGrad", "SyncResizeBilinearV2Grad"}, {"ParallelResizeBilinearGrad", "SyncResizeBilinearV2Grad"},
{"Split", "SplitD"}, {"Split", "SplitD"},
{"HSwish", "HardSwish"},
{"HSwishGrad", "HardSwishGrad"},
{"CeLU", "CeluV2"}, {"CeLU", "CeluV2"},
{"ArgminV2", "ArgMin"}, {"ArgminV2", "ArgMin"},
{"IndexAdd", "InplaceIndexAdd"}}; {"IndexAdd", "InplaceIndexAdd"}};

View File

@@ -59,6 +59,8 @@ std::string OpTilingCalculateAdapter::GetRealOpType(const std::string &op_type)
{"DynamicResizeNearestNeighbor", "ResizeNearestNeighborV2"}, {"DynamicResizeNearestNeighbor", "ResizeNearestNeighborV2"},
{"ParallelResizeBilinear", "SyncResizeBilinearV2"}, {"ParallelResizeBilinear", "SyncResizeBilinearV2"},
{"ParallelResizeBilinearGrad", "SyncResizeBilinearV2Grad"}, {"ParallelResizeBilinearGrad", "SyncResizeBilinearV2Grad"},
{"HSwish", "HardSwish"},
{"HSwishGrad", "HardSwishGrad"},
{"CeLU", "CeluV2"}, {"CeLU", "CeluV2"},
{"TransposeNOD", "Transpose"}, {"TransposeNOD", "Transpose"},
{"IndexAdd", "InplaceIndexAdd"}, {"IndexAdd", "InplaceIndexAdd"},

View File

@@ -199,6 +199,8 @@ constexpr auto kGridSampler3DGrad = "GridSampler3DGrad";
constexpr auto kAdaptiveMaxPool2D = "AdaptiveMaxPool2D";
constexpr auto kUpsampleTrilinear3D = "UpsampleTrilinear3D";
constexpr auto kUpsampleNearest3D = "UpsampleNearest3D";
constexpr auto kHSwish = "HSwish";
constexpr auto kHSwishGrad = "HSwishGrad";
// CSRTensor
constexpr auto kMakeCSRTensor = "MakeCSRTensor";
@@ -655,6 +657,8 @@ GVAR_DEF(PrimitivePtr, kPrimSoftShrink, std::make_shared<Primitive>("SoftShrink"
GVAR_DEF(PrimitivePtr, kPrimSoftShrinkGrad, std::make_shared<Primitive>("SoftShrinkGrad"));
GVAR_DEF(PrimitivePtr, kPrimHShrink, std::make_shared<Primitive>("HShrink"));
GVAR_DEF(PrimitivePtr, kPrimHShrinkGrad, std::make_shared<Primitive>("HShrinkGrad"));
GVAR_DEF(PrimitivePtr, kPrimHSwish, std::make_shared<Primitive>(kHSwish));
GVAR_DEF(PrimitivePtr, kPrimHSwishGrad, std::make_shared<Primitive>(kHSwishGrad));
GVAR_DEF(PrimitivePtr, kPrimHSVToRGB, std::make_shared<Primitive>("HSVToRGB"));
GVAR_DEF(PrimitivePtr, kPrimDeformableOffsets, std::make_shared<Primitive>("DeformableOffsets"));
GVAR_DEF(PrimitivePtr, kPrimApplyAdagradDA, std::make_shared<Primitive>("ApplyAdagradDA"));

View File

@@ -0,0 +1,66 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/grad/hswish_grad.h"
#include <string>
#include <algorithm>
#include <map>
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore::ops {
namespace {
abstract::ShapePtr HSwishGradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num,
primitive->name());
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
auto shape = input_shape[kShape];
auto min_shape = input_shape[kMinShape];
auto max_shape = input_shape[kMaxShape];
return std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
}
TypePtr HSwishGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num,
prim->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
std::map<std::string, TypePtr> types;
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
(void)types.emplace("y_grad", input_args[0]->BuildType());
(void)types.emplace("x", input_args[1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace
MIND_API_OPERATOR_IMPL(HSwishGrad, BaseOperator);
AbstractBasePtr HSwishGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto infer_type = HSwishGradInferType(primitive, input_args);
auto infer_shape = HSwishGradInferShape(primitive, input_args);
MS_EXCEPTION_IF_NULL(infer_shape);
return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(HSwishGrad, prim::kPrimHSwishGrad, HSwishGradInfer, nullptr, true);
} // namespace mindspore::ops
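The file above supplies only shape and dtype inference for HSwishGrad; the arithmetic lives in the backend HardSwishGrad kernel. As a hedged NumPy reference, derived from hswish(x) = x * relu6(x + 3) / 6, the derivative is 0 for x <= -3, 1 for x >= 3, and x/3 + 1/2 in between:

```python
import numpy as np

def hswish_grad(y_grad: np.ndarray, x: np.ndarray) -> np.ndarray:
    # Piecewise derivative of hswish, scaled by the incoming gradient.
    dx = np.where(x <= -3.0, 0.0, np.where(x >= 3.0, 1.0, x / 3.0 + 0.5))
    return y_grad * dx

x = np.array([-4.0, -1.0, 0.0, 1.0, 4.0], dtype=np.float32)
print(hswish_grad(np.ones_like(x), x))  # [0.  0.16666667  0.5  0.8333333  1.]
```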

View File

@@ -0,0 +1,32 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_HSWISH_GRAD_H_
#define MINDSPORE_CORE_OPS_HSWISH_GRAD_H_
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore::ops {
constexpr auto kNameHSwishGrad = "HSwishGrad";
class MIND_API HSwishGrad : public BaseOperator {
public:
MIND_API_BASE_MEMBER(HSwishGrad);
HSwishGrad() : BaseOperator(kNameHSwishGrad) { InitIOName({"y_grad", "x"}, {"output"}); }
};
} // namespace mindspore::ops
#endif // MINDSPORE_CORE_OPS_HSWISH_GRAD_H_

View File

@@ -0,0 +1,60 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/hswish.h"
#include <set>
#include <map>
#include <string>
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore::ops {
namespace {
TypePtr HSwishInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
auto x_dtype = input_args[0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, valid_types, prim->name());
return x_dtype;
}
abstract::ShapePtr HSwishInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
constexpr int64_t kInputSize = 1;
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize,
prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto shape = input_shape[kShape];
auto min_shape = input_shape[kMinShape];
auto max_shape = input_shape[kMaxShape];
auto out_shape = shape;
return std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape);
}
} // namespace
MIND_API_OPERATOR_IMPL(HSwish, BaseOperator);
AbstractBasePtr HSwishInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
return abstract::MakeAbstract(HSwishInferShape(primitive, input_args), HSwishInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(HSwish, prim::kPrimHSwish, HSwishInfer, nullptr, true);
} // namespace mindspore::ops
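The InferShape functions are what make these ops dynamic-shape ready: besides the concrete shape they forward the kMinShape/kMaxShape bounds that MindSpore attaches to dimensions recorded as -1. A toy sketch of that contract (hypothetical helper; the authoritative logic is the C++ above):

```python
def elementwise_infer_shape(shape, min_shape=None, max_shape=None):
    # An elementwise op such as HSwish forwards all three unchanged.
    return shape, min_shape, max_shape

# A batch dimension that is dynamic (-1) but bounded between 1 and 32:
print(elementwise_infer_shape([-1, 224, 224], [1, 224, 224], [32, 224, 224]))
```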

View File

@@ -0,0 +1,44 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_HSWISH_H_
#define MINDSPORE_CORE_OPS_HSWISH_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore::ops {
constexpr auto kNameHSwish = "HSwish";
/// \brief Calculates the Hard Swish (HSwish) activation function.
/// Refer to Python API @ref mindspore.ops.HSwish for more details.
class MIND_API HSwish : public BaseOperator {
public:
MIND_API_BASE_MEMBER(HSwish);
/// \brief Constructor.
HSwish() : BaseOperator(kNameHSwish) { InitIOName({"x"}, {"output"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.HSwish for the inputs.
void Init() const {}
};
abstract::AbstractBasePtr HSwishInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace mindspore::ops
#endif // MINDSPORE_CORE_OPS_HSWISH_H_

View File

@@ -1957,8 +1957,12 @@ class _ActivationGrad(PrimitiveWithInfer):
        return x_dtype
-class HSwishGrad(_ActivationGrad):
+class HSwishGrad(Primitive):
    """Gets the gradient of HSwish operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize HSwishGrad"""
        self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])
class HSigmoidGrad(_ActivationGrad):
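HSwishGrad now derives from Primitive rather than the PrimitiveWithInfer-based _ActivationGrad, so shape and dtype come from the new C++ infer functions (which carry the dynamic-shape bounds) instead of Python. Calling code is unchanged; a quick gradient check, assuming an HSwish kernel is available on the current backend:

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, nn, ops

net = nn.HSwish()
x = Tensor(np.array([-1.0, 0.0, 1.0]), ms.float32)
grad_fn = ops.GradOperation()(net)  # gradient w.r.t. the first input
print(grad_fn(x))  # ≈ [0.16666667  0.5  0.8333333]
```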

View File

@@ -816,7 +816,7 @@ class Elu(Primitive):
        self.init_prim_io_names(inputs=['x'], outputs=['output', 'mask'])
-class HSwish(PrimitiveWithInfer):
+class HSwish(Primitive):
    r"""
    Hard swish activation function.
@@ -838,13 +838,6 @@ class HSwish(PrimitiveWithInfer):
"""Initialize HSwish.""" """Initialize HSwish."""
self.init_prim_io_names(inputs=['x'], outputs=['output']) self.init_prim_io_names(inputs=['x'], outputs=['output'])
-    def infer_shape(self, xshape):
-        return xshape
-
-    def infer_dtype(self, x_dtype):
-        validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32), self.name)
-        return x_dtype
class Sigmoid(Primitive):
    r"""