!24156 Move InsertCastCpu before BuildKernel

Merge pull request !24156 from DeshiChen/0926_insert_cast_cpu
This commit is contained in:
i-robot 2021-09-28 03:19:31 +00:00 committed by Gitee
commit 5bd7e4def9
5 changed files with 11 additions and 43 deletions

View File

@ -28,10 +28,12 @@
#include "utils/utils.h"
#include "utils/ms_context.h"
#include "backend/kernel_compiler/common_utils.h"
#include "base/core_ops.h"
namespace mindspore {
namespace opt {
namespace {
constexpr unsigned int kLstmReserveIndex = 3;
AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
const TypeId &input_type, const TypeId &output_type,
const std::vector<size_t> &origin_shape, const TypeId &origin_type) {
@ -52,17 +54,8 @@ AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr
}
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get());
AnfAlgo::SetNodeAttr("dst_type", TypeIdToType(output_type), cast);
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
std::shared_ptr<kernel::CPUKernel> cpu_kernel = kernel::CPUKernelFactory::GetInstance().Create(kCastOpName, cast);
if (cpu_kernel == nullptr) {
MS_LOG(EXCEPTION) << "Operator[Cast] " << cast->kernel_info() << " is not support.";
}
try {
cpu_kernel->Init(cast);
} catch (std::exception &e) {
MS_LOG(EXCEPTION) << e.what() << "\nTrace: " << trace::DumpSourceLines(cast);
}
AnfAlgo::SetKernelMod(cpu_kernel, cast.get());
return cast;
}
@ -100,6 +93,11 @@ void InsertCastForGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &cn
auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, i);
auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, i);
const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, i);
// The shape of LSTM's reserved output will be changed in InitKernel, and this output is only used
// by its gradient operator, so we don't handle it in this pass.
if (IsPrimitiveCNode(cnode, prim::kPrimLstm) && i == kLstmReserveIndex) {
continue;
}
if (infer_type != device_type) {
auto used_node_list = GetRealNodeUsedListByOutputIdx(func_graph, cnode, i);
for (size_t j = 0; j < used_node_list->size(); j++) {

View File

@ -95,17 +95,7 @@ void CPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
}
#endif
pm->AddPass(std::make_shared<opt::InsertFormatTransformOpCPU>("insert_format_transform_op_cpu"));
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
}
void CPUSession::ProcessCast(const std::shared_ptr<KernelGraph> &kernel_graph) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
MS_EXCEPTION_IF_NULL(pm);
pm->AddPass(std::make_shared<opt::InsertCastCPU>("insert_cast_cpu"));
MS_LOG(INFO) << "Insert cast pass";
pm->AddPass(std::make_shared<opt::InsertCastCPU>("insert_cast"));
pm->AddPass(std::make_shared<opt::EraseVisitAttr>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);
@ -124,7 +114,6 @@ GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
FinalOptimize(graph);
MS_LOG(INFO) << "Build kernel";
BuildKernel(graph.get());
ProcessCast(graph);
// TODO: Remove this reorder once the PS feature has finished adapting push/pull in auto_monad.
auto execution_order = graph->execution_order();
Reorder(&execution_order);
@ -235,7 +224,6 @@ KernelGraphPtr CPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const Graph
SetKernelInfo(kernel_graph.get());
Optimize(kernel_graph);
BuildKernel(kernel_graph.get());
ProcessCast(kernel_graph);
auto enable_op_graph_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
if (enable_op_graph_cache) {
run_op_graphs_[graph_info] = kernel_graph;

View File

@ -56,7 +56,6 @@ class CPUSession : public SessionBasic {
private:
void Reorder(std::vector<CNodePtr> *node_list);
void ProcessCast(const std::shared_ptr<KernelGraph> &kernel_graph);
void SetKernelInfo(const KernelGraph *kernel_graph);
void BuildKernel(const KernelGraph *kernel_graph);
void SetOutputFlags(const VectorRef &base_ref);

View File

@ -118,6 +118,8 @@ void CPUDeviceContext::OptimizeGraphImpl(const KernelGraphPtr &graph) const {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::InsertFormatTransformOpCPU>("insert_format_transform_op_cpu"));
pm->AddPass(std::make_shared<opt::InsertCastCPU>("insert_cast"));
pm->AddPass(std::make_shared<opt::EraseVisitAttr>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(graph);
graph->SetExecOrderByDefault();
@ -179,32 +181,14 @@ void CPUDeviceContext::CreateKernel(const std::vector<CNodePtr> &nodes) const {
}
}
namespace {
// Runs a stand-alone optimization round on `graph` whose purpose is to insert CPU cast nodes.
//
// A fresh GraphOptimizer is created with a single PassManager that holds, in order,
// the InsertCastCPU pass (registered under the name "insert_cast_cpu") and the
// EraseVisitAttr pass. After the optimizer has run, the graph's execution order is
// reset through SetExecOrderByDefault().
//
// `graph` is checked with MS_EXCEPTION_IF_NULL first (presumably raises on a null
// pointer -- confirm against the macro definition).
void ProcessCast(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::InsertCastCPU>("insert_cast_cpu"));
// Note: this message is logged when the pass is registered, not when it actually runs.
MS_LOG(INFO) << "Insert cast pass";
pm->AddPass(std::make_shared<opt::EraseVisitAttr>());
optimizer->AddPassManager(pm);
// The return value of Optimize() is intentionally discarded.
(void)optimizer->Optimize(graph);
// NOTE(review): presumably needed because the passes above may have added nodes
// to the graph -- confirm.
graph->SetExecOrderByDefault();
}
} // namespace
// Preprocessing hook for a whole graph: validates `graph`, runs the cast-insertion
// round (ProcessCast), then reorders the graph's execution list with
// AnfAlgo::ReorderPosteriorExecList and stores the reordered list back on the graph.
void CPUDeviceContext::PreprocessBeforeRunGraph(const KernelGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(graph);
ProcessCast(graph);
// TODO: Remove this reorder once the PS feature has finished adapting push/pull in auto_monad.
// `auto` makes a local copy of the execution order; the copy is reordered in place
// through the pointer passed below and only then written back to the graph.
auto execution_order = graph->execution_order();
AnfAlgo::ReorderPosteriorExecList(NOT_NULL(&execution_order));
graph->set_execution_order(execution_order);
}
// Preprocessing hook for a single-op graph: only runs the cast-insertion round
// (ProcessCast) on `graph`. Unlike PreprocessBeforeRunGraph, it does not reorder the
// execution list; the null check on `graph` happens inside ProcessCast.
void CPUDeviceContext::PreprocessBeforeRunSingleOpGraph(const KernelGraphPtr &graph) const { ProcessCast(graph); }
bool CPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
bool) const {

View File

@ -51,7 +51,6 @@ class CPUDeviceContext : public DeviceContext {
void CreateKernel(const std::vector<CNodePtr> &nodes) const override;
void PreprocessBeforeRunGraph(const KernelGraphPtr &graph) const override;
void PreprocessBeforeRunSingleOpGraph(const KernelGraphPtr &graph) const override;
bool LaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,