forked from mindspore-Ecosystem/mindspore
gpu inference config
This commit is contained in:
parent
363e574ff8
commit
0ce25b4724
|
@ -22,7 +22,6 @@
|
|||
#include "include/api/status.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/graph.h"
|
||||
#include "include/api/context.h"
|
||||
|
||||
namespace mindspore {
|
||||
class InputAndOutput;
|
||||
|
@ -98,7 +97,6 @@ class MS_API GraphCell final : public Cell<GraphCell> {
|
|||
explicit GraphCell(Graph &&);
|
||||
explicit GraphCell(const std::shared_ptr<Graph> &);
|
||||
|
||||
void SetContext(const std::shared_ptr<Context> &context);
|
||||
const std::shared_ptr<Graph> &GetGraph() const { return graph_; }
|
||||
Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
|
||||
std::vector<MSTensor> GetInputs();
|
||||
|
|
|
@ -78,8 +78,6 @@ GraphCell::GraphCell(Graph &&graph)
|
|||
executor_->SetGraph(graph_);
|
||||
}
|
||||
|
||||
void GraphCell::SetContext(const std::shared_ptr<Context> &context) { return executor_->SetContext(context); }
|
||||
|
||||
Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
return executor_->Run(inputs, outputs);
|
||||
|
|
|
@ -51,10 +51,7 @@ Status GPUGraphImpl::InitEnv() {
|
|||
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
||||
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
|
||||
ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kGPUDevice);
|
||||
auto enable_trt = ModelContext::GetGpuTrtInferMode(graph_context_);
|
||||
if (enable_trt == "True") {
|
||||
ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, true);
|
||||
}
|
||||
ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, false);
|
||||
|
||||
session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice);
|
||||
if (session_impl_ == nullptr) {
|
||||
|
|
|
@ -29,12 +29,11 @@
|
|||
namespace mindspore {
|
||||
class GraphCell::GraphImpl {
|
||||
public:
|
||||
GraphImpl() : graph_(nullptr), graph_context_(nullptr) {}
|
||||
GraphImpl() : graph_(nullptr) {}
|
||||
virtual ~GraphImpl() = default;
|
||||
|
||||
std::shared_ptr<Graph::GraphData> &MutableGraphData() const { return graph_->graph_data_; }
|
||||
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
|
||||
void SetContext(const std::shared_ptr<Context> &context) { graph_context_ = context; }
|
||||
|
||||
virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) = 0;
|
||||
virtual Status Load() = 0;
|
||||
|
@ -44,7 +43,6 @@ class GraphCell::GraphImpl {
|
|||
|
||||
protected:
|
||||
std::shared_ptr<Graph> graph_;
|
||||
std::shared_ptr<Context> graph_context_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H
|
||||
|
|
|
@ -70,7 +70,6 @@ std::shared_ptr<GraphCell> MsModel::GenerateGraphCell(const std::vector<std::vec
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto graph_cell = std::make_shared<GraphCell>(graph);
|
||||
MS_EXCEPTION_IF_NULL(graph_cell);
|
||||
graph_cell->SetContext(model_context_);
|
||||
auto ret = ModelImpl::Load(graph_cell);
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(ERROR) << "Load failed.";
|
||||
|
@ -96,7 +95,6 @@ Status MsModel::Build() {
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto graph_cell = std::make_shared<GraphCell>(graph);
|
||||
MS_EXCEPTION_IF_NULL(graph_cell);
|
||||
graph_cell->SetContext(model_context_);
|
||||
auto ret = ModelImpl::Load(graph_cell);
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(ERROR) << "Load failed.";
|
||||
|
|
Loading…
Reference in New Issue