From 389da5452577d6530d4acd0fc9913bdfb5017f00 Mon Sep 17 00:00:00 2001 From: jjfeing Date: Tue, 15 Dec 2020 16:21:23 +0800 Subject: [PATCH] fix dropout unify_mindir pass --- .../ascend/mindir/dropout_unify_mindir.cc | 392 ++++++++++++------ .../ascend/mindir/dropout_unify_mindir.h | 18 + .../ccsrc/backend/session/ascend_session.cc | 9 + .../ccsrc/backend/session/session_basic.cc | 6 +- mindspore/nn/layer/basic.py | 23 +- 5 files changed, 302 insertions(+), 146 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc index d848c910ba0..89a8d5fbdbb 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc @@ -15,7 +15,9 @@ */ #include "backend/optimizer/ascend/mindir/dropout_unify_mindir.h" +#include #include +#include #include #include #include @@ -23,45 +25,69 @@ #include "backend/session/anf_runtime_algorithm.h" #include "utils/log_adapter.h" +/* + DropoutGenMask: + attr: seed0 seed1: + input: 1.shape <>; + 2. keep_prob: type base on inputx type, if x in float/float16, then use this type, else use float16; + output: shape: (count + 127) % 128 * 16 + */ +namespace mindspore::opt { +namespace { constexpr auto kKeepProb = "keep_prob"; constexpr auto kSeed0 = "Seed0"; constexpr auto kSeed1 = "Seed1"; constexpr auto kUint8BitSize = 8; - -namespace mindspore::opt { +constexpr int64_t kMaskAlignNum = 128; +constexpr int64_t kMaskMultiNum = 16; constexpr size_t kFloat16Len = 2; // size of float16 -namespace { -AnfNodePtr GetDropoutKeepProb(const AnfNodePtr &node, float *keep_prob) { - MS_LOG(INFO) << "GetDropoutNodeInfo start."; +constexpr size_t kInt64Len = 8; // size of int64 + +TypeId GetInputXDataType(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(keep_prob); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!AnfAlgo::HasNodeAttr(kKeepProb, cnode) || !AnfAlgo::HasNodeAttr(kSeed0, cnode) || - !AnfAlgo::HasNodeAttr(kSeed1, cnode)) { - MS_LOG(EXCEPTION) << "Dropout node does nothave attr: keep_prob or seed0 or seed1."; + auto dropout_input_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0); + if (dropout_input_type != kNumberTypeFloat32 && dropout_input_type != kNumberTypeFloat && + dropout_input_type != kNumberTypeFloat16) { + dropout_input_type = kNumberTypeFloat16; } - *keep_prob = AnfAlgo::GetNodeAttr(node, kKeepProb); - MS_LOG(INFO) << "keep_prob: " << *keep_prob; - // return dropout input. 
maybe tensor or pre cnode output - return cnode->input(1); + MS_LOG(INFO) << "Dropout input data type: " << TypeIdLabel(dropout_input_type); + return dropout_input_type; } -ValueNodePtr CreateKeepPorbValueNode(const FuncGraphPtr &func_graph, const float &keep_prob, const TypePtr &dtype) { - MS_LOG(INFO) << "CreateKeepPorbValueNode start."; +std::vector GetInputXShape(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + std::vector shapes; + auto shape_size_t = AnfAlgo::GetPrevNodeOutputInferShape(node, 0); + std::transform(shape_size_t.begin(), shape_size_t.end(), std::back_inserter(shapes), SizeToLong); + return shapes; +} + +ValueNodePtr CreateKeepPorbValueNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, TypeId type_id) { MS_EXCEPTION_IF_NULL(func_graph); - auto kernel_graph = func_graph->cast(); - MS_EXCEPTION_IF_NULL(kernel_graph); + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + // Step1: get keep_prob + if (!AnfAlgo::HasNodeAttr(kKeepProb, cnode)) { + MS_LOG(EXCEPTION) << "Dropout node does not have attr: keep_prob."; + } + if (AnfAlgo::GetCNodePrimitive(cnode)->ToString() == kDropoutOpName) { + if (!AnfAlgo::HasNodeAttr(kSeed0, cnode) || !AnfAlgo::HasNodeAttr(kSeed1, cnode)) { + MS_LOG(EXCEPTION) << "Dropout node does not have attr: seed0 or seed1."; + } + } + auto keep_prob = AnfAlgo::GetNodeAttr(node, kKeepProb); + MS_LOG(INFO) << "Keep_prob value: " << keep_prob; + std::vector keep_prob_shape = {}; - ShapeVector shape = {}; - auto keep_prob_tensor = std::make_shared(dtype->type_id(), keep_prob_shape); + auto keep_prob_tensor = std::make_shared(type_id, keep_prob_shape); MS_EXCEPTION_IF_NULL(keep_prob_tensor); auto data_ptr = keep_prob_tensor->data_c(); MS_EXCEPTION_IF_NULL(data_ptr); // keep_prob's datatype is same with input data - if (dtype->type_id() == kNumberTypeFloat16) { - float16 half_data = float16(keep_prob); - auto ret_code = memcpy_s(data_ptr, kFloat16Len, &half_data, kFloat16Len); + if (type_id == kNumberTypeFloat16) { + auto half_data = float16(keep_prob); + auto ret_code = memcpy_s(data_ptr, static_cast(keep_prob_tensor->data().nbytes()), &half_data, kFloat16Len); if (ret_code != 0) { MS_LOG(EXCEPTION) << "Failed to copy data into Tensor."; } @@ -69,59 +95,65 @@ ValueNodePtr CreateKeepPorbValueNode(const FuncGraphPtr &func_graph, const float auto *val = reinterpret_cast(data_ptr); *val = keep_prob; } - auto abstract = std::make_shared(dtype, shape); + auto kernel_graph = func_graph->cast(); + MS_EXCEPTION_IF_NULL(kernel_graph); + auto abstract = std::make_shared(TypeIdToType(type_id), keep_prob_shape); auto keep_prob_value = kernel_graph->NewValueNode(abstract, keep_prob_tensor); MS_EXCEPTION_IF_NULL(keep_prob_value); kernel_graph->AddValueNodeToGraph(keep_prob_value); return keep_prob_value; } -std::vector GetInputShape(const AnfNodePtr &node, const AnfNodePtr &dropout_input) { - MS_LOG(INFO) << "GetInputShape start."; - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(dropout_input); - std::vector shapes; - if (dropout_input->isa()) { - MS_LOG(INFO) << "Dropout input from parameter node."; - // single test case - auto dropout_input_value = dropout_input->cast(); - MS_EXCEPTION_IF_NULL(dropout_input_value); - MS_EXCEPTION_IF_NULL(dropout_input_value->Shape()); - auto shape = dropout_input_value->Shape()->cast(); - MS_EXCEPTION_IF_NULL(shape); - return shape->shape(); - } else if (dropout_input->isa()) { - MS_LOG(INFO) << "Dropout input from cnode."; - auto dropout_input_node = dropout_input->cast(); - 
MS_EXCEPTION_IF_NULL(dropout_input_node); - auto shape_size_t = AnfAlgo::GetPrevNodeOutputInferShape(node, 0); - std::transform(shape_size_t.begin(), shape_size_t.end(), std::back_inserter(shapes), SizeToLong); - return shapes; - } else { - MS_LOG(ERROR) << "Dropout input is not parameter or cnode."; - return {}; - } -} - -ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector &shape) { +ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector &shape, + bool is_pynative = false) { MS_LOG(INFO) << "CreateShapeValueNode start."; MS_EXCEPTION_IF_NULL(func_graph); auto kernel_graph = func_graph->cast(); MS_EXCEPTION_IF_NULL(kernel_graph); - std::vector dim_values{}; - abstract::AbstractBasePtrList abs{}; - for (const auto &dim : shape) { - dim_values.push_back(MakeValue(dim)); - abs.push_back(std::make_shared(dim)); + ValuePtr shape_value = nullptr; + AbstractBasePtr abstract = nullptr; + if (is_pynative) { + // pynative mode need to create tensor + int64_t shape_dim = SizeToLong(shape.size()); + std::vector shape_vec_shape = {shape_dim}; + auto shape_tensor = std::make_shared(kNumberTypeInt64, shape_vec_shape); + MS_EXCEPTION_IF_NULL(shape_tensor); + auto data_ptr = shape_tensor->data_c(); + MS_EXCEPTION_IF_NULL(data_ptr); + auto elem_num = shape.size() * kInt64Len; + auto ret_code = memcpy_s(data_ptr, static_cast(shape_tensor->data().nbytes()), &shape[0], elem_num); + if (ret_code != 0) { + MS_LOG(EXCEPTION) << "Failed to copy data into Tensor."; + } + shape_value = shape_tensor; + abstract = std::make_shared(kInt64, shape_vec_shape); + } else { + std::vector dim_values{}; + abstract::AbstractBasePtrList abs{}; + for (const auto &dim : shape) { + dim_values.push_back(MakeValue(dim)); + abs.push_back(std::make_shared(dim)); + } + shape_value = std::make_shared(dim_values); + abstract = std::make_shared(abs); } - auto shape_value_tuple = std::make_shared(dim_values); - MS_EXCEPTION_IF_NULL(shape_value_tuple); - auto abstract = std::make_shared(abs); - MS_EXCEPTION_IF_NULL(abstract); - auto shape_value = kernel_graph->NewValueNode(abstract, shape_value_tuple); MS_EXCEPTION_IF_NULL(shape_value); - kernel_graph->AddValueNodeToGraph(shape_value); - return shape_value; + MS_EXCEPTION_IF_NULL(abstract); + auto shape_value_node = kernel_graph->NewValueNode(abstract, shape_value); + MS_EXCEPTION_IF_NULL(shape_value_node); + kernel_graph->AddValueNodeToGraph(shape_value_node); + return shape_value_node; +} + +std::vector CalDropoutGenMaskOutput(const std::vector &shape) { + auto output_size = std::accumulate(shape.begin(), shape.end(), static_cast(1), std::multiplies()); + auto output_count = output_size / kMaskAlignNum; + if (output_size % kMaskAlignNum != 0) { + output_count++; + } + auto ret = output_count * kMaskMultiNum; + MS_LOG(INFO) << "Output_size: " << ret; + return {ret}; } } // namespace @@ -141,34 +173,34 @@ const AnfNodePtr DropoutUnifyMindIR::Process(const FuncGraphPtr &func_graph, con MS_EXCEPTION_IF_NULL(tuple_cnode); auto dropout_node = tuple_cnode->input(1); MS_EXCEPTION_IF_NULL(dropout_node); - float keep_prob = 0; - auto dropout_input = GetDropoutKeepProb(dropout_node, &keep_prob); - auto dropout_dtype = AnfAlgo::GetOutputInferDataType(dropout_node, 0) == kNumberTypeFloat16 ? 
kFloat16 : kFloat32; - auto keep_prob_value = CreateKeepPorbValueNode(func_graph, keep_prob, dropout_dtype); - auto shape = GetInputShape(dropout_node, dropout_input); - auto shape_value = CreateShapeValueNode(func_graph, shape); + + auto inputx_type_id = GetInputXDataType(dropout_node); + auto inputx_shape = GetInputXShape(dropout_node); + auto shape_value = CreateShapeValueNode(func_graph, inputx_shape); + auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id); + // CreateDropoutGenMask - auto output_size = std::accumulate(shape.begin(), shape.end(), static_cast(1), std::multiplies()); - output_size = output_size / kUint8BitSize; - MS_LOG(INFO) << "Output_size: " << output_size; std::vector dropout_gen_mask_inputs{NewValueNode(std::make_shared(kDropoutGenMaskOpName)), shape_value, keep_prob_value}; CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs); MS_EXCEPTION_IF_NULL(dropout_gen_mask); AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask); - ShapeVector dropout_gen_mask_output = {output_size}; - auto gen_mask_abstract = std::make_shared(kUInt8, dropout_gen_mask_output); + auto output_shape = CalDropoutGenMaskOutput(inputx_shape); + auto gen_mask_abstract = std::make_shared(kUInt8, output_shape); MS_EXCEPTION_IF_NULL(gen_mask_abstract); dropout_gen_mask->set_abstract(gen_mask_abstract); dropout_gen_mask->set_scope(node->scope()); // CreateDropoutDoMask + MS_EXCEPTION_IF_NULL(dropout_node); + auto dropout_cnode = dropout_node->cast(); + MS_EXCEPTION_IF_NULL(dropout_cnode); + auto dropout_input = dropout_cnode->input(1); std::vector dropout_do_mask_inputs{NewValueNode(std::make_shared(kDropoutDoMaskOpName)), dropout_input, dropout_gen_mask, keep_prob_value}; auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs); MS_EXCEPTION_IF_NULL(dropout_do_mask); - ShapeVector dropout_do_mask_output = shape; - auto do_mask_abstract = std::make_shared(dropout_dtype, dropout_do_mask_output); + auto do_mask_abstract = std::make_shared(TypeIdToType(inputx_type_id), inputx_shape); dropout_do_mask->set_abstract(do_mask_abstract); dropout_do_mask->set_scope(node->scope()); @@ -178,8 +210,6 @@ const AnfNodePtr DropoutUnifyMindIR::Process(const FuncGraphPtr &func_graph, con const BaseRef DropoutGradUnifyMindIR::DefinePattern() const { VarPtr X = std::make_shared(); VarPtr Y = std::make_shared(); - MS_EXCEPTION_IF_NULL(X); - MS_EXCEPTION_IF_NULL(Y); auto dropout_prim = std::make_shared(kDropoutOpName); auto tuple_getitem_prim = prim::kPrimTupleGetItem; auto dropout_grad_prim = std::make_shared(kDropoutGradOpName); @@ -194,58 +224,74 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(node); - auto dropout_grad = node->cast(); - MS_EXCEPTION_IF_NULL(dropout_grad); - auto tuple_getitem = dropout_grad->input(2); - MS_EXCEPTION_IF_NULL(tuple_getitem); - auto tuple_getitem_cnode = tuple_getitem->cast(); - MS_EXCEPTION_IF_NULL(tuple_getitem_cnode); - auto dropout_node = tuple_getitem_cnode->input(1); + auto dropout_grad_cnode = node->cast(); + MS_EXCEPTION_IF_NULL(dropout_grad_cnode); + auto getitem1_node = dropout_grad_cnode->input(2); + MS_EXCEPTION_IF_NULL(getitem1_node); + auto getitem1_cnode = getitem1_node->cast(); + MS_EXCEPTION_IF_NULL(getitem1_cnode); + auto dropout_node = getitem1_cnode->input(1); MS_EXCEPTION_IF_NULL(dropout_node); - float keep_prob = 0; - auto dropout_input = GetDropoutKeepProb(dropout_node, &keep_prob); - 
auto dropout_dtype = AnfAlgo::GetOutputInferDataType(dropout_node, 0) == kNumberTypeFloat16 ? kFloat16 : kFloat32; - auto keep_prob_value = CreateKeepPorbValueNode(func_graph, keep_prob, dropout_dtype); - auto shape = GetInputShape(dropout_node, dropout_input); - auto shape_value = CreateShapeValueNode(func_graph, shape); + auto dropout_cnode = dropout_node->cast(); + MS_EXCEPTION_IF_NULL(dropout_cnode); + + auto inputx_type_id = GetInputXDataType(dropout_node); + auto inputx_shape = GetInputXShape(dropout_node); + auto shape_value = CreateShapeValueNode(func_graph, inputx_shape); + auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id); + // CreateDropoutGenMask - auto output_size = std::accumulate(shape.begin(), shape.end(), static_cast(1), std::multiplies()); - output_size = output_size / kUint8BitSize; - MS_LOG(INFO) << "Output_size: " << output_size; std::vector dropout_gen_mask_inputs{NewValueNode(std::make_shared(kDropoutGenMaskOpName)), shape_value, keep_prob_value}; CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs); MS_EXCEPTION_IF_NULL(dropout_gen_mask); AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask); - ShapeVector dropout_gen_mask_output = {output_size}; - auto gen_mask_abstract = std::make_shared(kUInt8, dropout_gen_mask_output); + auto output_shape = CalDropoutGenMaskOutput(inputx_shape); + auto gen_mask_abstract = std::make_shared(kUInt8, output_shape); MS_EXCEPTION_IF_NULL(gen_mask_abstract); dropout_gen_mask->set_abstract(gen_mask_abstract); - dropout_gen_mask->set_scope(dropout_node->scope()); - // AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask); + dropout_gen_mask->set_scope(node->scope()); // CreateDropoutDoMask-forward auto manager = func_graph->manager(); MS_EXCEPTION_IF_NULL(manager); auto &node_users = manager->node_users(); auto iter = node_users.find(dropout_node); + CNodePtr dropout_do_mask1 = nullptr; if (iter != node_users.end()) { for (auto &node_index : iter->second) { - // Dropout has two outputs, so output node is tuple_getitem - auto tuple_getitem_cnode2 = node_index.first->cast(); - // check if Dropout's first output, which is used by forward, is used. 
- auto getitem_index = GetValue(tuple_getitem_cnode2->input(2)->cast()->value()); - if (getitem_index == 0) { - std::vector dropout_do_mask1_inputs{NewValueNode(std::make_shared(kDropoutDoMaskOpName)), - dropout_input, dropout_gen_mask, keep_prob_value}; - auto dropout_do_mask1 = func_graph->NewCNode(dropout_do_mask1_inputs); - MS_EXCEPTION_IF_NULL(dropout_do_mask1); - ShapeVector dropout_do_mask1_output = shape; - auto do_mask_abstract1 = std::make_shared(dropout_dtype, dropout_do_mask1_output); - dropout_do_mask1->set_abstract(do_mask_abstract1); - dropout_do_mask1->set_scope(dropout_node->scope()); - (void)manager->Replace(tuple_getitem_cnode2, dropout_do_mask1); - break; + auto used_node = node_index.first; + if (AnfAlgo::CheckPrimitiveType(used_node, prim::kPrimTupleGetItem)) { + // check if Dropout's first output, which is used by forward, is used + if (AnfAlgo::GetTupleGetItemOutIndex(used_node->cast()) == 0) { + // if Dropout's first output is used, create forward DropoutDoMask + auto dropout_input = dropout_cnode->input(1); + std::vector dropout_do_mask1_inputs{ + NewValueNode(std::make_shared(kDropoutDoMaskOpName)), dropout_input, dropout_gen_mask, + keep_prob_value}; + dropout_do_mask1 = func_graph->NewCNode(dropout_do_mask1_inputs); + MS_EXCEPTION_IF_NULL(dropout_do_mask1); + auto do_mask_abstract1 = + std::make_shared(TypeIdToType(inputx_type_id), inputx_shape); + dropout_do_mask1->set_abstract(do_mask_abstract1); + dropout_do_mask1->set_scope(dropout_node->scope()); + (void)manager->Replace(used_node, dropout_do_mask1); + break; + } + } + } + } + if (dropout_do_mask1 != nullptr) { + // Dropout is used by ControlDepend in some situation, need to replace ControlDepend. + auto &users = manager->node_users(); + iter = users.find(dropout_node); + if (iter != users.end()) { + for (auto &node_index : iter->second) { + auto used_node = node_index.first; + if (AnfAlgo::CheckPrimitiveType(used_node, prim::kPrimControlDepend)) { + (void)manager->Replace(used_node, dropout_do_mask1); + break; + } } } } @@ -254,16 +300,112 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, if (equiv->find(grad_input_) == equiv->end()) { MS_LOG(EXCEPTION) << "Can not find grad_input in this pattern."; } - auto grad_input = utils::cast((*equiv)[grad_input_]); - std::vector dropout_do_mask2_inputs{NewValueNode(std::make_shared(kDropoutDoMaskOpName)), - grad_input, dropout_gen_mask, keep_prob_value}; - auto dropout_do_mask2 = func_graph->NewCNode(dropout_do_mask2_inputs); - MS_EXCEPTION_IF_NULL(dropout_do_mask2); - ShapeVector dropout_do_mask2_output = shape; - auto do_mask_abstract2 = std::make_shared(dropout_dtype, dropout_do_mask2_output); - dropout_do_mask2->set_abstract(do_mask_abstract2); - dropout_do_mask2->set_scope(node->scope()); + auto dropout_grad_input = utils::cast((*equiv)[grad_input_]); + std::vector dropout_do_mask_inputs{NewValueNode(std::make_shared(kDropoutDoMaskOpName)), + dropout_grad_input, dropout_gen_mask, keep_prob_value}; + auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs); + MS_EXCEPTION_IF_NULL(dropout_do_mask); + auto do_mask_abstract = std::make_shared(TypeIdToType(inputx_type_id), inputx_shape); + dropout_do_mask->set_abstract(do_mask_abstract); + dropout_do_mask->set_scope(node->scope()); - return dropout_do_mask2; + return dropout_do_mask; +} + +const BaseRef DropoutUnifyMindIRPynative::DefinePattern() const { + VarPtr X = std::make_shared(); + VarPtr Y = std::make_shared(); + VarPtr Z = std::make_shared(); + auto dropout = 
VectorRef({prim::kPrimDropout, X}); + auto getitem0 = VectorRef({prim::kPrimTupleGetItem, dropout, Y}); + auto getitem1 = VectorRef({prim::kPrimTupleGetItem, dropout, Z}); + auto maketuple = VectorRef({prim::kPrimMakeTuple, getitem0, getitem1}); + return maketuple; +} + +const AnfNodePtr DropoutUnifyMindIRPynative::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto maketuple_cnode = node->cast(); + MS_EXCEPTION_IF_NULL(maketuple_cnode); + auto getitem0_node = maketuple_cnode->input(1); + MS_EXCEPTION_IF_NULL(getitem0_node); + auto getitem1_node = maketuple_cnode->input(2); + MS_EXCEPTION_IF_NULL(getitem1_node); + auto getitem1_cnode = getitem1_node->cast(); + MS_EXCEPTION_IF_NULL(getitem1_cnode); + auto dropout_node = getitem1_cnode->input(1); + MS_EXCEPTION_IF_NULL(dropout_node); + + auto inputx_type_id = GetInputXDataType(dropout_node); + auto inputx_shape = GetInputXShape(dropout_node); + auto shape_value = CreateShapeValueNode(func_graph, inputx_shape, true); + auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id); + + // CreateDropoutGenMask + std::vector dropout_gen_mask_inputs{NewValueNode(std::make_shared(kDropoutGenMaskOpName)), + shape_value, keep_prob_value}; + CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs); + MS_EXCEPTION_IF_NULL(dropout_gen_mask); + AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask); + auto output_shape = CalDropoutGenMaskOutput(inputx_shape); + auto gen_mask_abstract = std::make_shared(kUInt8, output_shape); + MS_EXCEPTION_IF_NULL(gen_mask_abstract); + dropout_gen_mask->set_abstract(gen_mask_abstract); + dropout_gen_mask->set_scope(node->scope()); + + // CreateDropoutDoMask + MS_EXCEPTION_IF_NULL(dropout_node); + auto dropout_cnode = dropout_node->cast(); + MS_EXCEPTION_IF_NULL(dropout_cnode); + auto dropout_input = dropout_cnode->input(1); + std::vector dropout_do_mask_inputs{NewValueNode(std::make_shared(kDropoutDoMaskOpName)), + dropout_input, dropout_gen_mask, keep_prob_value}; + auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs); + MS_EXCEPTION_IF_NULL(dropout_do_mask); + auto do_mask_abstract = std::make_shared(TypeIdToType(inputx_type_id), inputx_shape); + dropout_do_mask->set_abstract(do_mask_abstract); + dropout_do_mask->set_scope(node->scope()); + + // replace genmask and domask + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + (void)manager->Replace(getitem0_node, dropout_do_mask); + (void)manager->Replace(getitem1_node, dropout_gen_mask); + + return node; +} + +const BaseRef DropoutGradUnifyMindIRPynative::DefinePattern() const { + VarPtr X = std::make_shared(); + VarPtr Y = std::make_shared(); + auto dropout_grad_prim = std::make_shared(kDropoutGradOpName); + return VectorRef({dropout_grad_prim, X, Y}); +} + +const AnfNodePtr DropoutGradUnifyMindIRPynative::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto dropout_grad_cnode = node->cast(); + MS_EXCEPTION_IF_NULL(dropout_grad_cnode); + + auto grad_input_type_id = GetInputXDataType(dropout_grad_cnode); + auto grad_input_shape = GetInputXShape(dropout_grad_cnode); + auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_grad_cnode, grad_input_type_id); + + // CreateDropoutDoMask + auto grad_input = dropout_grad_cnode->input(1); + auto mask_input = 
dropout_grad_cnode->input(2); + std::vector dropout_do_mask_inputs{NewValueNode(std::make_shared(kDropoutDoMaskOpName)), + grad_input, mask_input, keep_prob_value}; + auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs); + MS_EXCEPTION_IF_NULL(dropout_do_mask); + auto do_mask_abstract = + std::make_shared(TypeIdToType(grad_input_type_id), grad_input_shape); + dropout_do_mask->set_abstract(do_mask_abstract); + dropout_do_mask->set_scope(node->scope()); + return dropout_do_mask; } } // namespace mindspore::opt diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.h b/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.h index 553796376cd..3a078a1d312 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.h @@ -42,6 +42,24 @@ class DropoutGradUnifyMindIR : public PatternProcessPass { private: VarPtr grad_input_; }; + +class DropoutUnifyMindIRPynative : public PatternProcessPass { + public: + explicit DropoutUnifyMindIRPynative(bool multigraph = true) + : PatternProcessPass("dropout_unify_mindir_pynative", multigraph) {} + ~DropoutUnifyMindIRPynative() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; + +class DropoutGradUnifyMindIRPynative : public PatternProcessPass { + public: + explicit DropoutGradUnifyMindIRPynative(bool multigraph = true) + : PatternProcessPass("dropout_grad_unify_mindir_pynative", multigraph) {} + ~DropoutGradUnifyMindIRPynative() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_DROPOUT_UNIFY_MINDIR_H_ diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index fdc6497e225..a110d45e63c 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -444,6 +444,15 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) { unify_mindir_pm->AddPass(std::make_shared()); unify_mindir_pm->AddPass(std::make_shared()); unify_mindir_pm->AddPass(std::make_shared()); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode) { + unify_mindir_pm->AddPass(std::make_shared()); + unify_mindir_pm->AddPass(std::make_shared()); + } else { + unify_mindir_pm->AddPass(std::make_shared()); + unify_mindir_pm->AddPass(std::make_shared()); + } optimizer->AddPassManager(unify_mindir_pm); (void)optimizer->Optimize(graph); diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index a43d5aa055c..33fe1551844 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -1633,7 +1633,11 @@ std::shared_ptr SessionBasic::ConstructSingleOpGraph(const OpRunInf manager->AddFuncGraph(graph); graph->set_manager(manager); } - UnifyMindIR(graph); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (ms_context->get_param(MS_CTX_ENABLE_PYNATIVE_INFER)) { + UnifyMindIR(graph); + } return graph; } diff --git a/mindspore/nn/layer/basic.py 
b/mindspore/nn/layer/basic.py
index e6ae679b6f6..30205b0b0bf 100644
--- a/mindspore/nn/layer/basic.py
+++ b/mindspore/nn/layer/basic.py
@@ -29,7 +29,6 @@
 from mindspore.ops.primitive import constexpr, Primitive
 from mindspore.common.parameter import Parameter
 from mindspore._extends import cell_attr_register
 from mindspore._checkparam import Rel, Validator
-from mindspore import context
 from ..cell import Cell
 from .activation import get_activation
@@ -146,33 +145,17 @@ class Dropout(Cell):
         seed0, seed1 = _get_graph_seed(0, "dropout")
         self.seed0 = seed0
         self.seed1 = seed1
-        self.dtype = dtype
-        self.get_shape = P.Shape()
-        self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1)
-        self.dropout_do_mask = P.DropoutDoMask()
-        self.cast = P.Cast()
-        self.is_ascend = context.get_context('device_target') in ["Ascend"]
-        self.dropout = P.Dropout(keep_prob)
+        self.dropout = P.Dropout(keep_prob, seed0, seed1)
 
     def construct(self, x):
         if not self.training:
             return x
 
-        if not self.is_ascend:
-            out, _ = self.dropout(x)
-            return out
-
         if self.keep_prob == 1:
             return x
 
-        shape = self.get_shape(x)
-        dtype = P.DType()(x)
-        if _is_float_dtype(dtype):
-            keep_prob = self.cast(self.keep_prob, dtype)
-        else:
-            keep_prob = self.cast(self.keep_prob, mstype.float16)
-        output = self.dropout_gen_mask(shape, keep_prob)
-        return self.dropout_do_mask(x, output, keep_prob)
+        out, _ = self.dropout(x)
+        return out
 
     def extend_repr(self):
         return 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype)
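
Note: a minimal usage sketch of the simplified Python-side Dropout path after this change (illustration only, not part of the diff; it assumes the standard mindspore.nn and Tensor APIs). The layer now simply calls P.Dropout, and the unify-mindir passes above split it into DropoutGenMask/DropoutDoMask on Ascend:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    net = nn.Dropout(keep_prob=0.9)
    net.set_train()                      # dropout only takes effect in training mode
    x = Tensor(np.ones((2, 128)).astype(np.float32))
    y = net(x)                           # backend passes rewrite Dropout into GenMask + DoMask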
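
For reference, the mask-size rule encoded by CalDropoutGenMaskOutput in the pass above: DropoutGenMask produces ceil(count / 128) * 16 bytes, i.e. one 16-byte block of mask bits per 128 input elements. A small Python sketch of the same arithmetic (the helper name is illustrative, not taken from the code):

    def dropout_gen_mask_bytes(shape):
        """Mirror of CalDropoutGenMaskOutput: size of the GenMask output buffer in bytes."""
        count = 1
        for dim in shape:
            count *= dim
        return (count + 127) // 128 * 16

    assert dropout_gen_mask_bytes((2, 128)) == 32   # 256 elements -> 2 blocks * 16 bytes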