forked from mindspore-Ecosystem/mindspore
gpu inference context
This commit is contained in:
parent
1892a629c8
commit
db2668d72a
|
@ -25,6 +25,7 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
class InputAndOutput;
|
class InputAndOutput;
|
||||||
|
class Context;
|
||||||
using Input = InputAndOutput;
|
using Input = InputAndOutput;
|
||||||
using Output = InputAndOutput;
|
using Output = InputAndOutput;
|
||||||
|
|
||||||
|
@ -97,6 +98,7 @@ class MS_API GraphCell final : public Cell<GraphCell> {
|
||||||
explicit GraphCell(Graph &&);
|
explicit GraphCell(Graph &&);
|
||||||
explicit GraphCell(const std::shared_ptr<Graph> &);
|
explicit GraphCell(const std::shared_ptr<Graph> &);
|
||||||
|
|
||||||
|
void SetContext(const std::shared_ptr<Context> &context);
|
||||||
const std::shared_ptr<Graph> &GetGraph() const { return graph_; }
|
const std::shared_ptr<Graph> &GetGraph() const { return graph_; }
|
||||||
Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
|
Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
|
||||||
std::vector<MSTensor> GetInputs();
|
std::vector<MSTensor> GetInputs();
|
||||||
|
|
|
@ -212,6 +212,5 @@ std::string GpuInferenceSession::InputsInfo(const std::vector<ParameterPtr> &par
|
||||||
}
|
}
|
||||||
return graph + " " + actual;
|
return graph + " " + actual;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace session
|
} // namespace session
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -73,6 +73,18 @@ GraphCell::GraphCell(const std::shared_ptr<Graph> &graph) : graph_(graph) { MS_E
|
||||||
|
|
||||||
GraphCell::GraphCell(Graph &&graph) : graph_(std::make_shared<Graph>(graph)) { MS_EXCEPTION_IF_NULL(graph_); }
|
GraphCell::GraphCell(Graph &&graph) : graph_(std::make_shared<Graph>(graph)) { MS_EXCEPTION_IF_NULL(graph_); }
|
||||||
|
|
||||||
|
void GraphCell::SetContext(const std::shared_ptr<Context> &context) {
|
||||||
|
if (executor_ == nullptr) {
|
||||||
|
executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
|
||||||
|
if (executor_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
executor_->SetGraph(graph_);
|
||||||
|
}
|
||||||
|
executor_->SetContext(context);
|
||||||
|
}
|
||||||
|
|
||||||
Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
|
Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
|
||||||
if (executor_ == nullptr) {
|
if (executor_ == nullptr) {
|
||||||
executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
|
executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
|
||||||
|
|
|
@ -54,7 +54,17 @@ Status GPUGraphImpl::InitEnv() {
|
||||||
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
||||||
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
|
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
|
||||||
ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kGPUDevice);
|
ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kGPUDevice);
|
||||||
ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, false);
|
|
||||||
|
auto &device_infos = graph_context_->MutableDeviceInfo();
|
||||||
|
if (device_infos.size() != 1) {
|
||||||
|
return kMCDeviceError;
|
||||||
|
}
|
||||||
|
auto gpu_info = device_infos[0]->Cast<NvidiaGPUDeviceInfo>();
|
||||||
|
if (gpu_info == nullptr) {
|
||||||
|
return kMCDeviceError;
|
||||||
|
}
|
||||||
|
auto enable_trt = gpu_info->GetGpuTrtInferMode();
|
||||||
|
ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, enable_trt);
|
||||||
|
|
||||||
session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice);
|
session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice);
|
||||||
if (session_impl_ == nullptr) {
|
if (session_impl_ == nullptr) {
|
||||||
|
|
|
@ -29,11 +29,12 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
class GraphCell::GraphImpl {
|
class GraphCell::GraphImpl {
|
||||||
public:
|
public:
|
||||||
GraphImpl() : graph_(nullptr) {}
|
GraphImpl() : graph_(nullptr), graph_context_(nullptr) {}
|
||||||
virtual ~GraphImpl() = default;
|
virtual ~GraphImpl() = default;
|
||||||
|
|
||||||
std::shared_ptr<Graph::GraphData> &MutableGraphData() const { return graph_->graph_data_; }
|
std::shared_ptr<Graph::GraphData> &MutableGraphData() const { return graph_->graph_data_; }
|
||||||
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
|
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
|
||||||
|
void SetContext(const std::shared_ptr<Context> &context) { graph_context_ = context; }
|
||||||
|
|
||||||
virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) = 0;
|
virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) = 0;
|
||||||
virtual Status Load(uint32_t device_id) = 0;
|
virtual Status Load(uint32_t device_id) = 0;
|
||||||
|
@ -43,6 +44,7 @@ class GraphCell::GraphImpl {
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::shared_ptr<Graph> graph_;
|
std::shared_ptr<Graph> graph_;
|
||||||
|
std::shared_ptr<Context> graph_context_;
|
||||||
};
|
};
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H
|
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H
|
||||||
|
|
|
@ -74,6 +74,7 @@ std::shared_ptr<GraphCell> MsModel::GenerateGraphCell(const std::vector<std::vec
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
auto graph_cell = std::make_shared<GraphCell>(graph);
|
auto graph_cell = std::make_shared<GraphCell>(graph);
|
||||||
MS_EXCEPTION_IF_NULL(graph_cell);
|
MS_EXCEPTION_IF_NULL(graph_cell);
|
||||||
|
graph_cell->SetContext(model_context_);
|
||||||
auto ret = ModelImpl::Load(graph_cell, GetDeviceID());
|
auto ret = ModelImpl::Load(graph_cell, GetDeviceID());
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "Load failed.";
|
MS_LOG(ERROR) << "Load failed.";
|
||||||
|
@ -99,6 +100,7 @@ Status MsModel::Build() {
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
auto graph_cell = std::make_shared<GraphCell>(graph);
|
auto graph_cell = std::make_shared<GraphCell>(graph);
|
||||||
MS_EXCEPTION_IF_NULL(graph_cell);
|
MS_EXCEPTION_IF_NULL(graph_cell);
|
||||||
|
graph_cell->SetContext(model_context_);
|
||||||
auto ret = ModelImpl::Load(graph_cell, GetDeviceID());
|
auto ret = ModelImpl::Load(graph_cell, GetDeviceID());
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "Load failed.";
|
MS_LOG(ERROR) << "Load failed.";
|
||||||
|
|
|
@ -83,6 +83,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
|
||||||
set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false);
|
set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false);
|
||||||
set_param<bool>(MS_CTX_ENABLE_SPARSE, false);
|
set_param<bool>(MS_CTX_ENABLE_SPARSE, false);
|
||||||
set_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT, false);
|
set_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT, false);
|
||||||
|
set_param<bool>(MS_CTX_ENABLE_INFER_OPT, false);
|
||||||
set_param<bool>(MS_CTX_GRAD_FOR_SCALAR, false);
|
set_param<bool>(MS_CTX_GRAD_FOR_SCALAR, false);
|
||||||
|
|
||||||
backend_policy_ = policy_map_[policy];
|
backend_policy_ = policy_map_[policy];
|
||||||
|
|
Loading…
Reference in New Issue