forked from mindspore-Ecosystem/mindspore
code clean backend/pass
This commit is contained in:
parent
9085be08b9
commit
a239d85c80
|
@ -105,6 +105,7 @@ bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bo
|
|||
bool CommonSubexpressionElimination::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto backend_cse = std::make_shared<BackendCSE>();
|
||||
MS_EXCEPTION_IF_NULL(backend_cse);
|
||||
return backend_cse->Cse(func_graph, func_graph->manager());
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -361,6 +361,7 @@ bool CommunicationOpFusion::DoFusion(const FuncGraphPtr &func_graph, const Commu
|
|||
continue;
|
||||
}
|
||||
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto graph_id = kernel_graph->graph_id();
|
||||
AnfNodePtr new_communication_op =
|
||||
CreateFusedCommunicationOp(func_graph, communication_op_info, start_index, end_index);
|
||||
|
|
|
@ -29,11 +29,14 @@ const AnfNodePtr ConvertAttrToUnifyMindIR::Process(const FuncGraphPtr &, const A
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto inputs = cnode->inputs();
|
||||
AnfNodePtr op = inputs[0];
|
||||
MS_EXCEPTION_IF_NULL(op);
|
||||
if (IsValueNode<Primitive>(op)) {
|
||||
auto prim = GetValueNode<PrimitivePtr>(op);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto attrs = prim->attrs();
|
||||
std::string type_name = prim->name();
|
||||
for (auto attr : attrs) {
|
||||
|
|
|
@ -34,7 +34,8 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
ConstInputToAttrInfoRegister reg;
|
||||
if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(cnode), ®)) {
|
||||
return nullptr;
|
||||
|
@ -47,6 +48,7 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
|
|||
}
|
||||
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimGatherD->name()) {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -68,6 +68,7 @@ const AnfNodePtr ConvertConstScalarToTensor::Process(const FuncGraphPtr &func_gr
|
|||
return nullptr;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
bool input_changed = false;
|
||||
for (size_t i = 0; i < cnode->inputs().size(); ++i) {
|
||||
auto new_input = CreateTensorInput(func_graph->cast<KernelGraphPtr>(), cnode->inputs()[i]);
|
||||
|
|
|
@ -38,11 +38,12 @@ AnfNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const AnfNode
|
|||
if (anf_graph != nullptr) {
|
||||
kernel_graph = anf_graph->cast<KernelGraphPtr>();
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
if (kernel_graph->FindTupleParameterToMakeTupleMap(tuple_anf)) {
|
||||
return kernel_graph->FindTupleParameterToMakeTupleMap(tuple_anf);
|
||||
}
|
||||
auto make_tuple = kernel_graph->TransTupleToMakeTuple(tuple_anf);
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
kernel_graph->InsertTupleParameterToMakeTupleMap(tuple_anf, make_tuple);
|
||||
// replace graph inputs if input is a parameter
|
||||
kernel_graph->ReplaceGraphInput(tuple_anf, make_tuple);
|
||||
|
|
|
@ -26,7 +26,7 @@ namespace {
|
|||
bool IsC(const BaseRef &n) {
|
||||
MS_EXCEPTION_IF_NULL(n);
|
||||
if (utils::isa<AnfNodePtr>(n)) {
|
||||
AnfNodePtr in = utils::cast<AnfNodePtr>(n);
|
||||
auto in = utils::cast<AnfNodePtr>(n);
|
||||
MS_EXCEPTION_IF_NULL(in);
|
||||
return in->isa<ValueNode>();
|
||||
} else {
|
||||
|
@ -43,7 +43,7 @@ const BaseRef GetitemTuple::DefinePattern() const {
|
|||
|
||||
const AnfNodePtr GetitemTuple::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
CNodePtr tuple_getitem = node->cast<CNodePtr>();
|
||||
auto tuple_getitem = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
||||
CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputTensorNum);
|
||||
AnfNodePtr make_tuple_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem);
|
||||
|
@ -51,10 +51,10 @@ const AnfNodePtr GetitemTuple::Process(const FuncGraphPtr &, const AnfNodePtr &n
|
|||
AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem);
|
||||
MS_EXCEPTION_IF_NULL(index_node);
|
||||
if (IsValueNode<Int64Imm>(index_node)) {
|
||||
ValueNodePtr value_node = index_node->cast<ValueNodePtr>();
|
||||
auto value_node = index_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto index = GetValue<int64_t>(value_node->value());
|
||||
CNodePtr make_tuple = make_tuple_anf->cast<CNodePtr>();
|
||||
auto make_tuple = make_tuple_anf->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
if (make_tuple->inputs().size() > LongToSize(index + 1)) {
|
||||
auto ret = make_tuple->input(LongToSize(index + 1));
|
||||
|
|
|
@ -60,6 +60,7 @@ CNodePtr CheckIsolatedVirtualNode(const CNodePtr &cnode) {
|
|||
return nullptr;
|
||||
}
|
||||
auto real_input_op = AnfAlgo::GetInputNode(cnode, kIsolatedDependRealInputIndex);
|
||||
MS_EXCEPTION_IF_NULL(real_input_op);
|
||||
if (!real_input_op->isa<CNode>()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -210,6 +211,7 @@ const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, c
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto depend_cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(depend_cnode);
|
||||
auto replacing_node = depend_cnode->input(index);
|
||||
MS_EXCEPTION_IF_NULL(replacing_node);
|
||||
if (!replacing_node->isa<CNode>()) {
|
||||
|
|
|
@ -35,6 +35,8 @@ const BaseRef OptimizeUpdateState::DefinePattern() const {
|
|||
|
||||
const AnfNodePtr OptimizeUpdateState::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto update_state = dyn_cast<CNode>(node);
|
||||
MS_EXCEPTION_IF_NULL(update_state);
|
||||
if (update_state->size() <= kAdditionalAttachIndex) {
|
||||
|
|
|
@ -60,7 +60,8 @@ bool ReplaceNodeByProxy::Run(const FuncGraphPtr &func_graph) {
|
|||
for (auto node : node_list) {
|
||||
if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) {
|
||||
TraceGuard guard(std::make_shared<TraceOpt>(node->debug_info()));
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto prim = std::make_shared<Primitive>(kEmbeddingLookupProxyOpName);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> proxy_inputs = {NewValueNode(prim)};
|
||||
|
|
Loading…
Reference in New Issue