forked from mindspore-Ecosystem/mindspore
add reduce precision in pynative mode
This commit is contained in:
parent
40cfca781a
commit
9197d9f2ee
|
@ -120,6 +120,15 @@ void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_gra
|
|||
kernel_graph->SetExecOrderByDefault();
|
||||
}
|
||||
|
||||
void GPUSession::RunOpHardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::ReducePrecisionFusion>("reduce_precision"));
|
||||
optimizer->AddPassManager(pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
}
|
||||
|
||||
void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
|
@ -329,6 +338,7 @@ void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap
|
|||
auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
SelectKernel(kernel_graph);
|
||||
RunOpHardwareOptimize(kernel_graph);
|
||||
StartKernelRT();
|
||||
// Hide NopOp from execution graph
|
||||
opt::HideNopNode(kernel_graph.get());
|
||||
|
|
|
@ -50,6 +50,8 @@ class GPUSession : public SessionBasic {
|
|||
|
||||
void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
|
||||
|
||||
void RunOpHardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
|
||||
|
||||
void GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
|
||||
|
||||
void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph);
|
||||
|
|
|
@ -395,11 +395,7 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
|
|||
result =
|
||||
kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node), builder->Build());
|
||||
if (!result) {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
|
||||
result = kernel::GpuKernelFactory::GetInstance().ReducePrecision(AnfAlgo::GetCNodeName(kernel_node), builder);
|
||||
}
|
||||
result = kernel::GpuKernelFactory::GetInstance().ReducePrecision(AnfAlgo::GetCNodeName(kernel_node), builder);
|
||||
}
|
||||
if (!result) {
|
||||
result = SelectAkgKernel(kernel_node, builder->Build());
|
||||
|
|
Loading…
Reference in New Issue