!35675 adapter dynamic shape of hswish & hswish grad
Merge pull request !35675 from jjfeing/add_hswish_operator
This commit is contained in:
commit
6d6327882f
|
@ -24,4 +24,5 @@ mindspore.ops.hardswish
|
||||||
|
|
||||||
**异常:**
|
**异常:**
|
||||||
|
|
||||||
|
- **TypeError** - `x` 不是一个Tensor。
|
||||||
- **TypeError** - `x` 的数据类型既不是float16也不是float32。
|
- **TypeError** - `x` 的数据类型既不是float16也不是float32。
|
|
@ -55,6 +55,8 @@ const std::map<std::string, std::string> opTypeAdapter = {{"ReLUV2", "ReluV2"},
|
||||||
{"ParallelResizeBilinear", "SyncResizeBilinearV2"},
|
{"ParallelResizeBilinear", "SyncResizeBilinearV2"},
|
||||||
{"ParallelResizeBilinearGrad", "SyncResizeBilinearV2Grad"},
|
{"ParallelResizeBilinearGrad", "SyncResizeBilinearV2Grad"},
|
||||||
{"Split", "SplitD"},
|
{"Split", "SplitD"},
|
||||||
|
{"HSwish", "HardSwish"},
|
||||||
|
{"HSwishGrad", "HardSwishGrad"},
|
||||||
{"CeLU", "CeluV2"},
|
{"CeLU", "CeluV2"},
|
||||||
{"ArgminV2", "ArgMin"},
|
{"ArgminV2", "ArgMin"},
|
||||||
{"IndexAdd", "InplaceIndexAdd"}};
|
{"IndexAdd", "InplaceIndexAdd"}};
|
||||||
|
|
|
@ -59,6 +59,8 @@ std::string OpTilingCalculateAdapter::GetRealOpType(const std::string &op_type)
|
||||||
{"DynamicResizeNearestNeighbor", "ResizeNearestNeighborV2"},
|
{"DynamicResizeNearestNeighbor", "ResizeNearestNeighborV2"},
|
||||||
{"ParallelResizeBilinear", "SyncResizeBilinearV2"},
|
{"ParallelResizeBilinear", "SyncResizeBilinearV2"},
|
||||||
{"ParallelResizeBilinearGrad", "SyncResizeBilinearV2Grad"},
|
{"ParallelResizeBilinearGrad", "SyncResizeBilinearV2Grad"},
|
||||||
|
{"HSwish", "HardSwish"},
|
||||||
|
{"HSwishGrad", "HardSwishGrad"},
|
||||||
{"CeLU", "CeluV2"},
|
{"CeLU", "CeluV2"},
|
||||||
{"TransposeNOD", "Transpose"},
|
{"TransposeNOD", "Transpose"},
|
||||||
{"IndexAdd", "InplaceIndexAdd"},
|
{"IndexAdd", "InplaceIndexAdd"},
|
||||||
|
|
|
@ -199,6 +199,8 @@ constexpr auto kGridSampler3DGrad = "GridSampler3DGrad";
|
||||||
constexpr auto kAdaptiveMaxPool2D = "AdaptiveMaxPool2D";
|
constexpr auto kAdaptiveMaxPool2D = "AdaptiveMaxPool2D";
|
||||||
constexpr auto kUpsampleTrilinear3D = "UpsampleTrilinear3D";
|
constexpr auto kUpsampleTrilinear3D = "UpsampleTrilinear3D";
|
||||||
constexpr auto kUpsampleNearest3D = "UpsampleNearest3D";
|
constexpr auto kUpsampleNearest3D = "UpsampleNearest3D";
|
||||||
|
constexpr auto kHSwish = "HSwish";
|
||||||
|
constexpr auto kHSwishGrad = "HSwishGrad";
|
||||||
|
|
||||||
// CSRTensor
|
// CSRTensor
|
||||||
constexpr auto kMakeCSRTensor = "MakeCSRTensor";
|
constexpr auto kMakeCSRTensor = "MakeCSRTensor";
|
||||||
|
@ -655,6 +657,8 @@ GVAR_DEF(PrimitivePtr, kPrimSoftShrink, std::make_shared<Primitive>("SoftShrink"
|
||||||
GVAR_DEF(PrimitivePtr, kPrimSoftShrinkGrad, std::make_shared<Primitive>("SoftShrinkGrad"));
|
GVAR_DEF(PrimitivePtr, kPrimSoftShrinkGrad, std::make_shared<Primitive>("SoftShrinkGrad"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimHShrink, std::make_shared<Primitive>("HShrink"));
|
GVAR_DEF(PrimitivePtr, kPrimHShrink, std::make_shared<Primitive>("HShrink"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimHShrinkGrad, std::make_shared<Primitive>("HShrinkGrad"));
|
GVAR_DEF(PrimitivePtr, kPrimHShrinkGrad, std::make_shared<Primitive>("HShrinkGrad"));
|
||||||
|
GVAR_DEF(PrimitivePtr, kPrimHSwish, std::make_shared<Primitive>(kHSwish));
|
||||||
|
GVAR_DEF(PrimitivePtr, kPrimHSwishGrad, std::make_shared<Primitive>(kHSwishGrad));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimHSVToRGB, std::make_shared<Primitive>("HSVToRGB"));
|
GVAR_DEF(PrimitivePtr, kPrimHSVToRGB, std::make_shared<Primitive>("HSVToRGB"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimDeformableOffsets, std::make_shared<Primitive>("DeformableOffsets"));
|
GVAR_DEF(PrimitivePtr, kPrimDeformableOffsets, std::make_shared<Primitive>("DeformableOffsets"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimApplyAdagradDA, std::make_shared<Primitive>("ApplyAdagradDA"));
|
GVAR_DEF(PrimitivePtr, kPrimApplyAdagradDA, std::make_shared<Primitive>("ApplyAdagradDA"));
|
||||||
|
|
|
@ -0,0 +1,66 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ops/grad/hswish_grad.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <map>
|
||||||
|
#include <set>
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "abstract/ops/primitive_infer_map.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
|
namespace mindspore::ops {
|
||||||
|
namespace {
|
||||||
|
abstract::ShapePtr HSwishGradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
const int64_t input_num = 2;
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num,
|
||||||
|
primitive->name());
|
||||||
|
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
|
||||||
|
auto shape = input_shape[kShape];
|
||||||
|
auto min_shape = input_shape[kMinShape];
|
||||||
|
auto max_shape = input_shape[kMaxShape];
|
||||||
|
return std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
TypePtr HSwishGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
const int64_t input_num = 2;
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num,
|
||||||
|
prim->name());
|
||||||
|
for (const auto &item : input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
}
|
||||||
|
std::map<std::string, TypePtr> types;
|
||||||
|
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||||
|
(void)types.emplace("y_grad", input_args[0]->BuildType());
|
||||||
|
(void)types.emplace("x", input_args[1]->BuildType());
|
||||||
|
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
MIND_API_OPERATOR_IMPL(HSwishGrad, BaseOperator);
|
||||||
|
AbstractBasePtr HSwishGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
auto infer_type = HSwishGradInferType(primitive, input_args);
|
||||||
|
auto infer_shape = HSwishGradInferShape(primitive, input_args);
|
||||||
|
MS_EXCEPTION_IF_NULL(infer_shape);
|
||||||
|
return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape->shape());
|
||||||
|
}
|
||||||
|
REGISTER_PRIMITIVE_EVAL_IMPL(HSwishGrad, prim::kPrimHSwishGrad, HSwishGradInfer, nullptr, true);
|
||||||
|
} // namespace mindspore::ops
|
|
@ -0,0 +1,32 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_HSWISH_GRAD_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_HSWISH_GRAD_H_
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include "ops/base_operator.h"
|
||||||
|
#include "mindapi/base/types.h"
|
||||||
|
|
||||||
|
namespace mindspore::ops {
|
||||||
|
constexpr auto kNameHSwishGrad = "HSwishGrad";
|
||||||
|
class MIND_API HSwishGrad : public BaseOperator {
|
||||||
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(HSwishGrad);
|
||||||
|
HSwishGrad() : BaseOperator(kNameHSwishGrad) { InitIOName({"y_grad", "x"}, {"output"}); }
|
||||||
|
};
|
||||||
|
} // namespace mindspore::ops
|
||||||
|
#endif // MINDSPORE_CORE_OPS_HSWISH_GRAD_H_
|
|
@ -0,0 +1,60 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "ops/hswish.h"
|
||||||
|
|
||||||
|
#include <set>
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
|
namespace mindspore::ops {
|
||||||
|
namespace {
|
||||||
|
TypePtr HSwishInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||||
|
auto x_dtype = input_args[0]->BuildType();
|
||||||
|
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, valid_types, prim->name());
|
||||||
|
return x_dtype;
|
||||||
|
}
|
||||||
|
|
||||||
|
abstract::ShapePtr HSwishInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto prim_name = primitive->name();
|
||||||
|
constexpr int64_t kInputSize = 1;
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize,
|
||||||
|
prim_name);
|
||||||
|
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||||
|
auto shape = input_shape[kShape];
|
||||||
|
auto min_shape = input_shape[kMinShape];
|
||||||
|
auto max_shape = input_shape[kMaxShape];
|
||||||
|
auto out_shape = shape;
|
||||||
|
|
||||||
|
return std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape);
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_OPERATOR_IMPL(HSwish, BaseOperator);
|
||||||
|
AbstractBasePtr HSwishInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
for (const auto &item : input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
}
|
||||||
|
return abstract::MakeAbstract(HSwishInferShape(primitive, input_args), HSwishInferType(primitive, input_args));
|
||||||
|
}
|
||||||
|
REGISTER_PRIMITIVE_EVAL_IMPL(HSwish, prim::kPrimHSwish, HSwishInfer, nullptr, true);
|
||||||
|
} // namespace mindspore::ops
|
|
@ -0,0 +1,44 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_HSWISH_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_HSWISH_H_
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include "ops/base_operator.h"
|
||||||
|
#include "mindapi/base/types.h"
|
||||||
|
|
||||||
|
namespace mindspore::ops {
|
||||||
|
constexpr auto kNameHSwish = "HSwish";
|
||||||
|
/// \brief Calculates kNameHSwish .
|
||||||
|
/// Refer to Python API @ref mindspore.ops.HSwish for more details.
|
||||||
|
class MIND_API HSwish : public BaseOperator {
|
||||||
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(HSwish);
|
||||||
|
/// \brief Constructor.
|
||||||
|
HSwish() : BaseOperator(kNameHSwish) { InitIOName({"x"}, {"output"}); }
|
||||||
|
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.HSwish for the inputs.
|
||||||
|
void Init() const {}
|
||||||
|
};
|
||||||
|
|
||||||
|
abstract::AbstractBasePtr HSwishInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
|
} // namespace mindspore::ops
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_HSWISH_H_
|
|
@ -1957,8 +1957,12 @@ class _ActivationGrad(PrimitiveWithInfer):
|
||||||
return x_dtype
|
return x_dtype
|
||||||
|
|
||||||
|
|
||||||
class HSwishGrad(_ActivationGrad):
|
class HSwishGrad(Primitive):
|
||||||
"""Gets the gradient of HSwish operation."""
|
"""Gets the gradient of HSwish operation."""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize HSwishGrad"""
|
||||||
|
self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])
|
||||||
|
|
||||||
|
|
||||||
class HSigmoidGrad(_ActivationGrad):
|
class HSigmoidGrad(_ActivationGrad):
|
||||||
|
|
|
@ -816,7 +816,7 @@ class Elu(Primitive):
|
||||||
self.init_prim_io_names(inputs=['x'], outputs=['output', 'mask'])
|
self.init_prim_io_names(inputs=['x'], outputs=['output', 'mask'])
|
||||||
|
|
||||||
|
|
||||||
class HSwish(PrimitiveWithInfer):
|
class HSwish(Primitive):
|
||||||
r"""
|
r"""
|
||||||
Hard swish activation function.
|
Hard swish activation function.
|
||||||
|
|
||||||
|
@ -838,13 +838,6 @@ class HSwish(PrimitiveWithInfer):
|
||||||
"""Initialize HSwish."""
|
"""Initialize HSwish."""
|
||||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||||
|
|
||||||
def infer_shape(self, xshape):
|
|
||||||
return xshape
|
|
||||||
|
|
||||||
def infer_dtype(self, x_dtype):
|
|
||||||
validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32), self.name)
|
|
||||||
return x_dtype
|
|
||||||
|
|
||||||
|
|
||||||
class Sigmoid(Primitive):
|
class Sigmoid(Primitive):
|
||||||
r"""
|
r"""
|
||||||
|
|
Loading…
Reference in New Issue