!5623 Fix run error when there is a Depend or ControlDepend on BatchNorm
Merge pull request !5623 from huanghui/bn-infer
This commit is contained in:
commit
43c1147843
|
@ -79,7 +79,6 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
|
|||
: AnfAlgo::GetOutputInferShape(input_node, insert_index);
|
||||
bool need_padding = is_insert_input ? trans::IsNeedPadding(dst_format, input_node_out_shape.size())
|
||||
: trans::IsNeedPadding(input_format, input_node_out_shape.size());
|
||||
|
||||
if (!need_padding) {
|
||||
// don't need padding insert transdata only
|
||||
trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
|
||||
|
|
|
@ -121,7 +121,9 @@ const AnfNodePtr BatchNorm2BNInfer::Process(const FuncGraphPtr &graph, const Anf
|
|||
if (!NeedFusion(graph, node, &batchnorm)) {
|
||||
return nullptr;
|
||||
}
|
||||
return CreateBNInfer(graph, batchnorm, node);
|
||||
auto bn_infer = CreateBNInfer(graph, batchnorm, node);
|
||||
TransferDepend(batchnorm, graph, bn_infer);
|
||||
return bn_infer;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -81,7 +81,7 @@ bool CheckBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &batchnormgrad
|
|||
return true;
|
||||
}
|
||||
|
||||
bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnormgrad) {
|
||||
bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnorm_grad) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto tuple_getitem = node->cast<CNodePtr>();
|
||||
|
@ -93,12 +93,12 @@ bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *bat
|
|||
return false;
|
||||
}
|
||||
|
||||
AnfNodePtr batchnormgrad_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem);
|
||||
MS_EXCEPTION_IF_NULL(batchnormgrad_anf);
|
||||
MS_EXCEPTION_IF_NULL(batchnormgrad);
|
||||
*batchnormgrad = batchnormgrad_anf->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(*batchnormgrad);
|
||||
return CheckBatchNormGrad(graph, *batchnormgrad);
|
||||
AnfNodePtr batchnorm_grad_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem);
|
||||
MS_EXCEPTION_IF_NULL(batchnorm_grad_anf);
|
||||
MS_EXCEPTION_IF_NULL(batchnorm_grad);
|
||||
*batchnorm_grad = batchnorm_grad_anf->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(*batchnorm_grad);
|
||||
return CheckBatchNormGrad(graph, *batchnorm_grad);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -117,11 +117,13 @@ const AnfNodePtr BatchNormGrad2BNInferGrad::Process(const FuncGraphPtr &graph, c
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
CNodePtr batchnormgrad = nullptr;
|
||||
if (!NeedFusion(graph, node, &batchnormgrad)) {
|
||||
CNodePtr batchnorm_grad = nullptr;
|
||||
if (!NeedFusion(graph, node, &batchnorm_grad)) {
|
||||
return nullptr;
|
||||
}
|
||||
return CreateBNInferGrad(graph, batchnormgrad, node);
|
||||
auto bn_infer_grad = CreateBNInferGrad(graph, batchnorm_grad, node);
|
||||
TransferDepend(batchnorm_grad, graph, bn_infer_grad);
|
||||
return bn_infer_grad;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -872,5 +872,26 @@ ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
|
|||
return new_value_node;
|
||||
}
|
||||
|
||||
void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
|
||||
MS_EXCEPTION_IF_NULL(old_node);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
// find BatchNorm's output which is a Depend or ControlDepend
|
||||
for (const auto &node_index : manager->node_users()[old_node]) {
|
||||
AnfNodePtr output = node_index.first;
|
||||
size_t index = IntToSize(node_index.second);
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
|
||||
auto control_depend = output->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(control_depend);
|
||||
control_depend->set_input(index, new_node);
|
||||
} else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) {
|
||||
auto depend = output->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(depend);
|
||||
depend->set_input(index, new_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -203,6 +203,9 @@ bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &suppor
|
|||
|
||||
// Create a new value node of func graph,not kernel graph
|
||||
ValueNodePtr MakeValueNode(const ValueNodePtr &value_node);
|
||||
|
||||
// Transfer depend or control_depend to the new node
|
||||
void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_
|
||||
|
|
|
@ -27,7 +27,6 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
|
||||
static std::vector<size_t> g_output_idx;
|
||||
|
||||
bool HasAtomic(const AnfNodePtr &input) {
|
||||
|
|
|
@ -98,7 +98,7 @@ void CheckInputs(const std::vector<AnfNodePtr> &fusion_inputs) {
|
|||
}
|
||||
}
|
||||
|
||||
bool CheckSegments(size_t segments, size_t communication_op_node_size, std::vector<size_t> *segment_index) {
|
||||
bool CheckSegments(size_t segments, size_t communication_op_node_size, const std::vector<size_t> *segment_index) {
|
||||
MS_EXCEPTION_IF_NULL(segment_index);
|
||||
if (segments >= communication_op_node_size) {
|
||||
MS_LOG(INFO) << "fusion not changed: segment_num=" << segments
|
||||
|
|
|
@ -24,7 +24,7 @@ namespace opt {
|
|||
class ConstToAttrStridedSliceGradPass : public PatternProcessPass {
|
||||
public:
|
||||
explicit ConstToAttrStridedSliceGradPass(bool multigraph = true)
|
||||
: PatternProcessPass("const_to_attr_strided_slice_grad_", multigraph) {}
|
||||
: PatternProcessPass("const_to_attr_strided_slice_grad", multigraph) {}
|
||||
~ConstToAttrStridedSliceGradPass() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
|
Loading…
Reference in New Issue