mindspore serving support gpu backend

This commit is contained in:
wilfChen 2021-01-28 09:17:24 +08:00
parent 3708624a25
commit a911b9ef9e
19 changed files with 775 additions and 112 deletions

View File

@ -104,6 +104,7 @@ class MS_API Buffer {
extern MS_API const char *kDeviceTypeAscend310;
extern MS_API const char *kDeviceTypeAscend910;
extern MS_API const char *kDeviceTypeGpu;
constexpr auto kModelOptionDumpCfgPath = "mindspore.option.dump_config_file_path";
constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path"; // aipp config file

View File

@ -13,7 +13,7 @@ if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
endif()
if(ENABLE_GPU)
file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu_session.cc")
file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu_session.cc" "gpu_inference_session.cc")
list(APPEND _SESSION_SRC_LIST ${_GPU_SRC_LIST})
endif()

View File

@ -0,0 +1,217 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "backend/session/gpu_inference_session.h"
#include "ir/tensor.h"
#include "ir/anf.h"
#include "ir/param_info.h"
#include "runtime/device/kernel_runtime.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/ms_utils.h"
#include "common/trans.h"
#include "utils/config_manager.h"
namespace mindspore {
namespace session {
void GpuInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
const std::vector<tensor::TensorPtr> &inputs_const) const {
MS_EXCEPTION_IF_NULL(kernel_graph);
std::vector<tensor::TensorPtr> inputs(inputs_const);
auto input_nodes = kernel_graph->inputs();
size_t no_weight_input = 0;
for (size_t i = 0; i < input_nodes.size(); ++i) {
tensor::TensorPtr tensor = nullptr;
if (!input_nodes[i]->isa<Parameter>()) {
MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter";
continue;
}
auto pk_node = input_nodes[i]->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(pk_node);
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
MS_EXCEPTION_IF_NULL(device_address);
if (!AnfAlgo::IsParameterWeight(pk_node)) {
tensor = inputs[no_weight_input++];
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}
}
}
}
GraphId GpuInferenceSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
auto graph_id = GPUSession::CompileGraphImpl(func_graph);
auto kernel_graph = GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(kernel_graph);
// load weight data to device
auto input_nodes = kernel_graph->inputs();
for (size_t i = 0; i < input_nodes.size(); ++i) {
if (!input_nodes[i]->isa<Parameter>()) {
MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter";
continue;
}
auto pk_node = input_nodes[i]->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(pk_node);
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
MS_EXCEPTION_IF_NULL(device_address);
if (AnfAlgo::IsParameterWeight(pk_node)) {
const auto &param_value = pk_node->default_param();
MS_EXCEPTION_IF_NULL(param_value);
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_value);
MS_EXCEPTION_IF_NULL(tensor);
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}
}
}
return graph_id;
}
bool GpuInferenceSession::CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs,
std::string *error_msg) const {
MS_LOG(INFO) << "Start check client inputs, graph id : " << graph_id;
auto kernel_graph = GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(kernel_graph);
auto kernel_graph_inputs = kernel_graph->inputs();
size_t no_weight_input = 0;
vector<ParameterPtr> paras;
// find parameters of graph inputs
for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) {
if (!kernel_graph_inputs[i]->isa<Parameter>()) {
MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter.";
continue;
}
auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
if (!AnfAlgo::IsParameterWeight(parameter)) {
paras.push_back(parameter);
}
}
// check inputs
for (size_t i = 0; i < paras.size(); ++i) {
// compare input number
if (paras.size() != inputs.size()) {
MS_LOG(ERROR) << "Input number is inconsistent. The actual input number [" << inputs.size()
<< "] but the graph input number is [" << paras.size() << "]";
MS_LOG(ERROR) << "InputsInfo --" << InputsInfo(paras, inputs);
if (error_msg != nullptr) {
std::stringstream str_stream;
str_stream << "Input number is inconsistent. The given input number [" << inputs.size()
<< "] but the graph input number is [" << paras.size() << "]\n";
str_stream << "InputsInfo --" << InputsInfo(paras, inputs);
*error_msg = str_stream.str();
}
return false;
}
auto input = inputs[no_weight_input++];
if (!CompareInput(input, paras[i])) {
MS_LOG(ERROR) << "Please check the input information.";
MS_LOG(ERROR) << "InputsInfo --" << InputsInfo(paras, inputs);
if (error_msg != nullptr) {
std::stringstream str_stream;
str_stream << "Please check the input information.\n";
str_stream << "InputsInfo --" << InputsInfo(paras, inputs);
*error_msg = str_stream.str();
}
return false;
}
}
return true;
}
bool GpuInferenceSession::CompareInput(const tensor::TensorPtr &input, const ParameterPtr &parameter) const {
MS_EXCEPTION_IF_NULL(input);
MS_EXCEPTION_IF_NULL(parameter);
// compare dims
auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0);
// compare shape
auto input_shape = input->shape();
vector<size_t> trans_input;
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(trans_input),
[](const int64_t dim) { return static_cast<size_t>(dim); });
auto is_scalar_shape = [](const vector<size_t> &shape) {
return shape.empty() || (shape.size() == 1 && shape[0] == 1);
};
if ((!is_scalar_shape(trans_input) || !is_scalar_shape(parameter_shape)) && (trans_input != parameter_shape)) {
MS_LOG(ERROR) << "Input shape is inconsistent. The actual shape is " << PrintInputShape(trans_input)
<< ", but the parameter shape is " << PrintInputShape(parameter_shape)
<< ". parameter : " << parameter->DebugString();
return false;
}
// compare data type
auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter);
if (input->data_type() != kernel_build_info->GetOutputDeviceType(0)) {
MS_LOG(ERROR) << "Input data type is inconsistent. The actual data type is " << input->data_type()
<< ", but the parameter data type is " << kernel_build_info->GetOutputDeviceType(0)
<< ". parameter : " << parameter->DebugString();
return false;
}
return true;
}
template <typename T>
std::string GpuInferenceSession::PrintInputShape(std::vector<T> shape) const {
string res = "[";
for (auto dim : shape) {
res += " " + std::to_string(dim);
}
return res + " ]";
}
std::string GpuInferenceSession::InputsInfo(const std::vector<ParameterPtr> &paras,
const std::vector<tensor::TensorPtr> &inputs) const {
const std::map<TypeId, std::string> dtype_name_map{
{TypeId::kNumberTypeBegin, "Unknown"}, {TypeId::kNumberTypeBool, "Bool"},
{TypeId::kNumberTypeFloat64, "Float64"}, {TypeId::kNumberTypeInt8, "Int8"},
{TypeId::kNumberTypeUInt8, "Uint8"}, {TypeId::kNumberTypeInt16, "Int16"},
{TypeId::kNumberTypeUInt16, "Uint16"}, {TypeId::kNumberTypeInt32, "Int32"},
{TypeId::kNumberTypeUInt32, "Uint32"}, {TypeId::kNumberTypeInt64, "Int64"},
{TypeId::kNumberTypeUInt64, "Uint64"}, {TypeId::kNumberTypeFloat16, "Float16"},
{TypeId::kNumberTypeFloat32, "Float32"},
};
auto data_type_to_string = [&dtype_name_map](TypeId type_id) {
auto it = dtype_name_map.find(type_id);
if (it == dtype_name_map.end()) {
return std::string("Unknown");
}
return it->second;
};
std::string graph = "graph inputs:{ ";
for (size_t i = 0; i < paras.size(); ++i) {
auto &para = paras[i];
graph += std::to_string(i) + ": dims " + std::to_string(AnfAlgo::GetOutputDeviceShape(para, 0).size()) +
", shape " + PrintInputShape(AnfAlgo::GetOutputDeviceShape(para, 0)) + ", data type " +
data_type_to_string(AnfAlgo::GetSelectKernelBuildInfo(para)->GetOutputDeviceType(0)) + " }";
}
std::string actual = "given inputs:{ ";
for (size_t i = 0; i < inputs.size(); ++i) {
actual += std::to_string(i) + ": dims " + std::to_string(inputs[i]->shape().size()) + ", shape " +
PrintInputShape(inputs[i]->shape()) + ", data type " + data_type_to_string(inputs[i]->data_type()) + " }";
}
return graph + " " + actual;
}
} // namespace session
} // namespace mindspore

View File

@ -0,0 +1,53 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_GPU_INFERENCE_SESSION_H
#define MINDSPORE_CCSRC_BACKEND_SESSION_GPU_INFERENCE_SESSION_H
#include <unordered_map>
#include <string>
#include <memory>
#include <vector>
#include <utility>
#include <stack>
#include <map>
#include <tuple>
#include <set>
#include "backend/session/gpu_session.h"
#include "backend/session/kernel_graph.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/session/session_factory.h"
namespace mindspore {
namespace session {
class GpuInferenceSession : public gpu::GPUSession {
public:
GpuInferenceSession() = default;
~GpuInferenceSession() = default;
void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
const std::vector<tensor::TensorPtr> &inputs_const) const;
bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs,
std::string *error_msg) const override;
bool CompareInput(const tensor::TensorPtr &input, const ParameterPtr &parameter) const;
template <typename T>
std::string PrintInputShape(std::vector<T> shape) const;
std::string InputsInfo(const std::vector<ParameterPtr> &paras, const std::vector<tensor::TensorPtr> &inputs) const;
protected:
GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override;
};
MS_REG_SESSION(kGpuInferenceDevice, GpuInferenceSession);
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_SESSION_GPU_INFERENCE_SESSION_H

View File

@ -15,9 +15,11 @@
*/
#include "backend/session/gpu_session.h"
#include <string>
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/pass_manager.h"
#include "backend/optimizer/common/common_backend_optimization.h"
#include "backend/optimizer/gpu/adam_weight_decay_fusion.h"
#include "backend/optimizer/gpu/adam_fusion.h"
#include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h"
@ -298,16 +300,31 @@ void GPUSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const
GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
// Construct graph, if successfully, graph_sum_ + 1
auto graph_id = graph_sum_;
auto graph = ConstructKernelGraph(lst, outputs);
MS_EXCEPTION_IF_NULL(graph);
return CompileGraphImpl(graph);
}
GraphId GPUSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
std::vector<KernelGraphPtr> all_graphs;
auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
MS_EXCEPTION_IF_NULL(root_graph);
if (all_graphs.size() != 1) {
MS_LOG(EXCEPTION) << "Gpu backend does not support multi-graph schedule. graph num" << all_graphs.size();
}
opt::BackendCommonOptimization(root_graph);
return CompileGraphImpl(root_graph);
}
GraphId GPUSession::CompileGraphImpl(KernelGraphPtr graph) {
// Prepare ms context info for dump .pb graph
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
// Dump .pb graph before graph optimization
if (save_graphs) {
DumpIRProto(graph, "before_opt_" + std::to_string(graph_id));
DumpIRProto(graph, "before_opt_" + std::to_string(graph->graph_id()));
}
// Graph optimization irrelevant to device data format
Optimize(graph);
@ -326,7 +343,7 @@ GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
AssignStream(graph);
// Dump .pb graph before remove nop nodes
if (save_graphs) {
DumpIRProto(graph, "before_removeNop_" + std::to_string(graph_id));
DumpIRProto(graph, "before_removeNop_" + std::to_string(graph->graph_id()));
}
// Update Graph Dynamic Shape Attr.
UpdateGraphDynamicShapeAttr(NOT_NULL(graph));
@ -343,7 +360,7 @@ GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
SetSummaryNodes(graph.get());
// Dump .pb graph after graph optimization
if (save_graphs) {
DumpIRProto(graph, "after_opt_" + std::to_string(graph_id));
DumpIRProto(graph, "after_opt_" + std::to_string(graph->graph_id()));
}
// Set graph manager.
MS_EXCEPTION_IF_NULL(context_);
@ -361,9 +378,8 @@ GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
debugger_->LoadGraphs(graph);
}
#endif
MS_LOG(INFO) << "CompileGraph graph_id: " << graph_id;
return graph_id;
MS_LOG(INFO) << "CompileGraph graph_id: " << graph->graph_id();
return graph->graph_id();
}
void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,

View File

@ -37,6 +37,7 @@ class GPUSession : public SessionBasic {
protected:
void UnifyMindIR(const KernelGraphPtr &graph) override { return; }
GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override;
void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors,
@ -81,6 +82,8 @@ class GPUSession : public SessionBasic {
void SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void CleanValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const;
GraphId CompileGraphImpl(KernelGraphPtr kernel_graph);
};
using GPUSessionPtr = std::shared_ptr<GPUSession>;
MS_REG_SESSION(kGPUDevice, GPUSession);

View File

@ -15,10 +15,15 @@ if(ENABLE_ACL)
"model/model_converter_utils/*.cc"
"graph/acl/*.cc"
)
endif()
if(ENABLE_D)
file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "python_utils.cc" "model/ms/*.cc" "graph/ms/*.cc")
file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR}
"python_utils.cc" "model/ms/*.cc" "graph/ascend/*.cc")
endif()
if(ENABLE_GPU)
file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "python_utils.cc" "model/ms/*.cc" "graph/gpu/*.cc")
endif()
set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc
@ -98,6 +103,15 @@ if(ENABLE_D)
target_link_libraries(mindspore_shared_lib PRIVATE ${adump_server})
endif()
if(ENABLE_GPU)
target_link_libraries(mindspore_shared_lib PRIVATE gpu_cuda_lib gpu_queue cublas
${CUDA_PATH}/lib64/libcurand.so
${CUDNN_LIBRARY_PATH}
${CUDA_PATH}/lib64/libcudart.so
${CUDA_PATH}/lib64/stubs/libcuda.so
${CUDA_PATH}/lib64/libcusolver.so)
endif()
if(CMAKE_SYSTEM_NAME MATCHES "Linux")
set(MINDSPORE_RPATH $ORIGIN)
if(ENABLE_D)
@ -110,7 +124,8 @@ if(CMAKE_SYSTEM_NAME MATCHES "Linux")
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/add-ons)
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling)
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
set(MINDSPORE_RPATH
${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
elseif(ENABLE_ACL)
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/atc/lib64)
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/atc/lib64)
@ -121,7 +136,8 @@ if(CMAKE_SYSTEM_NAME MATCHES "Linux")
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/add-ons)
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling)
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
set(MINDSPORE_RPATH
${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
endif()
set_target_properties(mindspore_shared_lib PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH})

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cxx_api/graph/ms/ms_graph_impl.h"
#include "cxx_api/graph/ascend/ascend_graph_impl.h"
#include <algorithm>
#include "include/api/context.h"
#include "cxx_api/factory.h"
@ -26,43 +26,9 @@
#include "runtime/device/kernel_runtime_manager.h"
namespace mindspore::api {
API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, MsGraphImpl);
API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl);
static DataType TransTypeId2InferDataType(TypeId type_id) {
const std::map<TypeId, api::DataType> id2type_map{
{TypeId::kNumberTypeBegin, api::kMsUnknown}, {TypeId::kNumberTypeBool, api::kMsBool},
{TypeId::kNumberTypeFloat64, api::kMsFloat64}, {TypeId::kNumberTypeInt8, api::kMsInt8},
{TypeId::kNumberTypeUInt8, api::kMsUint8}, {TypeId::kNumberTypeInt16, api::kMsInt16},
{TypeId::kNumberTypeUInt16, api::kMsUint16}, {TypeId::kNumberTypeInt32, api::kMsInt32},
{TypeId::kNumberTypeUInt32, api::kMsUint32}, {TypeId::kNumberTypeInt64, api::kMsInt64},
{TypeId::kNumberTypeUInt64, api::kMsUint64}, {TypeId::kNumberTypeFloat16, api::kMsFloat16},
{TypeId::kNumberTypeFloat32, api::kMsFloat32},
};
// cppcheck-suppress stlIfFind
if (auto it = id2type_map.find(type_id); it != id2type_map.end()) {
return it->second;
}
MS_LOG(WARNING) << "Unsupported data id " << type_id;
return api::kMsUnknown;
}
template <class T>
inline static void ClearIfNotNull(T *vec) {
if (vec != nullptr) {
vec->clear();
}
}
template <class T, class U = std::vector<T>>
inline static void PushbackIfNotNull(U *vec, T &&item) {
if (vec != nullptr) {
vec->emplace_back(item);
}
}
MsGraphImpl::MsGraphImpl()
AscendGraphImpl::AscendGraphImpl()
: session_impl_(nullptr),
graph_id_(0),
device_type_("Ascend"),
@ -75,9 +41,9 @@ MsGraphImpl::MsGraphImpl()
init_flag_(false),
load_flag_(false) {}
MsGraphImpl::~MsGraphImpl() { (void)FinalizeEnv(); }
AscendGraphImpl::~AscendGraphImpl() { (void)FinalizeEnv(); }
Status MsGraphImpl::InitEnv() {
Status AscendGraphImpl::InitEnv() {
if (init_flag_) {
return SUCCESS;
}
@ -108,7 +74,7 @@ Status MsGraphImpl::InitEnv() {
return SUCCESS;
}
Status MsGraphImpl::FinalizeEnv() {
Status AscendGraphImpl::FinalizeEnv() {
if (!init_flag_) {
return SUCCESS;
}
@ -136,7 +102,7 @@ Status MsGraphImpl::FinalizeEnv() {
return SUCCESS;
}
Status MsGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) {
Status AscendGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) {
MS_ASSERT(session_impl_ != nullptr);
try {
graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
@ -147,7 +113,7 @@ Status MsGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr)
}
}
std::vector<tensor::TensorPtr> MsGraphImpl::RunGraph(const std::vector<tensor::TensorPtr> &inputs) {
std::vector<tensor::TensorPtr> AscendGraphImpl::RunGraph(const std::vector<tensor::TensorPtr> &inputs) {
try {
VectorRef outputs;
session_impl_->RunGraph(graph_id_, inputs, &outputs);
@ -158,7 +124,7 @@ std::vector<tensor::TensorPtr> MsGraphImpl::RunGraph(const std::vector<tensor::T
}
}
Status MsGraphImpl::CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const {
Status AscendGraphImpl::CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const {
MS_ASSERT(session_impl_ != nullptr);
std::string error_msg;
if (!session_impl_->CheckModelInputs(graph_id_, inputs, &error_msg)) {
@ -167,7 +133,7 @@ Status MsGraphImpl::CheckModelInputs(const std::vector<tensor::TensorPtr> &input
return SUCCESS;
}
Status MsGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) {
Status AscendGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) {
MS_EXCEPTION_IF_NULL(reply);
if (context_ == nullptr) {
MS_LOG(ERROR) << "rtCtx is nullptr";
@ -206,7 +172,7 @@ Status MsGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector
return SUCCESS;
}
Status MsGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
Status AscendGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
if (!load_flag_) {
Status ret = Load();
@ -216,21 +182,21 @@ Status MsGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<s
}
}
ClearIfNotNull(names);
ClearIfNotNull(shapes);
ClearIfNotNull(data_types);
ClearIfNotNull(mem_sizes);
GraphUtils::ClearIfNotNull(names);
GraphUtils::ClearIfNotNull(shapes);
GraphUtils::ClearIfNotNull(data_types);
GraphUtils::ClearIfNotNull(mem_sizes);
for (size_t i = 0; i < inputs_.size(); i++) {
auto &tensor = inputs_[i];
PushbackIfNotNull(names, input_names_[i]);
PushbackIfNotNull(shapes, tensor->shape());
PushbackIfNotNull(data_types, TransTypeId2InferDataType(tensor->data_type()));
PushbackIfNotNull(mem_sizes, tensor->Size());
GraphUtils::PushbackIfNotNull(names, input_names_[i]);
GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
}
return SUCCESS;
}
Status MsGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
Status AscendGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
if (!load_flag_) {
Status ret = Load();
@ -240,22 +206,22 @@ Status MsGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<
}
}
ClearIfNotNull(names);
ClearIfNotNull(shapes);
ClearIfNotNull(data_types);
ClearIfNotNull(mem_sizes);
GraphUtils::ClearIfNotNull(names);
GraphUtils::ClearIfNotNull(shapes);
GraphUtils::ClearIfNotNull(data_types);
GraphUtils::ClearIfNotNull(mem_sizes);
for (size_t i = 0; i < outputs_.size(); i++) {
auto &tensor = outputs_[i];
PushbackIfNotNull(names, output_names_[i]);
PushbackIfNotNull(shapes, tensor->shape());
PushbackIfNotNull(data_types, TransTypeId2InferDataType(tensor->data_type()));
PushbackIfNotNull(mem_sizes, tensor->Size());
GraphUtils::PushbackIfNotNull(names, output_names_[i]);
GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
}
return SUCCESS;
}
Status MsGraphImpl::Load() {
Status AscendGraphImpl::Load() {
// check graph type
if (graph_->ModelType() != ModelType::kMindIR) {
MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType();
@ -311,7 +277,7 @@ Status MsGraphImpl::Load() {
return SUCCESS;
}
Status MsGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
Status AscendGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
if (!load_flag_) {
Status ret = Load();

View File

@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
#define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H
#define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H
#include <functional>
#include <map>
#include <string>
@ -28,12 +28,13 @@
#include "ir/anf.h"
#include "cxx_api/model/model_impl.h"
#include "runtime/context.h"
#include "cxx_api/graph/graph_utils.h"
namespace mindspore::api {
class MsGraphImpl : public GraphCell::GraphImpl {
class AscendGraphImpl : public GraphCell::GraphImpl {
public:
MsGraphImpl();
~MsGraphImpl() override;
AscendGraphImpl();
~AscendGraphImpl() override;
Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
Status Load() override;
@ -63,4 +64,4 @@ class MsGraphImpl : public GraphCell::GraphImpl {
bool load_flag_;
};
} // namespace mindspore::api
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H

View File

@ -20,7 +20,7 @@ VMCallbackRegister &VMCallbackRegister::GetInstance() {
return instance;
}
bool VMCallbackRegister::Registe(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback),
bool VMCallbackRegister::Register(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback),
Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback),
Status (*pRegProfReporterCallback)(MsprofReporterCallback),
Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t)) {

View File

@ -0,0 +1,256 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cxx_api/graph/gpu/gpu_graph_impl.h"
#include <algorithm>
#include "include/api/context.h"
#include "cxx_api/factory.h"
#include "utils/log_adapter.h"
#include "mindspore/core/base/base_ref_utils.h"
#include "backend/session/session_factory.h"
#include "backend/session/executor_manager.h"
#include "runtime/device/kernel_runtime_manager.h"
namespace mindspore::api {
API_FACTORY_REG(GraphCell::GraphImpl, GPU, GPUGraphImpl);
GPUGraphImpl::GPUGraphImpl()
: session_impl_(nullptr),
graph_id_(0),
device_id_(Context::Instance().GetDeviceID()),
inputs_(),
outputs_(),
input_names_(),
output_names_(),
init_flag_(false),
load_flag_(false) {}
Status GPUGraphImpl::InitEnv() {
if (init_flag_) {
MS_LOG(WARNING) << "Initialized again, return success.";
return SUCCESS;
}
auto ms_context = MsContext::GetInstance();
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";
return FAILED;
}
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kGPUDevice);
ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, true);
session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice);
if (session_impl_ == nullptr) {
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << kGpuInferenceDevice
<< " is available.";
return FAILED;
}
session_impl_->Init(device_id_);
init_flag_ = true;
return SUCCESS;
}
Status GPUGraphImpl::FinalizeEnv() {
if (!init_flag_) {
MS_LOG(WARNING) << "Never initialize before, return success";
return SUCCESS;
}
MS_LOG_INFO << "Start finalize env";
session::ExecutorManager::Instance().Clear();
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
init_flag_ = false;
MS_LOG(INFO) << "End finalize env";
return SUCCESS;
}
Status GPUGraphImpl::Load() {
// check graph type
if (graph_->ModelType() != ModelType::kMindIR) {
MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType();
return INVALID_INPUTS;
}
const auto &graph_data = GraphImpl::MutableGraphData();
MS_EXCEPTION_IF_NULL(graph_data);
auto func_graph = graph_data->GetFuncGraph();
// init
Status ret = InitEnv();
if (ret != SUCCESS) {
MS_LOG(ERROR) << "InitEnv failed.";
return FAILED;
}
ret = CompileGraph(func_graph);
if (ret != SUCCESS) {
MS_LOG(ERROR) << "Compile graph model failed";
return FAILED;
}
session_impl_->GetModelInputsInfo(graph_id_, &inputs_, &input_names_);
session_impl_->GetModelOutputsInfo(graph_id_, &outputs_, &output_names_);
if (inputs_.empty() || inputs_.size() != input_names_.size()) {
MS_LOG_ERROR << "Get model inputs info failed";
return FAILED;
}
if (outputs_.empty() || outputs_.size() != output_names_.size()) {
MS_LOG_ERROR << "Get model outputs info failed";
return FAILED;
}
load_flag_ = true;
return SUCCESS;
}
Status GPUGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) {
MS_ASSERT(session_impl_ != nullptr);
try {
graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
return SUCCESS;
} catch (std::exception &e) {
MS_LOG(ERROR) << "CompileGraph failed: " << e.what();
return FAILED;
}
}
std::vector<tensor::TensorPtr> GPUGraphImpl::RunGraph(const std::vector<tensor::TensorPtr> &inputs) {
try {
VectorRef outputs;
session_impl_->RunGraph(graph_id_, inputs, &outputs);
return TransformVectorRefToMultiTensor(outputs);
} catch (std::exception &e) {
MS_LOG(ERROR) << "RunGraph failed: " << e.what();
return std::vector<tensor::TensorPtr>();
}
}
Status GPUGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) {
MS_EXCEPTION_IF_NULL(reply);
vector<tensor::TensorPtr> inputs;
for (size_t i = 0; i < request.size(); i++) {
auto &item = request[i];
auto input = inputs_[i];
if (input->Size() != item.DataSize()) {
MS_LOG(ERROR) << "Input " << i << " data size " << item.DataSize() << " not match model input data size "
<< input->Size();
return FAILED;
}
auto ret = memcpy_s(input->data_c(), input->Size(), item.Data(), item.DataSize());
if (ret != SUCCESS) {
MS_LOG(ERROR) << "Tensor copy failed";
return FAILED;
}
inputs.push_back(input);
}
vector<tensor::TensorPtr> outputs = RunGraph(inputs);
if (outputs.empty()) {
MS_LOG(ERROR) << "Execute Model Failed";
return FAILED;
}
reply->clear();
std::transform(outputs.begin(), outputs.end(), std::back_inserter(*reply),
[](const tensor::TensorPtr &tensor) { return Buffer(tensor->data_c(), tensor->Size()); });
return SUCCESS;
}
Status GPUGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
if (!load_flag_) {
Status ret = Load();
if (ret != SUCCESS) {
MS_LOG(ERROR) << "PrepareModel failed.";
return ret;
}
}
if (inputs.size() != inputs_.size()) {
MS_LOG(ERROR) << "inputs count not match, required count " << inputs_.size() << ", given count " << inputs.size();
return INVALID_INPUTS;
}
for (size_t i = 0; i < inputs_.size(); ++i) {
if (inputs[i].DataSize() != inputs_[i]->Size()) {
MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_[i]->Size() << ", given count "
<< inputs[i].DataSize();
return INVALID_INPUTS;
}
}
if (ExecuteModel(inputs, outputs) != SUCCESS) {
MS_LOG(ERROR) << "Execute Model Failed";
return FAILED;
}
if (outputs_.size() != outputs->size()) {
MS_LOG(ERROR) << "Predict output size " << outputs->size() << " not match output size got from model info "
<< outputs_.size();
return FAILED;
}
return SUCCESS;
}
Status GPUGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
if (!load_flag_) {
Status ret = Load();
if (ret != SUCCESS) {
MS_LOG(ERROR) << "PrepareModel failed.";
return ret;
}
}
GraphUtils::ClearIfNotNull(names);
GraphUtils::ClearIfNotNull(shapes);
GraphUtils::ClearIfNotNull(data_types);
GraphUtils::ClearIfNotNull(mem_sizes);
for (size_t i = 0; i < inputs_.size(); i++) {
auto &tensor = inputs_[i];
GraphUtils::PushbackIfNotNull(names, input_names_[i]);
GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
}
return SUCCESS;
}
Status GPUGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
if (!load_flag_) {
Status ret = Load();
if (ret != SUCCESS) {
MS_LOG(ERROR) << "PrepareModel failed.";
return ret;
}
}
GraphUtils::ClearIfNotNull(names);
GraphUtils::ClearIfNotNull(shapes);
GraphUtils::ClearIfNotNull(data_types);
GraphUtils::ClearIfNotNull(mem_sizes);
for (size_t i = 0; i < outputs_.size(); i++) {
auto &tensor = outputs_[i];
GraphUtils::PushbackIfNotNull(names, output_names_[i]);
GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
}
return SUCCESS;
}
} // namespace mindspore::api

View File

@ -0,0 +1,67 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_GPU_GRAPH_IMPL_H
#define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_GPU_GRAPH_IMPL_H
#include <string>
#include <vector>
#include <utility>
#include <memory>
#include "include/api/status.h"
#include "include/api/graph.h"
#include "cxx_api/graph/graph_impl.h"
#include "backend/session/session_basic.h"
#include "ir/anf.h"
#include "cxx_api/model/model_impl.h"
#include "cxx_api/graph/graph_utils.h"
namespace mindspore::api {
class GPUGraphImpl : public GraphCell::GraphImpl {
public:
GPUGraphImpl();
~GPUGraphImpl() override = default;
Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
Status Load() override;
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
private:
Status InitEnv();
Status FinalizeEnv();
Status CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr);
Status CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const;
std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs);
Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
std::shared_ptr<session::SessionBasic> session_impl_;
uint32_t graph_id_;
std::string device_type_;
uint32_t device_id_;
std::vector<tensor::TensorPtr> inputs_;
std::vector<tensor::TensorPtr> outputs_;
std::vector<std::string> input_names_;
std::vector<std::string> output_names_;
bool init_flag_;
bool load_flag_;
// tensor-rt
uint32_t batch_size_;
uint32_t workspace_size_;
};
} // namespace mindspore::api
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_GPU_GRAPH_IMPL_H

View File

@ -0,0 +1,63 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H
#define MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H
#include <map>
#include <vector>
#include "include/api/types.h"
#include "ir/dtype/type_id.h"
#include "utils/log_adapter.h"
namespace mindspore::api {
class GraphUtils {
public:
static DataType TransTypeId2InferDataType(TypeId type_id) {
const std::map<TypeId, api::DataType> id2type_map{
{TypeId::kNumberTypeBegin, api::kMsUnknown}, {TypeId::kNumberTypeBool, api::kMsBool},
{TypeId::kNumberTypeFloat64, api::kMsFloat64}, {TypeId::kNumberTypeInt8, api::kMsInt8},
{TypeId::kNumberTypeUInt8, api::kMsUint8}, {TypeId::kNumberTypeInt16, api::kMsInt16},
{TypeId::kNumberTypeUInt16, api::kMsUint16}, {TypeId::kNumberTypeInt32, api::kMsInt32},
{TypeId::kNumberTypeUInt32, api::kMsUint32}, {TypeId::kNumberTypeInt64, api::kMsInt64},
{TypeId::kNumberTypeUInt64, api::kMsUint64}, {TypeId::kNumberTypeFloat16, api::kMsFloat16},
{TypeId::kNumberTypeFloat32, api::kMsFloat32},
};
auto it = id2type_map.find(type_id);
if (it != id2type_map.end()) {
return it->second;
}
MS_LOG(WARNING) << "Unsupported data id " << type_id;
return api::kMsUnknown;
}
template <class T>
inline static void ClearIfNotNull(T *vec) {
if (vec != nullptr) {
vec->clear();
}
}
template <class T, class U>
inline static void PushbackIfNotNull(U *vec, T &&item) {
if (vec != nullptr) {
vec->emplace_back(item);
}
}
};
} // namespace mindspore::api
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H

View File

@ -22,6 +22,7 @@
namespace mindspore {
namespace api {
API_FACTORY_REG(ModelImpl, Ascend910, MsModel);
API_FACTORY_REG(ModelImpl, GPU, MsModel);
Status MsModel::Build(const std::map<std::string, std::string> &) {
MS_LOG(INFO) << "Start build model.";

View File

@ -21,6 +21,7 @@
namespace mindspore::api {
const char *kDeviceTypeAscend310 = "Ascend310";
const char *kDeviceTypeAscend910 = "Ascend910";
const char *kDeviceTypeGpu = "GPU";
class DataImpl {
public:

View File

@ -31,7 +31,7 @@ constexpr Status PROF_FAILED = 0xFFFFFFFF;
} // namespace
Status RegProfCtrlCallback(MsprofCtrlCallback func) {
if (VMCallbackRegister::GetInstance().registed()) {
if (VMCallbackRegister::GetInstance().registered()) {
return VMCallbackRegister::GetInstance().DoRegProfCtrlCallback(func);
} else {
return PROF_SUCCESS;
@ -39,7 +39,7 @@ Status RegProfCtrlCallback(MsprofCtrlCallback func) {
}
Status RegProfSetDeviceCallback(MsprofSetDeviceCallback func) {
if (VMCallbackRegister::GetInstance().registed()) {
if (VMCallbackRegister::GetInstance().registered()) {
return VMCallbackRegister::GetInstance().DoRegProfSetDeviceCallback(func);
} else {
return PROF_SUCCESS;
@ -47,7 +47,7 @@ Status RegProfSetDeviceCallback(MsprofSetDeviceCallback func) {
}
Status RegProfReporterCallback(MsprofReporterCallback func) {
if (VMCallbackRegister::GetInstance().registed()) {
if (VMCallbackRegister::GetInstance().registered()) {
return VMCallbackRegister::GetInstance().DoRegProfReporterCallback(func);
} else {
return PROF_SUCCESS;
@ -55,7 +55,7 @@ Status RegProfReporterCallback(MsprofReporterCallback func) {
}
Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t len) {
if (VMCallbackRegister::GetInstance().registed()) {
if (VMCallbackRegister::GetInstance().registered()) {
return VMCallbackRegister::GetInstance().DoProfCommandHandle(type, data, len);
} else {
return PROF_SUCCESS;
@ -69,16 +69,16 @@ VMCallbackRegister &VMCallbackRegister::GetInstance() {
return instance;
}
bool VMCallbackRegister::Registe(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback),
bool VMCallbackRegister::Register(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback),
Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback),
Status (*pRegProfReporterCallback)(MsprofReporterCallback),
Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t)) {
if (!registed_) {
if (!registered_) {
pRegProfCtrlCallback_ = pRegProfCtrlCallback;
pRegProfSetDeviceCallback_ = pRegProfSetDeviceCallback;
pRegProfReporterCallback_ = pRegProfReporterCallback;
pProfCommandHandle_ = pProfCommandHandle;
registed_ = true;
registered_ = true;
ForceMsprofilerInit();
return true;
}

View File

@ -49,12 +49,12 @@ class __attribute__((visibility("default"))) VMCallbackRegister {
static VMCallbackRegister &GetInstance();
VMCallbackRegister(const VMCallbackRegister &) = delete;
VMCallbackRegister &operator=(const VMCallbackRegister &) = delete;
bool Registe(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback),
bool Register(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback),
Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback),
Status (*pRegProfReporterCallback)(MsprofReporterCallback),
Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t));
void ForceMsprofilerInit();
bool registed() { return registed_; }
bool registered() { return registered_; }
Status DoRegProfCtrlCallback(MsprofCtrlCallback func) { return pRegProfCtrlCallback_(func); }
Status DoRegProfSetDeviceCallback(MsprofSetDeviceCallback func) { return pRegProfSetDeviceCallback_(func); }
Status DoRegProfReporterCallback(MsprofReporterCallback func) { return pRegProfReporterCallback_(func); }
@ -64,7 +64,7 @@ class __attribute__((visibility("default"))) VMCallbackRegister {
private:
VMCallbackRegister()
: registed_(false),
: registered_(false),
ms_profile_inited_(false),
pRegProfCtrlCallback_(nullptr),
pRegProfSetDeviceCallback_(nullptr),
@ -72,7 +72,7 @@ class __attribute__((visibility("default"))) VMCallbackRegister {
pProfCommandHandle_(nullptr) {}
~VMCallbackRegister() = default;
bool registed_;
bool registered_;
bool ms_profile_inited_;
Status (*pRegProfCtrlCallback_)(MsprofCtrlCallback);
Status (*pRegProfSetDeviceCallback_)(MsprofSetDeviceCallback);

View File

@ -299,7 +299,7 @@ Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t len) {
bool DoRegiste() {
MS_LOG(INFO) << "VM profiling register start";
return VMCallbackRegister::GetInstance().Registe(RegProfCtrlCallback, RegProfSetDeviceCallback,
return VMCallbackRegister::GetInstance().Register(RegProfCtrlCallback, RegProfSetDeviceCallback,
RegProfReporterCallback, ProfCommandHandle);
}
static bool doRegiste = DoRegiste();

View File

@ -41,6 +41,7 @@ const char kCPUDevice[] = "CPU";
const char kGPUDevice[] = "GPU";
const char kAscendDevice[] = "Ascend";
const char kDavinciInferenceDevice[] = "AscendInference";
const char kGpuInferenceDevice[] = "GpuInference";
const char kDavinciDevice[] = "Davinci";
const char KNpuLog[] = "_npu_log";
const unsigned int MAX_CALL_DEPTH_DEFAULT = 1000;
@ -51,7 +52,7 @@ const float kDefaultMaxDeviceMemory = 1024;
// enum definition for MindSpore Context Parameter
enum MsCtxParam : unsigned {
// paramater of type bool
// parameter of type bool
MS_CTX_TYPE_BOOL_BEGIN,
MS_CTX_ENABLE_AUTO_MIXED_PRECISION = MS_CTX_TYPE_BOOL_BEGIN,
MS_CTX_CHECK_BPROP_FLAG,
@ -74,14 +75,15 @@ enum MsCtxParam : unsigned {
MS_CTX_ENABLE_PROFILING,
MS_CTX_SAVE_GRAPHS_FLAG,
MS_CTX_ENABLE_PARALLEL_SPLIT,
MS_CTX_ENABLE_INFER_OPT,
MS_CTX_TYPE_BOOL_END,
// paramater of type int
// parameter of type int
MS_CTX_TYPE_INT_BEGIN = MS_CTX_TYPE_BOOL_END,
MS_CTX_EXECUTION_MODE = MS_CTX_TYPE_INT_BEGIN,
MS_CTX_TYPE_INT_END,
// paramater of type uint32
// parameter of type uint32
MS_CTX_TYPE_UINT32_BEGIN = MS_CTX_TYPE_INT_END,
MS_CTX_DEVICE_ID = MS_CTX_TYPE_UINT32_BEGIN,
MS_CTX_GE_REF,
@ -89,12 +91,12 @@ enum MsCtxParam : unsigned {
MS_CTX_TSD_REF,
MS_CTX_TYPE_UINT32_END,
// paramater of type float
// parameter of type float
MS_CTX_TYPE_FLOAT_BEGIN = MS_CTX_TYPE_UINT32_END,
MS_CTX_MAX_DEVICE_MEMORY = MS_CTX_TYPE_FLOAT_BEGIN,
MS_CTX_TYPE_FLOAT_END,
// paramater of type string
// parameter of type string
MS_CTX_TYPE_STRING_BEGIN = MS_CTX_TYPE_FLOAT_END,
MS_CTX_DEVICE_TARGET = MS_CTX_TYPE_STRING_BEGIN,
MS_CTX_GRAPH_MEMORY_MAX_SIZE,