forked from mindspore-Ecosystem/mindspore
mindspore serving support gpu backend
This commit is contained in:
parent
3708624a25
commit
a911b9ef9e
|
@ -104,6 +104,7 @@ class MS_API Buffer {
|
||||||
|
|
||||||
extern MS_API const char *kDeviceTypeAscend310;
|
extern MS_API const char *kDeviceTypeAscend310;
|
||||||
extern MS_API const char *kDeviceTypeAscend910;
|
extern MS_API const char *kDeviceTypeAscend910;
|
||||||
|
extern MS_API const char *kDeviceTypeGpu;
|
||||||
|
|
||||||
constexpr auto kModelOptionDumpCfgPath = "mindspore.option.dump_config_file_path";
|
constexpr auto kModelOptionDumpCfgPath = "mindspore.option.dump_config_file_path";
|
||||||
constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path"; // aipp config file
|
constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path"; // aipp config file
|
||||||
|
|
|
@ -13,7 +13,7 @@ if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(ENABLE_GPU)
|
if(ENABLE_GPU)
|
||||||
file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu_session.cc")
|
file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu_session.cc" "gpu_inference_session.cc")
|
||||||
list(APPEND _SESSION_SRC_LIST ${_GPU_SRC_LIST})
|
list(APPEND _SESSION_SRC_LIST ${_GPU_SRC_LIST})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,217 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include "backend/session/gpu_inference_session.h"
|
||||||
|
#include "ir/tensor.h"
|
||||||
|
#include "ir/anf.h"
|
||||||
|
#include "ir/param_info.h"
|
||||||
|
#include "runtime/device/kernel_runtime.h"
|
||||||
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
#include "utils/ms_utils.h"
|
||||||
|
#include "common/trans.h"
|
||||||
|
#include "utils/config_manager.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace session {
|
||||||
|
void GpuInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||||
|
const std::vector<tensor::TensorPtr> &inputs_const) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
|
std::vector<tensor::TensorPtr> inputs(inputs_const);
|
||||||
|
auto input_nodes = kernel_graph->inputs();
|
||||||
|
|
||||||
|
size_t no_weight_input = 0;
|
||||||
|
for (size_t i = 0; i < input_nodes.size(); ++i) {
|
||||||
|
tensor::TensorPtr tensor = nullptr;
|
||||||
|
if (!input_nodes[i]->isa<Parameter>()) {
|
||||||
|
MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto pk_node = input_nodes[i]->cast<ParameterPtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(pk_node);
|
||||||
|
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
|
||||||
|
MS_EXCEPTION_IF_NULL(device_address);
|
||||||
|
if (!AnfAlgo::IsParameterWeight(pk_node)) {
|
||||||
|
tensor = inputs[no_weight_input++];
|
||||||
|
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
|
||||||
|
LongToSize(tensor->data().nbytes()), tensor->data_type(),
|
||||||
|
tensor->data_c())) {
|
||||||
|
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compiles func_graph through GPUSession::CompileGraphImpl and then copies the default value of
// every weight parameter of the resulting kernel graph from host memory to its device address.
//
// Returns the id of the compiled graph.
// Raises (MS_LOG(EXCEPTION) / MS_EXCEPTION_IF_NULL) when the compiled graph cannot be found, when a
// weight has no tensor default value, or when the host-to-device copy fails.
GraphId GpuInferenceSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
  auto graph_id = GPUSession::CompileGraphImpl(func_graph);
  auto kernel_graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  // load weight data to device
  auto input_nodes = kernel_graph->inputs();
  for (size_t i = 0; i < input_nodes.size(); ++i) {
    if (!input_nodes[i]->isa<Parameter>()) {
      MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter";
      continue;
    }
    auto pk_node = input_nodes[i]->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(pk_node);
    auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
    MS_EXCEPTION_IF_NULL(device_address);
    // Non-weight parameters are filled per request by LoadInputData, not here.
    if (!AnfAlgo::IsParameterWeight(pk_node)) {
      continue;
    }
    const auto &param_value = pk_node->default_param();
    MS_EXCEPTION_IF_NULL(param_value);
    // The default value must be a tensor to be copied; anything else is rejected below.
    auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_value);
    MS_EXCEPTION_IF_NULL(tensor);
    if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
                                          LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                          tensor->data_c())) {
      MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
    }
  }
  return graph_id;
}
|
||||||
|
|
||||||
|
bool GpuInferenceSession::CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||||
|
std::string *error_msg) const {
|
||||||
|
MS_LOG(INFO) << "Start check client inputs, graph id : " << graph_id;
|
||||||
|
auto kernel_graph = GetGraph(graph_id);
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
|
auto kernel_graph_inputs = kernel_graph->inputs();
|
||||||
|
size_t no_weight_input = 0;
|
||||||
|
vector<ParameterPtr> paras;
|
||||||
|
// find parameters of graph inputs
|
||||||
|
for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) {
|
||||||
|
if (!kernel_graph_inputs[i]->isa<Parameter>()) {
|
||||||
|
MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter.";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
|
||||||
|
if (!AnfAlgo::IsParameterWeight(parameter)) {
|
||||||
|
paras.push_back(parameter);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check inputs
|
||||||
|
for (size_t i = 0; i < paras.size(); ++i) {
|
||||||
|
// compare input number
|
||||||
|
if (paras.size() != inputs.size()) {
|
||||||
|
MS_LOG(ERROR) << "Input number is inconsistent. The actual input number [" << inputs.size()
|
||||||
|
<< "] but the graph input number is [" << paras.size() << "]";
|
||||||
|
MS_LOG(ERROR) << "InputsInfo --" << InputsInfo(paras, inputs);
|
||||||
|
if (error_msg != nullptr) {
|
||||||
|
std::stringstream str_stream;
|
||||||
|
str_stream << "Input number is inconsistent. The given input number [" << inputs.size()
|
||||||
|
<< "] but the graph input number is [" << paras.size() << "]\n";
|
||||||
|
str_stream << "InputsInfo --" << InputsInfo(paras, inputs);
|
||||||
|
*error_msg = str_stream.str();
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto input = inputs[no_weight_input++];
|
||||||
|
if (!CompareInput(input, paras[i])) {
|
||||||
|
MS_LOG(ERROR) << "Please check the input information.";
|
||||||
|
MS_LOG(ERROR) << "InputsInfo --" << InputsInfo(paras, inputs);
|
||||||
|
if (error_msg != nullptr) {
|
||||||
|
std::stringstream str_stream;
|
||||||
|
str_stream << "Please check the input information.\n";
|
||||||
|
str_stream << "InputsInfo --" << InputsInfo(paras, inputs);
|
||||||
|
*error_msg = str_stream.str();
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks that a host tensor matches a graph parameter in shape and data type.
//
// Shapes are compared element-wise against the parameter's output device shape, except that an
// empty shape and a one-element shape [1] are treated as interchangeable (both count as scalar).
// The data type is compared against the output device type of the parameter's selected kernel
// build info.
//
// Returns true on a match; logs the mismatch and returns false otherwise.
// Raises via MS_EXCEPTION_IF_NULL when input, parameter or the kernel build info is null.
bool GpuInferenceSession::CompareInput(const tensor::TensorPtr &input, const ParameterPtr &parameter) const {
  MS_EXCEPTION_IF_NULL(input);
  MS_EXCEPTION_IF_NULL(parameter);

  // compare shape
  auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0);
  auto input_shape = input->shape();
  // Convert the tensor's int64_t dimensions to size_t so both shapes have the same element type.
  std::vector<size_t> trans_input;
  trans_input.reserve(input_shape.size());
  (void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(trans_input),
                       [](const int64_t dim) { return static_cast<size_t>(dim); });
  auto is_scalar_shape = [](const std::vector<size_t> &shape) {
    return shape.empty() || (shape.size() == 1 && shape[0] == 1);
  };
  // A mismatch is only reported when the shapes differ and are not both scalar-like.
  if ((!is_scalar_shape(trans_input) || !is_scalar_shape(parameter_shape)) && (trans_input != parameter_shape)) {
    MS_LOG(ERROR) << "Input shape is inconsistent. The actual shape is " << PrintInputShape(trans_input)
                  << ", but the parameter shape is " << PrintInputShape(parameter_shape)
                  << ". parameter : " << parameter->DebugString();
    return false;
  }

  // compare data type
  auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter);
  MS_EXCEPTION_IF_NULL(kernel_build_info);
  if (input->data_type() != kernel_build_info->GetOutputDeviceType(0)) {
    MS_LOG(ERROR) << "Input data type is inconsistent. The actual data type is " << input->data_type()
                  << ", but the parameter data type is " << kernel_build_info->GetOutputDeviceType(0)
                  << ". parameter : " << parameter->DebugString();
    return false;
  }
  return true;
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::string GpuInferenceSession::PrintInputShape(std::vector<T> shape) const {
|
||||||
|
string res = "[";
|
||||||
|
for (auto dim : shape) {
|
||||||
|
res += " " + std::to_string(dim);
|
||||||
|
}
|
||||||
|
return res + " ]";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string GpuInferenceSession::InputsInfo(const std::vector<ParameterPtr> ¶s,
|
||||||
|
const std::vector<tensor::TensorPtr> &inputs) const {
|
||||||
|
const std::map<TypeId, std::string> dtype_name_map{
|
||||||
|
{TypeId::kNumberTypeBegin, "Unknown"}, {TypeId::kNumberTypeBool, "Bool"},
|
||||||
|
{TypeId::kNumberTypeFloat64, "Float64"}, {TypeId::kNumberTypeInt8, "Int8"},
|
||||||
|
{TypeId::kNumberTypeUInt8, "Uint8"}, {TypeId::kNumberTypeInt16, "Int16"},
|
||||||
|
{TypeId::kNumberTypeUInt16, "Uint16"}, {TypeId::kNumberTypeInt32, "Int32"},
|
||||||
|
{TypeId::kNumberTypeUInt32, "Uint32"}, {TypeId::kNumberTypeInt64, "Int64"},
|
||||||
|
{TypeId::kNumberTypeUInt64, "Uint64"}, {TypeId::kNumberTypeFloat16, "Float16"},
|
||||||
|
{TypeId::kNumberTypeFloat32, "Float32"},
|
||||||
|
};
|
||||||
|
auto data_type_to_string = [&dtype_name_map](TypeId type_id) {
|
||||||
|
auto it = dtype_name_map.find(type_id);
|
||||||
|
if (it == dtype_name_map.end()) {
|
||||||
|
return std::string("Unknown");
|
||||||
|
}
|
||||||
|
return it->second;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string graph = "graph inputs:{ ";
|
||||||
|
for (size_t i = 0; i < paras.size(); ++i) {
|
||||||
|
auto ¶ = paras[i];
|
||||||
|
graph += std::to_string(i) + ": dims " + std::to_string(AnfAlgo::GetOutputDeviceShape(para, 0).size()) +
|
||||||
|
", shape " + PrintInputShape(AnfAlgo::GetOutputDeviceShape(para, 0)) + ", data type " +
|
||||||
|
data_type_to_string(AnfAlgo::GetSelectKernelBuildInfo(para)->GetOutputDeviceType(0)) + " }";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string actual = "given inputs:{ ";
|
||||||
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
actual += std::to_string(i) + ": dims " + std::to_string(inputs[i]->shape().size()) + ", shape " +
|
||||||
|
PrintInputShape(inputs[i]->shape()) + ", data type " + data_type_to_string(inputs[i]->data_type()) + " }";
|
||||||
|
}
|
||||||
|
return graph + " " + actual;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace session
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,53 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_GPU_INFERENCE_SESSION_H
|
||||||
|
#define MINDSPORE_CCSRC_BACKEND_SESSION_GPU_INFERENCE_SESSION_H
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include <utility>
|
||||||
|
#include <stack>
|
||||||
|
#include <map>
|
||||||
|
#include <tuple>
|
||||||
|
#include <set>
|
||||||
|
#include "backend/session/gpu_session.h"
|
||||||
|
#include "backend/session/kernel_graph.h"
|
||||||
|
#include "backend/kernel_compiler/kernel.h"
|
||||||
|
#include "backend/session/session_factory.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace session {
|
||||||
|
class GpuInferenceSession : public gpu::GPUSession {
|
||||||
|
public:
|
||||||
|
GpuInferenceSession() = default;
|
||||||
|
~GpuInferenceSession() = default;
|
||||||
|
void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||||
|
const std::vector<tensor::TensorPtr> &inputs_const) const;
|
||||||
|
bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||||
|
std::string *error_msg) const override;
|
||||||
|
bool CompareInput(const tensor::TensorPtr &input, const ParameterPtr ¶meter) const;
|
||||||
|
template <typename T>
|
||||||
|
std::string PrintInputShape(std::vector<T> shape) const;
|
||||||
|
std::string InputsInfo(const std::vector<ParameterPtr> ¶s, const std::vector<tensor::TensorPtr> &inputs) const;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override;
|
||||||
|
};
|
||||||
|
MS_REG_SESSION(kGpuInferenceDevice, GpuInferenceSession);
|
||||||
|
} // namespace session
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_BACKEND_SESSION_GPU_INFERENCE_SESSION_H
|
|
@ -15,9 +15,11 @@
|
||||||
*/
|
*/
|
||||||
#include "backend/session/gpu_session.h"
|
#include "backend/session/gpu_session.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
#include "backend/optimizer/common/helper.h"
|
#include "backend/optimizer/common/helper.h"
|
||||||
#include "backend/optimizer/common/optimizer.h"
|
#include "backend/optimizer/common/optimizer.h"
|
||||||
#include "backend/optimizer/common/pass_manager.h"
|
#include "backend/optimizer/common/pass_manager.h"
|
||||||
|
#include "backend/optimizer/common/common_backend_optimization.h"
|
||||||
#include "backend/optimizer/gpu/adam_weight_decay_fusion.h"
|
#include "backend/optimizer/gpu/adam_weight_decay_fusion.h"
|
||||||
#include "backend/optimizer/gpu/adam_fusion.h"
|
#include "backend/optimizer/gpu/adam_fusion.h"
|
||||||
#include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h"
|
#include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h"
|
||||||
|
@ -298,16 +300,31 @@ void GPUSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const
|
||||||
|
|
||||||
GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
||||||
// Construct graph, if successfully, graph_sum_ + 1
|
// Construct graph, if successfully, graph_sum_ + 1
|
||||||
auto graph_id = graph_sum_;
|
|
||||||
auto graph = ConstructKernelGraph(lst, outputs);
|
auto graph = ConstructKernelGraph(lst, outputs);
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
return CompileGraphImpl(graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compiles a whole function graph: constructs the root kernel graph, runs the common backend
// optimization on it and hands it to the KernelGraphPtr overload of CompileGraphImpl.
//
// Returns the value produced by CompileGraphImpl(root_graph).
// Raises via MS_LOG(EXCEPTION) when graph construction yields anything other than exactly one
// kernel graph, because multi-graph scheduling is not supported here.
GraphId GPUSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
  std::vector<KernelGraphPtr> all_graphs;
  auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
  MS_EXCEPTION_IF_NULL(root_graph);
  if (all_graphs.size() != 1) {
    MS_LOG(EXCEPTION) << "Gpu backend does not support multi-graph schedule. graph num: " << all_graphs.size();
  }

  opt::BackendCommonOptimization(root_graph);
  return CompileGraphImpl(root_graph);
}
|
||||||
|
|
||||||
|
GraphId GPUSession::CompileGraphImpl(KernelGraphPtr graph) {
|
||||||
// Prepare ms context info for dump .pb graph
|
// Prepare ms context info for dump .pb graph
|
||||||
auto context_ptr = MsContext::GetInstance();
|
auto context_ptr = MsContext::GetInstance();
|
||||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||||
// Dump .pb graph before graph optimization
|
// Dump .pb graph before graph optimization
|
||||||
if (save_graphs) {
|
if (save_graphs) {
|
||||||
DumpIRProto(graph, "before_opt_" + std::to_string(graph_id));
|
DumpIRProto(graph, "before_opt_" + std::to_string(graph->graph_id()));
|
||||||
}
|
}
|
||||||
// Graph optimization irrelevant to device data format
|
// Graph optimization irrelevant to device data format
|
||||||
Optimize(graph);
|
Optimize(graph);
|
||||||
|
@ -326,7 +343,7 @@ GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
|
||||||
AssignStream(graph);
|
AssignStream(graph);
|
||||||
// Dump .pb graph before remove nop nodes
|
// Dump .pb graph before remove nop nodes
|
||||||
if (save_graphs) {
|
if (save_graphs) {
|
||||||
DumpIRProto(graph, "before_removeNop_" + std::to_string(graph_id));
|
DumpIRProto(graph, "before_removeNop_" + std::to_string(graph->graph_id()));
|
||||||
}
|
}
|
||||||
// Update Graph Dynamic Shape Attr.
|
// Update Graph Dynamic Shape Attr.
|
||||||
UpdateGraphDynamicShapeAttr(NOT_NULL(graph));
|
UpdateGraphDynamicShapeAttr(NOT_NULL(graph));
|
||||||
|
@ -343,7 +360,7 @@ GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
|
||||||
SetSummaryNodes(graph.get());
|
SetSummaryNodes(graph.get());
|
||||||
// Dump .pb graph after graph optimization
|
// Dump .pb graph after graph optimization
|
||||||
if (save_graphs) {
|
if (save_graphs) {
|
||||||
DumpIRProto(graph, "after_opt_" + std::to_string(graph_id));
|
DumpIRProto(graph, "after_opt_" + std::to_string(graph->graph_id()));
|
||||||
}
|
}
|
||||||
// Set graph manager.
|
// Set graph manager.
|
||||||
MS_EXCEPTION_IF_NULL(context_);
|
MS_EXCEPTION_IF_NULL(context_);
|
||||||
|
@ -361,9 +378,8 @@ GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
|
||||||
debugger_->LoadGraphs(graph);
|
debugger_->LoadGraphs(graph);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
MS_LOG(INFO) << "CompileGraph graph_id: " << graph_id;
|
MS_LOG(INFO) << "CompileGraph graph_id: " << graph->graph_id();
|
||||||
|
return graph->graph_id();
|
||||||
return graph_id;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||||
|
|
|
@ -37,6 +37,7 @@ class GPUSession : public SessionBasic {
|
||||||
protected:
|
protected:
|
||||||
void UnifyMindIR(const KernelGraphPtr &graph) override { return; }
|
void UnifyMindIR(const KernelGraphPtr &graph) override { return; }
|
||||||
GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
|
GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
|
||||||
|
GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override;
|
||||||
void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
|
void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
|
||||||
void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||||
const std::vector<tensor::TensorPtr> &input_tensors,
|
const std::vector<tensor::TensorPtr> &input_tensors,
|
||||||
|
@ -81,6 +82,8 @@ class GPUSession : public SessionBasic {
|
||||||
void SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
void SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||||
|
|
||||||
void CleanValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
void CleanValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||||
|
|
||||||
|
GraphId CompileGraphImpl(KernelGraphPtr kernel_graph);
|
||||||
};
|
};
|
||||||
using GPUSessionPtr = std::shared_ptr<GPUSession>;
|
using GPUSessionPtr = std::shared_ptr<GPUSession>;
|
||||||
MS_REG_SESSION(kGPUDevice, GPUSession);
|
MS_REG_SESSION(kGPUDevice, GPUSession);
|
||||||
|
|
|
@ -15,10 +15,15 @@ if(ENABLE_ACL)
|
||||||
"model/model_converter_utils/*.cc"
|
"model/model_converter_utils/*.cc"
|
||||||
"graph/acl/*.cc"
|
"graph/acl/*.cc"
|
||||||
)
|
)
|
||||||
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(ENABLE_D)
|
if(ENABLE_D)
|
||||||
file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "python_utils.cc" "model/ms/*.cc" "graph/ms/*.cc")
|
file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
"python_utils.cc" "model/ms/*.cc" "graph/ascend/*.cc")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(ENABLE_GPU)
|
||||||
|
file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "python_utils.cc" "model/ms/*.cc" "graph/gpu/*.cc")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc
|
set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc
|
||||||
|
@ -98,6 +103,15 @@ if(ENABLE_D)
|
||||||
target_link_libraries(mindspore_shared_lib PRIVATE ${adump_server})
|
target_link_libraries(mindspore_shared_lib PRIVATE ${adump_server})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(ENABLE_GPU)
|
||||||
|
target_link_libraries(mindspore_shared_lib PRIVATE gpu_cuda_lib gpu_queue cublas
|
||||||
|
${CUDA_PATH}/lib64/libcurand.so
|
||||||
|
${CUDNN_LIBRARY_PATH}
|
||||||
|
${CUDA_PATH}/lib64/libcudart.so
|
||||||
|
${CUDA_PATH}/lib64/stubs/libcuda.so
|
||||||
|
${CUDA_PATH}/lib64/libcusolver.so)
|
||||||
|
endif()
|
||||||
|
|
||||||
if(CMAKE_SYSTEM_NAME MATCHES "Linux")
|
if(CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||||
set(MINDSPORE_RPATH $ORIGIN)
|
set(MINDSPORE_RPATH $ORIGIN)
|
||||||
if(ENABLE_D)
|
if(ENABLE_D)
|
||||||
|
@ -110,7 +124,8 @@ if(CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/add-ons)
|
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/add-ons)
|
||||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling)
|
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling)
|
||||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
|
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
|
||||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
|
set(MINDSPORE_RPATH
|
||||||
|
${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
|
||||||
elseif(ENABLE_ACL)
|
elseif(ENABLE_ACL)
|
||||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/atc/lib64)
|
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/atc/lib64)
|
||||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/atc/lib64)
|
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/atc/lib64)
|
||||||
|
@ -121,7 +136,8 @@ if(CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/add-ons)
|
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/add-ons)
|
||||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling)
|
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling)
|
||||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
|
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
|
||||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
|
set(MINDSPORE_RPATH
|
||||||
|
${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set_target_properties(mindspore_shared_lib PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH})
|
set_target_properties(mindspore_shared_lib PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH})
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#include "cxx_api/graph/ms/ms_graph_impl.h"
|
#include "cxx_api/graph/ascend/ascend_graph_impl.h"
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include "include/api/context.h"
|
#include "include/api/context.h"
|
||||||
#include "cxx_api/factory.h"
|
#include "cxx_api/factory.h"
|
||||||
|
@ -26,43 +26,9 @@
|
||||||
#include "runtime/device/kernel_runtime_manager.h"
|
#include "runtime/device/kernel_runtime_manager.h"
|
||||||
|
|
||||||
namespace mindspore::api {
|
namespace mindspore::api {
|
||||||
API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, MsGraphImpl);
|
API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl);
|
||||||
|
|
||||||
static DataType TransTypeId2InferDataType(TypeId type_id) {
|
AscendGraphImpl::AscendGraphImpl()
|
||||||
const std::map<TypeId, api::DataType> id2type_map{
|
|
||||||
{TypeId::kNumberTypeBegin, api::kMsUnknown}, {TypeId::kNumberTypeBool, api::kMsBool},
|
|
||||||
{TypeId::kNumberTypeFloat64, api::kMsFloat64}, {TypeId::kNumberTypeInt8, api::kMsInt8},
|
|
||||||
{TypeId::kNumberTypeUInt8, api::kMsUint8}, {TypeId::kNumberTypeInt16, api::kMsInt16},
|
|
||||||
{TypeId::kNumberTypeUInt16, api::kMsUint16}, {TypeId::kNumberTypeInt32, api::kMsInt32},
|
|
||||||
{TypeId::kNumberTypeUInt32, api::kMsUint32}, {TypeId::kNumberTypeInt64, api::kMsInt64},
|
|
||||||
{TypeId::kNumberTypeUInt64, api::kMsUint64}, {TypeId::kNumberTypeFloat16, api::kMsFloat16},
|
|
||||||
{TypeId::kNumberTypeFloat32, api::kMsFloat32},
|
|
||||||
};
|
|
||||||
|
|
||||||
// cppcheck-suppress stlIfFind
|
|
||||||
if (auto it = id2type_map.find(type_id); it != id2type_map.end()) {
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
MS_LOG(WARNING) << "Unsupported data id " << type_id;
|
|
||||||
return api::kMsUnknown;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
inline static void ClearIfNotNull(T *vec) {
|
|
||||||
if (vec != nullptr) {
|
|
||||||
vec->clear();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T, class U = std::vector<T>>
|
|
||||||
inline static void PushbackIfNotNull(U *vec, T &&item) {
|
|
||||||
if (vec != nullptr) {
|
|
||||||
vec->emplace_back(item);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
MsGraphImpl::MsGraphImpl()
|
|
||||||
: session_impl_(nullptr),
|
: session_impl_(nullptr),
|
||||||
graph_id_(0),
|
graph_id_(0),
|
||||||
device_type_("Ascend"),
|
device_type_("Ascend"),
|
||||||
|
@ -75,9 +41,9 @@ MsGraphImpl::MsGraphImpl()
|
||||||
init_flag_(false),
|
init_flag_(false),
|
||||||
load_flag_(false) {}
|
load_flag_(false) {}
|
||||||
|
|
||||||
MsGraphImpl::~MsGraphImpl() { (void)FinalizeEnv(); }
|
AscendGraphImpl::~AscendGraphImpl() { (void)FinalizeEnv(); }
|
||||||
|
|
||||||
Status MsGraphImpl::InitEnv() {
|
Status AscendGraphImpl::InitEnv() {
|
||||||
if (init_flag_) {
|
if (init_flag_) {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
@ -108,7 +74,7 @@ Status MsGraphImpl::InitEnv() {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MsGraphImpl::FinalizeEnv() {
|
Status AscendGraphImpl::FinalizeEnv() {
|
||||||
if (!init_flag_) {
|
if (!init_flag_) {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
@ -136,7 +102,7 @@ Status MsGraphImpl::FinalizeEnv() {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MsGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) {
|
Status AscendGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) {
|
||||||
MS_ASSERT(session_impl_ != nullptr);
|
MS_ASSERT(session_impl_ != nullptr);
|
||||||
try {
|
try {
|
||||||
graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
|
graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
|
||||||
|
@ -147,7 +113,7 @@ Status MsGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<tensor::TensorPtr> MsGraphImpl::RunGraph(const std::vector<tensor::TensorPtr> &inputs) {
|
std::vector<tensor::TensorPtr> AscendGraphImpl::RunGraph(const std::vector<tensor::TensorPtr> &inputs) {
|
||||||
try {
|
try {
|
||||||
VectorRef outputs;
|
VectorRef outputs;
|
||||||
session_impl_->RunGraph(graph_id_, inputs, &outputs);
|
session_impl_->RunGraph(graph_id_, inputs, &outputs);
|
||||||
|
@ -158,7 +124,7 @@ std::vector<tensor::TensorPtr> MsGraphImpl::RunGraph(const std::vector<tensor::T
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MsGraphImpl::CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const {
|
Status AscendGraphImpl::CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const {
|
||||||
MS_ASSERT(session_impl_ != nullptr);
|
MS_ASSERT(session_impl_ != nullptr);
|
||||||
std::string error_msg;
|
std::string error_msg;
|
||||||
if (!session_impl_->CheckModelInputs(graph_id_, inputs, &error_msg)) {
|
if (!session_impl_->CheckModelInputs(graph_id_, inputs, &error_msg)) {
|
||||||
|
@ -167,7 +133,7 @@ Status MsGraphImpl::CheckModelInputs(const std::vector<tensor::TensorPtr> &input
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MsGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) {
|
Status AscendGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) {
|
||||||
MS_EXCEPTION_IF_NULL(reply);
|
MS_EXCEPTION_IF_NULL(reply);
|
||||||
if (context_ == nullptr) {
|
if (context_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "rtCtx is nullptr";
|
MS_LOG(ERROR) << "rtCtx is nullptr";
|
||||||
|
@ -206,8 +172,8 @@ Status MsGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MsGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
Status AscendGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
|
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
|
||||||
if (!load_flag_) {
|
if (!load_flag_) {
|
||||||
Status ret = Load();
|
Status ret = Load();
|
||||||
if (ret != SUCCESS) {
|
if (ret != SUCCESS) {
|
||||||
|
@ -216,22 +182,22 @@ Status MsGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ClearIfNotNull(names);
|
GraphUtils::ClearIfNotNull(names);
|
||||||
ClearIfNotNull(shapes);
|
GraphUtils::ClearIfNotNull(shapes);
|
||||||
ClearIfNotNull(data_types);
|
GraphUtils::ClearIfNotNull(data_types);
|
||||||
ClearIfNotNull(mem_sizes);
|
GraphUtils::ClearIfNotNull(mem_sizes);
|
||||||
for (size_t i = 0; i < inputs_.size(); i++) {
|
for (size_t i = 0; i < inputs_.size(); i++) {
|
||||||
auto &tensor = inputs_[i];
|
auto &tensor = inputs_[i];
|
||||||
PushbackIfNotNull(names, input_names_[i]);
|
GraphUtils::PushbackIfNotNull(names, input_names_[i]);
|
||||||
PushbackIfNotNull(shapes, tensor->shape());
|
GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
|
||||||
PushbackIfNotNull(data_types, TransTypeId2InferDataType(tensor->data_type()));
|
GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
|
||||||
PushbackIfNotNull(mem_sizes, tensor->Size());
|
GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
|
||||||
}
|
}
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MsGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
Status AscendGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
|
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
|
||||||
if (!load_flag_) {
|
if (!load_flag_) {
|
||||||
Status ret = Load();
|
Status ret = Load();
|
||||||
if (ret != SUCCESS) {
|
if (ret != SUCCESS) {
|
||||||
|
@ -240,22 +206,22 @@ Status MsGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ClearIfNotNull(names);
|
GraphUtils::ClearIfNotNull(names);
|
||||||
ClearIfNotNull(shapes);
|
GraphUtils::ClearIfNotNull(shapes);
|
||||||
ClearIfNotNull(data_types);
|
GraphUtils::ClearIfNotNull(data_types);
|
||||||
ClearIfNotNull(mem_sizes);
|
GraphUtils::ClearIfNotNull(mem_sizes);
|
||||||
for (size_t i = 0; i < outputs_.size(); i++) {
|
for (size_t i = 0; i < outputs_.size(); i++) {
|
||||||
auto &tensor = outputs_[i];
|
auto &tensor = outputs_[i];
|
||||||
PushbackIfNotNull(names, output_names_[i]);
|
GraphUtils::PushbackIfNotNull(names, output_names_[i]);
|
||||||
PushbackIfNotNull(shapes, tensor->shape());
|
GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
|
||||||
PushbackIfNotNull(data_types, TransTypeId2InferDataType(tensor->data_type()));
|
GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
|
||||||
PushbackIfNotNull(mem_sizes, tensor->Size());
|
GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
|
||||||
}
|
}
|
||||||
|
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MsGraphImpl::Load() {
|
Status AscendGraphImpl::Load() {
|
||||||
// check graph type
|
// check graph type
|
||||||
if (graph_->ModelType() != ModelType::kMindIR) {
|
if (graph_->ModelType() != ModelType::kMindIR) {
|
||||||
MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType();
|
MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType();
|
||||||
|
@ -311,7 +277,7 @@ Status MsGraphImpl::Load() {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MsGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
|
Status AscendGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
|
||||||
MS_EXCEPTION_IF_NULL(outputs);
|
MS_EXCEPTION_IF_NULL(outputs);
|
||||||
if (!load_flag_) {
|
if (!load_flag_) {
|
||||||
Status ret = Load();
|
Status ret = Load();
|
|
@ -13,8 +13,8 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
|
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H
|
||||||
#define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
|
#define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -28,12 +28,13 @@
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "cxx_api/model/model_impl.h"
|
#include "cxx_api/model/model_impl.h"
|
||||||
#include "runtime/context.h"
|
#include "runtime/context.h"
|
||||||
|
#include "cxx_api/graph/graph_utils.h"
|
||||||
|
|
||||||
namespace mindspore::api {
|
namespace mindspore::api {
|
||||||
class MsGraphImpl : public GraphCell::GraphImpl {
|
class AscendGraphImpl : public GraphCell::GraphImpl {
|
||||||
public:
|
public:
|
||||||
MsGraphImpl();
|
AscendGraphImpl();
|
||||||
~MsGraphImpl() override;
|
~AscendGraphImpl() override;
|
||||||
|
|
||||||
Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
|
Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
|
||||||
Status Load() override;
|
Status Load() override;
|
||||||
|
@ -63,4 +64,4 @@ class MsGraphImpl : public GraphCell::GraphImpl {
|
||||||
bool load_flag_;
|
bool load_flag_;
|
||||||
};
|
};
|
||||||
} // namespace mindspore::api
|
} // namespace mindspore::api
|
||||||
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
|
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H
|
|
@ -20,10 +20,10 @@ VMCallbackRegister &VMCallbackRegister::GetInstance() {
|
||||||
return instance;
|
return instance;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool VMCallbackRegister::Registe(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback),
|
bool VMCallbackRegister::Register(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback),
|
||||||
Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback),
|
Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback),
|
||||||
Status (*pRegProfReporterCallback)(MsprofReporterCallback),
|
Status (*pRegProfReporterCallback)(MsprofReporterCallback),
|
||||||
Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t)) {
|
Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,256 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "cxx_api/graph/gpu/gpu_graph_impl.h"
|
||||||
|
#include <algorithm>
|
||||||
|
#include "include/api/context.h"
|
||||||
|
#include "cxx_api/factory.h"
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
#include "mindspore/core/base/base_ref_utils.h"
|
||||||
|
#include "backend/session/session_factory.h"
|
||||||
|
#include "backend/session/executor_manager.h"
|
||||||
|
#include "runtime/device/kernel_runtime_manager.h"
|
||||||
|
|
||||||
|
namespace mindspore::api {
|
||||||
|
// Registers GPUGraphImpl with the API factory as the GraphCell::GraphImpl for the GPU key.
// NOTE(review): the exact lookup key / registration mechanics are defined in cxx_api/factory.h -- confirm there.
API_FACTORY_REG(GraphCell::GraphImpl, GPU, GPUGraphImpl);
|
||||||
|
|
||||||
|
GPUGraphImpl::GPUGraphImpl()
|
||||||
|
: session_impl_(nullptr),
|
||||||
|
graph_id_(0),
|
||||||
|
device_id_(Context::Instance().GetDeviceID()),
|
||||||
|
inputs_(),
|
||||||
|
outputs_(),
|
||||||
|
input_names_(),
|
||||||
|
output_names_(),
|
||||||
|
init_flag_(false),
|
||||||
|
load_flag_(false) {}
|
||||||
|
|
||||||
|
// Configures the global MsContext for GPU graph-mode inference and creates the inference session.
// A repeated call after a successful one only logs a warning and returns SUCCESS.
// Returns FAILED when the global context is unavailable or the session cannot be created.
Status GPUGraphImpl::InitEnv() {
  if (init_flag_) {
    MS_LOG(WARNING) << "Initialized again, return success.";
    return SUCCESS;
  }

  const auto context = MsContext::GetInstance();
  if (context == nullptr) {
    MS_LOG(ERROR) << "Get Context failed!";
    return FAILED;
  }

  // Context parameters are written before the session is created.
  context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
  context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
  context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kGPUDevice);
  context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, true);

  session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice);
  if (session_impl_ == nullptr) {
    MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << kGpuInferenceDevice
                  << " is available.";
    return FAILED;
  }

  session_impl_->Init(device_id_);
  init_flag_ = true;
  return SUCCESS;
}
|
||||||
|
|
||||||
|
// Tears down what InitEnv() set up: clears the executor manager and the kernel runtime resources,
// then resets the init flag. Always returns SUCCESS; when InitEnv() never ran it only logs a warning.
Status GPUGraphImpl::FinalizeEnv() {
  if (init_flag_) {
    MS_LOG_INFO << "Start finalize env";
    session::ExecutorManager::Instance().Clear();
    device::KernelRuntimeManager::Instance().ClearRuntimeResource();

    init_flag_ = false;
    MS_LOG(INFO) << "End finalize env";
    return SUCCESS;
  }

  MS_LOG(WARNING) << "Never initialize before, return success";
  return SUCCESS;
}
|
||||||
|
|
||||||
|
// Loads the model held by graph_: validates that it is MindIR, initializes the environment,
// compiles the function graph, and caches the model's input/output tensors and names.
// Returns INVALID_INPUTS for a non-MindIR model, FAILED for any later failure, and SUCCESS
// (with load_flag_ set) otherwise.
Status GPUGraphImpl::Load() {
  // Only MindIR graphs are accepted here.
  if (graph_->ModelType() != ModelType::kMindIR) {
    MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType();
    return INVALID_INPUTS;
  }

  const auto &graph_data = GraphImpl::MutableGraphData();
  MS_EXCEPTION_IF_NULL(graph_data);
  auto graph_to_compile = graph_data->GetFuncGraph();

  // Environment setup must precede compilation because it creates session_impl_.
  if (InitEnv() != SUCCESS) {
    MS_LOG(ERROR) << "InitEnv failed.";
    return FAILED;
  }

  if (CompileGraph(graph_to_compile) != SUCCESS) {
    MS_LOG(ERROR) << "Compile graph model failed";
    return FAILED;
  }

  // Query the compiled graph's I/O description and sanity-check tensor/name pairing.
  session_impl_->GetModelInputsInfo(graph_id_, &inputs_, &input_names_);
  session_impl_->GetModelOutputsInfo(graph_id_, &outputs_, &output_names_);

  const bool inputs_ok = !inputs_.empty() && inputs_.size() == input_names_.size();
  if (!inputs_ok) {
    MS_LOG_ERROR << "Get model inputs info failed";
    return FAILED;
  }
  const bool outputs_ok = !outputs_.empty() && outputs_.size() == output_names_.size();
  if (!outputs_ok) {
    MS_LOG_ERROR << "Get model outputs info failed";
    return FAILED;
  }

  load_flag_ = true;
  return SUCCESS;
}
|
||||||
|
|
||||||
|
// Compiles funcGraphPtr on the session and stores the resulting graph id in graph_id_.
// Any std::exception raised during compilation is logged and reported as FAILED.
Status GPUGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) {
  MS_ASSERT(session_impl_ != nullptr);
  Status status = FAILED;
  try {
    graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
    status = SUCCESS;
  } catch (const std::exception &err) {
    MS_LOG(ERROR) << "CompileGraph failed: " << err.what();
  }
  return status;
}
|
||||||
|
|
||||||
|
// Runs the compiled graph with the given input tensors and flattens the result into a tensor list.
// Returns an empty vector when a std::exception is raised; the error is logged.
std::vector<tensor::TensorPtr> GPUGraphImpl::RunGraph(const std::vector<tensor::TensorPtr> &inputs) {
  std::vector<tensor::TensorPtr> results;
  try {
    VectorRef raw_outputs;
    session_impl_->RunGraph(graph_id_, inputs, &raw_outputs);
    results = TransformVectorRefToMultiTensor(raw_outputs);
  } catch (const std::exception &err) {
    MS_LOG(ERROR) << "RunGraph failed: " << err.what();
  }
  return results;
}
|
||||||
|
|
||||||
|
// Copies every request buffer into the matching cached model input tensor, runs the compiled
// graph, and hands the output tensors' bytes back through `reply`.
//
// request: one buffer per model input, in model input order; each buffer's byte size must equal
//          the size of the corresponding input tensor.
// reply:   must not be null (enforced via MS_EXCEPTION_IF_NULL); cleared and refilled with one
//          buffer per graph output on success.
// Returns SUCCESS, or FAILED on a count/size mismatch, a failed copy, or an empty graph result.
Status GPUGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) {
  MS_EXCEPTION_IF_NULL(reply);

  // The loop below indexes inputs_ with the request index. Run() already rejects mismatched
  // counts, but guard here as well so a longer request can never read past the end of inputs_.
  if (request.size() > inputs_.size()) {
    MS_LOG(ERROR) << "Request count " << request.size() << " exceeds model input count " << inputs_.size();
    return FAILED;
  }

  std::vector<tensor::TensorPtr> inputs;
  inputs.reserve(request.size());
  for (size_t i = 0; i < request.size(); i++) {
    const auto &item = request[i];
    const auto &input = inputs_[i];
    if (input->Size() != item.DataSize()) {
      MS_LOG(ERROR) << "Input " << i << " data size " << item.DataSize() << " not match model input data size "
                    << input->Size();
      return FAILED;
    }
    // memcpy_s reports success as EOK; its result is an errno-style code, not a Status value.
    auto ret = memcpy_s(input->data_c(), input->Size(), item.Data(), item.DataSize());
    if (ret != EOK) {
      MS_LOG(ERROR) << "Tensor copy failed";
      return FAILED;
    }
    inputs.push_back(input);
  }

  std::vector<tensor::TensorPtr> outputs = RunGraph(inputs);
  if (outputs.empty()) {
    // RunGraph() signals failure with an empty result.
    MS_LOG(ERROR) << "Execute Model Failed";
    return FAILED;
  }

  reply->clear();
  reply->reserve(outputs.size());
  std::transform(outputs.begin(), outputs.end(), std::back_inserter(*reply),
                 [](const tensor::TensorPtr &tensor) { return Buffer(tensor->data_c(), tensor->Size()); });
  return SUCCESS;
}
|
||||||
|
|
||||||
|
// Public inference entry point: loads the model on first use, validates the caller's buffers
// against the cached model inputs, executes the graph, and checks the number of results.
//
// inputs:  one buffer per model input; count and per-buffer byte size must match the model.
// outputs: must not be null (enforced via MS_EXCEPTION_IF_NULL); filled by ExecuteModel().
// Returns the Load() status if loading fails, INVALID_INPUTS on a count or size mismatch,
// FAILED if execution fails or yields an unexpected number of outputs, SUCCESS otherwise.
Status GPUGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
  MS_EXCEPTION_IF_NULL(outputs);
  if (!load_flag_) {
    Status ret = Load();
    if (ret != SUCCESS) {
      MS_LOG(ERROR) << "PrepareModel failed.";
      return ret;
    }
  }

  if (inputs.size() != inputs_.size()) {
    MS_LOG(ERROR) << "inputs count not match, required count " << inputs_.size() << ", given count " << inputs.size();
    return INVALID_INPUTS;
  }

  for (size_t i = 0; i < inputs_.size(); ++i) {
    if (inputs[i].DataSize() != inputs_[i]->Size()) {
      // Both values reported here are byte sizes, so label the given one as a size too.
      MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_[i]->Size() << ", given size "
                    << inputs[i].DataSize();
      return INVALID_INPUTS;
    }
  }

  if (ExecuteModel(inputs, outputs) != SUCCESS) {
    MS_LOG(ERROR) << "Execute Model Failed";
    return FAILED;
  }

  // The graph must produce exactly as many outputs as the model description announced.
  if (outputs_.size() != outputs->size()) {
    MS_LOG(ERROR) << "Predict output size " << outputs->size() << " not match output size got from model info "
                  << outputs_.size();
    return FAILED;
  }

  return SUCCESS;
}
|
||||||
|
|
||||||
|
// Reports name, shape, data type and byte size of every model input, loading the model first if
// that has not happened yet. Each out-parameter may be null, in which case it is skipped;
// non-null ones are cleared before being filled. Returns the Load() status on load failure.
Status GPUGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
                                   std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
  if (!load_flag_) {
    const Status load_status = Load();
    if (load_status != SUCCESS) {
      MS_LOG(ERROR) << "PrepareModel failed.";
      return load_status;
    }
  }

  GraphUtils::ClearIfNotNull(names);
  GraphUtils::ClearIfNotNull(shapes);
  GraphUtils::ClearIfNotNull(data_types);
  GraphUtils::ClearIfNotNull(mem_sizes);

  // inputs_ and input_names_ are index-aligned.
  size_t idx = 0;
  for (auto &tensor : inputs_) {
    GraphUtils::PushbackIfNotNull(names, input_names_[idx]);
    GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
    GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
    GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
    ++idx;
  }
  return SUCCESS;
}
|
||||||
|
|
||||||
|
// Reports name, shape, data type and byte size of every model output, loading the model first if
// that has not happened yet. Each out-parameter may be null, in which case it is skipped;
// non-null ones are cleared before being filled. Returns the Load() status on load failure.
Status GPUGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
                                    std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
  if (!load_flag_) {
    const Status load_status = Load();
    if (load_status != SUCCESS) {
      MS_LOG(ERROR) << "PrepareModel failed.";
      return load_status;
    }
  }

  GraphUtils::ClearIfNotNull(names);
  GraphUtils::ClearIfNotNull(shapes);
  GraphUtils::ClearIfNotNull(data_types);
  GraphUtils::ClearIfNotNull(mem_sizes);

  // outputs_ and output_names_ are index-aligned.
  size_t idx = 0;
  for (auto &tensor : outputs_) {
    GraphUtils::PushbackIfNotNull(names, output_names_[idx]);
    GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
    GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
    GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
    ++idx;
  }

  return SUCCESS;
}
|
||||||
|
|
||||||
|
} // namespace mindspore::api
|
|
@ -0,0 +1,67 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_GPU_GRAPH_IMPL_H
|
||||||
|
#define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_GPU_GRAPH_IMPL_H
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <utility>
|
||||||
|
#include <memory>
|
||||||
|
#include "include/api/status.h"
|
||||||
|
#include "include/api/graph.h"
|
||||||
|
#include "cxx_api/graph/graph_impl.h"
|
||||||
|
#include "backend/session/session_basic.h"
|
||||||
|
#include "ir/anf.h"
|
||||||
|
#include "cxx_api/model/model_impl.h"
|
||||||
|
#include "cxx_api/graph/graph_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore::api {
|
||||||
|
class GPUGraphImpl : public GraphCell::GraphImpl {
|
||||||
|
public:
|
||||||
|
GPUGraphImpl();
|
||||||
|
~GPUGraphImpl() override = default;
|
||||||
|
|
||||||
|
Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
|
||||||
|
Status Load() override;
|
||||||
|
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||||
|
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
|
||||||
|
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||||
|
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
Status InitEnv();
|
||||||
|
Status FinalizeEnv();
|
||||||
|
Status CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr);
|
||||||
|
Status CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const;
|
||||||
|
std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs);
|
||||||
|
Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
|
||||||
|
|
||||||
|
std::shared_ptr<session::SessionBasic> session_impl_;
|
||||||
|
uint32_t graph_id_;
|
||||||
|
std::string device_type_;
|
||||||
|
uint32_t device_id_;
|
||||||
|
std::vector<tensor::TensorPtr> inputs_;
|
||||||
|
std::vector<tensor::TensorPtr> outputs_;
|
||||||
|
std::vector<std::string> input_names_;
|
||||||
|
std::vector<std::string> output_names_;
|
||||||
|
bool init_flag_;
|
||||||
|
bool load_flag_;
|
||||||
|
|
||||||
|
// tensor-rt
|
||||||
|
uint32_t batch_size_;
|
||||||
|
uint32_t workspace_size_;
|
||||||
|
};
|
||||||
|
} // namespace mindspore::api
|
||||||
|
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_GPU_GRAPH_IMPL_H
|
|
@ -0,0 +1,63 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H
|
||||||
|
#define MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H
|
||||||
|
#include <map>
#include <utility>
#include <vector>
#include "include/api/types.h"
#include "ir/dtype/type_id.h"
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore::api {
|
||||||
|
class GraphUtils {
|
||||||
|
public:
|
||||||
|
static DataType TransTypeId2InferDataType(TypeId type_id) {
|
||||||
|
const std::map<TypeId, api::DataType> id2type_map{
|
||||||
|
{TypeId::kNumberTypeBegin, api::kMsUnknown}, {TypeId::kNumberTypeBool, api::kMsBool},
|
||||||
|
{TypeId::kNumberTypeFloat64, api::kMsFloat64}, {TypeId::kNumberTypeInt8, api::kMsInt8},
|
||||||
|
{TypeId::kNumberTypeUInt8, api::kMsUint8}, {TypeId::kNumberTypeInt16, api::kMsInt16},
|
||||||
|
{TypeId::kNumberTypeUInt16, api::kMsUint16}, {TypeId::kNumberTypeInt32, api::kMsInt32},
|
||||||
|
{TypeId::kNumberTypeUInt32, api::kMsUint32}, {TypeId::kNumberTypeInt64, api::kMsInt64},
|
||||||
|
{TypeId::kNumberTypeUInt64, api::kMsUint64}, {TypeId::kNumberTypeFloat16, api::kMsFloat16},
|
||||||
|
{TypeId::kNumberTypeFloat32, api::kMsFloat32},
|
||||||
|
};
|
||||||
|
|
||||||
|
auto it = id2type_map.find(type_id);
|
||||||
|
if (it != id2type_map.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(WARNING) << "Unsupported data id " << type_id;
|
||||||
|
return api::kMsUnknown;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
inline static void ClearIfNotNull(T *vec) {
|
||||||
|
if (vec != nullptr) {
|
||||||
|
vec->clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T, class U>
|
||||||
|
inline static void PushbackIfNotNull(U *vec, T &&item) {
|
||||||
|
if (vec != nullptr) {
|
||||||
|
vec->emplace_back(item);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace mindspore::api
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H
|
|
@ -22,6 +22,7 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace api {
|
namespace api {
|
||||||
API_FACTORY_REG(ModelImpl, Ascend910, MsModel);
|
API_FACTORY_REG(ModelImpl, Ascend910, MsModel);
|
||||||
|
API_FACTORY_REG(ModelImpl, GPU, MsModel);
|
||||||
|
|
||||||
Status MsModel::Build(const std::map<std::string, std::string> &) {
|
Status MsModel::Build(const std::map<std::string, std::string> &) {
|
||||||
MS_LOG(INFO) << "Start build model.";
|
MS_LOG(INFO) << "Start build model.";
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
namespace mindspore::api {
|
namespace mindspore::api {
|
||||||
const char *kDeviceTypeAscend310 = "Ascend310";
|
const char *kDeviceTypeAscend310 = "Ascend310";
|
||||||
const char *kDeviceTypeAscend910 = "Ascend910";
|
const char *kDeviceTypeAscend910 = "Ascend910";
|
||||||
|
const char *kDeviceTypeGpu = "GPU";
|
||||||
|
|
||||||
class DataImpl {
|
class DataImpl {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -31,7 +31,7 @@ constexpr Status PROF_FAILED = 0xFFFFFFFF;
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Status RegProfCtrlCallback(MsprofCtrlCallback func) {
|
Status RegProfCtrlCallback(MsprofCtrlCallback func) {
|
||||||
if (VMCallbackRegister::GetInstance().registed()) {
|
if (VMCallbackRegister::GetInstance().registered()) {
|
||||||
return VMCallbackRegister::GetInstance().DoRegProfCtrlCallback(func);
|
return VMCallbackRegister::GetInstance().DoRegProfCtrlCallback(func);
|
||||||
} else {
|
} else {
|
||||||
return PROF_SUCCESS;
|
return PROF_SUCCESS;
|
||||||
|
@ -39,7 +39,7 @@ Status RegProfCtrlCallback(MsprofCtrlCallback func) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Status RegProfSetDeviceCallback(MsprofSetDeviceCallback func) {
|
Status RegProfSetDeviceCallback(MsprofSetDeviceCallback func) {
|
||||||
if (VMCallbackRegister::GetInstance().registed()) {
|
if (VMCallbackRegister::GetInstance().registered()) {
|
||||||
return VMCallbackRegister::GetInstance().DoRegProfSetDeviceCallback(func);
|
return VMCallbackRegister::GetInstance().DoRegProfSetDeviceCallback(func);
|
||||||
} else {
|
} else {
|
||||||
return PROF_SUCCESS;
|
return PROF_SUCCESS;
|
||||||
|
@ -47,7 +47,7 @@ Status RegProfSetDeviceCallback(MsprofSetDeviceCallback func) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Status RegProfReporterCallback(MsprofReporterCallback func) {
|
Status RegProfReporterCallback(MsprofReporterCallback func) {
|
||||||
if (VMCallbackRegister::GetInstance().registed()) {
|
if (VMCallbackRegister::GetInstance().registered()) {
|
||||||
return VMCallbackRegister::GetInstance().DoRegProfReporterCallback(func);
|
return VMCallbackRegister::GetInstance().DoRegProfReporterCallback(func);
|
||||||
} else {
|
} else {
|
||||||
return PROF_SUCCESS;
|
return PROF_SUCCESS;
|
||||||
|
@ -55,7 +55,7 @@ Status RegProfReporterCallback(MsprofReporterCallback func) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t len) {
|
Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t len) {
|
||||||
if (VMCallbackRegister::GetInstance().registed()) {
|
if (VMCallbackRegister::GetInstance().registered()) {
|
||||||
return VMCallbackRegister::GetInstance().DoProfCommandHandle(type, data, len);
|
return VMCallbackRegister::GetInstance().DoProfCommandHandle(type, data, len);
|
||||||
} else {
|
} else {
|
||||||
return PROF_SUCCESS;
|
return PROF_SUCCESS;
|
||||||
|
@ -69,16 +69,16 @@ VMCallbackRegister &VMCallbackRegister::GetInstance() {
|
||||||
return instance;
|
return instance;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool VMCallbackRegister::Registe(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback),
|
bool VMCallbackRegister::Register(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback),
|
||||||
Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback),
|
Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback),
|
||||||
Status (*pRegProfReporterCallback)(MsprofReporterCallback),
|
Status (*pRegProfReporterCallback)(MsprofReporterCallback),
|
||||||
Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t)) {
|
Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t)) {
|
||||||
if (!registed_) {
|
if (!registered_) {
|
||||||
pRegProfCtrlCallback_ = pRegProfCtrlCallback;
|
pRegProfCtrlCallback_ = pRegProfCtrlCallback;
|
||||||
pRegProfSetDeviceCallback_ = pRegProfSetDeviceCallback;
|
pRegProfSetDeviceCallback_ = pRegProfSetDeviceCallback;
|
||||||
pRegProfReporterCallback_ = pRegProfReporterCallback;
|
pRegProfReporterCallback_ = pRegProfReporterCallback;
|
||||||
pProfCommandHandle_ = pProfCommandHandle;
|
pProfCommandHandle_ = pProfCommandHandle;
|
||||||
registed_ = true;
|
registered_ = true;
|
||||||
ForceMsprofilerInit();
|
ForceMsprofilerInit();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,12 +49,12 @@ class __attribute__((visibility("default"))) VMCallbackRegister {
|
||||||
static VMCallbackRegister &GetInstance();
|
static VMCallbackRegister &GetInstance();
|
||||||
VMCallbackRegister(const VMCallbackRegister &) = delete;
|
VMCallbackRegister(const VMCallbackRegister &) = delete;
|
||||||
VMCallbackRegister &operator=(const VMCallbackRegister &) = delete;
|
VMCallbackRegister &operator=(const VMCallbackRegister &) = delete;
|
||||||
bool Registe(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback),
|
bool Register(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback),
|
||||||
Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback),
|
Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback),
|
||||||
Status (*pRegProfReporterCallback)(MsprofReporterCallback),
|
Status (*pRegProfReporterCallback)(MsprofReporterCallback),
|
||||||
Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t));
|
Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t));
|
||||||
void ForceMsprofilerInit();
|
void ForceMsprofilerInit();
|
||||||
bool registed() { return registed_; }
|
bool registered() { return registered_; }
|
||||||
Status DoRegProfCtrlCallback(MsprofCtrlCallback func) { return pRegProfCtrlCallback_(func); }
|
Status DoRegProfCtrlCallback(MsprofCtrlCallback func) { return pRegProfCtrlCallback_(func); }
|
||||||
Status DoRegProfSetDeviceCallback(MsprofSetDeviceCallback func) { return pRegProfSetDeviceCallback_(func); }
|
Status DoRegProfSetDeviceCallback(MsprofSetDeviceCallback func) { return pRegProfSetDeviceCallback_(func); }
|
||||||
Status DoRegProfReporterCallback(MsprofReporterCallback func) { return pRegProfReporterCallback_(func); }
|
Status DoRegProfReporterCallback(MsprofReporterCallback func) { return pRegProfReporterCallback_(func); }
|
||||||
|
@ -64,7 +64,7 @@ class __attribute__((visibility("default"))) VMCallbackRegister {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
VMCallbackRegister()
|
VMCallbackRegister()
|
||||||
: registed_(false),
|
: registered_(false),
|
||||||
ms_profile_inited_(false),
|
ms_profile_inited_(false),
|
||||||
pRegProfCtrlCallback_(nullptr),
|
pRegProfCtrlCallback_(nullptr),
|
||||||
pRegProfSetDeviceCallback_(nullptr),
|
pRegProfSetDeviceCallback_(nullptr),
|
||||||
|
@ -72,7 +72,7 @@ class __attribute__((visibility("default"))) VMCallbackRegister {
|
||||||
pProfCommandHandle_(nullptr) {}
|
pProfCommandHandle_(nullptr) {}
|
||||||
~VMCallbackRegister() = default;
|
~VMCallbackRegister() = default;
|
||||||
|
|
||||||
bool registed_;
|
bool registered_;
|
||||||
bool ms_profile_inited_;
|
bool ms_profile_inited_;
|
||||||
Status (*pRegProfCtrlCallback_)(MsprofCtrlCallback);
|
Status (*pRegProfCtrlCallback_)(MsprofCtrlCallback);
|
||||||
Status (*pRegProfSetDeviceCallback_)(MsprofSetDeviceCallback);
|
Status (*pRegProfSetDeviceCallback_)(MsprofSetDeviceCallback);
|
||||||
|
|
|
@ -299,8 +299,8 @@ Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t len) {
|
||||||
|
|
||||||
bool DoRegiste() {
|
bool DoRegiste() {
|
||||||
MS_LOG(INFO) << "VM profiling register start";
|
MS_LOG(INFO) << "VM profiling register start";
|
||||||
return VMCallbackRegister::GetInstance().Registe(RegProfCtrlCallback, RegProfSetDeviceCallback,
|
return VMCallbackRegister::GetInstance().Register(RegProfCtrlCallback, RegProfSetDeviceCallback,
|
||||||
RegProfReporterCallback, ProfCommandHandle);
|
RegProfReporterCallback, ProfCommandHandle);
|
||||||
}
|
}
|
||||||
static bool doRegiste = DoRegiste();
|
static bool doRegiste = DoRegiste();
|
||||||
} // namespace ascend
|
} // namespace ascend
|
||||||
|
|
|
@ -41,6 +41,7 @@ const char kCPUDevice[] = "CPU";
|
||||||
const char kGPUDevice[] = "GPU";
|
const char kGPUDevice[] = "GPU";
|
||||||
const char kAscendDevice[] = "Ascend";
|
const char kAscendDevice[] = "Ascend";
|
||||||
const char kDavinciInferenceDevice[] = "AscendInference";
|
const char kDavinciInferenceDevice[] = "AscendInference";
|
||||||
|
const char kGpuInferenceDevice[] = "GpuInference";
|
||||||
const char kDavinciDevice[] = "Davinci";
|
const char kDavinciDevice[] = "Davinci";
|
||||||
const char KNpuLog[] = "_npu_log";
|
const char KNpuLog[] = "_npu_log";
|
||||||
const unsigned int MAX_CALL_DEPTH_DEFAULT = 1000;
|
const unsigned int MAX_CALL_DEPTH_DEFAULT = 1000;
|
||||||
|
@ -51,7 +52,7 @@ const float kDefaultMaxDeviceMemory = 1024;
|
||||||
|
|
||||||
// enum definition for MindSpore Context Parameter
|
// enum definition for MindSpore Context Parameter
|
||||||
enum MsCtxParam : unsigned {
|
enum MsCtxParam : unsigned {
|
||||||
// paramater of type bool
|
// parameter of type bool
|
||||||
MS_CTX_TYPE_BOOL_BEGIN,
|
MS_CTX_TYPE_BOOL_BEGIN,
|
||||||
MS_CTX_ENABLE_AUTO_MIXED_PRECISION = MS_CTX_TYPE_BOOL_BEGIN,
|
MS_CTX_ENABLE_AUTO_MIXED_PRECISION = MS_CTX_TYPE_BOOL_BEGIN,
|
||||||
MS_CTX_CHECK_BPROP_FLAG,
|
MS_CTX_CHECK_BPROP_FLAG,
|
||||||
|
@ -74,14 +75,15 @@ enum MsCtxParam : unsigned {
|
||||||
MS_CTX_ENABLE_PROFILING,
|
MS_CTX_ENABLE_PROFILING,
|
||||||
MS_CTX_SAVE_GRAPHS_FLAG,
|
MS_CTX_SAVE_GRAPHS_FLAG,
|
||||||
MS_CTX_ENABLE_PARALLEL_SPLIT,
|
MS_CTX_ENABLE_PARALLEL_SPLIT,
|
||||||
|
MS_CTX_ENABLE_INFER_OPT,
|
||||||
MS_CTX_TYPE_BOOL_END,
|
MS_CTX_TYPE_BOOL_END,
|
||||||
|
|
||||||
// paramater of type int
|
// parameter of type int
|
||||||
MS_CTX_TYPE_INT_BEGIN = MS_CTX_TYPE_BOOL_END,
|
MS_CTX_TYPE_INT_BEGIN = MS_CTX_TYPE_BOOL_END,
|
||||||
MS_CTX_EXECUTION_MODE = MS_CTX_TYPE_INT_BEGIN,
|
MS_CTX_EXECUTION_MODE = MS_CTX_TYPE_INT_BEGIN,
|
||||||
MS_CTX_TYPE_INT_END,
|
MS_CTX_TYPE_INT_END,
|
||||||
|
|
||||||
// paramater of type uint32
|
// parameter of type uint32
|
||||||
MS_CTX_TYPE_UINT32_BEGIN = MS_CTX_TYPE_INT_END,
|
MS_CTX_TYPE_UINT32_BEGIN = MS_CTX_TYPE_INT_END,
|
||||||
MS_CTX_DEVICE_ID = MS_CTX_TYPE_UINT32_BEGIN,
|
MS_CTX_DEVICE_ID = MS_CTX_TYPE_UINT32_BEGIN,
|
||||||
MS_CTX_GE_REF,
|
MS_CTX_GE_REF,
|
||||||
|
@ -89,12 +91,12 @@ enum MsCtxParam : unsigned {
|
||||||
MS_CTX_TSD_REF,
|
MS_CTX_TSD_REF,
|
||||||
MS_CTX_TYPE_UINT32_END,
|
MS_CTX_TYPE_UINT32_END,
|
||||||
|
|
||||||
// paramater of type float
|
// parameter of type float
|
||||||
MS_CTX_TYPE_FLOAT_BEGIN = MS_CTX_TYPE_UINT32_END,
|
MS_CTX_TYPE_FLOAT_BEGIN = MS_CTX_TYPE_UINT32_END,
|
||||||
MS_CTX_MAX_DEVICE_MEMORY = MS_CTX_TYPE_FLOAT_BEGIN,
|
MS_CTX_MAX_DEVICE_MEMORY = MS_CTX_TYPE_FLOAT_BEGIN,
|
||||||
MS_CTX_TYPE_FLOAT_END,
|
MS_CTX_TYPE_FLOAT_END,
|
||||||
|
|
||||||
// paramater of type string
|
// parameter of type string
|
||||||
MS_CTX_TYPE_STRING_BEGIN = MS_CTX_TYPE_FLOAT_END,
|
MS_CTX_TYPE_STRING_BEGIN = MS_CTX_TYPE_FLOAT_END,
|
||||||
MS_CTX_DEVICE_TARGET = MS_CTX_TYPE_STRING_BEGIN,
|
MS_CTX_DEVICE_TARGET = MS_CTX_TYPE_STRING_BEGIN,
|
||||||
MS_CTX_GRAPH_MEMORY_MAX_SIZE,
|
MS_CTX_GRAPH_MEMORY_MAX_SIZE,
|
||||||
|
|
Loading…
Reference in New Issue