From 0ce25b47242c5c5493d0a747139c6d61aa775ed3 Mon Sep 17 00:00:00 2001 From: wilfChen Date: Mon, 8 Mar 2021 12:05:12 +0800 Subject: [PATCH] gpu inference config --- include/api/cell.h | 2 -- mindspore/ccsrc/cxx_api/cell.cc | 2 -- mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc | 5 +---- mindspore/ccsrc/cxx_api/graph/graph_impl.h | 4 +--- mindspore/ccsrc/cxx_api/model/ms/ms_model.cc | 2 -- 5 files changed, 2 insertions(+), 13 deletions(-) diff --git a/include/api/cell.h b/include/api/cell.h index 8b9580af49f..3039fa816bb 100644 --- a/include/api/cell.h +++ b/include/api/cell.h @@ -22,7 +22,6 @@ #include "include/api/status.h" #include "include/api/types.h" #include "include/api/graph.h" -#include "include/api/context.h" namespace mindspore { class InputAndOutput; @@ -98,7 +97,6 @@ class MS_API GraphCell final : public Cell { explicit GraphCell(Graph &&); explicit GraphCell(const std::shared_ptr &); - void SetContext(const std::shared_ptr &context); const std::shared_ptr &GetGraph() const { return graph_; } Status Run(const std::vector &inputs, std::vector *outputs) override; std::vector GetInputs(); diff --git a/mindspore/ccsrc/cxx_api/cell.cc b/mindspore/ccsrc/cxx_api/cell.cc index d7d5b53b304..ebf3a4706ed 100644 --- a/mindspore/ccsrc/cxx_api/cell.cc +++ b/mindspore/ccsrc/cxx_api/cell.cc @@ -78,8 +78,6 @@ GraphCell::GraphCell(Graph &&graph) executor_->SetGraph(graph_); } -void GraphCell::SetContext(const std::shared_ptr &context) { return executor_->SetContext(context); } - Status GraphCell::Run(const std::vector &inputs, std::vector *outputs) { MS_EXCEPTION_IF_NULL(executor_); return executor_->Run(inputs, outputs); diff --git a/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc b/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc index d7872894f83..b755a0600e4 100644 --- a/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc +++ b/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc @@ -51,10 +51,7 @@ Status GPUGraphImpl::InitEnv() { ms_context->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); ms_context->set_param(MS_CTX_DEVICE_ID, device_id_); ms_context->set_param(MS_CTX_DEVICE_TARGET, kGPUDevice); - auto enable_trt = ModelContext::GetGpuTrtInferMode(graph_context_); - if (enable_trt == "True") { - ms_context->set_param(MS_CTX_ENABLE_INFER_OPT, true); - } + ms_context->set_param(MS_CTX_ENABLE_INFER_OPT, false); session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice); if (session_impl_ == nullptr) { diff --git a/mindspore/ccsrc/cxx_api/graph/graph_impl.h b/mindspore/ccsrc/cxx_api/graph/graph_impl.h index 3b46e408da4..401df187da7 100644 --- a/mindspore/ccsrc/cxx_api/graph/graph_impl.h +++ b/mindspore/ccsrc/cxx_api/graph/graph_impl.h @@ -29,12 +29,11 @@ namespace mindspore { class GraphCell::GraphImpl { public: - GraphImpl() : graph_(nullptr), graph_context_(nullptr) {} + GraphImpl() : graph_(nullptr) {} virtual ~GraphImpl() = default; std::shared_ptr &MutableGraphData() const { return graph_->graph_data_; } void SetGraph(const std::shared_ptr &graph) { graph_ = graph; } - void SetContext(const std::shared_ptr &context) { graph_context_ = context; } virtual Status Run(const std::vector &inputs, std::vector *outputs) = 0; virtual Status Load() = 0; @@ -44,7 +43,6 @@ class GraphCell::GraphImpl { protected: std::shared_ptr graph_; - std::shared_ptr graph_context_; }; } // namespace mindspore #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H diff --git a/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc b/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc index 9a07105a228..5a4366d0b7d 100644 --- a/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc +++ b/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc @@ -70,7 +70,6 @@ std::shared_ptr MsModel::GenerateGraphCell(const std::vector(graph); MS_EXCEPTION_IF_NULL(graph_cell); - graph_cell->SetContext(model_context_); auto ret = ModelImpl::Load(graph_cell); if (ret != kSuccess) { MS_LOG(ERROR) << "Load failed."; @@ -96,7 +95,6 @@ Status MsModel::Build() { MS_EXCEPTION_IF_NULL(graph); auto graph_cell = std::make_shared(graph); MS_EXCEPTION_IF_NULL(graph_cell); - graph_cell->SetContext(model_context_); auto ret = ModelImpl::Load(graph_cell); if (ret != kSuccess) { MS_LOG(ERROR) << "Load failed.";