code clean backend/pass

This commit is contained in:
zhaosida 2021-06-17 16:22:32 +08:00
parent 9085be08b9
commit a239d85c80
10 changed files with 22 additions and 8 deletions

View File

@ -105,6 +105,7 @@ bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bo
bool CommonSubexpressionElimination::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto backend_cse = std::make_shared<BackendCSE>();
MS_EXCEPTION_IF_NULL(backend_cse);
return backend_cse->Cse(func_graph, func_graph->manager());
}
} // namespace opt

View File

@ -361,6 +361,7 @@ bool CommunicationOpFusion::DoFusion(const FuncGraphPtr &func_graph, const Commu
continue;
}
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
auto graph_id = kernel_graph->graph_id();
AnfNodePtr new_communication_op =
CreateFusedCommunicationOp(func_graph, communication_op_info, start_index, end_index);

View File

@ -29,11 +29,14 @@ const AnfNodePtr ConvertAttrToUnifyMindIR::Process(const FuncGraphPtr &, const A
return nullptr;
}
CNodePtr cnode = node->cast<CNodePtr>();
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto inputs = cnode->inputs();
AnfNodePtr op = inputs[0];
MS_EXCEPTION_IF_NULL(op);
if (IsValueNode<Primitive>(op)) {
auto prim = GetValueNode<PrimitivePtr>(op);
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
std::string type_name = prim->name();
for (auto attr : attrs) {

View File

@ -34,7 +34,8 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
return nullptr;
}
CNodePtr cnode = node->cast<CNodePtr>();
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
ConstInputToAttrInfoRegister reg;
if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(cnode), &reg)) {
return nullptr;
@ -47,6 +48,7 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
}
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimGatherD->name()) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
return nullptr;
}

View File

@ -68,6 +68,7 @@ const AnfNodePtr ConvertConstScalarToTensor::Process(const FuncGraphPtr &func_gr
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
bool input_changed = false;
for (size_t i = 0; i < cnode->inputs().size(); ++i) {
auto new_input = CreateTensorInput(func_graph->cast<KernelGraphPtr>(), cnode->inputs()[i]);

View File

@ -38,11 +38,12 @@ AnfNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const AnfNode
if (anf_graph != nullptr) {
kernel_graph = anf_graph->cast<KernelGraphPtr>();
}
MS_EXCEPTION_IF_NULL(kernel_graph);
if (kernel_graph->FindTupleParameterToMakeTupleMap(tuple_anf)) {
return kernel_graph->FindTupleParameterToMakeTupleMap(tuple_anf);
}
auto make_tuple = kernel_graph->TransTupleToMakeTuple(tuple_anf);
MS_EXCEPTION_IF_NULL(make_tuple);
kernel_graph->InsertTupleParameterToMakeTupleMap(tuple_anf, make_tuple);
// replace graph inputs if input is a parameter
kernel_graph->ReplaceGraphInput(tuple_anf, make_tuple);

View File

@ -26,7 +26,7 @@ namespace {
bool IsC(const BaseRef &n) {
MS_EXCEPTION_IF_NULL(n);
if (utils::isa<AnfNodePtr>(n)) {
AnfNodePtr in = utils::cast<AnfNodePtr>(n);
auto in = utils::cast<AnfNodePtr>(n);
MS_EXCEPTION_IF_NULL(in);
return in->isa<ValueNode>();
} else {
@ -43,7 +43,7 @@ const BaseRef GetitemTuple::DefinePattern() const {
const AnfNodePtr GetitemTuple::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(node);
CNodePtr tuple_getitem = node->cast<CNodePtr>();
auto tuple_getitem = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_getitem);
CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputTensorNum);
AnfNodePtr make_tuple_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem);
@ -51,10 +51,10 @@ const AnfNodePtr GetitemTuple::Process(const FuncGraphPtr &, const AnfNodePtr &n
AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem);
MS_EXCEPTION_IF_NULL(index_node);
if (IsValueNode<Int64Imm>(index_node)) {
ValueNodePtr value_node = index_node->cast<ValueNodePtr>();
auto value_node = index_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto index = GetValue<int64_t>(value_node->value());
CNodePtr make_tuple = make_tuple_anf->cast<CNodePtr>();
auto make_tuple = make_tuple_anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
if (make_tuple->inputs().size() > LongToSize(index + 1)) {
auto ret = make_tuple->input(LongToSize(index + 1));

View File

@ -60,6 +60,7 @@ CNodePtr CheckIsolatedVirtualNode(const CNodePtr &cnode) {
return nullptr;
}
auto real_input_op = AnfAlgo::GetInputNode(cnode, kIsolatedDependRealInputIndex);
MS_EXCEPTION_IF_NULL(real_input_op);
if (!real_input_op->isa<CNode>()) {
return nullptr;
}
@ -210,6 +211,7 @@ const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, c
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto depend_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(depend_cnode);
auto replacing_node = depend_cnode->input(index);
MS_EXCEPTION_IF_NULL(replacing_node);
if (!replacing_node->isa<CNode>()) {

View File

@ -35,6 +35,8 @@ const BaseRef OptimizeUpdateState::DefinePattern() const {
const AnfNodePtr OptimizeUpdateState::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto update_state = dyn_cast<CNode>(node);
MS_EXCEPTION_IF_NULL(update_state);
if (update_state->size() <= kAdditionalAttachIndex) {

View File

@ -60,7 +60,8 @@ bool ReplaceNodeByProxy::Run(const FuncGraphPtr &func_graph) {
for (auto node : node_list) {
if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) {
TraceGuard guard(std::make_shared<TraceOpt>(node->debug_info()));
CNodePtr cnode = node->cast<CNodePtr>();
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto prim = std::make_shared<Primitive>(kEmbeddingLookupProxyOpName);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> proxy_inputs = {NewValueNode(prim)};