forked from mindspore-Ecosystem/mindspore
!24156 Move InsertCastCpu before BuildKernel
Merge pull request !24156 from DeshiChen/0926_insert_cast_cpu
This commit is contained in:
commit
5bd7e4def9
|
@ -28,10 +28,12 @@
|
|||
#include "utils/utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "base/core_ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr unsigned int kLstmReserveIndex = 3;
|
||||
AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
|
||||
const TypeId &input_type, const TypeId &output_type,
|
||||
const std::vector<size_t> &origin_shape, const TypeId &origin_type) {
|
||||
|
@ -52,17 +54,8 @@ AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr
|
|||
}
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
|
||||
AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get());
|
||||
AnfAlgo::SetNodeAttr("dst_type", TypeIdToType(output_type), cast);
|
||||
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
|
||||
std::shared_ptr<kernel::CPUKernel> cpu_kernel = kernel::CPUKernelFactory::GetInstance().Create(kCastOpName, cast);
|
||||
if (cpu_kernel == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Operator[Cast] " << cast->kernel_info() << " is not support.";
|
||||
}
|
||||
try {
|
||||
cpu_kernel->Init(cast);
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(EXCEPTION) << e.what() << "\nTrace: " << trace::DumpSourceLines(cast);
|
||||
}
|
||||
AnfAlgo::SetKernelMod(cpu_kernel, cast.get());
|
||||
return cast;
|
||||
}
|
||||
|
||||
|
@ -100,6 +93,11 @@ void InsertCastForGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &cn
|
|||
auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, i);
|
||||
auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, i);
|
||||
const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, i);
|
||||
// The shape of LSTM's reserved output will be changed in InitKernel, and this output is only used
|
||||
// by its gradient operator, so we don't handle it in this pass.
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimLstm) && i == kLstmReserveIndex) {
|
||||
continue;
|
||||
}
|
||||
if (infer_type != device_type) {
|
||||
auto used_node_list = GetRealNodeUsedListByOutputIdx(func_graph, cnode, i);
|
||||
for (size_t j = 0; j < used_node_list->size(); j++) {
|
||||
|
|
|
@ -95,17 +95,7 @@ void CPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
|||
}
|
||||
#endif
|
||||
pm->AddPass(std::make_shared<opt::InsertFormatTransformOpCPU>("insert_format_transform_op_cpu"));
|
||||
optimizer->AddPassManager(pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
}
|
||||
|
||||
void CPUSession::ProcessCast(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
MS_EXCEPTION_IF_NULL(pm);
|
||||
pm->AddPass(std::make_shared<opt::InsertCastCPU>("insert_cast_cpu"));
|
||||
MS_LOG(INFO) << "Insert cast pass";
|
||||
pm->AddPass(std::make_shared<opt::InsertCastCPU>("insert_cast"));
|
||||
pm->AddPass(std::make_shared<opt::EraseVisitAttr>());
|
||||
optimizer->AddPassManager(pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
|
@ -124,7 +114,6 @@ GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
|
|||
FinalOptimize(graph);
|
||||
MS_LOG(INFO) << "Build kernel";
|
||||
BuildKernel(graph.get());
|
||||
ProcessCast(graph);
|
||||
// Remove reorder after PS feature finish adapting push/pull in auto_monad.
|
||||
auto execution_order = graph->execution_order();
|
||||
Reorder(&execution_order);
|
||||
|
@ -235,7 +224,6 @@ KernelGraphPtr CPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const Graph
|
|||
SetKernelInfo(kernel_graph.get());
|
||||
Optimize(kernel_graph);
|
||||
BuildKernel(kernel_graph.get());
|
||||
ProcessCast(kernel_graph);
|
||||
auto enable_op_graph_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
|
||||
if (enable_op_graph_cache) {
|
||||
run_op_graphs_[graph_info] = kernel_graph;
|
||||
|
|
|
@ -56,7 +56,6 @@ class CPUSession : public SessionBasic {
|
|||
|
||||
private:
|
||||
void Reorder(std::vector<CNodePtr> *node_list);
|
||||
void ProcessCast(const std::shared_ptr<KernelGraph> &kernel_graph);
|
||||
void SetKernelInfo(const KernelGraph *kernel_graph);
|
||||
void BuildKernel(const KernelGraph *kernel_graph);
|
||||
void SetOutputFlags(const VectorRef &base_ref);
|
||||
|
|
|
@ -118,6 +118,8 @@ void CPUDeviceContext::OptimizeGraphImpl(const KernelGraphPtr &graph) const {
|
|||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::InsertFormatTransformOpCPU>("insert_format_transform_op_cpu"));
|
||||
pm->AddPass(std::make_shared<opt::InsertCastCPU>("insert_cast"));
|
||||
pm->AddPass(std::make_shared<opt::EraseVisitAttr>());
|
||||
optimizer->AddPassManager(pm);
|
||||
(void)optimizer->Optimize(graph);
|
||||
graph->SetExecOrderByDefault();
|
||||
|
@ -179,32 +181,14 @@ void CPUDeviceContext::CreateKernel(const std::vector<CNodePtr> &nodes) const {
|
|||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
void ProcessCast(const KernelGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::InsertCastCPU>("insert_cast_cpu"));
|
||||
MS_LOG(INFO) << "Insert cast pass";
|
||||
pm->AddPass(std::make_shared<opt::EraseVisitAttr>());
|
||||
optimizer->AddPassManager(pm);
|
||||
(void)optimizer->Optimize(graph);
|
||||
graph->SetExecOrderByDefault();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void CPUDeviceContext::PreprocessBeforeRunGraph(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
ProcessCast(graph);
|
||||
|
||||
// Remove reorder after PS feature finish adapting push/pull in auto_monad.
|
||||
auto execution_order = graph->execution_order();
|
||||
AnfAlgo::ReorderPosteriorExecList(NOT_NULL(&execution_order));
|
||||
graph->set_execution_order(execution_order);
|
||||
}
|
||||
|
||||
void CPUDeviceContext::PreprocessBeforeRunSingleOpGraph(const KernelGraphPtr &graph) const { ProcessCast(graph); }
|
||||
|
||||
bool CPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
|
||||
bool) const {
|
||||
|
|
|
@ -51,7 +51,6 @@ class CPUDeviceContext : public DeviceContext {
|
|||
void CreateKernel(const std::vector<CNodePtr> &nodes) const override;
|
||||
|
||||
void PreprocessBeforeRunGraph(const KernelGraphPtr &graph) const override;
|
||||
void PreprocessBeforeRunSingleOpGraph(const KernelGraphPtr &graph) const override;
|
||||
|
||||
bool LaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
|
||||
|
|
Loading…
Reference in New Issue