!11317 modify dropout unify mindir pass

From: @yuchaojie
Reviewed-by: @jjfeing,@kisnwang
Signed-off-by: @kisnwang
This commit is contained in:
mindspore-ci-bot 2021-02-25 09:29:03 +08:00 committed by Gitee
commit 9147f57e40
3 changed files with 121 additions and 84 deletions

View File

@ -155,59 +155,20 @@ std::vector<int64_t> CalDropoutGenMaskOutput(const std::vector<int64_t> &shape)
MS_LOG(INFO) << "Output_size: " << ret;
return {ret};
}
// Returns true when this TupleGetItem extracts output index 1 of its producer
// (presumably Dropout's mask output — the only index this pass rewrites).
bool NeedUpdate(const CNodePtr &getitem_cnode) {
  MS_EXCEPTION_IF_NULL(getitem_cnode);
  MS_EXCEPTION_IF_NULL(getitem_cnode->input(2));
  // input(2) of TupleGetItem is the constant index operand.
  auto idx_vnode = getitem_cnode->input(2)->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(idx_vnode);
  auto idx_value = idx_vnode->value();
  MS_EXCEPTION_IF_NULL(idx_value);
  return GetValue<int64_t>(idx_value) == 1;
}
} // namespace
// Pattern: TupleGetItem(Dropout(X), Y).
const BaseRef DropoutUnifyMindIR::DefinePattern() const {
  auto dropout_prim = std::make_shared<Primitive>(kDropoutOpName);
  VarPtr input_x = std::make_shared<Var>();
  VarPtr getitem_idx = std::make_shared<Var>();
  VectorRef dropout_ref({dropout_prim, input_x});
  return VectorRef({prim::kPrimTupleGetItem, dropout_ref, getitem_idx});
}
// Rewrites the matched TupleGetItem(Dropout(x), idx) into the op pair
// DropoutGenMask(shape, keep_prob) + DropoutDoMask(x, mask, keep_prob) and
// returns the DropoutDoMask node as the replacement for the getitem node.
// The DoMask output keeps x's dtype and shape; the GenMask output is uint8.
const AnfNodePtr DropoutUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                             const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto tuple_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(tuple_cnode);
  // input(1) of the TupleGetItem is the Dropout cnode itself.
  auto dropout_node = tuple_cnode->input(1);
  MS_EXCEPTION_IF_NULL(dropout_node);
  auto inputx_type_id = GetInputXDataType(dropout_node);
  auto inputx_shape = GetInputXShape(dropout_node);
  auto shape_value = CreateShapeValueNode(func_graph, inputx_shape);
  // NOTE(review): helper spells "Porb" (sic, for "Prob"); declared elsewhere.
  auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id);
  // CreateDropoutGenMask
  std::vector<AnfNodePtr> dropout_gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName)),
                                                  shape_value, keep_prob_value};
  CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs);
  MS_EXCEPTION_IF_NULL(dropout_gen_mask);
  // Copy the getitem node's attrs onto GenMask (presumably dropout seeds —
  // TODO confirm which attrs are expected here).
  AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask);
  // Mask tensor is uint8 with a shape derived from x's element count.
  auto output_shape = CalDropoutGenMaskOutput(inputx_shape);
  auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, output_shape);
  MS_EXCEPTION_IF_NULL(gen_mask_abstract);
  dropout_gen_mask->set_abstract(gen_mask_abstract);
  dropout_gen_mask->set_scope(node->scope());
  // CreateDropoutDoMask
  MS_EXCEPTION_IF_NULL(dropout_node);
  auto dropout_cnode = dropout_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(dropout_cnode);
  // x: the original input tensor of Dropout.
  auto dropout_input = dropout_cnode->input(1);
  std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
                                                 dropout_input, dropout_gen_mask, keep_prob_value};
  auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
  MS_EXCEPTION_IF_NULL(dropout_do_mask);
  // DoMask result mirrors x's dtype/shape.
  auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), inputx_shape);
  dropout_do_mask->set_abstract(do_mask_abstract);
  dropout_do_mask->set_scope(node->scope());
  return dropout_do_mask;
}
const BaseRef DropoutGradUnifyMindIR::DefinePattern() const {
const BaseRef DropoutAndDropoutGradUnifyMindIR::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VarPtr Y = std::make_shared<Var>();
auto dropout_prim = std::make_shared<Primitive>(kDropoutOpName);
@ -220,8 +181,8 @@ const BaseRef DropoutGradUnifyMindIR::DefinePattern() const {
return VectorRef({dropout_grad_prim, grad_input_, ref1});
}
const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
const AnfNodePtr DropoutAndDropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto dropout_grad_cnode = node->cast<CNodePtr>();
@ -312,13 +273,74 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph,
return dropout_do_mask;
}
const BaseRef DropoutUnifyMindIRPynative::DefinePattern() const {
// Pattern: TupleGetItem(Dropout(X), Y).
const BaseRef DropoutUnifyMindIR0::DefinePattern() const {
  auto dropout_prim = std::make_shared<Primitive>(kDropoutOpName);
  VarPtr input_x = std::make_shared<Var>();
  VarPtr getitem_idx = std::make_shared<Var>();
  VectorRef dropout_ref({dropout_prim, input_x});
  return VectorRef({prim::kPrimTupleGetItem, dropout_ref, getitem_idx});
}
// Rewrites TupleGetItem(Dropout(x), 1) — a use of Dropout's second output —
// into DropoutGenMask + DropoutDoMask. The Dropout node is replaced by
// MakeTuple(do_mask, gen_mask) so that any other getitem users of Dropout
// keep resolving, and the matched getitem node (re-abstracted to the uint8
// mask) is returned as its own replacement.
const AnfNodePtr DropoutUnifyMindIR0::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                              const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto tuple_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(tuple_cnode);
  // Only rewrite getitem with index 1; other indices are left untouched.
  if (!NeedUpdate(tuple_cnode)) {
    return nullptr;
  }
  // input(1) of the TupleGetItem is the Dropout cnode itself.
  auto dropout_node = tuple_cnode->input(1);
  MS_EXCEPTION_IF_NULL(dropout_node);
  auto inputx_type_id = GetInputXDataType(dropout_node);
  auto inputx_shape = GetInputXShape(dropout_node);
  auto shape_value = CreateShapeValueNode(func_graph, inputx_shape);
  // NOTE(review): helper spells "Porb" (sic, for "Prob"); declared elsewhere.
  auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id);
  // CreateDropoutGenMask
  std::vector<AnfNodePtr> dropout_gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName)),
                                                  shape_value, keep_prob_value};
  CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs);
  MS_EXCEPTION_IF_NULL(dropout_gen_mask);
  // Copy the getitem node's attrs onto GenMask (presumably dropout seeds —
  // TODO confirm which attrs are expected here).
  AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask);
  // Mask tensor is uint8 with a shape derived from x's element count.
  auto output_shape = CalDropoutGenMaskOutput(inputx_shape);
  auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, output_shape);
  MS_EXCEPTION_IF_NULL(gen_mask_abstract);
  dropout_gen_mask->set_abstract(gen_mask_abstract);
  dropout_gen_mask->set_scope(node->scope());
  // CreateDropoutDoMask
  MS_EXCEPTION_IF_NULL(dropout_node);
  auto dropout_cnode = dropout_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(dropout_cnode);
  // x: the original input tensor of Dropout.
  auto dropout_input = dropout_cnode->input(1);
  std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
                                                 dropout_input, dropout_gen_mask, keep_prob_value};
  auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
  MS_EXCEPTION_IF_NULL(dropout_do_mask);
  // DoMask result mirrors x's dtype/shape.
  auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), inputx_shape);
  dropout_do_mask->set_abstract(do_mask_abstract);
  dropout_do_mask->set_scope(node->scope());
  // make tuple to replace dropout: (do_mask, gen_mask) mirrors Dropout's
  // two outputs, so remaining getitem users of the old Dropout still work.
  std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), dropout_do_mask, dropout_gen_mask};
  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  (void)manager->Replace(dropout_node, make_tuple);
  // The matched getitem now selects the gen-mask slot; fix its abstract.
  tuple_cnode->set_abstract(gen_mask_abstract);
  return tuple_cnode;
}
// Pattern: a bare Dropout(X) node (no getitem wrapper required).
const BaseRef DropoutUnifyMindIR1::DefinePattern() const {
  VarPtr input_x = std::make_shared<Var>();
  return VectorRef({prim::kPrimDropout, input_x});
}
const AnfNodePtr DropoutUnifyMindIRPynative::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
const AnfNodePtr DropoutUnifyMindIR1::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto dropout_node = node->cast<CNodePtr>();
@ -359,15 +381,15 @@ const AnfNodePtr DropoutUnifyMindIRPynative::Process(const FuncGraphPtr &func_gr
return make_tuple;
}
const BaseRef DropoutGradUnifyMindIRPynative::DefinePattern() const {
// Pattern: DropoutGrad(X, Y) — gradient input and mask as free variables.
const BaseRef DropoutGradUnifyMindIR::DefinePattern() const {
  auto dropout_grad_prim = std::make_shared<Primitive>(kDropoutGradOpName);
  VarPtr grad_in = std::make_shared<Var>();
  VarPtr mask_in = std::make_shared<Var>();
  return VectorRef({dropout_grad_prim, grad_in, mask_in});
}
const AnfNodePtr DropoutGradUnifyMindIRPynative::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto dropout_grad_cnode = node->cast<CNodePtr>();
@ -377,9 +399,25 @@ const AnfNodePtr DropoutGradUnifyMindIRPynative::Process(const FuncGraphPtr &fun
auto grad_input_shape = GetInputXShape(dropout_grad_cnode);
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_grad_cnode, grad_input_type_id);
// DropoutGrad may not in the same graph with Dropout in heterogeneous scene, and mask input which is a parameter
// in that scene, need to be updated.
auto mask_input = dropout_grad_cnode->input(2);
if (mask_input->isa<Parameter>()) {
// update abstract
auto mask_abstract = mask_input->abstract();
MS_EXCEPTION_IF_NULL(mask_abstract);
auto mask_shape = CalDropoutGenMaskOutput(grad_input_shape);
mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, mask_shape);
mask_input->set_abstract(mask_abstract);
// update kernel info
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{kNumberTypeUInt8});
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), mask_input.get());
}
// CreateDropoutDoMask
auto grad_input = dropout_grad_cnode->input(1);
auto mask_input = dropout_grad_cnode->input(2);
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
grad_input, mask_input, keep_prob_value};
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);

View File

@ -21,21 +21,13 @@
namespace mindspore {
namespace opt {
class DropoutUnifyMindIR : public PatternProcessPass {
class DropoutAndDropoutGradUnifyMindIR : public PatternProcessPass {
public:
explicit DropoutUnifyMindIR(bool multigraph = true) : PatternProcessPass("dropout_unify_mindir", multigraph) {}
~DropoutUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
class DropoutGradUnifyMindIR : public PatternProcessPass {
public:
explicit DropoutGradUnifyMindIR(bool multigraph = true)
: PatternProcessPass("dropout_grad_unify_mindir", multigraph) {
explicit DropoutAndDropoutGradUnifyMindIR(bool multigraph = true)
: PatternProcessPass("dropout_and_dropoutgrad_unify_mindir", multigraph) {
grad_input_ = std::make_shared<Var>();
}
~DropoutGradUnifyMindIR() override = default;
~DropoutAndDropoutGradUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
@ -43,20 +35,27 @@ class DropoutGradUnifyMindIR : public PatternProcessPass {
VarPtr grad_input_;
};
class DropoutUnifyMindIRPynative : public PatternProcessPass {
class DropoutUnifyMindIR0 : public PatternProcessPass {
public:
explicit DropoutUnifyMindIRPynative(bool multigraph = true)
: PatternProcessPass("dropout_unify_mindir_pynative", multigraph) {}
~DropoutUnifyMindIRPynative() override = default;
explicit DropoutUnifyMindIR0(bool multigraph = true) : PatternProcessPass("dropout_unify_mindir0", multigraph) {}
~DropoutUnifyMindIR0() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
class DropoutGradUnifyMindIRPynative : public PatternProcessPass {
class DropoutUnifyMindIR1 : public PatternProcessPass {
public:
explicit DropoutGradUnifyMindIRPynative(bool multigraph = true)
: PatternProcessPass("dropout_grad_unify_mindir_pynative", multigraph) {}
~DropoutGradUnifyMindIRPynative() override = default;
explicit DropoutUnifyMindIR1(bool multigraph = true) : PatternProcessPass("dropout_unify_mindir1", multigraph) {}
~DropoutUnifyMindIR1() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
// Pass "dropoutgrad_unify_mindir": matches a DropoutGrad node and rewrites it
// (implementation in the corresponding .cc defines the pattern and rewrite).
class DropoutGradUnifyMindIR : public PatternProcessPass {
 public:
  // multigraph: forwarded to PatternProcessPass; true applies the pass across
  // sub-graphs as well.
  explicit DropoutGradUnifyMindIR(bool multigraph = true)
      : PatternProcessPass("dropoutgrad_unify_mindir", multigraph) {}
  ~DropoutGradUnifyMindIR() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};

View File

@ -236,17 +236,17 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutAndDropoutGradUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIR0>());
unify_mindir_pm->AddPass(std::make_shared<opt::GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2>());
unify_mindir_pm->AddPass(std::make_shared<opt::SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
} else {
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIRPynative>());
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIRPynative>());
unify_mindir_pm->AddPass(std::make_shared<opt::PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
}
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIR1>());
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIR>());
optimizer->AddPassManager(unify_mindir_pm);
(void)optimizer->Optimize(graph);