diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/dropout_for_ge.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/dropout_for_ge.cc
index 03e63ef3e6a..65a8a5a48c0 100644
--- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/dropout_for_ge.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/dropout_for_ge.cc
@@ -45,6 +45,75 @@ std::vector<int64_t> CalGenMaskOutputShape(const std::vector<int64_t> &shape) {
   return {ret};
 }
 
+abstract::ShapePtr GetDropoutInputShape(const CNodePtr &dropout_node) {
+  MS_EXCEPTION_IF_NULL(dropout_node);
+  auto input = dropout_node->input(kInputIndexOne);
+  MS_EXCEPTION_IF_NULL(input);
+  auto input_base_shape = input->Shape();
+  MS_EXCEPTION_IF_NULL(input_base_shape);
+  auto input_shape = input_base_shape->cast<abstract::ShapePtr>();
+  MS_EXCEPTION_IF_NULL(input_shape);
+  return input_shape;
+}
+
+CNodePtr CreateDynamicShapeCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node_input,
+                                 const abstract::ShapePtr &input_shape) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(input_shape);
+  std::vector<AnfNodePtr> dynamic_shape_inputs{NewValueNode(std::make_shared<Primitive>("Shape")), node_input};
+  CNodePtr dynamic_shape = func_graph->NewCNode(dynamic_shape_inputs);
+  MS_EXCEPTION_IF_NULL(dynamic_shape);
+  ShapeVector tensor_shp({static_cast<int64_t>(input_shape->shape().size())});
+  auto dynamic_shape_abstract =
+    std::make_shared<abstract::AbstractTensor>(kInt64, std::make_shared<abstract::Shape>(tensor_shp));
+  MS_EXCEPTION_IF_NULL(dynamic_shape_abstract);
+  dynamic_shape->set_abstract(dynamic_shape_abstract);
+  return dynamic_shape;
+}
+
+CNodePtr CreateDropoutGenMaskCNode(const FuncGraphPtr &func_graph, const CNodePtr &dropout,
+                                   const AnfNodePtr &keep_prob_value, const abstract::ShapePtr &input_shape,
+                                   const ValuePtr &seed_0, const ValuePtr &seed_1) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(dropout);
+  MS_EXCEPTION_IF_NULL(input_shape);
+  std::vector<AnfNodePtr> dropout_gen_mask_inputs =
+    std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName))};
+  if (input_shape->IsDynamic() || common::AnfAlgo::HasNodeAttr(kAttrMutableKernel, dropout)) {
+    CNodePtr dynamic_shape = CreateDynamicShapeCNode(func_graph, dropout->input(kIndex1), input_shape);
+    dynamic_shape->set_scope(dropout->scope());
+    dropout_gen_mask_inputs.push_back(dynamic_shape);
+    dropout_gen_mask_inputs.push_back(keep_prob_value);
+  } else {
+    auto shape_value = CreateShapeValueNode(func_graph, input_shape->shape(), true);
+    dropout_gen_mask_inputs.push_back(shape_value);
+    dropout_gen_mask_inputs.push_back(keep_prob_value);
+  }
+  CNodePtr dropout_gen_mask = opt::NewCNode(dropout_gen_mask_inputs, func_graph, {dropout});
+  MS_EXCEPTION_IF_NULL(dropout_gen_mask);
+  if (dropout->HasPrimalAttr(kAttrFusion)) {
+    dropout_gen_mask->AddPrimalAttr(kAttrFusion, dropout->GetPrimalAttr(kAttrFusion));
+  }
+
+  std::shared_ptr<abstract::AbstractTensor> gen_mask_abstract;
+  if (input_shape->IsDynamic() || common::AnfAlgo::HasNodeAttr(kAttrMutableKernel, dropout)) {
+    ShapeVector mask_shp = {abstract::Shape::kShapeDimAny};
+    auto gen_mask_shape = std::make_shared<abstract::Shape>(mask_shp);
+    MS_EXCEPTION_IF_NULL(gen_mask_shape);
+    gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, gen_mask_shape);
+  } else {
+    auto gen_mask_shape = CalGenMaskOutputShape(input_shape->shape());
+    gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, gen_mask_shape);
+  }
+  MS_EXCEPTION_IF_NULL(gen_mask_abstract);
+  dropout_gen_mask->set_abstract(gen_mask_abstract);
+  dropout_gen_mask->set_scope(dropout->scope());
+  common::AnfAlgo::CopyNodeAttrs(dropout, dropout_gen_mask);
+  common::AnfAlgo::SetNodeAttr(kSeed0AttrName, seed_0, dropout_gen_mask);
+  common::AnfAlgo::SetNodeAttr(kSeed1AttrName, seed_1, dropout_gen_mask);
+  return dropout_gen_mask;
+}
+
 const BaseRef DropoutForGE::DefinePattern() const {
   VarPtr x1 = std::make_shared<Var>();
   return VectorRef({prim::kPrimDropout, x1});
@@ -61,21 +130,10 @@ const AnfNodePtr DropoutForGE::Process(const FuncGraphPtr &graph, const AnfNodeP
   auto seed_0 = origin_prim->GetAttr(kSeed0AttrName);
   auto seed_1 = origin_prim->GetAttr(kSeed1AttrName);
 
-  auto shape = dropout_node->input(kInputIndexOne)->Shape();
-  MS_EXCEPTION_IF_NULL(shape);
-  auto input_shape_ptr = shape->cast<abstract::ShapePtr>();
-  MS_EXCEPTION_IF_NULL(input_shape_ptr);
-  if (input_shape_ptr->IsDynamic()) {
-    MS_LOG(EXCEPTION) << "Dropout does not support dynamic shape in GE backend for now";
-  }
-  auto shape_vector = input_shape_ptr->shape();
-  auto shape_value = MakeValue(shape_vector);
-  auto gen_mask_input_shape = NewValueNode(shape_value);
-  gen_mask_input_shape->set_abstract(shape_value->ToAbstract());
+  auto input_shape_ptr = GetDropoutInputShape(dropout_node);
   auto keep_prob_node = NewValueNode(keep_prob);
   MS_EXCEPTION_IF_NULL(keep_prob_node);
   keep_prob_node->set_abstract(keep_prob->ToAbstract());
-
   AnfNodePtr gen_mask_input_prob = keep_prob_node;
 
   auto dtype_id = common::AnfAlgo::GetPrevNodeOutputInferDataType(dropout_node, 0);
   if (dtype_id == TypeId::kNumberTypeFloat16) {
@@ -94,15 +152,8 @@ const AnfNodePtr DropoutForGE::Process(const FuncGraphPtr &graph, const AnfNodeP
     cast_node->set_abstract(cast_abstract);
     gen_mask_input_prob = cast_node;
   }
-
-  auto mask_shape = CalGenMaskOutputShape(shape_vector);
-  auto dropout_gen_mask_node = node->func_graph()->NewCNode(
-    {NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName)), gen_mask_input_shape, gen_mask_input_prob});
-  MS_EXCEPTION_IF_NULL(dropout_gen_mask_node);
-  common::AnfAlgo::SetNodeAttr(kSeed0AttrName, seed_0, dropout_gen_mask_node);
-  common::AnfAlgo::SetNodeAttr(kSeed1AttrName, seed_1, dropout_gen_mask_node);
-  auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, mask_shape);
-  dropout_gen_mask_node->set_abstract(gen_mask_abstract);
+  CNodePtr dropout_gen_mask_node =
+    CreateDropoutGenMaskCNode(graph, dropout_node, gen_mask_input_prob, input_shape_ptr, seed_0, seed_1);
 
   auto dropout_do_mask_node = node->func_graph()->NewCNode({NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
@@ -112,7 +163,7 @@ const AnfNodePtr DropoutForGE::Process(const FuncGraphPtr &graph, const AnfNodeP
 
   std::vector<AbstractBasePtr> make_tuple_input;
   make_tuple_input.push_back(do_mask_abstract);
-  make_tuple_input.push_back(gen_mask_abstract);
+  make_tuple_input.push_back(dropout_gen_mask_node->abstract());
   std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), dropout_do_mask_node,
                                                dropout_gen_mask_node};
   auto new_make_tuple_node = NewCNode(make_tuple_inputs, graph);
diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h b/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h
index 6aba7bb52ed..9d43c16d948 100644
--- a/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h
+++ b/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h
@@ -348,6 +348,7 @@ constexpr const char kNameMaxPoolV3[] = "MaxPoolV3";
 constexpr const char kNameAvgPoolV2[] = "AvgPoolV2";
 constexpr const char kNameShape[] = "Shape";
 constexpr const char kNameTensorShape[] = "TensorShape";
+constexpr const char kNameDynamicShape[] = "DynamicShape";
 constexpr const char kNameGather[] = "Gather";
 constexpr const char kNameUnsqueeze[] = "Unsqueeze";
 constexpr const char kNamePadV3[] = "PadV3";
diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/array_ops_declare.cc b/mindspore/ccsrc/transform/graph_ir/op_declare/array_ops_declare.cc
index 4eb289a8a4c..c21772b0404 100644
--- a/mindspore/ccsrc/transform/graph_ir/op_declare/array_ops_declare.cc
+++ b/mindspore/ccsrc/transform/graph_ir/op_declare/array_ops_declare.cc
@@ -55,6 +55,9 @@ REG_ADPT_DESC(Shape, kNameShape, ADPT_DESC(Shape))
 // TensorShape
 REG_ADPT_DESC(TensorShape, kNameTensorShape, ADPT_DESC(Shape))
 
+// DynamicShape
+REG_ADPT_DESC(DynamicShape, kNameDynamicShape, ADPT_DESC(Shape))
+
 // GetShape
 INPUT_MAP(GetShape) = EMPTY_INPUT_MAP;
 DYN_INPUT_MAP(GetShape) = {{1, DYN_INPUT_DESC(x)}};