add reduce precision in pynative mode

This commit is contained in:
chujinjin 2020-10-28 16:40:31 +08:00
parent 40cfca781a
commit 9197d9f2ee
3 changed files with 13 additions and 5 deletions

View File

@ -120,6 +120,15 @@ void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_gra
kernel_graph->SetExecOrderByDefault();
}
// Applies the hardware-level optimization used on the single-op build path
// (invoked from BuildOpImpl right after kernel selection).
//
// A GraphOptimizer is set up with one PassManager that holds a single pass,
// ReducePrecisionFusion (registered under the name "reduce_precision"). The
// optimizer is run on `kernel_graph`; its return value is intentionally
// discarded. Afterwards the graph's execution order is reset to the default
// order so that it reflects any changes made by the pass.
//
// `kernel_graph` is dereferenced without a null check here; callers are
// expected to pass a non-null graph.
void GPUSession::RunOpHardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
  auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pass_manager = std::make_shared<opt::PassManager>();

  // The only pass registered for this path is the reduce-precision fusion.
  pass_manager->AddPass(std::make_shared<opt::ReducePrecisionFusion>("reduce_precision"));
  graph_optimizer->AddPassManager(pass_manager);

  // Result of Optimize is deliberately ignored.
  (void)graph_optimizer->Optimize(kernel_graph);

  // Recompute the execution order after the optimizer has run.
  kernel_graph->SetExecOrderByDefault();
}
void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
@ -329,6 +338,7 @@ void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap
auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
MS_EXCEPTION_IF_NULL(kernel_graph);
SelectKernel(kernel_graph);
RunOpHardwareOptimize(kernel_graph);
StartKernelRT();
// Hide NopOp from execution graph
opt::HideNopNode(kernel_graph.get());

View File

@ -50,6 +50,8 @@ class GPUSession : public SessionBasic {
void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
void RunOpHardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
void GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph);

View File

@ -395,11 +395,7 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
result =
kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node), builder->Build());
if (!result) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
result = kernel::GpuKernelFactory::GetInstance().ReducePrecision(AnfAlgo::GetCNodeName(kernel_node), builder);
}
result = kernel::GpuKernelFactory::GetInstance().ReducePrecision(AnfAlgo::GetCNodeName(kernel_node), builder);
}
if (!result) {
result = SelectAkgKernel(kernel_node, builder->Build());