!35675 adapter dynamic shape of hswish & hswish grad
Merge pull request !35675 from jjfeing/add_hswish_operator
This commit is contained in:
commit
6d6327882f
|
@ -24,4 +24,5 @@ mindspore.ops.hardswish
|
|||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - `x` 不是一个Tensor。
|
||||
- **TypeError** - `x` 的数据类型既不是float16也不是float32。
|
|
@ -55,6 +55,8 @@ const std::map<std::string, std::string> opTypeAdapter = {{"ReLUV2", "ReluV2"},
|
|||
{"ParallelResizeBilinear", "SyncResizeBilinearV2"},
|
||||
{"ParallelResizeBilinearGrad", "SyncResizeBilinearV2Grad"},
|
||||
{"Split", "SplitD"},
|
||||
{"HSwish", "HardSwish"},
|
||||
{"HSwishGrad", "HardSwishGrad"},
|
||||
{"CeLU", "CeluV2"},
|
||||
{"ArgminV2", "ArgMin"},
|
||||
{"IndexAdd", "InplaceIndexAdd"}};
|
||||
|
|
|
@ -59,6 +59,8 @@ std::string OpTilingCalculateAdapter::GetRealOpType(const std::string &op_type)
|
|||
{"DynamicResizeNearestNeighbor", "ResizeNearestNeighborV2"},
|
||||
{"ParallelResizeBilinear", "SyncResizeBilinearV2"},
|
||||
{"ParallelResizeBilinearGrad", "SyncResizeBilinearV2Grad"},
|
||||
{"HSwish", "HardSwish"},
|
||||
{"HSwishGrad", "HardSwishGrad"},
|
||||
{"CeLU", "CeluV2"},
|
||||
{"TransposeNOD", "Transpose"},
|
||||
{"IndexAdd", "InplaceIndexAdd"},
|
||||
|
|
|
@ -199,6 +199,8 @@ constexpr auto kGridSampler3DGrad = "GridSampler3DGrad";
|
|||
constexpr auto kAdaptiveMaxPool2D = "AdaptiveMaxPool2D";
|
||||
constexpr auto kUpsampleTrilinear3D = "UpsampleTrilinear3D";
|
||||
constexpr auto kUpsampleNearest3D = "UpsampleNearest3D";
|
||||
constexpr auto kHSwish = "HSwish";
|
||||
constexpr auto kHSwishGrad = "HSwishGrad";
|
||||
|
||||
// CSRTensor
|
||||
constexpr auto kMakeCSRTensor = "MakeCSRTensor";
|
||||
|
@ -655,6 +657,8 @@ GVAR_DEF(PrimitivePtr, kPrimSoftShrink, std::make_shared<Primitive>("SoftShrink"
|
|||
GVAR_DEF(PrimitivePtr, kPrimSoftShrinkGrad, std::make_shared<Primitive>("SoftShrinkGrad"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimHShrink, std::make_shared<Primitive>("HShrink"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimHShrinkGrad, std::make_shared<Primitive>("HShrinkGrad"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimHSwish, std::make_shared<Primitive>(kHSwish));
|
||||
GVAR_DEF(PrimitivePtr, kPrimHSwishGrad, std::make_shared<Primitive>(kHSwishGrad));
|
||||
GVAR_DEF(PrimitivePtr, kPrimHSVToRGB, std::make_shared<Primitive>("HSVToRGB"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimDeformableOffsets, std::make_shared<Primitive>("DeformableOffsets"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimApplyAdagradDA, std::make_shared<Primitive>("ApplyAdagradDA"));
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/grad/hswish_grad.h"
|
||||
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore::ops {
|
||||
namespace {
|
||||
abstract::ShapePtr HSwishGradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num,
|
||||
primitive->name());
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
|
||||
auto shape = input_shape[kShape];
|
||||
auto min_shape = input_shape[kMinShape];
|
||||
auto max_shape = input_shape[kMaxShape];
|
||||
return std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
|
||||
}
|
||||
|
||||
TypePtr HSwishGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num,
|
||||
prim->name());
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
(void)types.emplace("y_grad", input_args[0]->BuildType());
|
||||
(void)types.emplace("x", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
MIND_API_OPERATOR_IMPL(HSwishGrad, BaseOperator);
|
||||
AbstractBasePtr HSwishGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto infer_type = HSwishGradInferType(primitive, input_args);
|
||||
auto infer_shape = HSwishGradInferShape(primitive, input_args);
|
||||
MS_EXCEPTION_IF_NULL(infer_shape);
|
||||
return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(HSwishGrad, prim::kPrimHSwishGrad, HSwishGradInfer, nullptr, true);
|
||||
} // namespace mindspore::ops
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_HSWISH_GRAD_H_
|
||||
#define MINDSPORE_CORE_OPS_HSWISH_GRAD_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore::ops {
|
||||
constexpr auto kNameHSwishGrad = "HSwishGrad";
|
||||
class MIND_API HSwishGrad : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(HSwishGrad);
|
||||
HSwishGrad() : BaseOperator(kNameHSwishGrad) { InitIOName({"y_grad", "x"}, {"output"}); }
|
||||
};
|
||||
} // namespace mindspore::ops
|
||||
#endif // MINDSPORE_CORE_OPS_HSWISH_GRAD_H_
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ops/hswish.h"
|
||||
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore::ops {
|
||||
namespace {
|
||||
TypePtr HSwishInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
auto x_dtype = input_args[0]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, valid_types, prim->name());
|
||||
return x_dtype;
|
||||
}
|
||||
|
||||
abstract::ShapePtr HSwishInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
constexpr int64_t kInputSize = 1;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize,
|
||||
prim_name);
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
auto shape = input_shape[kShape];
|
||||
auto min_shape = input_shape[kMinShape];
|
||||
auto max_shape = input_shape[kMaxShape];
|
||||
auto out_shape = shape;
|
||||
|
||||
return std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(HSwish, BaseOperator);
|
||||
AbstractBasePtr HSwishInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
return abstract::MakeAbstract(HSwishInferShape(primitive, input_args), HSwishInferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(HSwish, prim::kPrimHSwish, HSwishInfer, nullptr, true);
|
||||
} // namespace mindspore::ops
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_HSWISH_H_
|
||||
#define MINDSPORE_CORE_OPS_HSWISH_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore::ops {
|
||||
constexpr auto kNameHSwish = "HSwish";
|
||||
/// \brief Calculates kNameHSwish .
|
||||
/// Refer to Python API @ref mindspore.ops.HSwish for more details.
|
||||
class MIND_API HSwish : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(HSwish);
|
||||
/// \brief Constructor.
|
||||
HSwish() : BaseOperator(kNameHSwish) { InitIOName({"x"}, {"output"}); }
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.HSwish for the inputs.
|
||||
void Init() const {}
|
||||
};
|
||||
|
||||
abstract::AbstractBasePtr HSwishInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace mindspore::ops
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_HSWISH_H_
|
|
@ -1957,8 +1957,12 @@ class _ActivationGrad(PrimitiveWithInfer):
|
|||
return x_dtype
|
||||
|
||||
|
||||
class HSwishGrad(_ActivationGrad):
|
||||
class HSwishGrad(Primitive):
|
||||
"""Gets the gradient of HSwish operation."""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize HSwishGrad"""
|
||||
self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])
|
||||
|
||||
|
||||
class HSigmoidGrad(_ActivationGrad):
|
||||
|
|
|
@ -816,7 +816,7 @@ class Elu(Primitive):
|
|||
self.init_prim_io_names(inputs=['x'], outputs=['output', 'mask'])
|
||||
|
||||
|
||||
class HSwish(PrimitiveWithInfer):
|
||||
class HSwish(Primitive):
|
||||
r"""
|
||||
Hard swish activation function.
|
||||
|
||||
|
@ -838,13 +838,6 @@ class HSwish(PrimitiveWithInfer):
|
|||
"""Initialize HSwish."""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, xshape):
|
||||
return xshape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32), self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
class Sigmoid(Primitive):
|
||||
r"""
|
||||
|
|
Loading…
Reference in New Issue