From cd9173fdfd7ec4233bad84b01b868642ae9793ef Mon Sep 17 00:00:00 2001 From: "wangnan39@huawei.com" Date: Sat, 30 Jan 2021 16:29:58 +0800 Subject: [PATCH] unify the output num of optimizer ops --- .../ascend/ascend_backend_optimization.cc | 1 - .../ascend/mindir/optimizer_unify_output.cc | 126 ++++++++++++++++++ .../ascend/mindir/optimizer_unify_output.h | 58 ++++++++ .../ccsrc/backend/session/ascend_session.cc | 5 + .../op_declare/nn_pooling_ops_declare.h | 2 +- .../op_declare/nn_training_ops_declare.cc | 31 +++-- .../op_declare/nn_training_ops_declare.h | 8 +- mindspore/core/base/core_ops.h | 3 +- mindspore/ops/_grad/grad_array_ops.py | 3 +- mindspore/ops/operations/nn_ops.py | 74 ++-------- .../insert_memcpy_async_for_hccl_op_test.cc | 9 +- .../ir_fusion/add_input_to_output_test.cc | 74 ---------- .../pre_activate/add_input_to_output_test.py | 39 ------ .../insert_memcpy_async_for_hccl_op.py | 10 +- .../momentum_lossscale_fusion_test.py | 2 +- 15 files changed, 234 insertions(+), 211 deletions(-) create mode 100644 mindspore/ccsrc/backend/optimizer/ascend/mindir/optimizer_unify_output.cc create mode 100644 mindspore/ccsrc/backend/optimizer/ascend/mindir/optimizer_unify_output.h delete mode 100644 tests/ut/cpp/pre_activate/ascend/ir_fusion/add_input_to_output_test.cc delete mode 100644 tests/ut/cpp/python_input/gtest_input/pre_activate/add_input_to_output_test.py diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc index 3908ed00d3f..c7ff281706d 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -292,7 +292,6 @@ void AscendBackendIRFusionOptimization(const std::shared_ptrAddPass(std::make_shared()); } ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/optimizer_unify_output.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/optimizer_unify_output.cc new file mode 100644 index 00000000000..14799217f1e --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/optimizer_unify_output.cc @@ -0,0 +1,126 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/optimizer/ascend/mindir/optimizer_unify_output.h" + +#include +#include + +#include "abstract/abstract_value.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +namespace { +constexpr size_t kFtrlOutputNum = 3; +constexpr size_t kMomentumOutputNum = 2; +constexpr size_t kRMSPropOutputNum = 3; +constexpr size_t kCenteredRMSPropOutputNum = 4; + +CNodePtr ProcessOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, const size_t output_size) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + + auto cnode_ptr = node->cast(); + MS_EXCEPTION_IF_NULL(cnode_ptr); + + auto abstract = cnode_ptr->abstract(); + MS_EXCEPTION_IF_NULL(abstract); + + if (AnfAlgo::HasNodeAttr("optim_output_passed", cnode_ptr) && abstract->isa()) { + return nullptr; + } + AnfAlgo::SetNodeAttr("optim_output_passed", MakeValue(true), cnode_ptr); + + std::vector abstract_list; + for (size_t i = 0; i < output_size; i++) { + abstract_list.push_back(abstract->Clone()); + } + auto abstract_tuple = std::make_shared(abstract_list); + cnode_ptr->set_abstract(abstract_tuple); + + auto index = NewValueNode(static_cast(0)); + auto get_item = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode_ptr, index}); + MS_EXCEPTION_IF_NULL(get_item); + + get_item->set_abstract(abstract->Clone()); + return get_item; +} +} // namespace + +const BaseRef FtrlUnifyOutput::DefinePattern() const { + VarPtr var = std::make_shared(); + VarPtr accum = std::make_shared(); + VarPtr linear = std::make_shared(); + VarPtr grad = std::make_shared(); + VarPtr lr = std::make_shared(); + VarPtr l1 = std::make_shared(); + VarPtr l2 = std::make_shared(); + VarPtr lr_power = std::make_shared(); + VectorRef pattern({prim::kPrimApplyFtrl, var, accum, linear, grad, lr, l1, l2, lr_power}); + return pattern; +} + +const AnfNodePtr FtrlUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { + return ProcessOutput(graph, node, kFtrlOutputNum); +} + +const BaseRef MomentumUnifyOutput::DefinePattern() const { + VarPtr var = std::make_shared(); + VarPtr accum = std::make_shared(); + VarPtr lr = std::make_shared(); + VarPtr grad = std::make_shared(); + VarPtr momentum = std::make_shared(); + VectorRef pattern({prim::kPrimApplyMomentum, var, accum, lr, grad, momentum}); + return pattern; +} + +const AnfNodePtr MomentumUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &) const { + return ProcessOutput(graph, node, kMomentumOutputNum); +} + +const BaseRef RMSPropUnifyOutput::DefinePattern() const { + VarPtr inputs = std::make_shared(); + VectorRef pattern({prim::kPrimApplyRMSProp, inputs}); + return pattern; +} + +const AnfNodePtr RMSPropUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &) const { + return ProcessOutput(graph, node, kRMSPropOutputNum); +} + +const BaseRef CenteredRMSPropUnifyOutput::DefinePattern() const { + VarPtr var = std::make_shared(); + VarPtr mg = std::make_shared(); + VarPtr ms = std::make_shared(); + VarPtr mom = std::make_shared(); + VarPtr grad = std::make_shared(); + VarPtr lr = std::make_shared(); + VarPtr rho = std::make_shared(); + VarPtr momentum = std::make_shared(); + VarPtr epsilon = std::make_shared(); + VectorRef pattern({prim::kPrimApplyCenteredRMSProp, var, mg, ms, mom, grad, lr, rho, momentum, epsilon}); + return pattern; +} + +const AnfNodePtr CenteredRMSPropUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr 
&node, + const EquivPtr &) const { + return ProcessOutput(graph, node, kCenteredRMSPropOutputNum); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/optimizer_unify_output.h b/mindspore/ccsrc/backend/optimizer/ascend/mindir/optimizer_unify_output.h new file mode 100644 index 00000000000..4596d9f4105 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/optimizer_unify_output.h @@ -0,0 +1,58 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_OPTIMIZER_UNIFY_OUTPUT_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_OPTIMIZER_UNIFY_OUTPUT_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class FtrlUnifyOutput : public PatternProcessPass { + public: + explicit FtrlUnifyOutput(bool multigraph = true) : PatternProcessPass("ftrl_unify_output", multigraph) {} + ~FtrlUnifyOutput() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; + +class MomentumUnifyOutput : public PatternProcessPass { + public: + explicit MomentumUnifyOutput(bool multigraph = true) : PatternProcessPass("momentum_unify_output", multigraph) {} + ~MomentumUnifyOutput() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; + +class CenteredRMSPropUnifyOutput : public PatternProcessPass { + public: + explicit CenteredRMSPropUnifyOutput(bool multigraph = true) + : PatternProcessPass("centered_rmsprop_unify_output", multigraph) {} + ~CenteredRMSPropUnifyOutput() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; + +class RMSPropUnifyOutput : public PatternProcessPass { + public: + explicit RMSPropUnifyOutput(bool multigraph = true) : PatternProcessPass("rmsprop_unify_output", multigraph) {} + ~RMSPropUnifyOutput() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_OPTIMIZER_UNIFY_OUTPUT_H_ diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index d36369be7f1..7008fa3023d 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -38,6 +38,7 @@ #include "backend/optimizer/ascend/mindir/maxpool_to_maxpool_with_argmax.h" #include "backend/optimizer/ascend/mindir/maxpool_with_argmax_unify_mindir.h" #include "backend/optimizer/ascend/mindir/conv2d_unify_mindir.h" 
+#include "backend/optimizer/ascend/mindir/optimizer_unify_output.h" #include "backend/optimizer/ascend/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.h" #include "backend/optimizer/ascend/mindir/slice_grad_unify_mindir.h" #include "runtime/device/kernel_adjust.h" @@ -217,6 +218,10 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) { unify_mindir_pm->AddPass(std::make_shared()); unify_mindir_pm->AddPass(std::make_shared()); unify_mindir_pm->AddPass(std::make_shared()); + unify_mindir_pm->AddPass(std::make_shared()); + unify_mindir_pm->AddPass(std::make_shared()); + unify_mindir_pm->AddPass(std::make_shared()); + unify_mindir_pm->AddPass(std::make_shared()); auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); if (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode) { diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/nn_pooling_ops_declare.h b/mindspore/ccsrc/transform/graph_ir/op_declare/nn_pooling_ops_declare.h index 6b423bdff79..da313111b24 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_declare/nn_pooling_ops_declare.h +++ b/mindspore/ccsrc/transform/graph_ir/op_declare/nn_pooling_ops_declare.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/nn_training_ops_declare.cc b/mindspore/ccsrc/transform/graph_ir/op_declare/nn_training_ops_declare.cc index 934329d2c36..13de37ff108 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_declare/nn_training_ops_declare.cc +++ b/mindspore/ccsrc/transform/graph_ir/op_declare/nn_training_ops_declare.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -143,13 +143,13 @@ ATTR_MAP(SparseApplyFtrlD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; -OUTPUT_MAP(ApplyFtrlD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}, {2, OUTPUT_DESC(linear)}}; -REG_ADPT_DESC(ApplyFtrlD, kNameApplyFtrl, ADPT_DESC(ApplyFtrlD)) +// ApplyFtrl +INPUT_MAP(ApplyFtrl) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(linear)}, + {4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}, {6, INPUT_DESC(l1)}, + {7, INPUT_DESC(l2)}, {8, INPUT_DESC(lr_power)}}; +ATTR_MAP(ApplyFtrl) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ApplyFtrl) = {{0, OUTPUT_DESC(var)}}; +REG_ADPT_DESC(ApplyFtrl, kNameApplyFtrl, ADPT_DESC(ApplyFtrl)) // ApplyRMSPropD INPUT_MAP(ApplyRMSPropD) = { @@ -161,12 +161,11 @@ ATTR_MAP(ApplyRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; -OUTPUT_MAP(ApplyCenteredRMSPropD) = { - {0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(mg)}, {2, OUTPUT_DESC(ms)}, {3, OUTPUT_DESC(mom)}}; -REG_ADPT_DESC(ApplyCenteredRMSPropD, kNameApplyCenteredRMSProp, ADPT_DESC(ApplyCenteredRMSPropD)) +// ApplyCenteredRMSProp +INPUT_MAP(ApplyCenteredRMSProp) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)}, + {4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)}, + {7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}}; +ATTR_MAP(ApplyCenteredRMSProp) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ApplyCenteredRMSProp) = {{0, OUTPUT_DESC(var)}}; +REG_ADPT_DESC(ApplyCenteredRMSProp, kNameApplyCenteredRMSProp, ADPT_DESC(ApplyCenteredRMSProp)) } // namespace mindspore::transform diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/nn_training_ops_declare.h b/mindspore/ccsrc/transform/graph_ir/op_declare/nn_training_ops_declare.h index 5668edaf8db..2bc725323ca 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_declare/nn_training_ops_declare.h +++ b/mindspore/ccsrc/transform/graph_ir/op_declare/nn_training_ops_declare.h @@ -62,8 +62,8 @@ DECLARE_OP_USE_OUTPUT(ApplyProximalAdagradD) DECLARE_OP_ADAPTER(LarsV2Update) DECLARE_OP_USE_OUTPUT(LarsV2Update) -DECLARE_OP_ADAPTER(ApplyFtrlD) -DECLARE_OP_USE_OUTPUT(ApplyFtrlD) +DECLARE_OP_ADAPTER(ApplyFtrl) +DECLARE_OP_USE_OUTPUT(ApplyFtrl) DECLARE_OP_ADAPTER(SparseApplyFtrlD) DECLARE_OP_USE_OUTPUT(SparseApplyFtrlD) @@ -72,7 +72,7 @@ DECLARE_OP_ADAPTER(ApplyRMSPropD) DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD) DECLARE_OP_USE_OUTPUT(ApplyRMSPropD) -DECLARE_OP_ADAPTER(ApplyCenteredRMSPropD) -DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSPropD) +DECLARE_OP_ADAPTER(ApplyCenteredRMSProp) +DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSProp) } // namespace mindspore::transform #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_NN_TRAINING_OPS_DECLARE_H_ diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 34ea77ddd83..a2d7c9b030f 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -239,6 +239,7 @@ inline const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits = std::make_shared("SparseSoftmaxCrossEntropyWithLogits"); inline const PrimitivePtr kPrimMomentum = std::make_shared("Momentum"); inline const PrimitivePtr kPrimApplyMomentum = std::make_shared("ApplyMomentum"); +inline const PrimitivePtr kPrimApplyFtrl = std::make_shared("ApplyFtrl"); inline const PrimitivePtr kPrimLayerNorm = std::make_shared("LayerNorm"); inline const PrimitivePtr kPrimLrn = std::make_shared("Lrn"); inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared("LayerNormGrad"); @@ -452,7 +453,7 @@ 
inline const PrimitivePtr kPrimGetRefKey = std::make_shared("get_ref_ inline const PrimitivePtr kPrimMakeRef = std::make_shared("make_ref"); inline const PrimitivePtr kPrimGetRefValue = std::make_shared("get_ref_value"); -// Other primitve not used by backend but used in core; +// Other primitive not used by backend but used in core; inline const PrimitivePtr kPrimStateSetItem = std::make_shared("state_setitem"); inline const PrimitivePtr kPrimJ = std::make_shared("J"); diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index 3b8ce7c1e21..2336579b851 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -308,7 +308,6 @@ def _concat_grad_uniform(input_shapes, input_nums): def get_bprop_concat(self): """Generate bprop for Concat""" axis = self.axis - is_ascend = context.get_context('device_target') == "Ascend" def bprop(x, out, dout): dx = () @@ -318,7 +317,7 @@ def get_bprop_concat(self): for i in range(input_nums): input_shapes = input_shapes + (shape_op(x[i]),) is_uniform = _concat_grad_uniform(input_shapes, input_nums) - if is_uniform and is_ascend: + if is_uniform: dx = P.Split(axis, input_nums)(dout) else: for i in range(input_nums): diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index c354d855e0a..d43e59d8a37 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -2413,12 +2413,8 @@ class ApplyMomentum(PrimitiveWithInfer): validator.check_value_type('gradient_scale', gradient_scale, [float], self.name) self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'], outputs=['output']) - self.is_tbe = context.get_context("device_target") == "Ascend" - self.is_ge = context.get_context("enable_ge") def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape): - if not self.is_ge and self.is_tbe: - return v_shape, v_shape return v_shape def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype): @@ -2429,9 +2425,7 @@ class ApplyMomentum(PrimitiveWithInfer): validator.check_scalar_or_tensor_types_same({"l_dtype": l_dtype}, valid_dtypes, self.name) validator.check_scalar_or_tensor_types_same({"g_dtype": g_dtype}, valid_dtypes, self.name) validator.check_scalar_or_tensor_types_same({"m_dtype": m_dtype}, valid_dtypes, self.name) - if not self.is_ge and self.is_tbe: - return g_dtype, g_dtype - return g_dtype + return v_dtype class SmoothL1Loss(PrimitiveWithInfer): @@ -2763,9 +2757,8 @@ class ApplyRMSProp(PrimitiveWithInfer): >>> momentum = 1e-10 >>> epsilon = 0.001 >>> output = apply_rms(input_x, mean_square, moment, learning_rate, grad, decay, momentum, epsilon) - >>> print(output) - (Tensor(shape=[], dtype=Float32, value= 0.100112), Tensor(shape=[], dtype=Float32, value= 4), - Tensor(shape=[], dtype=Float32, value= 0.899888)) + >>> output + Tensor(shape=[], dtype=Float32, value= 0.100112) """ @prim_attr_register @@ -2773,16 +2766,12 @@ class ApplyRMSProp(PrimitiveWithInfer): self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) self.init_prim_io_names(inputs=['var', 'mean_square', 'moment', 'learning_rate', 'grad', 'rho', 'momentum', 'epsilon'], outputs=['output']) - self.is_ge = context.get_context("enable_ge") - self.is_d = context.get_context("device_target") == "Ascend" def infer_shape(self, var_shape, mean_square_shape, moment_shape, learning_rate_shape, grad_shape, decay_shape, momentum_shape, epsilon_shape): validator.check("var_shape", 
var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name) validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name) validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name) - if not self.is_ge and self.is_d: - return var_shape, var_shape, var_shape return var_shape def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, learning_rate_dtype, grad_dtype, decay_dtype, @@ -2795,8 +2784,6 @@ class ApplyRMSProp(PrimitiveWithInfer): validator.check_types_same_and_valid(args_decay, valid_dtypes, self.name) args_lr = {"learning_rate": learning_rate_dtype, "decay": decay_dtype} validator.check_scalar_or_tensor_types_same(args_lr, valid_dtypes, self.name, allow_mix=True) - if not self.is_ge and self.is_d: - return var_dtype, var_dtype, var_dtype return var_dtype def infer_value(self, var, mean_square, moment, learning_rate, grad, decay, momentum, epsilon): @@ -2867,22 +2854,15 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer): >>> epsilon = 0.05 >>> output = centered_rms_prop(input_x, mean_grad, mean_square, moment, grad, ... learning_rate, decay, momentum, epsilon) - >>> print(output) - (Tensor(shape=[2, 2], dtype=Float32, value= + >>> output + Tensor(shape=[2, 2], dtype=Float32, value= [[-2.00000000e+00, -5.02492237e+00], - [-8.04984474e+00, -1.10747662e+01]]), Tensor(shape=[2, 2], dtype=Float32, value= - [[ 0.00000000e+00, 1.00000000e+00], - [ 2.00000000e+00, 3.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value= - [[ 0.00000000e+00, 1.00000000e+00], - [ 4.00000000e+00, 9.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value= - [[ 0.00000000e+00, 4.02492237e+00], - [ 8.04984474e+00, 1.20747662e+01]])) + [-8.04984474e+00, -1.10747662e+01]]) """ @prim_attr_register def __init__(self, use_locking=False): self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) - self.is_ascend = context.get_context("device_target") == "Ascend" def infer_shape(self, var_shape, mean_gradient_shape, mean_square_shape, moment_shape, grad_shape, learning_rate_shape, decay_shape, momentum_shape, epsilon_shape): @@ -2890,8 +2870,6 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer): validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name) validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name) validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name) - if self.is_ascend: - return var_shape, mean_gradient_shape, mean_square_shape, moment_shape return var_shape def infer_dtype(self, var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype, grad_dtype, @@ -2905,8 +2883,6 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer): validator.check_types_same_and_valid(args_rho, valid_dtypes, self.name) args_lr = {"learning_rate": learning_rate_dtype, "rho": rho_dtype} validator.check_scalar_or_tensor_types_same(args_lr, valid_dtypes, self.name, allow_mix=True) - if self.is_ascend: - return var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype return var_dtype @@ -6176,15 +6152,8 @@ class ApplyFtrl(PrimitiveWithInfer): Default: -0.5. It must be a float number or a scalar tensor with float16 or float32 data type. Outputs: - There are three outputs for Ascend environment. - - - **var** (Tensor) - represents the updated `var`. - - **accum** (Tensor) - represents the updated `accum`. - - **linear** (Tensor) - represents the updated `linear`. - - There is only one output for GPU environment. 
- - - **var** (Tensor) - This value is always zero and the input parameters has been updated in-place. + - **var** (Tensor) - represents the updated `var`. As the input parameters has been updated in-place, this + value is always zero when the platforms is GPU. Supported Platforms: ``Ascend`` ``GPU`` @@ -6217,26 +6186,10 @@ class ApplyFtrl(PrimitiveWithInfer): >>> net = ApplyFtrlNet() >>> input_x = Tensor(np.random.randint(-4, 4, (2, 2)), mindspore.float32) >>> output = net(input_x) - >>> is_tbe = context.get_context("device_target") == "Ascend" - >>> if is_tbe: - ... print(output) - (Tensor(shape=[2, 2], dtype=Float32, value= + >>> output + Tensor(shape=[2, 2], dtype=Float32, value= [[ 4.61418092e-01, 5.30964255e-01], - [ 2.68715084e-01, 3.82065028e-01]]), Tensor(shape=[2, 2], dtype=Float32, value= - [[ 1.64236546e+01, 9.64589405e+00], - [ 1.43758726e+00, 9.89177322e+00]]), Tensor(shape=[2, 2], dtype=Float32, value= - [[-1.86994812e+03, -1.64906018e+03], - [-3.22187836e+02, -1.20163989e+03]])) - ... else: - ... print(net.var.asnumpy()) - [[0.4614181 0.5309642 ] - [0.2687151 0.38206503]] - ... print(net.accum.asnumpy()) - [[16.423655 9.645894 ] - [ 1.4375873 9.891773 ]] - ... print(net.linear.asnumpy()) - [[-1869.9479 -1649.0599] - [ -322.1879 -1201.6399]] + [ 2.68715084e-01, 3.82065028e-01]]) """ @prim_attr_register @@ -6244,14 +6197,11 @@ class ApplyFtrl(PrimitiveWithInfer): self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'lr', 'l1', 'l2', 'lr_power'], outputs=['output']) self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) - self.is_tbe = context.get_context("device_target") == "Ascend" def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, lr_shape, l1_shape, l2_shape, lr_power_shape): validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name) - if self.is_tbe: - return var_shape, var_shape, var_shape return var_shape def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type, lr_power_type): @@ -6263,8 +6213,6 @@ class ApplyFtrl(PrimitiveWithInfer): validator.check_scalar_or_tensor_types_same({"l1": l1_type}, valid_dtypes, self.name) validator.check_scalar_or_tensor_types_same({"l2": l2_type}, valid_dtypes, self.name) validator.check_scalar_or_tensor_types_same({"lr_power": lr_power_type}, valid_dtypes, self.name) - if self.is_tbe: - return var_type, var_type, var_type return var_type diff --git a/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc b/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc index dadf84c5950..3a16ee0c8a8 100644 --- a/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -48,7 +48,8 @@ class MockInsertMemcpyForHcclKernelQuery : public KernelQuery { if (!node->isa()) { return false; } - return AnfAlgo::GetCNodeName(node->cast()) == "ApplyMomentum"; + auto node_name = AnfAlgo::GetCNodeName(node->cast()); + return node_name == "ApplyMomentum" || node_name == "AssignAdd"; } }; @@ -103,9 +104,9 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond3) { get_py_fun_.SetDoResolve(true); FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond3", "before"); ASSERT_TRUE(g != nullptr); - std::vector shp_x{1, 64, 112, 112}; + std::vector shp_x{3, 2}; auto x_abstract = std::make_shared(kFloat32, shp_x); - AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract, x_abstract, x_abstract}; + AbstractBasePtrList args_spec_list{x_abstract, x_abstract}; auto kg = GetKernelGraph(g, args_spec_list); EXPECT_NE(kg, nullptr); diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/add_input_to_output_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/add_input_to_output_test.cc deleted file mode 100644 index cb427449b14..00000000000 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/add_input_to_output_test.cc +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "common/backend_common_test.h" -#include "common/py_func_graph_fetcher.h" -#include "debug/anf_ir_dump.h" - -#define private public -#define protected public -#include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h" -#undef private -#undef protected - -namespace mindspore { -namespace opt { -class TestHWAddInputToOutput : public BackendCommon { - public: - TestHWAddInputToOutput() : getPyFun_("gtest_input.pre_activate.add_input_to_output_test", true) {} - ~TestHWAddInputToOutput() override = default; - - public: - UT::PyFuncGraphFetcher getPyFun_; -}; - -class MockOpFinder : public OpFinder { - public: - MockOpFinder() = default; - ~MockOpFinder() override = default; - int GetOpRegisteredOutputNum(const std::string &op_name, const CNodePtr &cnode) override { return 2; } -}; - -TEST_F(TestHWAddInputToOutput, test_add_input_to_output) { - FuncGraphPtr g = getPyFun_.CallAndParseRet("test_add_input_to_output", "before"); - EXPECT_NE(g, nullptr); - std::vector shp{2, 32, 224, 224}; - auto x_abstract = std::make_shared(kFloat32, shp); - AbstractBasePtrList args_spec_list; - for (size_t i = 0; i < 5; ++i) { - args_spec_list.push_back(x_abstract); - } - auto kg = GetKernelGraph(g, args_spec_list); - EXPECT_NE(kg, nullptr); - auto ret = kg->get_return(); - EXPECT_NE(ret, nullptr); - auto make_tuple = ret->input(1); - EXPECT_NE(make_tuple, nullptr); - auto momentum = make_tuple->cast()->input(1); - EXPECT_NE(momentum, nullptr); - EXPECT_NE(momentum->abstract(), nullptr); - EXPECT_FALSE(momentum->abstract()->isa()); - - auto optimizer = std::make_shared(); - auto pm = std::make_shared(); - auto pass = std::make_shared(); - pass->op_finder_ = std::make_shared(); - pm->AddPass(pass); - optimizer->AddPassManager(pm); - (void)optimizer->Optimize(kg); - EXPECT_TRUE(momentum->abstract()->isa()); -} -} // namespace opt -} // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/add_input_to_output_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/add_input_to_output_test.py deleted file mode 100644 index 4d4fa1fe963..00000000000 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/add_input_to_output_test.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ - -from mindspore.ops import operations as P - -ApplyMomentum = P.ApplyMomentum() - - -class FnDict: - def __init__(self): - self.fnDict = {} - - def __call__(self, fn): - self.fnDict[fn.__name__] = fn - - def __getitem__(self, name): - return self.fnDict[name] - - -def test_add_input_to_output(tag): - fns = FnDict() - - @fns - def before(input0, input1, input2, input3, input4): - return ApplyMomentum(input0, input1, input2, input3, input4) - - return fns[tag] diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_memcpy_async_for_hccl_op.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_memcpy_async_for_hccl_op.py index d7cfd5af3c2..9f7abbaa64f 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_memcpy_async_for_hccl_op.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_memcpy_async_for_hccl_op.py @@ -22,7 +22,7 @@ broadcast = P.Broadcast(1) memcpy_async = Primitive('memcpy_async') make_tuple = Primitive('make_tuple') tuple_getitem = Primitive(Constants.kTupleGetItem) -apply_momentun = P.ApplyMomentum() +assign_add = P.AssignAdd() control_depend = P.ControlDepend() relu = P.ReLU() @@ -84,14 +84,14 @@ def test_insert_memcpy_async_for_hccl_op_cond3(tag): fns = FnDict() @fns - def before(a, b, c, d, e): - res = apply_momentun(a, b, c, d, e) + def before(a, b): + res = assign_add(a, b) res = all_reduce(res) return res @fns - def after(a, b, c, d, e): - res = apply_momentun(a, b, c, d, e) + def after(a, b): + res = assign_add(a, b) res = memcpy_async(res) res = all_reduce(res) return make_tuple(res) diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/momentum_lossscale_fusion_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/momentum_lossscale_fusion_test.py index fd6f44b021f..160e7da73de 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/momentum_lossscale_fusion_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/momentum_lossscale_fusion_test.py @@ -48,6 +48,6 @@ def test_momentum_lossscale_fusion(tag): @fns def after(input0, input1, input2, input3, input4): - return make_tuple(FusedMulApplyMomentum(input0, input1, input2, input3, input4, constant)) + return make_tuple(tuple_getitem(FusedMulApplyMomentum(input0, input1, input2, input3, input4, constant), 0)) return fns[tag]
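
Note on the user-facing effect of this patch: each of the touched optimizer primitives (ApplyMomentum, ApplyFtrl, ApplyRMSProp, ApplyCenteredRMSProp) now infers a single output — the updated `var` — regardless of device target, and the new Ascend `*UnifyOutput` MindIR passes re-expand the node to the multi-output form expected by the Ascend kernel and read element 0 back through a `TupleGetItem` (which is also why the expected graph in momentum_lossscale_fusion_test.py now wraps `FusedMulApplyMomentum` in `tuple_getitem(..., 0)`). The sketch below is illustrative only and is not part of the patch: the `MomentumNet` cell name and the hyperparameter values are made up, assuming the MindSpore 1.x Python API.

    import numpy as np
    import mindspore
    import mindspore.nn as nn
    from mindspore import Tensor, Parameter
    from mindspore.ops import operations as P

    class MomentumNet(nn.Cell):
        """Hypothetical cell wrapping ApplyMomentum; names and values are illustrative."""
        def __init__(self):
            super(MomentumNet, self).__init__()
            self.apply_momentum = P.ApplyMomentum()
            self.var = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="var")
            self.accum = Parameter(Tensor(np.zeros([2, 2]).astype(np.float32)), name="accum")
            self.lr = 0.01
            self.momentum = 0.9

        def construct(self, grad):
            # With this patch the primitive infers one output (the updated var) on
            # every backend; the Ascend/GE branches in infer_shape/infer_dtype are gone.
            return self.apply_momentum(self.var, self.accum, self.lr, grad, self.momentum)

    net = MomentumNet()
    out = net(Tensor(np.ones([2, 2]).astype(np.float32)))
    print(out.shape)  # (2, 2) -- a single tensor, not a tuple of (var, accum)

Before this change the same call returned a tuple on Ascend and a single tensor elsewhere, which is the device-dependent behavior the removed `is_tbe`/`is_ge`/`is_ascend` checks encoded.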