forked from mindspore-Ecosystem/mindspore
!25856 [assistant][ops]New operator implementation, including ApplyPowerSignD
Merge pull request !25856 from 傅嘉俊/ApplyPowerSignD
commit 2399ed98be
@@ -319,6 +319,7 @@ inline const PrimitivePtr kPrimAsinh = std::make_shared<Primitive>("Asinh");
inline const PrimitivePtr kPrimAcosh = std::make_shared<Primitive>("Acosh");
inline const PrimitivePtr kPrimAtanh = std::make_shared<Primitive>("Atanh");
inline const PrimitivePtr kPrimApplyGradientDescent = std::make_shared<Primitive>("ApplyGradientDescent");
inline const PrimitivePtr kPrimApplyPowerSignD = std::make_shared<Primitive>("ApplyPowerSign");
inline const PrimitivePtr kPrimBesselI0e = std::make_shared<Primitive>("BesselI0e");
inline const PrimitivePtr kPrimBesselI1e = std::make_shared<Primitive>("BesselI1e");
inline const PrimitivePtr kPrimTanhGrad = std::make_shared<Primitive>("TanhGrad");
@@ -0,0 +1,120 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/apply_power_sign_d.h"

#include <vector>
#include <memory>
#include <set>
#include <map>
#include <string>

#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
namespace {
const int64_t kInputNum = 7;
abstract::TupleShapePtr ApplyPowerSignDInferShape(const PrimitivePtr &primitive,
                                                  const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum,
                                           prim_name);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto var_shape = input_args[kInputIndex0]->BuildShape();
  auto m_shape = input_args[kInputIndex1]->BuildShape();
  auto lr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape];
  auto logbase_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->GetShapeTrack())[kShape];
  auto sign_decay_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->GetShapeTrack())[kShape];
  auto beta_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->GetShapeTrack())[kShape];
  auto grad_shape = input_args[kInputIndex6]->BuildShape();
  (void)CheckAndConvertUtils::CheckInteger("lr_shape size", lr_shape.size(), kLessEqual, 1, prim_name);
  if (lr_shape.size() == 1) {
    (void)CheckAndConvertUtils::CheckInteger("lr_shape[0] size", lr_shape[0], kEqual, 1, prim_name);
  }
  (void)CheckAndConvertUtils::CheckInteger("logbase_shape size", logbase_shape.size(), kLessEqual, 1, prim_name);
  if (logbase_shape.size() == 1) {
    (void)CheckAndConvertUtils::CheckInteger("logbase_shape[0] size", logbase_shape[0], kEqual, 1, prim_name);
  }
  (void)CheckAndConvertUtils::CheckInteger("sign_decay_shape size", sign_decay_shape.size(), kLessEqual, 1, prim_name);
  if (sign_decay_shape.size() == 1) {
    (void)CheckAndConvertUtils::CheckInteger("sign_decay_shape[0] size", sign_decay_shape[0], kEqual, 1, prim_name);
  }
  (void)CheckAndConvertUtils::CheckInteger("beta_shape size", beta_shape.size(), kLessEqual, 1, prim_name);
  if (beta_shape.size() == 1) {
    (void)CheckAndConvertUtils::CheckInteger("beta_shape[0] size", beta_shape[0], kEqual, 1, prim_name);
  }
  if (grad_shape->IsDynamic() || var_shape->IsDynamic() || m_shape->IsDynamic()) {
    return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{var_shape, m_shape});
  }
  // var, m and grad must have the same shape
  std::map<std::string, abstract::BaseShapePtr> same_shape_args_map;
  same_shape_args_map.insert({"m", m_shape});
  same_shape_args_map.insert({"grad", grad_shape});
  for (auto &elem : same_shape_args_map) {
    if (*elem.second != *var_shape) {
      MS_EXCEPTION(ValueError) << prim_name << " evaluator arg " << elem.first << " shape " << elem.second->ToString()
                               << " are not consistent with var shape " << var_shape->ToString();
    }
  }
  return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{var_shape, m_shape});
}

TuplePtr ApplyPowerSignDInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();
  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum,
                                           prim_name);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto var_type = input_args[kInputIndex0]->BuildType();
  auto m_type = input_args[kInputIndex1]->BuildType();
  auto lr_type = input_args[kInputIndex2]->BuildType();
  auto logbase_type = input_args[kInputIndex3]->BuildType();
  auto sign_decay_type = input_args[kInputIndex4]->BuildType();
  auto beta_type = input_args[kInputIndex5]->BuildType();
  auto grad_type = input_args[kInputIndex6]->BuildType();
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
  std::map<std::string, TypePtr> args;
  args.insert({"var", var_type});
  args.insert({"m", m_type});
  args.insert({"grad", grad_type});
  (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("lr_dtype", lr_type, valid_types, prim_name);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("logbase_dtype", logbase_type, valid_types, prim_name);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("sign_decay_dtype", sign_decay_type, valid_types, prim_name);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("beta_dtype", beta_type, valid_types, prim_name);
  return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, m_type});
}
}  // namespace

AbstractBasePtr ApplyPowerSignDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto infer_type = ApplyPowerSignDInferType(primitive, input_args);
  auto infer_shape = ApplyPowerSignDInferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ApplyPowerSign, prim::kPrimApplyPowerSignD, ApplyPowerSignDInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
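For reference, the shape rules enforced by ApplyPowerSignDInferShape above can be restated in plain Python: lr, logbase, sign_decay and beta must be scalars (rank 0, or rank 1 with a single element), while m and grad must match var exactly. The helper below is an illustrative sketch only; its name and the tuple-based shapes are not part of this commit.

def check_apply_power_sign_shapes(var, m, lr, logbase, sign_decay, beta, grad):
    """Mirror the checks in ApplyPowerSignDInferShape, with shapes given as tuples."""
    # lr/logbase/sign_decay/beta: rank <= 1, and a rank-1 shape must hold exactly one element.
    for name, shape in {"lr": lr, "logbase": logbase, "sign_decay": sign_decay, "beta": beta}.items():
        if len(shape) > 1 or (len(shape) == 1 and shape[0] != 1):
            raise ValueError(f"{name} must be a scalar or have shape (1,), got {shape}")
    # m and grad must have exactly the same shape as var.
    for name, shape in {"m": m, "grad": grad}.items():
        if shape != var:
            raise ValueError(f"{name} shape {shape} is not consistent with var shape {var}")
    return var, m  # the op outputs the updated (var, m), with unchanged shapes

# Passes: var/m/grad share one shape, hyper-parameters are scalars or 1-element tensors.
check_apply_power_sign_shapes((2, 2), (2, 2), (1,), (), (), (), (2, 2))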
@@ -0,0 +1,42 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_APPLY_POWER_SIGN_D_H_
#define MINDSPORE_CORE_OPS_APPLY_POWER_SIGN_D_H_
#include <vector>
#include <memory>

#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
constexpr auto kNameApplyPowerSign = "ApplyPowerSign";
class ApplyPowerSign : public PrimitiveC {
 public:
  ApplyPowerSign() : PrimitiveC(kNameApplyPowerSign) {
    InitIOName({"var", "m", "lr", "logbase", "sign_decay", "beta", "grad"}, {"var", "m"});
  }
  ~ApplyPowerSign() = default;
  MS_DECLARE_PARENT(ApplyPowerSign, PrimitiveC);
};
AbstractBasePtr ApplyPowerSignDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const std::vector<AbstractBasePtr> &input_args);
using kPrimApplyPowerSignDPtr = std::shared_ptr<ApplyPowerSign>;
}  // namespace ops
}  // namespace mindspore
#endif  // MINDSPORE_CORE_OPS_APPLY_POWER_SIGN_D_H_
@@ -58,6 +58,7 @@ from .apply_adagrad_d_a import _apply_adagrad_d_a_tbe
from .apply_add_sign import _apply_add_sign_tbe
from .apply_add_sign_ds import _apply_add_sign_ds_tbe
from .apply_power_sign import _apply_power_sign_tbe
from .apply_power_sign_ds import _apply_power_sign_ds_tbe
from .apply_gradient_descent import _apply_gradient_descent_tbe
from .apply_gradient_descent_ds import _apply_gradient_descent_ds_tbe
from .apply_proximal_gradient_descent import _apply_proximal_gradient_descent_tbe
@@ -0,0 +1,66 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""ApplyPowerSignD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

apply_power_sign_d_op_info = TBERegOp("ApplyPowerSign") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("apply_power_sign_d.so") \
    .compute_cost(10) \
    .dynamic_shape(True) \
    .kernel_name("apply_power_sign_d") \
    .partial_flag(True) \
    .input(0, "var", False, "required", "all") \
    .input(1, "m", False, "required", "all") \
    .input(2, "lr", False, "required", "all") \
    .input(3, "logbase", False, "required", "all") \
    .input(4, "sign_decay", False, "required", "all") \
    .input(5, "beta", False, "required", "all") \
    .input(6, "grad", False, "required", "all") \
    .output(0, "var", False, "required", "all") \
    .output(1, "m", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD,
                  DataType.F16_5HD) \
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
                  DataType.F16_C1HWNCoC0) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ,
                  DataType.F16_FracZ) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD,
                  DataType.F32_5HD) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
                  DataType.F32_C1HWNCoC0) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ,
                  DataType.F32_FracZ) \
    .get_op_info()


@op_info_register(apply_power_sign_d_op_info)
def _apply_power_sign_ds_tbe():
    """ApplyPowerSignD TBE register"""
    return
@@ -6184,7 +6184,7 @@ class ApplyAddSign(PrimitiveWithInfer):
        """Initialize ApplyAddSign."""
        self.add_prim_attr('side_effect_mem', True)


class ApplyPowerSign(PrimitiveWithInfer):
class ApplyPowerSign(Primitive):
    r"""
    Updates relevant entries according to the PowerSign algorithm.
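For reference, the update this operator applies is the PowerSign rule from Bello et al., "Neural Optimizer Search with Reinforcement Learning" (2017). A sketch in the operator's own input names (m is the accumulated momentum, g the incoming grad); this block is illustrative and not part of the diff:

\begin{aligned}
m_{t+1} &= \beta \cdot m_{t} + (1 - \beta) \cdot g \\
\text{update} &= \exp\big(\text{logbase} \cdot \text{sign\_decay} \cdot \operatorname{sign}(g) \cdot \operatorname{sign}(m_{t+1})\big) \cdot g \\
\text{var} &\leftarrow \text{var} - lr \cdot \text{update}
\end{aligned}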
@@ -6278,38 +6278,6 @@ class ApplyPowerSign(PrimitiveWithInfer):
        """Initialize ApplyPowerSign."""
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, var_shape, m_shape, lr_shape, logbase_shape, sign_decay_shape,
                    beta_shape, grad_shape):
        validator.check('m_shape', m_shape, 'var_shape', var_shape, Rel.EQ, self.name)
        validator.check('grad_shape', grad_shape, 'var_shape', var_shape, Rel.EQ, self.name)
        lr_shape_len = len(lr_shape)
        validator.check_int(lr_shape_len, 1, Rel.LE, "lr's rank", self.name)
        if lr_shape_len == 1:
            validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
        logbase_shape_len = len(logbase_shape)
        validator.check_int(logbase_shape_len, 1, Rel.LE, "logbase's rank", self.name)
        if logbase_shape_len == 1:
            validator.check_int(logbase_shape[0], 1, Rel.EQ, "logbase_shape[0]", self.name)
        sign_decay_shape_len = len(sign_decay_shape)
        validator.check_int(sign_decay_shape_len, 1, Rel.LE, "sign_decay's rank", self.name)
        if sign_decay_shape_len == 1:
            validator.check_int(sign_decay_shape[0], 1, Rel.EQ, "sign_decay_shape[0]", self.name)
        beta_shape_len = len(beta_shape)
        validator.check_int(beta_shape_len, 1, Rel.LE, "beta's rank", self.name)
        if beta_shape_len == 1:
            validator.check_int(beta_shape[0], 1, Rel.EQ, "beta_shape[0]", self.name)
        return var_shape, m_shape

    def infer_dtype(self, var_dtype, m_dtype, lr_dtype, logbase_dtype, sign_decay_dtype,
                    beta_dtype, grad_dtype):
        valid_dtypes = [mstype.float16, mstype.float32]
        args = {'var': var_dtype, 'm': m_dtype, 'grad': grad_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
        validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name)
        validator.check_scalar_or_tensor_types_same({"logbase": logbase_dtype}, valid_dtypes, self.name)
        validator.check_scalar_or_tensor_types_same({"sign_decay": sign_decay_dtype}, valid_dtypes, self.name)
        validator.check_scalar_or_tensor_types_same({"beta": beta_dtype}, valid_dtypes, self.name)
        return var_dtype, m_dtype


class ApplyGradientDescent(Primitive):
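With the Python infer_shape and infer_dtype methods removed, shape and dtype inference for ApplyPowerSign is handled by the C++ ApplyPowerSignDInferShape and ApplyPowerSignDInferType registered above; usage from Python is unchanged. A minimal usage sketch in the style of the operator docstring examples (the Net class and hyper-parameter values are illustrative, not part of the commit):

import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, Parameter

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.apply_power_sign = ops.ApplyPowerSign()
        self.var = Parameter(Tensor(np.ones((2, 2)).astype(np.float32)), name="var")
        self.m = Parameter(Tensor(np.ones((2, 2)).astype(np.float32)), name="m")
        self.lr = 0.001
        self.logbase = np.e
        self.sign_decay = 0.99
        self.beta = 0.9

    def construct(self, grad):
        # Inputs follow the order declared in InitIOName; returns the updated (var, m).
        return self.apply_power_sign(self.var, self.m, self.lr, self.logbase,
                                     self.sign_decay, self.beta, grad)

net = Net()
grad = Tensor(np.ones((2, 2)).astype(np.float32))
output = net(grad)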