From 380a57f3c1b86df4d7a9249c55c16dba9bc043f4 Mon Sep 17 00:00:00 2001 From: liuxiao93 Date: Wed, 15 Jul 2020 14:58:50 +0800 Subject: [PATCH] Adapt ApplyCenteredRmsProp. --- .../kernel_compiler/tbe/tbe_adapter.cc | 1 + mindspore/ccsrc/transform/graph_ir/convert.cc | 2 +- .../ccsrc/transform/graph_ir/op_declare.cc | 13 +++---- .../ccsrc/transform/graph_ir/op_declare.h | 4 +-- .../_op_impl/tbe/apply_centered_rms_prop.py | 35 ++++++++++++------- mindspore/ops/operations/nn_ops.py | 5 +++ 6 files changed, 39 insertions(+), 21 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc index 449a9f45564..e24663fb10a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc @@ -81,6 +81,7 @@ static std::map tbe_func_adapter_map = { {"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"}, {"apply_add_sign", "apply_add_sign_d"}, {"apply_power_sign", "apply_power_sign_d"}, + {"apply_centered_rms_prop", "apply_centered_rms_prop_d"}, {"transpose", "transpose_d"}, {"fill", "fill_d"}, {"unsorted_segment_sum", "unsorted_segment_sum_d"}, diff --git a/mindspore/ccsrc/transform/graph_ir/convert.cc b/mindspore/ccsrc/transform/graph_ir/convert.cc index 7419dd2cc98..56028bbdd90 100644 --- a/mindspore/ccsrc/transform/graph_ir/convert.cc +++ b/mindspore/ccsrc/transform/graph_ir/convert.cc @@ -409,7 +409,7 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma {string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)}, {string(kNameAtan2), ADPT_DESC(Atan2)}, {string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)}, - {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}, + {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSPropD)}, {string(kNameL2Loss), ADPT_DESC(L2Loss)}, {string(kNameCTCLoss), ADPT_DESC(CTCLoss)}, {string(kNameRange), ADPT_DESC(RangeD)}, diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare.cc b/mindspore/ccsrc/transform/graph_ir/op_declare.cc index e3751e0c925..939e5feba18 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_declare.cc +++ b/mindspore/ccsrc/transform/graph_ir/op_declare.cc @@ -1284,12 +1284,13 @@ INPUT_ATTR_MAP(ApplyRMSPropD) = {{6, ATTR_DESC(rho, AnyTraits())}, ATTR_MAP(ApplyRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; OUTPUT_MAP(ApplyRMSPropD) = {{0, OUTPUT_DESC(var)}}; -// ApplyCenteredRMSProp -INPUT_MAP(ApplyCenteredRMSProp) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)}, - {4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)}, - {7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}}; -ATTR_MAP(ApplyCenteredRMSProp) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; -OUTPUT_MAP(ApplyCenteredRMSProp) = {{0, OUTPUT_DESC(var)}}; +// ApplyCenteredRMSPropD +INPUT_MAP(ApplyCenteredRMSPropD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)}, + {4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)}, + {7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}}; +ATTR_MAP(ApplyCenteredRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ApplyCenteredRMSPropD) = { + {0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(mg)}, {2, OUTPUT_DESC(ms)}, {3, OUTPUT_DESC(mom)}}; // L2Loss INPUT_MAP(L2Loss) = {{1, INPUT_DESC(x)}}; diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare.h b/mindspore/ccsrc/transform/graph_ir/op_declare.h index e493ea0e528..2774ac1ff83 100755 --- a/mindspore/ccsrc/transform/graph_ir/op_declare.h +++ b/mindspore/ccsrc/transform/graph_ir/op_declare.h @@ -486,8 +486,8 @@ DECLARE_OP_USE_OUTPUT(Atan2) DECLARE_OP_ADAPTER(ApplyRMSPropD) DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD) DECLARE_OP_USE_OUTPUT(ApplyRMSPropD) -DECLARE_OP_ADAPTER(ApplyCenteredRMSProp) -DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSProp) +DECLARE_OP_ADAPTER(ApplyCenteredRMSPropD) +DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSPropD) DECLARE_OP_ADAPTER(L2Loss) DECLARE_OP_USE_OUTPUT(L2Loss) DECLARE_OP_ADAPTER(CTCLoss) diff --git a/mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py b/mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py index 8499614324c..f372583a85d 100644 --- a/mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +++ b/mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py @@ -13,15 +13,15 @@ # limitations under the License. # ============================================================================ -"""ApplyCenteredRMSProp op""" +"""ApplyCenteredRMSPropD op""" from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType apply_centered_rms_prop_op_info = TBERegOp("ApplyCenteredRMSProp") \ .fusion_type("OPAQUE") \ .async_flag(False) \ - .binfile_name("apply_centered_rms_prop.so") \ + .binfile_name("apply_centered_rms_prop_d.so") \ .compute_cost(10) \ - .kernel_name("apply_centered_rms_prop") \ + .kernel_name("apply_centered_rms_prop_d") \ .partial_flag(True) \ .input(0, "var", False, "required", "all") \ .input(1, "mg", False, "required", "all") \ @@ -33,34 +33,45 @@ apply_centered_rms_prop_op_info = TBERegOp("ApplyCenteredRMSProp") \ .input(7, "epsilon", False, "required", "all") \ .input(8, "grad", False, "required", "all") \ .output(0, "var", False, "required", "all") \ + .output(1, "mg", False, "required", "all") \ + .output(2, "ms", False, "required", "all") \ + .output(3, "mom", False, "required", "all") \ .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, - DataType.F16_5HD, DataType.F16_5HD) \ + DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, + DataType.F16_5HD) \ .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, - DataType.F16_FracZ, DataType.F16_FracZ) \ + DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, + DataType.F16_FracZ) \ .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, - DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ + DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, + DataType.F16_C1HWNCoC0) \ .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, - DataType.F16_Default, DataType.F16_Default) \ + DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, + DataType.F16_Default) \ .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, - DataType.F32_5HD, DataType.F32_5HD) \ + DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, + DataType.F32_5HD) \ .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, - DataType.F32_FracZ, DataType.F32_FracZ) \ + DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, + DataType.F32_FracZ) \ .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, - DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ + DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, + DataType.F32_C1HWNCoC0) \ .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, - DataType.F32_Default, DataType.F32_Default) \ + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default) \ .get_op_info() @op_info_register(apply_centered_rms_prop_op_info) def _apply_centered_rms_prop_tbe(): - """ApplyCenteredRMSProp TBE register""" + """ApplyCenteredRMSPropD TBE register""" return diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index e97c4c91c8c..d2b47357d93 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1962,6 +1962,7 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer): @prim_attr_register def __init__(self, use_locking=False): self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) + self.is_ascend = context.get_context("device_target") == "Ascend" def infer_shape(self, var_shape, mean_gradient_shape, mean_square_shape, moment_shape, grad_shape, learning_rate_shape, decay_shape, momentum_shape, epsilon_shape): @@ -1969,6 +1970,8 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer): validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name) validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name) validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name) + if self.is_ascend: + return var_shape, mean_gradient_shape, mean_square_shape, moment_shape return var_shape def infer_dtype(self, var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype, grad_dtype, @@ -1982,6 +1985,8 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer): validator.check_type_same(args_rho, valid_types, self.name) args_lr = {"learning_rate": learning_rate_dtype, "rho": rho_dtype} validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True) + if self.is_ascend: + return var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype return var_dtype