From 9969ef726059fc7262fbeed53ff1e994bf6ace7f Mon Sep 17 00:00:00 2001
From: unknown
Date: Tue, 7 Dec 2021 14:09:07 +0800
Subject: [PATCH] [feat] [assistant] [I4809E] add new data operator

---
 mindspore/core/base/core_ops.h                |   1 +
 mindspore/core/ops/apply_power_sign_d.cc      | 120 ++++++++++++++++++
 mindspore/core/ops/apply_power_sign_d.h       |  42 ++++++
 .../mindspore/ops/_op_impl/tbe/__init__.py    |   1 +
 .../ops/_op_impl/tbe/apply_power_sign_ds.py   |  66 ++++++++++
 .../python/mindspore/ops/operations/nn_ops.py |  34 +----
 6 files changed, 231 insertions(+), 33 deletions(-)
 create mode 100644 mindspore/core/ops/apply_power_sign_d.cc
 create mode 100644 mindspore/core/ops/apply_power_sign_d.h
 create mode 100644 mindspore/python/mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py

diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 512b6d5e685..83627ff8843 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -302,6 +302,7 @@ inline const PrimitivePtr kPrimAsinh = std::make_shared<Primitive>("Asinh");
 inline const PrimitivePtr kPrimAcosh = std::make_shared<Primitive>("Acosh");
 inline const PrimitivePtr kPrimAtanh = std::make_shared<Primitive>("Atanh");
 inline const PrimitivePtr kPrimApplyGradientDescent = std::make_shared<Primitive>("ApplyGradientDescent");
+inline const PrimitivePtr kPrimApplyPowerSignD = std::make_shared<Primitive>("ApplyPowerSign");
 inline const PrimitivePtr kPrimBesselI0e = std::make_shared<Primitive>("BesselI0e");
 inline const PrimitivePtr kPrimBesselI1e = std::make_shared<Primitive>("BesselI1e");
 inline const PrimitivePtr kPrimTanhGrad = std::make_shared<Primitive>("TanhGrad");
diff --git a/mindspore/core/ops/apply_power_sign_d.cc b/mindspore/core/ops/apply_power_sign_d.cc
new file mode 100644
index 00000000000..f4b3ef69d89
--- /dev/null
+++ b/mindspore/core/ops/apply_power_sign_d.cc
@@ -0,0 +1,120 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ops/apply_power_sign_d.h"
+
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "ops/op_utils.h"
+#include "utils/check_convert_utils.h"
+
+namespace mindspore {
+namespace ops {
+namespace {
+const int64_t kInputNum = 7;
+abstract::TupleShapePtr ApplyPowerSignDInferShape(const PrimitivePtr &primitive,
+                                                  const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto prim_name = primitive->name();
+  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum,
+                                           prim_name);
+  for (const auto &item : input_args) {
+    MS_EXCEPTION_IF_NULL(item);
+  }
+  auto var_shape = input_args[kInputIndex0]->BuildShape();
+  auto m_shape = input_args[kInputIndex1]->BuildShape();
+  auto lr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape];
+  auto logbase_shape =
+    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->GetShapeTrack())[kShape];
+  auto sign_decay_shape =
+    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->GetShapeTrack())[kShape];
+  auto beta_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->GetShapeTrack())[kShape];
+  auto grad_shape = input_args[kInputIndex6]->BuildShape();
+  (void)CheckAndConvertUtils::CheckInteger("lr_shape size", lr_shape.size(), kLessEqual, 1, prim_name);
+  if (lr_shape.size() == 1) {
+    (void)CheckAndConvertUtils::CheckInteger("lr_shape[0] size", lr_shape[0], kEqual, 1, prim_name);
+  }
+  (void)CheckAndConvertUtils::CheckInteger("logbase_shape size", logbase_shape.size(), kLessEqual, 1, prim_name);
+  if (logbase_shape.size() == 1) {
+    (void)CheckAndConvertUtils::CheckInteger("logbase_shape[0] size", logbase_shape[0], kEqual, 1, prim_name);
+  }
+  (void)CheckAndConvertUtils::CheckInteger("sign_decay_shape size", sign_decay_shape.size(), kLessEqual, 1, prim_name);
+  if (sign_decay_shape.size() == 1) {
+    (void)CheckAndConvertUtils::CheckInteger("sign_decay_shape[0] size", sign_decay_shape[0], kEqual, 1, prim_name);
+  }
+  (void)CheckAndConvertUtils::CheckInteger("beta_shape size", beta_shape.size(), kLessEqual, 1, prim_name);
+  if (beta_shape.size() == 1) {
+    (void)CheckAndConvertUtils::CheckInteger("beta_shape[0] size", beta_shape[0], kEqual, 1, prim_name);
+  }
+  if (grad_shape->IsDynamic() || var_shape->IsDynamic() || m_shape->IsDynamic()) {
+    return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{var_shape, m_shape});
+  }
+  // var, m and grad must have the same shape
+  std::map<std::string, abstract::BaseShapePtr> same_shape_args_map;
+  same_shape_args_map.insert({"m", m_shape});
+  same_shape_args_map.insert({"grad", grad_shape});
+  for (auto &elem : same_shape_args_map) {
+    if (*elem.second != *var_shape) {
+      MS_EXCEPTION(ValueError) << prim_name << " evaluator arg " << elem.first << " shape " << elem.second->ToString()
+                               << " are not consistent with var shape " << var_shape->ToString();
+    }
+  }
+  return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{var_shape, m_shape});
+}
+
+TuplePtr ApplyPowerSignDInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(prim);
+  auto prim_name = prim->name();
+  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum,
+                                           prim_name);
+  for (const auto &item : input_args) {
+    MS_EXCEPTION_IF_NULL(item);
+  }
+  auto var_type = input_args[kInputIndex0]->BuildType();
+  auto m_type = input_args[kInputIndex1]->BuildType();
+  auto lr_type = input_args[kInputIndex2]->BuildType();
+  auto logbase_type = input_args[kInputIndex3]->BuildType();
+  auto sign_decay_type = input_args[kInputIndex4]->BuildType();
+  auto beta_type = input_args[kInputIndex5]->BuildType();
+  auto grad_type = input_args[kInputIndex6]->BuildType();
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
+  std::map<std::string, TypePtr> args;
+  args.insert({"var", var_type});
+  args.insert({"m", m_type});
+  args.insert({"grad", grad_type});
+  (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("lr_dtype", lr_type, valid_types, prim_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("logbase_dtype", logbase_type, valid_types, prim_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("sign_decay_dtype", sign_decay_type, valid_types, prim_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("beta_dtype", beta_type, valid_types, prim_name);
+  return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, m_type});
+}
+}  // namespace
+
+AbstractBasePtr ApplyPowerSignDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto infer_type = ApplyPowerSignDInferType(primitive, input_args);
+  auto infer_shape = ApplyPowerSignDInferShape(primitive, input_args);
+  return abstract::MakeAbstract(infer_shape, infer_type);
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(ApplyPowerSign, prim::kPrimApplyPowerSignD, ApplyPowerSignDInfer, nullptr, true);
+}  // namespace ops
+}  // namespace mindspore
diff --git a/mindspore/core/ops/apply_power_sign_d.h b/mindspore/core/ops/apply_power_sign_d.h
new file mode 100644
index 00000000000..a06eb4d4aac
--- /dev/null
+++ b/mindspore/core/ops/apply_power_sign_d.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_OPS_APPLY_POWER_SIGN_D_H_
+#define MINDSPORE_CORE_OPS_APPLY_POWER_SIGN_D_H_
+#include <memory>
+#include <vector>
+
+#include "ops/primitive_c.h"
+#include "abstract/abstract_value.h"
+#include "utils/check_convert_utils.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameApplyPowerSign = "ApplyPowerSign";
+class ApplyPowerSign : public PrimitiveC {
+ public:
+  ApplyPowerSign() : PrimitiveC(kNameApplyPowerSign) {
+    InitIOName({"var", "m", "lr", "logbase", "sign_decay", "beta", "grad"}, {"var", "m"});
+  }
+  ~ApplyPowerSign() = default;
+  MS_DECLARE_PARENT(ApplyPowerSign, PrimitiveC);
+};
+AbstractBasePtr ApplyPowerSignDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const std::vector<AbstractBasePtr> &input_args);
+using kPrimApplyPowerSignDPtr = std::shared_ptr<ApplyPowerSign>;
+}  // namespace ops
+}  // namespace mindspore
+#endif  // MINDSPORE_CORE_OPS_APPLY_POWER_SIGN_D_H_
diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py
index da4308eb069..8d29b812ca2 100644
--- a/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py
@@ -56,6 +56,7 @@ from .apply_adagrad_d_a import _apply_adagrad_d_a_tbe
 from .apply_add_sign import _apply_add_sign_tbe
 from .apply_add_sign_ds import _apply_add_sign_ds_tbe
 from .apply_power_sign import _apply_power_sign_tbe
+from .apply_power_sign_ds import _apply_power_sign_ds_tbe
 from .apply_gradient_descent import _apply_gradient_descent_tbe
 from .apply_gradient_descent_ds import _apply_gradient_descent_ds_tbe
 from .apply_proximal_gradient_descent import _apply_proximal_gradient_descent_tbe
diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py b/mindspore/python/mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py
new file mode 100644
index 00000000000..54a6b64a1da
--- /dev/null
+++ b/mindspore/python/mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py
@@ -0,0 +1,66 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""ApplyPowerSignD op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+apply_power_sign_d_op_info = TBERegOp("ApplyPowerSign") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("apply_power_sign_d.so") \
+    .compute_cost(10) \
+    .dynamic_shape(True) \
+    .kernel_name("apply_power_sign_d") \
+    .partial_flag(True) \
+    .input(0, "var", False, "required", "all") \
+    .input(1, "m", False, "required", "all") \
+    .input(2, "lr", False, "required", "all") \
+    .input(3, "logbase", False, "required", "all") \
+    .input(4, "sign_decay", False, "required", "all") \
+    .input(5, "beta", False, "required", "all") \
+    .input(6, "grad", False, "required", "all") \
+    .output(0, "var", False, "required", "all") \
+    .output(1, "m", False, "required", "all") \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default,
+                  DataType.F16_Default, DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD,
+                  DataType.F16_5HD) \
+    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default,
+                  DataType.F16_Default, DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
+                  DataType.F16_C1HWNCoC0) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
+                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
+                  DataType.F16_Default) \
+    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default,
+                  DataType.F16_Default, DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ,
+                  DataType.F16_FracZ) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default,
+                  DataType.F32_Default, DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD,
+                  DataType.F32_5HD) \
+    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default,
+                  DataType.F32_Default, DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
+                  DataType.F32_C1HWNCoC0) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
+                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
+                  DataType.F32_Default) \
+    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default,
+                  DataType.F32_Default, DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ,
+                  DataType.F32_FracZ) \
+    .get_op_info()
+
+
+@op_info_register(apply_power_sign_d_op_info)
+def _apply_power_sign_ds_tbe():
+    """ApplyPowerSignD TBE register"""
+    return
diff --git a/mindspore/python/mindspore/ops/operations/nn_ops.py b/mindspore/python/mindspore/ops/operations/nn_ops.py
index 666e5f13de1..93cf3f29b87 100644
--- a/mindspore/python/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/python/mindspore/ops/operations/nn_ops.py
@@ -6209,7 +6209,7 @@ class ApplyAddSign(PrimitiveWithInfer):
         """Initialize ApplyAddSign."""
         self.add_prim_attr('side_effect_mem', True)


-class ApplyPowerSign(PrimitiveWithInfer):
+class ApplyPowerSign(Primitive):
     r"""
     Updates relevant entries according to the AddSign algorithm.
@@ -6303,38 +6303,6 @@ class ApplyPowerSign(PrimitiveWithInfer):
         """Initialize ApplyPowerSign."""
         self.add_prim_attr('side_effect_mem', True)

-    def infer_shape(self, var_shape, m_shape, lr_shape, logbase_shape, sign_decay_shape,
-                    beta_shape, grad_shape):
-        validator.check('m_shape', m_shape, 'var_shape', var_shape, Rel.EQ, self.name)
-        validator.check('grad_shape', grad_shape, 'var_shape', var_shape, Rel.EQ, self.name)
-        lr_shape_len = len(lr_shape)
-        validator.check_int(lr_shape_len, 1, Rel.LE, "lr's rank", self.name)
-        if lr_shape_len == 1:
-            validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
-        logbase_shape_len = len(logbase_shape)
-        validator.check_int(logbase_shape_len, 1, Rel.LE, "logbase's rank", self.name)
-        if logbase_shape_len == 1:
-            validator.check_int(logbase_shape[0], 1, Rel.EQ, "logbase_shape[0]", self.name)
-        sign_decay_shape_len = len(sign_decay_shape)
-        validator.check_int(sign_decay_shape_len, 1, Rel.LE, "sign_decay's rank", self.name)
-        if sign_decay_shape_len == 1:
-            validator.check_int(sign_decay_shape[0], 1, Rel.EQ, "sign_decay_shape[0]", self.name)
-        beta_shape_len = len(beta_shape)
-        validator.check_int(beta_shape_len, 1, Rel.LE, "beta's rank", self.name)
-        if beta_shape_len == 1:
-            validator.check_int(beta_shape[0], 1, Rel.EQ, "beta_shape[0]", self.name)
-        return var_shape, m_shape
-
-    def infer_dtype(self, var_dtype, m_dtype, lr_dtype, logbase_dtype, sign_decay_dtype,
-                    beta_dtype, grad_dtype):
-        valid_dtypes = [mstype.float16, mstype.float32]
-        args = {'var': var_dtype, 'm': m_dtype, 'grad': grad_dtype}
-        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
-        validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name)
-        validator.check_scalar_or_tensor_types_same({"logbase": logbase_dtype}, valid_dtypes, self.name)
-        validator.check_scalar_or_tensor_types_same({"sign_decay": sign_decay_dtype}, valid_dtypes, self.name)
-        validator.check_scalar_or_tensor_types_same({"beta": beta_dtype}, valid_dtypes, self.name)
-        return var_dtype, m_dtype


 class ApplyGradientDescent(Primitive):
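
Reviewer note (not part of the patch): a minimal usage sketch of the ApplyPowerSign primitive whose shape/dtype inference moves from Python to C++ above. It assumes the public mindspore.ops.ApplyPowerSign API and a backend that provides the kernel (for example Ascend, via the TBE registration added here); the class name, tensor values and hyper-parameters below are illustrative only.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter, ops


class PowerSignNet(nn.Cell):
    """Hypothetical cell wrapping ops.ApplyPowerSign for a single update step."""

    def __init__(self):
        super(PowerSignNet, self).__init__()
        self.apply_power_sign = ops.ApplyPowerSign()
        # var is the parameter to update, m its accumulated momentum.
        self.var = Parameter(Tensor(np.array([[0.6, 0.4], [0.1, 0.5]], np.float32)), name="var")
        self.m = Parameter(Tensor(np.array([[0.6, 0.5], [0.2, 0.6]], np.float32)), name="m")
        # lr, logbase, sign_decay and beta may be Python scalars or scalar tensors.
        self.lr = 0.001
        self.logbase = np.e
        self.sign_decay = 0.99
        self.beta = 0.9

    def construct(self, grad):
        # Returns the updated (var, m) pair; the parameters are updated in place.
        return self.apply_power_sign(self.var, self.m, self.lr, self.logbase,
                                     self.sign_decay, self.beta, grad)


net = PowerSignNet()
grad = Tensor(np.array([[0.3, 0.7], [0.1, 0.8]], np.float32))
updated_var, updated_m = net(grad)
print(updated_var, updated_m)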