forked from mindspore-Ecosystem/mindspore
!2078 replace first input of dropout_gen_mask of the subgraph instead of the whole sub graph
Merge pull request !2078 from yihuaijie/dev
This commit is contained in:
commit
b1ff4c15c2
|
@ -204,7 +204,7 @@ Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr &strategy) {
|
|||
|
||||
PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
|
||||
if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
|
||||
MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
|
||||
}
|
||||
|
||||
|
@ -215,8 +215,7 @@ PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) {
|
|||
}
|
||||
|
||||
auto dropout_gen_mask_cnode = dropout_gen_mask->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(dropout_gen_mask_cnode);
|
||||
if (dropout_gen_mask_cnode->inputs().size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
|
||||
if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
|
||||
MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE;
|
||||
}
|
||||
if (!IsValueNode<Primitive>(dropout_gen_mask_cnode->input(0))) {
|
||||
|
@ -233,11 +232,45 @@ PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) {
|
|||
return prim;
|
||||
}
|
||||
|
||||
void SetGenMaskShape(const CNodePtr &cnode, const Shape &input_slice_shape) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
|
||||
MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
|
||||
}
|
||||
|
||||
AnfNodePtr dropout_gen_mask = cnode->input(DROPOUT_GEN_MASK_INDEX);
|
||||
MS_EXCEPTION_IF_NULL(dropout_gen_mask);
|
||||
if (!dropout_gen_mask->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "The dropout do mask cnode's input[" << DROPOUT_GEN_MASK_INDEX << "] must be a cnode.";
|
||||
}
|
||||
|
||||
auto dropout_gen_mask_cnode = dropout_gen_mask->cast<CNodePtr>();
|
||||
if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
|
||||
MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE;
|
||||
}
|
||||
|
||||
if (!IsValueNode<ValueTuple>(dropout_gen_mask_cnode->input(1))) {
|
||||
MS_LOG(EXCEPTION) << "The input[1] of dropout gen mask cnode is not ValueTuple.";
|
||||
}
|
||||
|
||||
FuncGraphPtr func_graph = cnode->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
if (manager == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failure: AddNode error since manager is nullptr.";
|
||||
}
|
||||
|
||||
ValuePtr new_shape = MakeValue(input_slice_shape);
|
||||
AnfNodePtr val = NewValueNode(new_shape);
|
||||
(void)manager->Replace(dropout_gen_mask_cnode->input(1), val);
|
||||
}
|
||||
|
||||
// DropoutDoMask needs to be used together with DropoutGenMask. Only the first input tensor of DropoutGenMask is
|
||||
// split. Find the DropoutGenMask node in the anf graph according to DropoutDoMask node, and modify the input shape
|
||||
// of DropoutGenMask according to the strategy of DropoutDoMask. When the DropoutDoMask performs repeated calculation
|
||||
// and both seeds of DropoutGenMask are 0, two new seeds are automatically generated for DropoutGenMask.
|
||||
Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
|
||||
std::vector<Operator> DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
|
||||
std::vector<Operator> replace_ops;
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
PrimitivePtr prim = GetDropoutGenMaskPrim(cnode);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
|
@ -260,15 +293,20 @@ Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
|
|||
if ((attr.find(SEED0) == attr.end()) || (attr.find(SEED1) == attr.end())) {
|
||||
MS_LOG(EXCEPTION) << "The attrs of dropout gen mask must be have seed0 and seed1";
|
||||
}
|
||||
|
||||
Shape input_slice_shape = inputs_tensor_info_[0].slice_shape();
|
||||
int32_t seed_0 = GetValue<int32_t>(attr[SEED0]);
|
||||
int32_t seed_1 = GetValue<int32_t>(attr[SEED1]);
|
||||
if ((seed_0 == 0) && (seed_1 == 0) && (repeated_calc_num_ > 1)) {
|
||||
seed_0 = SEED_NUM;
|
||||
seed_1 = SEED_NUM;
|
||||
SEED_NUM++;
|
||||
} else {
|
||||
SetGenMaskShape(cnode, input_slice_shape);
|
||||
MS_LOG(DEBUG) << "The input slice shape droupout is " << ShapeToString(input_slice_shape);
|
||||
return replace_ops;
|
||||
}
|
||||
|
||||
Shape input_slice_shape = inputs_tensor_info_[0].slice_shape();
|
||||
ValuePtr new_shape = MakeValue(input_slice_shape);
|
||||
Attr attr_0 = std::make_pair(SEED0, MakeValue(seed_0));
|
||||
Attr attr_1 = std::make_pair(SEED1, MakeValue(seed_1));
|
||||
|
@ -278,7 +316,8 @@ Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
|
|||
OperatorParams params = {std::make_pair(param_0, 1), std::make_pair(param_1, 2)};
|
||||
OperatorArgs args = std::make_pair(attrs, params);
|
||||
Operator replace_op = {std::make_pair(DROPOUT_GEN_MASK, args)};
|
||||
return replace_op;
|
||||
replace_ops.push_back(replace_op);
|
||||
return replace_ops;
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -41,7 +41,7 @@ class DropoutDoMaskInfo : public OperatorInfo {
|
|||
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
|
||||
Status InitForCostModel(const StrategyPtr &strategy) override;
|
||||
std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override;
|
||||
Operator GetDropoutGenMaskReplaceOp(const CNodePtr &cnode);
|
||||
std::vector<Operator> GetDropoutGenMaskReplaceOp(const CNodePtr &cnode);
|
||||
|
||||
protected:
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
|
|
|
@ -1876,11 +1876,15 @@ void HandleDropoutNode(const OperatorInfoPtr &distribute_operator, const CNodePt
|
|||
|
||||
DropoutDoMaskInfoPtr dropout_do_mask = std::dynamic_pointer_cast<DropoutDoMaskInfo>(distribute_operator);
|
||||
MS_EXCEPTION_IF_NULL(dropout_do_mask);
|
||||
Operator replace_op = dropout_do_mask->GetDropoutGenMaskReplaceOp(cnode);
|
||||
std::vector<Operator> replace_op = dropout_do_mask->GetDropoutGenMaskReplaceOp(cnode);
|
||||
if (replace_op.empty()) {
|
||||
MS_LOG(DEBUG) << "No need to replace dropout_gen_mask";
|
||||
return;
|
||||
}
|
||||
if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
|
||||
MS_LOG(EXCEPTION) << "The size of drop out do mask cnode's input is not " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
|
||||
}
|
||||
ReplaceOneOp(replace_op, cnode->input(DROPOUT_GEN_MASK_INDEX)->cast<CNodePtr>());
|
||||
ReplaceOneOp(replace_op[0], cnode->input(DROPOUT_GEN_MASK_INDEX)->cast<CNodePtr>());
|
||||
}
|
||||
|
||||
void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
|
||||
|
|
Loading…
Reference in New Issue