diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/dropout_unify_mindir.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/dropout_unify_mindir.cc index 658b8e3d192..3212cb1317c 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/dropout_unify_mindir.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/dropout_unify_mindir.cc @@ -124,6 +124,22 @@ std::vector CalGenMaskV3OutputShape(const std::vector &shape, return shape; } +std::shared_ptr GetDropoutMaskShapeAbstract(const abstract::ShapePtr &input_shape, + const CNodePtr &dropout, bool use_v3) { + std::shared_ptr gen_mask_abstract; + if (input_shape->IsDynamic() || (dropout != nullptr && common::AnfAlgo::HasNodeAttr(kAttrMutableKernel, dropout))) { + ShapeVector mask_shp = {abstract::Shape::kShapeDimAny}; + auto gen_mask_shape = std::make_shared(mask_shp); + MS_EXCEPTION_IF_NULL(gen_mask_shape); + gen_mask_abstract = std::make_shared(kUInt8, gen_mask_shape); + } else { + auto gen_mask_shape = use_v3 ? CalGenMaskV3OutputShape(input_shape->shape(), kNumberTypeUInt8) + : CalGenMaskOutputShape(input_shape->shape()); + gen_mask_abstract = std::make_shared(kUInt8, gen_mask_shape); + } + return gen_mask_abstract; +} + bool NeedUpdate(const CNodePtr &getitem_cnode) { MS_EXCEPTION_IF_NULL(getitem_cnode); MS_EXCEPTION_IF_NULL(getitem_cnode->input(kIndex2)); @@ -200,17 +216,7 @@ CNodePtr CreateDropoutGenMaskCNode(const FuncGraphPtr &func_graph, const CNodePt dropout_gen_mask->AddPrimalAttr(kAttrFusion, dropout->GetPrimalAttr(kAttrFusion)); } - std::shared_ptr gen_mask_abstract; - if (input_shape->IsDynamic() || common::AnfAlgo::HasNodeAttr(kAttrMutableKernel, dropout)) { - ShapeVector mask_shp = {abstract::Shape::kShapeDimAny}; - auto gen_mask_shape = std::make_shared(mask_shp); - MS_EXCEPTION_IF_NULL(gen_mask_shape); - gen_mask_abstract = std::make_shared(kUInt8, gen_mask_shape); - } else { - auto gen_mask_shape = use_v3 ? CalGenMaskV3OutputShape(input_shape->shape(), kNumberTypeUInt8) - : CalGenMaskOutputShape(input_shape->shape()); - gen_mask_abstract = std::make_shared(kUInt8, gen_mask_shape); - } + auto gen_mask_abstract = GetDropoutMaskShapeAbstract(input_shape, dropout, use_v3); MS_EXCEPTION_IF_NULL(gen_mask_abstract); dropout_gen_mask->set_abstract(gen_mask_abstract); dropout_gen_mask->set_scope(dropout->scope()); @@ -345,10 +351,7 @@ AnfNodePtr BuildDropoutDoMask(const PatternMap &m, const AnfNodePtr &) { // update abstract auto mask_abstract = mask_input->abstract(); MS_EXCEPTION_IF_NULL(mask_abstract); - auto grad_shape_vec = grad_input_shape->shape(); - auto mask_shape = - use_v3 ? CalGenMaskV3OutputShape(grad_shape_vec, kNumberTypeUInt8) : CalGenMaskOutputShape(grad_shape_vec); - mask_abstract = std::make_shared(kUInt8, mask_shape); + mask_abstract = GetDropoutMaskShapeAbstract(grad_input_shape, nullptr, use_v3); mask_input->set_abstract(mask_abstract); // update kernel info auto kernel_build_info_builder = std::make_shared(); @@ -375,10 +378,7 @@ AnfNodePtr BuildDropoutDoMask(const PatternMap &m, const AnfNodePtr &) { // modify mask abstract auto mask_abstract = mask_input->abstract(); MS_EXCEPTION_IF_NULL(mask_abstract); - auto grad_shape_vec = grad_input_shape->shape(); - auto mask_shape = - use_v3 ? CalGenMaskV3OutputShape(grad_shape_vec, kNumberTypeUInt8) : CalGenMaskOutputShape(grad_shape_vec); - mask_abstract = std::make_shared(kUInt8, mask_shape); + mask_abstract = GetDropoutMaskShapeAbstract(grad_input_shape, nullptr, use_v3); mask_input->set_abstract(mask_abstract); abs.push_back(mask_abstract); auto new_abstract = std::make_shared(abs);