diff --git a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc index 87b8d15cca0..e88868c772d 100644 --- a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc @@ -204,7 +204,7 @@ Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr &strategy) { PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { + if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE; } @@ -215,8 +215,7 @@ PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) { } auto dropout_gen_mask_cnode = dropout_gen_mask->cast(); - MS_EXCEPTION_IF_NULL(dropout_gen_mask_cnode); - if (dropout_gen_mask_cnode->inputs().size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) { + if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) { MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE; } if (!IsValueNode(dropout_gen_mask_cnode->input(0))) { @@ -233,11 +232,45 @@ PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) { return prim; } +void SetGenMaskShape(const CNodePtr &cnode, const Shape &input_slice_shape) { + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { + MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE; + } + + AnfNodePtr dropout_gen_mask = cnode->input(DROPOUT_GEN_MASK_INDEX); + MS_EXCEPTION_IF_NULL(dropout_gen_mask); + if (!dropout_gen_mask->isa()) { + MS_LOG(EXCEPTION) << "The dropout do mask cnode's input[" << DROPOUT_GEN_MASK_INDEX << "] must be a cnode."; + } + + auto dropout_gen_mask_cnode = dropout_gen_mask->cast(); + if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) { + MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE; + } + + if (!IsValueNode(dropout_gen_mask_cnode->input(1))) { + MS_LOG(EXCEPTION) << "The input[1] of dropout gen mask cnode is not ValueTuple."; + } + + FuncGraphPtr func_graph = cnode->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = func_graph->manager(); + if (manager == nullptr) { + MS_LOG(EXCEPTION) << "Failure: AddNode error since manager is nullptr."; + } + + ValuePtr new_shape = MakeValue(input_slice_shape); + AnfNodePtr val = NewValueNode(new_shape); + (void)manager->Replace(dropout_gen_mask_cnode->input(1), val); +} + // DropoutDoMask needs to be used together with DropoutGenMask. Only the first input tensor of DropoutGenMask is // split. Find the DropoutGenMask node in the anf graph according to DropoutDoMask node, and modify the input shape // of DropoutGenMask according to the strategy of DropoutDoMask. When the DropoutDoMask performs repeated calculation // and both seeds of DropoutGenMask are 0, two new seeds are automatically generated for DropoutGenMask. -Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) { +std::vector DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) { + std::vector replace_ops; MS_EXCEPTION_IF_NULL(cnode); PrimitivePtr prim = GetDropoutGenMaskPrim(cnode); MS_EXCEPTION_IF_NULL(prim); @@ -260,15 +293,20 @@ Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) { if ((attr.find(SEED0) == attr.end()) || (attr.find(SEED1) == attr.end())) { MS_LOG(EXCEPTION) << "The attrs of dropout gen mask must be have seed0 and seed1"; } + + Shape input_slice_shape = inputs_tensor_info_[0].slice_shape(); int32_t seed_0 = GetValue(attr[SEED0]); int32_t seed_1 = GetValue(attr[SEED1]); if ((seed_0 == 0) && (seed_1 == 0) && (repeated_calc_num_ > 1)) { seed_0 = SEED_NUM; seed_1 = SEED_NUM; SEED_NUM++; + } else { + SetGenMaskShape(cnode, input_slice_shape); + MS_LOG(DEBUG) << "The input slice shape droupout is " << ShapeToString(input_slice_shape); + return replace_ops; } - Shape input_slice_shape = inputs_tensor_info_[0].slice_shape(); ValuePtr new_shape = MakeValue(input_slice_shape); Attr attr_0 = std::make_pair(SEED0, MakeValue(seed_0)); Attr attr_1 = std::make_pair(SEED1, MakeValue(seed_1)); @@ -278,7 +316,8 @@ Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) { OperatorParams params = {std::make_pair(param_0, 1), std::make_pair(param_1, 2)}; OperatorArgs args = std::make_pair(attrs, params); Operator replace_op = {std::make_pair(DROPOUT_GEN_MASK, args)}; - return replace_op; + replace_ops.push_back(replace_op); + return replace_ops; } } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h index c0d112f52d4..c51a0a95135 100644 --- a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h +++ b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h @@ -41,7 +41,7 @@ class DropoutDoMaskInfo : public OperatorInfo { Status SetCostUnderStrategy(const StrategyPtr &strategy) override; Status InitForCostModel(const StrategyPtr &strategy) override; std::shared_ptr>> GenerateBatchStrategies() override; - Operator GetDropoutGenMaskReplaceOp(const CNodePtr &cnode); + std::vector GetDropoutGenMaskReplaceOp(const CNodePtr &cnode); protected: Status CheckStrategy(const StrategyPtr &strategy) override; diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc index 4528ff8639c..39dd2c96e00 100644 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ b/mindspore/ccsrc/parallel/step_parallel.cc @@ -1876,11 +1876,15 @@ void HandleDropoutNode(const OperatorInfoPtr &distribute_operator, const CNodePt DropoutDoMaskInfoPtr dropout_do_mask = std::dynamic_pointer_cast(distribute_operator); MS_EXCEPTION_IF_NULL(dropout_do_mask); - Operator replace_op = dropout_do_mask->GetDropoutGenMaskReplaceOp(cnode); + std::vector replace_op = dropout_do_mask->GetDropoutGenMaskReplaceOp(cnode); + if (replace_op.empty()) { + MS_LOG(DEBUG) << "No need to replace dropout_gen_mask"; + return; + } if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { MS_LOG(EXCEPTION) << "The size of drop out do mask cnode's input is not " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE; } - ReplaceOneOp(replace_op, cnode->input(DROPOUT_GEN_MASK_INDEX)->cast()); + ReplaceOneOp(replace_op[0], cnode->input(DROPOUT_GEN_MASK_INDEX)->cast()); } void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {