diff --git a/mindspore/ccsrc/backend/optimizer/cpu/insert_cast_cpu.cc b/mindspore/ccsrc/backend/optimizer/cpu/insert_cast_cpu.cc
index ad4b0bbc0d8..3b762f255bc 100644
--- a/mindspore/ccsrc/backend/optimizer/cpu/insert_cast_cpu.cc
+++ b/mindspore/ccsrc/backend/optimizer/cpu/insert_cast_cpu.cc
@@ -28,10 +28,12 @@
 #include "utils/utils.h"
 #include "utils/ms_context.h"
 #include "backend/kernel_compiler/common_utils.h"
+#include "base/core_ops.h"

 namespace mindspore {
 namespace opt {
 namespace {
+constexpr unsigned int kLstmReserveIndex = 3;
 AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
                                 const TypeId &input_type, const TypeId &output_type,
                                 const std::vector &origin_shape, const TypeId &origin_type) {
@@ -52,17 +54,8 @@ AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr
   }
   AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
   AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get());
+  AnfAlgo::SetNodeAttr("dst_type", TypeIdToType(output_type), cast);
   AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
-  std::shared_ptr cpu_kernel = kernel::CPUKernelFactory::GetInstance().Create(kCastOpName, cast);
-  if (cpu_kernel == nullptr) {
-    MS_LOG(EXCEPTION) << "Operator[Cast] " << cast->kernel_info() << " is not support.";
-  }
-  try {
-    cpu_kernel->Init(cast);
-  } catch (std::exception &e) {
-    MS_LOG(EXCEPTION) << e.what() << "\nTrace: " << trace::DumpSourceLines(cast);
-  }
-  AnfAlgo::SetKernelMod(cpu_kernel, cast.get());
   return cast;
 }

@@ -100,6 +93,11 @@ void InsertCastForGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &cn
     auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, i);
     auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, i);
     const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, i);
+    // The shape of LSTM's reserved output will be changed in InitKernel, and this output is only used
+    // by its gradient operator, so we don't handle it in this pass.
+    if (IsPrimitiveCNode(cnode, prim::kPrimLstm) && i == kLstmReserveIndex) {
+      continue;
+    }
     if (infer_type != device_type) {
       auto used_node_list = GetRealNodeUsedListByOutputIdx(func_graph, cnode, i);
       for (size_t j = 0; j < used_node_list->size(); j++) {
diff --git a/mindspore/ccsrc/backend/session/cpu_session.cc b/mindspore/ccsrc/backend/session/cpu_session.cc
index bf5de6bd35e..fe95ad3be0f 100644
--- a/mindspore/ccsrc/backend/session/cpu_session.cc
+++ b/mindspore/ccsrc/backend/session/cpu_session.cc
@@ -95,17 +95,7 @@ void CPUSession::Optimize(const std::shared_ptr &kernel_graph) {
   }
 #endif
   pm->AddPass(std::make_shared("insert_format_transform_op_cpu"));
-  optimizer->AddPassManager(pm);
-  (void)optimizer->Optimize(kernel_graph);
-  kernel_graph->SetExecOrderByDefault();
-}
-
-void CPUSession::ProcessCast(const std::shared_ptr &kernel_graph) {
-  auto optimizer = std::make_shared();
-  auto pm = std::make_shared();
-  MS_EXCEPTION_IF_NULL(pm);
-  pm->AddPass(std::make_shared("insert_cast_cpu"));
-  MS_LOG(INFO) << "Insert cast pass";
+  pm->AddPass(std::make_shared("insert_cast"));
   pm->AddPass(std::make_shared());
   optimizer->AddPassManager(pm);
   (void)optimizer->Optimize(kernel_graph);
@@ -124,7 +114,6 @@ GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
   FinalOptimize(graph);
   MS_LOG(INFO) << "Build kernel";
   BuildKernel(graph.get());
-  ProcessCast(graph);
   // Remove reorder after PS feature finish adapting push/pull in auto_monad.
   auto execution_order = graph->execution_order();
   Reorder(&execution_order);
@@ -235,7 +224,6 @@ KernelGraphPtr CPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const Graph
   SetKernelInfo(kernel_graph.get());
   Optimize(kernel_graph);
   BuildKernel(kernel_graph.get());
-  ProcessCast(kernel_graph);
   auto enable_op_graph_cache = MsContext::GetInstance()->get_param(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
   if (enable_op_graph_cache) {
     run_op_graphs_[graph_info] = kernel_graph;
diff --git a/mindspore/ccsrc/backend/session/cpu_session.h b/mindspore/ccsrc/backend/session/cpu_session.h
index e9a29abbbef..90f6066407a 100644
--- a/mindspore/ccsrc/backend/session/cpu_session.h
+++ b/mindspore/ccsrc/backend/session/cpu_session.h
@@ -56,7 +56,6 @@ class CPUSession : public SessionBasic {

  private:
   void Reorder(std::vector *node_list);
-  void ProcessCast(const std::shared_ptr &kernel_graph);
   void SetKernelInfo(const KernelGraph *kernel_graph);
   void BuildKernel(const KernelGraph *kernel_graph);
   void SetOutputFlags(const VectorRef &base_ref);
diff --git a/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc b/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc
index 6e15f806b62..a749fb01974 100644
--- a/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc
+++ b/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc
@@ -118,6 +118,8 @@ void CPUDeviceContext::OptimizeGraphImpl(const KernelGraphPtr &graph) const {
   auto optimizer = std::make_shared();
   auto pm = std::make_shared();
   pm->AddPass(std::make_shared("insert_format_transform_op_cpu"));
+  pm->AddPass(std::make_shared("insert_cast"));
+  pm->AddPass(std::make_shared());
   optimizer->AddPassManager(pm);
   (void)optimizer->Optimize(graph);
   graph->SetExecOrderByDefault();
@@ -179,32 +181,14 @@ void CPUDeviceContext::CreateKernel(const std::vector &nodes) const {
   }
 }

-namespace {
-void ProcessCast(const KernelGraphPtr &graph) {
-  MS_EXCEPTION_IF_NULL(graph);
-  auto optimizer = std::make_shared();
-  auto pm = std::make_shared();
-  pm->AddPass(std::make_shared("insert_cast_cpu"));
-  MS_LOG(INFO) << "Insert cast pass";
-  pm->AddPass(std::make_shared());
-  optimizer->AddPassManager(pm);
-  (void)optimizer->Optimize(graph);
-  graph->SetExecOrderByDefault();
-}
-}  // namespace
-
 void CPUDeviceContext::PreprocessBeforeRunGraph(const KernelGraphPtr &graph) const {
   MS_EXCEPTION_IF_NULL(graph);
-  ProcessCast(graph);
-
-  // Remove reorder after PS feature finish adapting push/pull in auto_monad.
   auto execution_order = graph->execution_order();
   AnfAlgo::ReorderPosteriorExecList(NOT_NULL(&execution_order));
   graph->set_execution_order(execution_order);
 }

-void CPUDeviceContext::PreprocessBeforeRunSingleOpGraph(const KernelGraphPtr &graph) const { ProcessCast(graph); }
-
 bool CPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector &inputs,
                                     const std::vector &workspace, const std::vector &outputs, bool) const {
diff --git a/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.h b/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.h
index 0855c709282..8190b7e2c55 100644
--- a/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.h
+++ b/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.h
@@ -51,7 +51,6 @@ class CPUDeviceContext : public DeviceContext {
   void CreateKernel(const std::vector &nodes) const override;

   void PreprocessBeforeRunGraph(const KernelGraphPtr &graph) const override;
-  void PreprocessBeforeRunSingleOpGraph(const KernelGraphPtr &graph) const override;

   bool LaunchKernel(const CNodePtr &kernel, const std::vector &inputs,
                     const std::vector &workspace, const std::vector &outputs,
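
The first hunk changes how the inserted Cast node is prepared: the pass no longer creates and initializes a CPU kernel for the Cast inline; it only records the destination type as a "dst_type" attribute (alongside the existing backend-cast flag) and leaves kernel construction to the regular kernel-build step. Below is a minimal, self-contained sketch of that "annotate in the pass, build later" split; every type and helper in it (`CastNode`, `CastKernel`, `MakeCastNode`) is an illustrative stand-in, not MindSpore's real API.

```cpp
// Sketch of deferring kernel construction: the optimizer pass only annotates the
// Cast node; a later build phase turns the annotation into an executable kernel.
// All types here are illustrative stand-ins for the framework's real classes.
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct CastNode {
  std::string format;
  std::map<std::string, std::string> attrs;  // e.g. "dst_type" -> "float32"
};

// What the pass does after the patch: annotate only.
std::shared_ptr<CastNode> MakeCastNode(const std::string &format, const std::string &dst_type) {
  auto cast = std::make_shared<CastNode>();
  cast->format = format;
  cast->attrs["dst_type"] = dst_type;      // read later by the Cast kernel
  cast->attrs["is_backend_cast"] = "true";
  return cast;
}

// What the regular kernel-build step does afterwards (previously this logic
// lived inside the cast-insertion pass itself).
struct CastKernel {
  explicit CastKernel(const CastNode &node) : dst_type(node.attrs.at("dst_type")) {}
  std::string dst_type;
};

int main() {
  auto node = MakeCastNode("DefaultFormat", "float32");
  CastKernel kernel(*node);  // built in the dedicated build phase, not in the pass
  std::cout << "cast to " << kernel.dst_type << "\n";
  return 0;
}
```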
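The new guard in `InsertCastForGraphOutput` skips the LSTM node's reserved output (index `kLstmReserveIndex` = 3), because that output's shape is rewritten later in InitKernel and it is only consumed by the LSTM gradient operator. A simplified, self-contained sketch of that decision follows; only the index value and the skip condition come from the patch, while the `Node`/`TypeId` types and the "LSTM" string check are stand-ins for the real `IsPrimitiveCNode(cnode, prim::kPrimLstm)` test.

```cpp
// Simplified sketch of the output-cast decision with the LSTM reserve-output skip.
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

enum class TypeId { kFloat16, kFloat32 };

struct NodeOutput {
  TypeId infer_type;   // type the frontend inferred
  TypeId device_type;  // type the selected kernel produces
};

struct Node {
  std::string op_name;
  std::vector<NodeOutput> outputs;
};

constexpr unsigned int kLstmReserveIndex = 3;

// Returns the output indices that still need a Cast inserted after the node.
std::vector<size_t> OutputsNeedingCast(const Node &node) {
  std::vector<size_t> result;
  for (size_t i = 0; i < node.outputs.size(); ++i) {
    // LSTM's reserved output is resized later in InitKernel and is only read by
    // the LSTM gradient kernel, so the pass leaves it alone.
    if (node.op_name == "LSTM" && i == kLstmReserveIndex) {
      continue;
    }
    if (node.outputs[i].infer_type != node.outputs[i].device_type) {
      result.push_back(i);
    }
  }
  return result;
}

int main() {
  Node lstm{"LSTM",
            {{TypeId::kFloat32, TypeId::kFloat32},
             {TypeId::kFloat32, TypeId::kFloat16},    // needs a cast
             {TypeId::kFloat32, TypeId::kFloat32},
             {TypeId::kFloat32, TypeId::kFloat16}}};  // reserved output: skipped
  for (size_t idx : OutputsNeedingCast(lstm)) {
    std::cout << "insert Cast after output " << idx << "\n";
  }
  return 0;
}
```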
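The session and device-context hunks fold cast insertion into the single CPU optimization pipeline: the common "insert_cast" pass is registered right after "insert_format_transform_op_cpu", and the separate late ProcessCast step (which ran "insert_cast_cpu" after kernel build) is deleted along with its declarations. The sketch below is a toy model of that pass-manager pattern, not MindSpore's `opt::GraphOptimizer`/`opt::PassManager` API; the pass class names behind the diff's `make_shared` calls are not visible in the hunks, so only the registered pass-name strings are taken from the patch.

```cpp
// Toy model of the pass-registration pattern the patch consolidates.
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct KernelGraph {};  // placeholder for the real kernel graph

class Pass {
 public:
  explicit Pass(std::string name) : name_(std::move(name)) {}
  virtual ~Pass() = default;
  virtual void Run(KernelGraph *graph) const { std::cout << "run pass: " << name_ << "\n"; }

 private:
  std::string name_;
};

class PassManager {
 public:
  void AddPass(const std::shared_ptr<Pass> &pass) { passes_.push_back(pass); }
  void Optimize(KernelGraph *graph) const {
    for (const auto &pass : passes_) {
      pass->Run(graph);
    }
  }

 private:
  std::vector<std::shared_ptr<Pass>> passes_;
};

// After the patch, cast insertion is registered in the one CPU optimization
// pipeline instead of a separate late ProcessCast() step.
void OptimizeCpuGraph(KernelGraph *graph) {
  PassManager pm;
  pm.AddPass(std::make_shared<Pass>("insert_format_transform_op_cpu"));
  pm.AddPass(std::make_shared<Pass>("insert_cast"));       // was "insert_cast_cpu" in ProcessCast()
  pm.AddPass(std::make_shared<Pass>("cleanup_after_cast"));  // hypothetical name for the unnamed follow-up pass
  pm.Optimize(graph);
}

int main() {
  KernelGraph graph;
  OptimizeCpuGraph(&graph);
  return 0;
}
```

Because `Optimize` runs before `BuildKernel`/`CreateKernel` in both the session and the device context, Cast nodes inserted this way are picked up by the normal kernel-build step, which appears to be why the inline kernel creation inside the pass could be dropped.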