forked from mindspore-Ecosystem/mindspore
!11317 modify dropout unify mindir pass
From: @yuchaojie Reviewed-by: @jjfeing,@kisnwang Signed-off-by: @kisnwang
This commit is contained in:
commit
9147f57e40
|
@ -155,59 +155,20 @@ std::vector<int64_t> CalDropoutGenMaskOutput(const std::vector<int64_t> &shape)
|
|||
MS_LOG(INFO) << "Output_size: " << ret;
|
||||
return {ret};
|
||||
}
|
||||
|
||||
bool NeedUpdate(const CNodePtr &getitem_cnode) {
|
||||
MS_EXCEPTION_IF_NULL(getitem_cnode);
|
||||
MS_EXCEPTION_IF_NULL(getitem_cnode->input(2));
|
||||
auto index_vnode = getitem_cnode->input(2)->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(index_vnode);
|
||||
auto index_value = index_vnode->value();
|
||||
MS_EXCEPTION_IF_NULL(index_value);
|
||||
auto index = GetValue<int64_t>(index_value);
|
||||
return index == 1;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef DropoutUnifyMindIR::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
VarPtr Y = std::make_shared<Var>();
|
||||
auto prim = std::make_shared<Primitive>(kDropoutOpName);
|
||||
auto ref = VectorRef({prim, X});
|
||||
return VectorRef({prim::kPrimTupleGetItem, ref, Y});
|
||||
}
|
||||
|
||||
const AnfNodePtr DropoutUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto tuple_cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_cnode);
|
||||
auto dropout_node = tuple_cnode->input(1);
|
||||
MS_EXCEPTION_IF_NULL(dropout_node);
|
||||
|
||||
auto inputx_type_id = GetInputXDataType(dropout_node);
|
||||
auto inputx_shape = GetInputXShape(dropout_node);
|
||||
auto shape_value = CreateShapeValueNode(func_graph, inputx_shape);
|
||||
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id);
|
||||
|
||||
// CreateDropoutGenMask
|
||||
std::vector<AnfNodePtr> dropout_gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName)),
|
||||
shape_value, keep_prob_value};
|
||||
CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs);
|
||||
MS_EXCEPTION_IF_NULL(dropout_gen_mask);
|
||||
AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask);
|
||||
auto output_shape = CalDropoutGenMaskOutput(inputx_shape);
|
||||
auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, output_shape);
|
||||
MS_EXCEPTION_IF_NULL(gen_mask_abstract);
|
||||
dropout_gen_mask->set_abstract(gen_mask_abstract);
|
||||
dropout_gen_mask->set_scope(node->scope());
|
||||
|
||||
// CreateDropoutDoMask
|
||||
MS_EXCEPTION_IF_NULL(dropout_node);
|
||||
auto dropout_cnode = dropout_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(dropout_cnode);
|
||||
auto dropout_input = dropout_cnode->input(1);
|
||||
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
|
||||
dropout_input, dropout_gen_mask, keep_prob_value};
|
||||
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
|
||||
MS_EXCEPTION_IF_NULL(dropout_do_mask);
|
||||
auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), inputx_shape);
|
||||
dropout_do_mask->set_abstract(do_mask_abstract);
|
||||
dropout_do_mask->set_scope(node->scope());
|
||||
|
||||
return dropout_do_mask;
|
||||
}
|
||||
|
||||
const BaseRef DropoutGradUnifyMindIR::DefinePattern() const {
|
||||
const BaseRef DropoutAndDropoutGradUnifyMindIR::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
VarPtr Y = std::make_shared<Var>();
|
||||
auto dropout_prim = std::make_shared<Primitive>(kDropoutOpName);
|
||||
|
@ -220,7 +181,7 @@ const BaseRef DropoutGradUnifyMindIR::DefinePattern() const {
|
|||
return VectorRef({dropout_grad_prim, grad_input_, ref1});
|
||||
}
|
||||
|
||||
const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const AnfNodePtr DropoutAndDropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -312,12 +273,73 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph,
|
|||
return dropout_do_mask;
|
||||
}
|
||||
|
||||
const BaseRef DropoutUnifyMindIRPynative::DefinePattern() const {
|
||||
const BaseRef DropoutUnifyMindIR0::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
VarPtr Y = std::make_shared<Var>();
|
||||
auto prim = std::make_shared<Primitive>(kDropoutOpName);
|
||||
auto ref = VectorRef({prim, X});
|
||||
return VectorRef({prim::kPrimTupleGetItem, ref, Y});
|
||||
}
|
||||
|
||||
const AnfNodePtr DropoutUnifyMindIR0::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto tuple_cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_cnode);
|
||||
if (!NeedUpdate(tuple_cnode)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto dropout_node = tuple_cnode->input(1);
|
||||
MS_EXCEPTION_IF_NULL(dropout_node);
|
||||
auto inputx_type_id = GetInputXDataType(dropout_node);
|
||||
auto inputx_shape = GetInputXShape(dropout_node);
|
||||
auto shape_value = CreateShapeValueNode(func_graph, inputx_shape);
|
||||
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id);
|
||||
|
||||
// CreateDropoutGenMask
|
||||
std::vector<AnfNodePtr> dropout_gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName)),
|
||||
shape_value, keep_prob_value};
|
||||
CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs);
|
||||
MS_EXCEPTION_IF_NULL(dropout_gen_mask);
|
||||
AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask);
|
||||
auto output_shape = CalDropoutGenMaskOutput(inputx_shape);
|
||||
auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, output_shape);
|
||||
MS_EXCEPTION_IF_NULL(gen_mask_abstract);
|
||||
dropout_gen_mask->set_abstract(gen_mask_abstract);
|
||||
dropout_gen_mask->set_scope(node->scope());
|
||||
|
||||
// CreateDropoutDoMask
|
||||
MS_EXCEPTION_IF_NULL(dropout_node);
|
||||
auto dropout_cnode = dropout_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(dropout_cnode);
|
||||
auto dropout_input = dropout_cnode->input(1);
|
||||
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
|
||||
dropout_input, dropout_gen_mask, keep_prob_value};
|
||||
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
|
||||
MS_EXCEPTION_IF_NULL(dropout_do_mask);
|
||||
auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), inputx_shape);
|
||||
dropout_do_mask->set_abstract(do_mask_abstract);
|
||||
dropout_do_mask->set_scope(node->scope());
|
||||
|
||||
// make tuple to replace dropout
|
||||
std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), dropout_do_mask, dropout_gen_mask};
|
||||
auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
|
||||
auto manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
(void)manager->Replace(dropout_node, make_tuple);
|
||||
|
||||
tuple_cnode->set_abstract(gen_mask_abstract);
|
||||
return tuple_cnode;
|
||||
}
|
||||
|
||||
const BaseRef DropoutUnifyMindIR1::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
return VectorRef({prim::kPrimDropout, X});
|
||||
}
|
||||
|
||||
const AnfNodePtr DropoutUnifyMindIRPynative::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const AnfNodePtr DropoutUnifyMindIR1::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -359,14 +381,14 @@ const AnfNodePtr DropoutUnifyMindIRPynative::Process(const FuncGraphPtr &func_gr
|
|||
return make_tuple;
|
||||
}
|
||||
|
||||
const BaseRef DropoutGradUnifyMindIRPynative::DefinePattern() const {
|
||||
const BaseRef DropoutGradUnifyMindIR::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
VarPtr Y = std::make_shared<Var>();
|
||||
auto dropout_grad_prim = std::make_shared<Primitive>(kDropoutGradOpName);
|
||||
return VectorRef({dropout_grad_prim, X, Y});
|
||||
}
|
||||
|
||||
const AnfNodePtr DropoutGradUnifyMindIRPynative::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -377,9 +399,25 @@ const AnfNodePtr DropoutGradUnifyMindIRPynative::Process(const FuncGraphPtr &fun
|
|||
auto grad_input_shape = GetInputXShape(dropout_grad_cnode);
|
||||
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_grad_cnode, grad_input_type_id);
|
||||
|
||||
// DropoutGrad may not in the same graph with Dropout in heterogeneous scene, and mask input which is a parameter
|
||||
// in that scene, need to be updated.
|
||||
auto mask_input = dropout_grad_cnode->input(2);
|
||||
if (mask_input->isa<Parameter>()) {
|
||||
// update abstract
|
||||
auto mask_abstract = mask_input->abstract();
|
||||
MS_EXCEPTION_IF_NULL(mask_abstract);
|
||||
auto mask_shape = CalDropoutGenMaskOutput(grad_input_shape);
|
||||
mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, mask_shape);
|
||||
mask_input->set_abstract(mask_abstract);
|
||||
// update kernel info
|
||||
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
|
||||
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{kNumberTypeUInt8});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), mask_input.get());
|
||||
}
|
||||
|
||||
// CreateDropoutDoMask
|
||||
auto grad_input = dropout_grad_cnode->input(1);
|
||||
auto mask_input = dropout_grad_cnode->input(2);
|
||||
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
|
||||
grad_input, mask_input, keep_prob_value};
|
||||
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
|
||||
|
|
|
@ -21,21 +21,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class DropoutUnifyMindIR : public PatternProcessPass {
|
||||
class DropoutAndDropoutGradUnifyMindIR : public PatternProcessPass {
|
||||
public:
|
||||
explicit DropoutUnifyMindIR(bool multigraph = true) : PatternProcessPass("dropout_unify_mindir", multigraph) {}
|
||||
~DropoutUnifyMindIR() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
|
||||
class DropoutGradUnifyMindIR : public PatternProcessPass {
|
||||
public:
|
||||
explicit DropoutGradUnifyMindIR(bool multigraph = true)
|
||||
: PatternProcessPass("dropout_grad_unify_mindir", multigraph) {
|
||||
explicit DropoutAndDropoutGradUnifyMindIR(bool multigraph = true)
|
||||
: PatternProcessPass("dropout_and_dropoutgrad_unify_mindir", multigraph) {
|
||||
grad_input_ = std::make_shared<Var>();
|
||||
}
|
||||
~DropoutGradUnifyMindIR() override = default;
|
||||
~DropoutAndDropoutGradUnifyMindIR() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
|
@ -43,20 +35,27 @@ class DropoutGradUnifyMindIR : public PatternProcessPass {
|
|||
VarPtr grad_input_;
|
||||
};
|
||||
|
||||
class DropoutUnifyMindIRPynative : public PatternProcessPass {
|
||||
class DropoutUnifyMindIR0 : public PatternProcessPass {
|
||||
public:
|
||||
explicit DropoutUnifyMindIRPynative(bool multigraph = true)
|
||||
: PatternProcessPass("dropout_unify_mindir_pynative", multigraph) {}
|
||||
~DropoutUnifyMindIRPynative() override = default;
|
||||
explicit DropoutUnifyMindIR0(bool multigraph = true) : PatternProcessPass("dropout_unify_mindir0", multigraph) {}
|
||||
~DropoutUnifyMindIR0() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
|
||||
class DropoutGradUnifyMindIRPynative : public PatternProcessPass {
|
||||
class DropoutUnifyMindIR1 : public PatternProcessPass {
|
||||
public:
|
||||
explicit DropoutGradUnifyMindIRPynative(bool multigraph = true)
|
||||
: PatternProcessPass("dropout_grad_unify_mindir_pynative", multigraph) {}
|
||||
~DropoutGradUnifyMindIRPynative() override = default;
|
||||
explicit DropoutUnifyMindIR1(bool multigraph = true) : PatternProcessPass("dropout_unify_mindir1", multigraph) {}
|
||||
~DropoutUnifyMindIR1() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
|
||||
class DropoutGradUnifyMindIR : public PatternProcessPass {
|
||||
public:
|
||||
explicit DropoutGradUnifyMindIR(bool multigraph = true)
|
||||
: PatternProcessPass("dropoutgrad_unify_mindir", multigraph) {}
|
||||
~DropoutGradUnifyMindIR() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
|
|
|
@ -236,17 +236,17 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) {
|
|||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
|
||||
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIR>());
|
||||
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIR>());
|
||||
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutAndDropoutGradUnifyMindIR>());
|
||||
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIR0>());
|
||||
unify_mindir_pm->AddPass(std::make_shared<opt::GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
|
||||
unify_mindir_pm->AddPass(std::make_shared<opt::GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2>());
|
||||
unify_mindir_pm->AddPass(std::make_shared<opt::SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
|
||||
} else {
|
||||
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIRPynative>());
|
||||
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIRPynative>());
|
||||
unify_mindir_pm->AddPass(std::make_shared<opt::PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
|
||||
unify_mindir_pm->AddPass(std::make_shared<opt::PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
|
||||
}
|
||||
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIR1>());
|
||||
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIR>());
|
||||
|
||||
optimizer->AddPassManager(unify_mindir_pm);
|
||||
(void)optimizer->Optimize(graph);
|
||||
|
|
Loading…
Reference in New Issue