dropout_do_mask: only replace the first input of dropout_gen_mask

Replace only the first input (the shape value) of the dropout_gen_mask cnode in
the subgraph, instead of replacing the whole subgraph.
This commit is contained in:
Yi Huaijie 2020-06-12 17:20:14 +08:00
parent 553432c968
commit 6c85fc9f9f
3 changed files with 52 additions and 9 deletions
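For orientation, the op pair this pass rewrites is built from the Python front end roughly as in the sketch below; shapes, the keep_prob value, and variable names are illustrative and not taken from the commit, and the operator signatures are the MindSpore ops of that period. The shape tuple is the first input of DropoutGenMask, and under auto-parallel it is this value that the patch now swaps for the slice shape instead of replacing the whole DropoutGenMask node.

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor
    from mindspore.ops import operations as P

    # Illustrative values only; what matters is the subgraph structure.
    x = Tensor(np.ones((64, 128)), ms.float32)
    keep_prob = Tensor(0.9, ms.float32)

    gen_mask = P.DropoutGenMask()  # both seeds default to 0
    do_mask = P.DropoutDoMask()

    # DropoutGenMask(shape, keep_prob): the shape tuple is its first input.
    # Under a sharding strategy this tuple is what gets replaced with the
    # slice shape derived from DropoutDoMask's strategy.
    mask = gen_mask((64, 128), keep_prob)

    # DropoutDoMask(x, mask, keep_prob): applies the generated mask to x.
    out = do_mask(x, mask, keep_prob)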

View File

@@ -204,7 +204,7 @@ Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr &strategy) {
 PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
-  if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
+  if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
     MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
   }
@@ -215,8 +215,7 @@ PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) {
   }
   auto dropout_gen_mask_cnode = dropout_gen_mask->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(dropout_gen_mask_cnode);
-  if (dropout_gen_mask_cnode->inputs().size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
+  if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
     MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE;
   }
   if (!IsValueNode<Primitive>(dropout_gen_mask_cnode->input(0))) {
@@ -233,11 +232,45 @@ PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) {
   return prim;
 }
+void SetGenMaskShape(const CNodePtr &cnode, const Shape &input_slice_shape) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
+    MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
+  }
+  AnfNodePtr dropout_gen_mask = cnode->input(DROPOUT_GEN_MASK_INDEX);
+  MS_EXCEPTION_IF_NULL(dropout_gen_mask);
+  if (!dropout_gen_mask->isa<CNode>()) {
+    MS_LOG(EXCEPTION) << "The dropout do mask cnode's input[" << DROPOUT_GEN_MASK_INDEX << "] must be a cnode.";
+  }
+  auto dropout_gen_mask_cnode = dropout_gen_mask->cast<CNodePtr>();
+  if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
+    MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE;
+  }
+  if (!IsValueNode<ValueTuple>(dropout_gen_mask_cnode->input(1))) {
+    MS_LOG(EXCEPTION) << "The input[1] of dropout gen mask cnode is not ValueTuple.";
+  }
+  FuncGraphPtr func_graph = cnode->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
+  FuncGraphManagerPtr manager = func_graph->manager();
+  if (manager == nullptr) {
+    MS_LOG(EXCEPTION) << "Failure: AddNode error since manager is nullptr.";
+  }
+  ValuePtr new_shape = MakeValue(input_slice_shape);
+  AnfNodePtr val = NewValueNode(new_shape);
+  (void)manager->Replace(dropout_gen_mask_cnode->input(1), val);
+}
 // DropoutDoMask needs to be used together with DropoutGenMask. Only the first input tensor of DropoutGenMask is
 // split. Find the DropoutGenMask node in the anf graph according to DropoutDoMask node, and modify the input shape
 // of DropoutGenMask according to the strategy of DropoutDoMask. When the DropoutDoMask performs repeated calculation
 // and both seeds of DropoutGenMask are 0, two new seeds are automatically generated for DropoutGenMask.
-Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
+std::vector<Operator> DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
+  std::vector<Operator> replace_ops;
   MS_EXCEPTION_IF_NULL(cnode);
   PrimitivePtr prim = GetDropoutGenMaskPrim(cnode);
   MS_EXCEPTION_IF_NULL(prim);
@@ -260,15 +293,20 @@ Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
   if ((attr.find(SEED0) == attr.end()) || (attr.find(SEED1) == attr.end())) {
     MS_LOG(EXCEPTION) << "The attrs of dropout gen mask must be have seed0 and seed1";
   }
+  Shape input_slice_shape = inputs_tensor_info_[0].slice_shape();
   int32_t seed_0 = GetValue<int32_t>(attr[SEED0]);
   int32_t seed_1 = GetValue<int32_t>(attr[SEED1]);
   if ((seed_0 == 0) && (seed_1 == 0) && (repeated_calc_num_ > 1)) {
     seed_0 = SEED_NUM;
     seed_1 = SEED_NUM;
     SEED_NUM++;
+  } else {
+    SetGenMaskShape(cnode, input_slice_shape);
+    MS_LOG(DEBUG) << "The input slice shape droupout is " << ShapeToString(input_slice_shape);
+    return replace_ops;
   }
-  Shape input_slice_shape = inputs_tensor_info_[0].slice_shape();
   ValuePtr new_shape = MakeValue(input_slice_shape);
   Attr attr_0 = std::make_pair(SEED0, MakeValue(seed_0));
   Attr attr_1 = std::make_pair(SEED1, MakeValue(seed_1));
@@ -278,7 +316,8 @@ Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
   OperatorParams params = {std::make_pair(param_0, 1), std::make_pair(param_1, 2)};
   OperatorArgs args = std::make_pair(attrs, params);
   Operator replace_op = {std::make_pair(DROPOUT_GEN_MASK, args)};
-  return replace_op;
+  replace_ops.push_back(replace_op);
+  return replace_ops;
 }
 } // namespace parallel
 } // namespace mindspore

View File

@@ -41,7 +41,7 @@ class DropoutDoMaskInfo : public OperatorInfo {
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
   Status InitForCostModel(const StrategyPtr &strategy) override;
   std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override;
-  Operator GetDropoutGenMaskReplaceOp(const CNodePtr &cnode);
+  std::vector<Operator> GetDropoutGenMaskReplaceOp(const CNodePtr &cnode);
  protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;

View File

@@ -1876,11 +1876,15 @@ void HandleDropoutNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
   DropoutDoMaskInfoPtr dropout_do_mask = std::dynamic_pointer_cast<DropoutDoMaskInfo>(distribute_operator);
   MS_EXCEPTION_IF_NULL(dropout_do_mask);
-  Operator replace_op = dropout_do_mask->GetDropoutGenMaskReplaceOp(cnode);
+  std::vector<Operator> replace_op = dropout_do_mask->GetDropoutGenMaskReplaceOp(cnode);
+  if (replace_op.empty()) {
+    MS_LOG(DEBUG) << "No need to replace dropout_gen_mask";
+    return;
+  }
   if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
     MS_LOG(EXCEPTION) << "The size of drop out do mask cnode's input is not " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
   }
-  ReplaceOneOp(replace_op, cnode->input(DROPOUT_GEN_MASK_INDEX)->cast<CNodePtr>());
+  ReplaceOneOp(replace_op[0], cnode->input(DROPOUT_GEN_MASK_INDEX)->cast<CNodePtr>());
 }
 void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {