diff --git a/include/api/types.h b/include/api/types.h
index 709796e5f86..73950728699 100644
--- a/include/api/types.h
+++ b/include/api/types.h
@@ -104,6 +104,7 @@ class MS_API Buffer {
 
 extern MS_API const char *kDeviceTypeAscend310;
 extern MS_API const char *kDeviceTypeAscend910;
+extern MS_API const char *kDeviceTypeGpu;
 
 constexpr auto kModelOptionDumpCfgPath = "mindspore.option.dump_config_file_path";
 constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path";  // aipp config file
diff --git a/mindspore/ccsrc/backend/session/CMakeLists.txt b/mindspore/ccsrc/backend/session/CMakeLists.txt
index 145d32ec4df..e67fcbc7b80 100644
--- a/mindspore/ccsrc/backend/session/CMakeLists.txt
+++ b/mindspore/ccsrc/backend/session/CMakeLists.txt
@@ -13,7 +13,7 @@ if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
 endif()
 
 if(ENABLE_GPU)
-    file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu_session.cc")
+    file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu_session.cc" "gpu_inference_session.cc")
     list(APPEND _SESSION_SRC_LIST ${_GPU_SRC_LIST})
 endif()
diff --git a/mindspore/ccsrc/backend/session/gpu_inference_session.cc b/mindspore/ccsrc/backend/session/gpu_inference_session.cc
new file mode 100644
index 00000000000..c1091d7c814
--- /dev/null
+++ b/mindspore/ccsrc/backend/session/gpu_inference_session.cc
@@ -0,0 +1,217 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include "backend/session/gpu_inference_session.h"
+#include "ir/tensor.h"
+#include "ir/anf.h"
+#include "ir/param_info.h"
+#include "runtime/device/kernel_runtime.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "utils/ms_utils.h"
+#include "common/trans.h"
+#include "utils/config_manager.h"
+
+namespace mindspore {
+namespace session {
+void GpuInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
+                                        const std::vector<tensor::TensorPtr> &inputs_const) const {
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  std::vector<tensor::TensorPtr> inputs(inputs_const);
+  auto input_nodes = kernel_graph->inputs();
+
+  size_t no_weight_input = 0;
+  for (size_t i = 0; i < input_nodes.size(); ++i) {
+    tensor::TensorPtr tensor = nullptr;
+    if (!input_nodes[i]->isa<Parameter>()) {
+      MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter";
+      continue;
+    }
+    auto pk_node = input_nodes[i]->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(pk_node);
+    auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
+    MS_EXCEPTION_IF_NULL(device_address);
+    if (!AnfAlgo::IsParameterWeight(pk_node)) {
+      tensor = inputs[no_weight_input++];
+      if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
+                                            LongToSize(tensor->data().nbytes()), tensor->data_type(),
+                                            tensor->data_c())) {
+        MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
+      }
+    }
+  }
+}
+
+GraphId GpuInferenceSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
+  auto graph_id = GPUSession::CompileGraphImpl(func_graph);
+  auto kernel_graph = GetGraph(graph_id);
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  // load weight data to device
+  auto input_nodes = kernel_graph->inputs();
+  for (size_t i = 0; i < input_nodes.size(); ++i) {
+    if (!input_nodes[i]->isa<Parameter>()) {
+      MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter";
+      continue;
+    }
+    auto pk_node = input_nodes[i]->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(pk_node);
+    auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
+    MS_EXCEPTION_IF_NULL(device_address);
+    if (AnfAlgo::IsParameterWeight(pk_node)) {
+      const auto &param_value = pk_node->default_param();
+      MS_EXCEPTION_IF_NULL(param_value);
+      auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_value);
+      MS_EXCEPTION_IF_NULL(tensor);
+      if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
+                                            LongToSize(tensor->data().nbytes()), tensor->data_type(),
+                                            tensor->data_c())) {
+        MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
+      }
+    }
+  }
+  return graph_id;
+}
+
+bool GpuInferenceSession::CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs,
+                                           std::string *error_msg) const {
+  MS_LOG(INFO) << "Start check client inputs, graph id : " << graph_id;
+  auto kernel_graph = GetGraph(graph_id);
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  auto kernel_graph_inputs = kernel_graph->inputs();
+  size_t no_weight_input = 0;
+  vector<ParameterPtr> paras;
+  // find parameters of graph inputs
+  for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) {
+    if (!kernel_graph_inputs[i]->isa<Parameter>()) {
+      MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter.";
+      continue;
+    }
+    auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
+    if (!AnfAlgo::IsParameterWeight(parameter)) {
+      paras.push_back(parameter);
+    }
+  }
+
+  // check inputs
+  for (size_t i = 0; i < paras.size(); ++i) {
+    // compare input number
+    if (paras.size() != inputs.size()) {
+      MS_LOG(ERROR) << "Input number is inconsistent. The actual input number [" << inputs.size()
+                    << "] but the graph input number is [" << paras.size() << "]";
+      MS_LOG(ERROR) << "InputsInfo --" << InputsInfo(paras, inputs);
+      if (error_msg != nullptr) {
+        std::stringstream str_stream;
+        str_stream << "Input number is inconsistent. The given input number [" << inputs.size()
+                   << "] but the graph input number is [" << paras.size() << "]\n";
+        str_stream << "InputsInfo --" << InputsInfo(paras, inputs);
+        *error_msg = str_stream.str();
+      }
+      return false;
+    }
+    auto input = inputs[no_weight_input++];
+    if (!CompareInput(input, paras[i])) {
+      MS_LOG(ERROR) << "Please check the input information.";
+      MS_LOG(ERROR) << "InputsInfo --" << InputsInfo(paras, inputs);
+      if (error_msg != nullptr) {
+        std::stringstream str_stream;
+        str_stream << "Please check the input information.\n";
+        str_stream << "InputsInfo --" << InputsInfo(paras, inputs);
+        *error_msg = str_stream.str();
+      }
+      return false;
+    }
+  }
+  return true;
+}
+
+bool GpuInferenceSession::CompareInput(const tensor::TensorPtr &input, const ParameterPtr &parameter) const {
+  MS_EXCEPTION_IF_NULL(input);
+  MS_EXCEPTION_IF_NULL(parameter);
+  // compare dims
+  auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0);
+
+  // compare shape
+  auto input_shape = input->shape();
+  vector<size_t> trans_input;
+  (void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(trans_input),
+                       [](const int64_t dim) { return static_cast<size_t>(dim); });
+  auto is_scalar_shape = [](const vector<size_t> &shape) {
+    return shape.empty() || (shape.size() == 1 && shape[0] == 1);
+  };
+  if ((!is_scalar_shape(trans_input) || !is_scalar_shape(parameter_shape)) && (trans_input != parameter_shape)) {
+    MS_LOG(ERROR) << "Input shape is inconsistent. The actual shape is " << PrintInputShape(trans_input)
+                  << ", but the parameter shape is " << PrintInputShape(parameter_shape)
+                  << ". parameter : " << parameter->DebugString();
+    return false;
+  }
+
+  // compare data type
+  auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter);
+  if (input->data_type() != kernel_build_info->GetOutputDeviceType(0)) {
+    MS_LOG(ERROR) << "Input data type is inconsistent. The actual data type is " << input->data_type()
+                  << ", but the parameter data type is " << kernel_build_info->GetOutputDeviceType(0)
+                  << ". parameter : " << parameter->DebugString();
+    return false;
+  }
+  return true;
+}
+
+template <typename T>
+std::string GpuInferenceSession::PrintInputShape(std::vector<T> shape) const {
+  string res = "[";
+  for (auto dim : shape) {
+    res += " " + std::to_string(dim);
+  }
+  return res + " ]";
+}
+
+std::string GpuInferenceSession::InputsInfo(const std::vector<ParameterPtr> &paras,
+                                            const std::vector<tensor::TensorPtr> &inputs) const {
+  const std::map<TypeId, std::string> dtype_name_map{
+    {TypeId::kNumberTypeBegin, "Unknown"},   {TypeId::kNumberTypeBool, "Bool"},
+    {TypeId::kNumberTypeFloat64, "Float64"}, {TypeId::kNumberTypeInt8, "Int8"},
+    {TypeId::kNumberTypeUInt8, "Uint8"},     {TypeId::kNumberTypeInt16, "Int16"},
+    {TypeId::kNumberTypeUInt16, "Uint16"},   {TypeId::kNumberTypeInt32, "Int32"},
+    {TypeId::kNumberTypeUInt32, "Uint32"},   {TypeId::kNumberTypeInt64, "Int64"},
+    {TypeId::kNumberTypeUInt64, "Uint64"},   {TypeId::kNumberTypeFloat16, "Float16"},
+    {TypeId::kNumberTypeFloat32, "Float32"},
+  };
+  auto data_type_to_string = [&dtype_name_map](TypeId type_id) {
+    auto it = dtype_name_map.find(type_id);
+    if (it == dtype_name_map.end()) {
+      return std::string("Unknown");
+    }
+    return it->second;
+  };
+
+  std::string graph = "graph inputs:{ ";
+  for (size_t i = 0; i < paras.size(); ++i) {
+    auto &para = paras[i];
+    graph += std::to_string(i) + ": dims " + std::to_string(AnfAlgo::GetOutputDeviceShape(para, 0).size()) +
+             ", shape " + PrintInputShape(AnfAlgo::GetOutputDeviceShape(para, 0)) + ", data type " +
+             data_type_to_string(AnfAlgo::GetSelectKernelBuildInfo(para)->GetOutputDeviceType(0)) + " }";
+  }
+
+  std::string actual = "given inputs:{ ";
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    actual += std::to_string(i) + ": dims " + std::to_string(inputs[i]->shape().size()) + ", shape " +
+              PrintInputShape(inputs[i]->shape()) + ", data type " + data_type_to_string(inputs[i]->data_type()) +
+              " }";
+  }
+  return graph + " " + actual;
+}
+
+}  // namespace session
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/session/gpu_inference_session.h b/mindspore/ccsrc/backend/session/gpu_inference_session.h
new file mode 100644
index 00000000000..0ee2459799f
--- /dev/null
+++ b/mindspore/ccsrc/backend/session/gpu_inference_session.h
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_GPU_INFERENCE_SESSION_H
+#define MINDSPORE_CCSRC_BACKEND_SESSION_GPU_INFERENCE_SESSION_H
+#include <unordered_map>
+#include <string>
+#include <memory>
+#include <vector>
+#include <utility>
+#include <stack>
+#include <map>
+#include <tuple>
+#include <set>
+#include "backend/session/gpu_session.h"
+#include "backend/session/kernel_graph.h"
+#include "backend/kernel_compiler/kernel.h"
+#include "backend/session/session_factory.h"
+
+namespace mindspore {
+namespace session {
+class GpuInferenceSession : public gpu::GPUSession {
+ public:
+  GpuInferenceSession() = default;
+  ~GpuInferenceSession() = default;
+  void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
+                     const std::vector<tensor::TensorPtr> &inputs_const) const;
+  bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs,
+                        std::string *error_msg) const override;
+  bool CompareInput(const tensor::TensorPtr &input, const ParameterPtr &parameter) const;
+  template <typename T>
+  std::string PrintInputShape(std::vector<T> shape) const;
+  std::string InputsInfo(const std::vector<ParameterPtr> &paras, const std::vector<tensor::TensorPtr> &inputs) const;
+
+ protected:
+  GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override;
+};
+MS_REG_SESSION(kGpuInferenceDevice, GpuInferenceSession);
+}  // namespace session
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_SESSION_GPU_INFERENCE_SESSION_H
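`MS_REG_SESSION(kGpuInferenceDevice, GpuInferenceSession)` is what makes the new session reachable at runtime: `GPUGraphImpl::InitEnv()` later asks the session factory for the `"GpuInference"` target by name. Below is a self-contained sketch of this self-registration pattern with illustrative names only; the real factory lives in `backend/session/session_factory.h`.

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Session {
  virtual ~Session() = default;
  virtual const char *Name() const = 0;
};

// Factory keyed by device string; creators are registered at static-init time.
class SessionFactory {
 public:
  static SessionFactory &Get() {
    static SessionFactory instance;
    return instance;
  }
  void Register(const std::string &device, std::function<std::shared_ptr<Session>()> creator) {
    creators_[device] = std::move(creator);
  }
  std::shared_ptr<Session> Create(const std::string &device) {
    auto it = creators_.find(device);
    return it == creators_.end() ? nullptr : it->second();
  }

 private:
  std::map<std::string, std::function<std::shared_ptr<Session>()>> creators_;
};

struct GpuInferenceSession : Session {
  const char *Name() const override { return "GpuInference"; }
};

// Conceptually what a macro like MS_REG_SESSION expands to: a namespace-scope
// object whose initializer runs before main() and registers the creator.
static const bool g_gpu_inference_reg = [] {
  SessionFactory::Get().Register("GpuInference", [] { return std::make_shared<GpuInferenceSession>(); });
  return true;
}();

int main() {
  auto session = SessionFactory::Get().Create("GpuInference");
  std::cout << (session ? session->Name() : "not found") << "\n";
  return 0;
}
```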
graph num" << all_graphs.size(); + } + + opt::BackendCommonOptimization(root_graph); + return CompileGraphImpl(root_graph); +} + +GraphId GPUSession::CompileGraphImpl(KernelGraphPtr graph) { // Prepare ms context info for dump .pb graph auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); bool save_graphs = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_FLAG); // Dump .pb graph before graph optimization if (save_graphs) { - DumpIRProto(graph, "before_opt_" + std::to_string(graph_id)); + DumpIRProto(graph, "before_opt_" + std::to_string(graph->graph_id())); } // Graph optimization irrelevant to device data format Optimize(graph); @@ -326,7 +343,7 @@ GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr AssignStream(graph); // Dump .pb graph before remove nop nodes if (save_graphs) { - DumpIRProto(graph, "before_removeNop_" + std::to_string(graph_id)); + DumpIRProto(graph, "before_removeNop_" + std::to_string(graph->graph_id())); } // Update Graph Dynamic Shape Attr. UpdateGraphDynamicShapeAttr(NOT_NULL(graph)); @@ -343,7 +360,7 @@ GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr SetSummaryNodes(graph.get()); // Dump .pb graph after graph optimization if (save_graphs) { - DumpIRProto(graph, "after_opt_" + std::to_string(graph_id)); + DumpIRProto(graph, "after_opt_" + std::to_string(graph->graph_id())); } // Set graph manager. MS_EXCEPTION_IF_NULL(context_); @@ -361,9 +378,8 @@ GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr debugger_->LoadGraphs(graph); } #endif - MS_LOG(INFO) << "CompileGraph graph_id: " << graph_id; - - return graph_id; + MS_LOG(INFO) << "CompileGraph graph_id: " << graph->graph_id(); + return graph->graph_id(); } void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector &inputs, diff --git a/mindspore/ccsrc/backend/session/gpu_session.h b/mindspore/ccsrc/backend/session/gpu_session.h index 05ad8b5796c..2a9e23f933b 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.h +++ b/mindspore/ccsrc/backend/session/gpu_session.h @@ -37,6 +37,7 @@ class GPUSession : public SessionBasic { protected: void UnifyMindIR(const KernelGraphPtr &graph) override { return; } GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; + GraphId CompileGraphImpl(NotNull func_graph) override; void RunGraphImpl(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, const std::vector &input_tensors, @@ -81,6 +82,8 @@ class GPUSession : public SessionBasic { void SyncValueNodeDeviceAddr(const std::shared_ptr &kernel_graph) const; void CleanValueNodeDeviceAddr(const std::shared_ptr &kernel_graph) const; + + GraphId CompileGraphImpl(KernelGraphPtr kernel_graph); }; using GPUSessionPtr = std::shared_ptr; MS_REG_SESSION(kGPUDevice, GPUSession); diff --git a/mindspore/ccsrc/cxx_api/CMakeLists.txt b/mindspore/ccsrc/cxx_api/CMakeLists.txt index 4d14c811cbb..4992308f15b 100644 --- a/mindspore/ccsrc/cxx_api/CMakeLists.txt +++ b/mindspore/ccsrc/cxx_api/CMakeLists.txt @@ -15,10 +15,15 @@ if(ENABLE_ACL) "model/model_converter_utils/*.cc" "graph/acl/*.cc" ) - endif() + if(ENABLE_D) - file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "python_utils.cc" "model/ms/*.cc" "graph/ms/*.cc") + file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} + "python_utils.cc" "model/ms/*.cc" "graph/ascend/*.cc") +endif() + +if(ENABLE_GPU) + 
diff --git a/mindspore/ccsrc/cxx_api/CMakeLists.txt b/mindspore/ccsrc/cxx_api/CMakeLists.txt
index 4d14c811cbb..4992308f15b 100644
--- a/mindspore/ccsrc/cxx_api/CMakeLists.txt
+++ b/mindspore/ccsrc/cxx_api/CMakeLists.txt
@@ -15,10 +15,15 @@ if(ENABLE_ACL)
         "model/model_converter_utils/*.cc"
         "graph/acl/*.cc"
         )
-    endif()
+endif()
+
 if(ENABLE_D)
-    file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "python_utils.cc" "model/ms/*.cc" "graph/ms/*.cc")
+    file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR}
+        "python_utils.cc" "model/ms/*.cc" "graph/ascend/*.cc")
+endif()
+
+if(ENABLE_GPU)
+    file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "python_utils.cc" "model/ms/*.cc" "graph/gpu/*.cc")
 endif()
 
 set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc
@@ -98,6 +103,15 @@ if(ENABLE_D)
     target_link_libraries(mindspore_shared_lib PRIVATE ${adump_server})
 endif()
 
+if(ENABLE_GPU)
+    target_link_libraries(mindspore_shared_lib PRIVATE gpu_cuda_lib gpu_queue cublas
+                          ${CUDA_PATH}/lib64/libcurand.so
+                          ${CUDNN_LIBRARY_PATH}
+                          ${CUDA_PATH}/lib64/libcudart.so
+                          ${CUDA_PATH}/lib64/stubs/libcuda.so
+                          ${CUDA_PATH}/lib64/libcusolver.so)
+endif()
+
 if(CMAKE_SYSTEM_NAME MATCHES "Linux")
     set(MINDSPORE_RPATH $ORIGIN)
     if(ENABLE_D)
@@ -110,7 +124,8 @@ if(CMAKE_SYSTEM_NAME MATCHES "Linux")
         set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/add-ons)
         set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling)
         set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
-        set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
+        set(MINDSPORE_RPATH
+            ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
     elseif(ENABLE_ACL)
         set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/atc/lib64)
         set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/atc/lib64)
@@ -121,7 +136,8 @@ if(CMAKE_SYSTEM_NAME MATCHES "Linux")
         set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/add-ons)
         set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling)
         set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
-        set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
+        set(MINDSPORE_RPATH
+            ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
     endif()
 
     set_target_properties(mindspore_shared_lib PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH})
diff --git a/mindspore/ccsrc/cxx_api/graph/ms/ms_graph_impl.cc b/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.cc
similarity index 71%
rename from mindspore/ccsrc/cxx_api/graph/ms/ms_graph_impl.cc
rename to mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.cc
index 2839964f173..9490bcf74b1 100644
--- a/mindspore/ccsrc/cxx_api/graph/ms/ms_graph_impl.cc
+++ b/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.cc
@@ -13,7 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "cxx_api/graph/ms/ms_graph_impl.h"
+#include "cxx_api/graph/ascend/ascend_graph_impl.h"
 #include <algorithm>
 #include "include/api/context.h"
 #include "cxx_api/factory.h"
@@ -26,43 +26,9 @@
 #include "runtime/device/kernel_runtime_manager.h"
 
 namespace mindspore::api {
-API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, MsGraphImpl);
+API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl);
 
-static DataType TransTypeId2InferDataType(TypeId type_id) {
-  const std::map<TypeId, api::DataType> id2type_map{
-    {TypeId::kNumberTypeBegin, api::kMsUnknown},   {TypeId::kNumberTypeBool, api::kMsBool},
-    {TypeId::kNumberTypeFloat64, api::kMsFloat64}, {TypeId::kNumberTypeInt8, api::kMsInt8},
-    {TypeId::kNumberTypeUInt8, api::kMsUint8},     {TypeId::kNumberTypeInt16, api::kMsInt16},
-    {TypeId::kNumberTypeUInt16, api::kMsUint16},   {TypeId::kNumberTypeInt32, api::kMsInt32},
-    {TypeId::kNumberTypeUInt32, api::kMsUint32},   {TypeId::kNumberTypeInt64, api::kMsInt64},
-    {TypeId::kNumberTypeUInt64, api::kMsUint64},   {TypeId::kNumberTypeFloat16, api::kMsFloat16},
-    {TypeId::kNumberTypeFloat32, api::kMsFloat32},
-  };
-
-  // cppcheck-suppress stlIfFind
-  if (auto it = id2type_map.find(type_id); it != id2type_map.end()) {
-    return it->second;
-  }
-
-  MS_LOG(WARNING) << "Unsupported data id " << type_id;
-  return api::kMsUnknown;
-}
-
-template <class T>
-inline static void ClearIfNotNull(T *vec) {
-  if (vec != nullptr) {
-    vec->clear();
-  }
-}
-
-template <class T, class U = std::vector<T>>
-inline static void PushbackIfNotNull(U *vec, T &&item) {
-  if (vec != nullptr) {
-    vec->emplace_back(item);
-  }
-}
-
-MsGraphImpl::MsGraphImpl()
+AscendGraphImpl::AscendGraphImpl()
     : session_impl_(nullptr),
       graph_id_(0),
       device_type_("Ascend"),
@@ -75,9 +41,9 @@ MsGraphImpl::MsGraphImpl()
       init_flag_(false),
       load_flag_(false) {}
 
-MsGraphImpl::~MsGraphImpl() { (void)FinalizeEnv(); }
+AscendGraphImpl::~AscendGraphImpl() { (void)FinalizeEnv(); }
 
-Status MsGraphImpl::InitEnv() {
+Status AscendGraphImpl::InitEnv() {
   if (init_flag_) {
     return SUCCESS;
   }
@@ -108,7 +74,7 @@ Status MsGraphImpl::InitEnv() {
   return SUCCESS;
 }
 
-Status MsGraphImpl::FinalizeEnv() {
+Status AscendGraphImpl::FinalizeEnv() {
   if (!init_flag_) {
     return SUCCESS;
   }
@@ -136,7 +102,7 @@ Status MsGraphImpl::FinalizeEnv() {
   return SUCCESS;
 }
 
-Status MsGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) {
+Status AscendGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) {
   MS_ASSERT(session_impl_ != nullptr);
   try {
     graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
@@ -147,7 +113,7 @@ Status MsGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr)
   }
 }
 
-std::vector<tensor::TensorPtr> MsGraphImpl::RunGraph(const std::vector<tensor::TensorPtr> &inputs) {
+std::vector<tensor::TensorPtr> AscendGraphImpl::RunGraph(const std::vector<tensor::TensorPtr> &inputs) {
   try {
     VectorRef outputs;
     session_impl_->RunGraph(graph_id_, inputs, &outputs);
@@ -158,7 +124,7 @@ std::vector<tensor::TensorPtr> MsGraphImpl::RunGraph(const std::vector<tensor::TensorPtr> &inputs)
   }
 }
 
-Status MsGraphImpl::CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const {
+Status AscendGraphImpl::CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const {
   MS_ASSERT(session_impl_ != nullptr);
   std::string error_msg;
   if (!session_impl_->CheckModelInputs(graph_id_, inputs, &error_msg)) {
@@ -167,7 +133,7 @@ Status MsGraphImpl::CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const {
   return SUCCESS;
 }
 
-Status MsGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) {
+Status AscendGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) {
   MS_EXCEPTION_IF_NULL(reply);
   if (context_ == nullptr) {
     MS_LOG(ERROR) << "rtCtx is nullptr";
@@ -206,8 +172,8 @@ Status MsGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply)
   return SUCCESS;
 }
 
-Status MsGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                                  std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
+Status AscendGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
+                                      std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
   if (!load_flag_) {
     Status ret = Load();
     if (ret != SUCCESS) {
@@ -216,22 +182,22 @@ Status MsGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
     }
   }
 
-  ClearIfNotNull(names);
-  ClearIfNotNull(shapes);
-  ClearIfNotNull(data_types);
-  ClearIfNotNull(mem_sizes);
+  GraphUtils::ClearIfNotNull(names);
+  GraphUtils::ClearIfNotNull(shapes);
+  GraphUtils::ClearIfNotNull(data_types);
+  GraphUtils::ClearIfNotNull(mem_sizes);
   for (size_t i = 0; i < inputs_.size(); i++) {
     auto &tensor = inputs_[i];
-    PushbackIfNotNull(names, input_names_[i]);
-    PushbackIfNotNull(shapes, tensor->shape());
-    PushbackIfNotNull(data_types, TransTypeId2InferDataType(tensor->data_type()));
-    PushbackIfNotNull(mem_sizes, tensor->Size());
+    GraphUtils::PushbackIfNotNull(names, input_names_[i]);
+    GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
+    GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
+    GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
   }
   return SUCCESS;
 }
 
-Status MsGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                                   std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
+Status AscendGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
+                                       std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
   if (!load_flag_) {
     Status ret = Load();
     if (ret != SUCCESS) {
@@ -240,22 +206,22 @@ Status MsGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
     }
   }
 
-  ClearIfNotNull(names);
-  ClearIfNotNull(shapes);
-  ClearIfNotNull(data_types);
-  ClearIfNotNull(mem_sizes);
+  GraphUtils::ClearIfNotNull(names);
+  GraphUtils::ClearIfNotNull(shapes);
+  GraphUtils::ClearIfNotNull(data_types);
+  GraphUtils::ClearIfNotNull(mem_sizes);
   for (size_t i = 0; i < outputs_.size(); i++) {
     auto &tensor = outputs_[i];
-    PushbackIfNotNull(names, output_names_[i]);
-    PushbackIfNotNull(shapes, tensor->shape());
-    PushbackIfNotNull(data_types, TransTypeId2InferDataType(tensor->data_type()));
-    PushbackIfNotNull(mem_sizes, tensor->Size());
+    GraphUtils::PushbackIfNotNull(names, output_names_[i]);
+    GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
+    GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
+    GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
   }
   return SUCCESS;
 }
 
-Status MsGraphImpl::Load() {
+Status AscendGraphImpl::Load() {
   // check graph type
   if (graph_->ModelType() != ModelType::kMindIR) {
     MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType();
@@ -311,7 +277,7 @@ Status MsGraphImpl::Load() {
   return SUCCESS;
 }
 
-Status MsGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
+Status AscendGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
   MS_EXCEPTION_IF_NULL(outputs);
   if (!load_flag_) {
     Status ret = Load();
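Hoisting `ClearIfNotNull`/`PushbackIfNotNull` out of this file into `GraphUtils` (introduced later in this diff) preserves the nullptr-tolerant out-parameter convention of `GetInputsInfo`/`GetOutputsInfo`: callers pass `nullptr` for any output they do not need, and the helpers silently skip it. A standalone sketch of the pattern, with illustrative names:

```cpp
#include <iostream>
#include <string>
#include <vector>

template <class T>
inline void ClearIfNotNull(T *vec) {
  if (vec != nullptr) {
    vec->clear();
  }
}

template <class T, class U = std::vector<T>>
inline void PushbackIfNotNull(U *vec, T &&item) {
  if (vec != nullptr) {
    vec->emplace_back(item);
  }
}

// A GetInputsInfo-style function: every output parameter is optional.
void GetInfo(std::vector<std::string> *names, std::vector<size_t> *sizes) {
  ClearIfNotNull(names);
  ClearIfNotNull(sizes);
  PushbackIfNotNull(names, std::string("input0"));
  PushbackIfNotNull(sizes, static_cast<size_t>(4096));
}

int main() {
  std::vector<std::string> names;
  GetInfo(&names, nullptr);  // only names requested; sizes skipped safely
  std::cout << names[0] << "\n";
  return 0;
}
```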
 */
diff --git a/mindspore/ccsrc/cxx_api/graph/ms/ms_graph_impl.h b/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.h
similarity index 87%
rename from mindspore/ccsrc/cxx_api/graph/ms/ms_graph_impl.h
rename to mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.h
index ec7d148be1d..fae683558e3 100644
--- a/mindspore/ccsrc/cxx_api/graph/ms/ms_graph_impl.h
+++ b/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.h
@@ -13,8 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
-#define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
+#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H
+#define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H
 #include <functional>
 #include <map>
 #include <string>
@@ -28,12 +28,13 @@
 #include "ir/anf.h"
 #include "cxx_api/model/model_impl.h"
 #include "runtime/context.h"
+#include "cxx_api/graph/graph_utils.h"
 
 namespace mindspore::api {
-class MsGraphImpl : public GraphCell::GraphImpl {
+class AscendGraphImpl : public GraphCell::GraphImpl {
  public:
-  MsGraphImpl();
-  ~MsGraphImpl() override;
+  AscendGraphImpl();
+  ~AscendGraphImpl() override;
 
   Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
   Status Load() override;
@@ -63,4 +64,4 @@ class MsGraphImpl : public GraphCell::GraphImpl {
   bool load_flag_;
 };
 }  // namespace mindspore::api
-#endif  // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
+#endif  // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H
diff --git a/mindspore/ccsrc/cxx_api/model/ms/ms_stub.cc b/mindspore/ccsrc/cxx_api/graph/ascend/ascend_stub.cc
similarity index 69%
rename from mindspore/ccsrc/cxx_api/model/ms/ms_stub.cc
rename to mindspore/ccsrc/cxx_api/graph/ascend/ascend_stub.cc
index 50e2c46d73a..464a8a2f11e 100644
--- a/mindspore/ccsrc/cxx_api/model/ms/ms_stub.cc
+++ b/mindspore/ccsrc/cxx_api/graph/ascend/ascend_stub.cc
@@ -20,10 +20,10 @@ VMCallbackRegister &VMCallbackRegister::GetInstance() {
   return instance;
 }
 
-bool VMCallbackRegister::Registe(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback),
-                                 Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback),
-                                 Status (*pRegProfReporterCallback)(MsprofReporterCallback),
-                                 Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t)) {
+bool VMCallbackRegister::Register(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback),
+                                  Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback),
+                                  Status (*pRegProfReporterCallback)(MsprofReporterCallback),
+                                  Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t)) {
   return false;
 }
diff --git a/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc b/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc
new file mode 100644
index 00000000000..6af3a9ab9b6
--- /dev/null
+++ b/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc
@@ -0,0 +1,256 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "cxx_api/graph/gpu/gpu_graph_impl.h"
+#include <algorithm>
+#include "include/api/context.h"
+#include "cxx_api/factory.h"
+#include "utils/log_adapter.h"
+#include "mindspore/core/base/base_ref_utils.h"
+#include "backend/session/session_factory.h"
+#include "backend/session/executor_manager.h"
+#include "runtime/device/kernel_runtime_manager.h"
+
+namespace mindspore::api {
+API_FACTORY_REG(GraphCell::GraphImpl, GPU, GPUGraphImpl);
+
+GPUGraphImpl::GPUGraphImpl()
+    : session_impl_(nullptr),
+      graph_id_(0),
+      device_id_(Context::Instance().GetDeviceID()),
+      inputs_(),
+      outputs_(),
+      input_names_(),
+      output_names_(),
+      init_flag_(false),
+      load_flag_(false) {}
+
+Status GPUGraphImpl::InitEnv() {
+  if (init_flag_) {
+    MS_LOG(WARNING) << "Initialized again, return success.";
+    return SUCCESS;
+  }
+
+  auto ms_context = MsContext::GetInstance();
+  if (ms_context == nullptr) {
+    MS_LOG(ERROR) << "Get Context failed!";
+    return FAILED;
+  }
+  ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
+  ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
+  ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kGPUDevice);
+  ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, true);
+
+  session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice);
+  if (session_impl_ == nullptr) {
+    MS_LOG(ERROR) << "Session create failed! Please make sure target device:" << kGpuInferenceDevice
+                  << " is available.";
+    return FAILED;
+  }
+
+  session_impl_->Init(device_id_);
+  init_flag_ = true;
+  return SUCCESS;
+}
+
+Status GPUGraphImpl::FinalizeEnv() {
+  if (!init_flag_) {
+    MS_LOG(WARNING) << "Never initialized before, return success";
+    return SUCCESS;
+  }
+
+  MS_LOG(INFO) << "Start finalize env";
+  session::ExecutorManager::Instance().Clear();
+  device::KernelRuntimeManager::Instance().ClearRuntimeResource();
+
+  init_flag_ = false;
+  MS_LOG(INFO) << "End finalize env";
+  return SUCCESS;
+}
+
+Status GPUGraphImpl::Load() {
+  // check graph type
+  if (graph_->ModelType() != ModelType::kMindIR) {
+    MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType();
+    return INVALID_INPUTS;
+  }
+
+  const auto &graph_data = GraphImpl::MutableGraphData();
+  MS_EXCEPTION_IF_NULL(graph_data);
+  auto func_graph = graph_data->GetFuncGraph();
+
+  // init
+  Status ret = InitEnv();
+  if (ret != SUCCESS) {
+    MS_LOG(ERROR) << "InitEnv failed.";
+    return FAILED;
+  }
+
+  ret = CompileGraph(func_graph);
+  if (ret != SUCCESS) {
+    MS_LOG(ERROR) << "Compile graph model failed";
+    return FAILED;
+  }
+  session_impl_->GetModelInputsInfo(graph_id_, &inputs_, &input_names_);
+  session_impl_->GetModelOutputsInfo(graph_id_, &outputs_, &output_names_);
+  if (inputs_.empty() || inputs_.size() != input_names_.size()) {
+    MS_LOG(ERROR) << "Get model inputs info failed";
+    return FAILED;
+  }
+  if (outputs_.empty() || outputs_.size() != output_names_.size()) {
+    MS_LOG(ERROR) << "Get model outputs info failed";
+    return FAILED;
+  }
+  load_flag_ = true;
+  return SUCCESS;
+}
+
+Status GPUGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) {
+  MS_ASSERT(session_impl_ != nullptr);
+  try {
+    graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
+    return SUCCESS;
+  } catch (std::exception &e) {
+    MS_LOG(ERROR) << "CompileGraph failed: " << e.what();
+    return FAILED;
+  }
+}
+
+std::vector<tensor::TensorPtr> GPUGraphImpl::RunGraph(const std::vector<tensor::TensorPtr> &inputs) {
+  try {
+    VectorRef outputs;
+    session_impl_->RunGraph(graph_id_, inputs, &outputs);
+    return TransformVectorRefToMultiTensor(outputs);
+  } catch (std::exception &e) {
+    MS_LOG(ERROR) << "RunGraph failed: " << e.what();
+    return std::vector<tensor::TensorPtr>();
+  }
+}
+
+Status GPUGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) {
+  MS_EXCEPTION_IF_NULL(reply);
+
+  vector<tensor::TensorPtr> inputs;
+  for (size_t i = 0; i < request.size(); i++) {
+    auto &item = request[i];
+    auto input = inputs_[i];
+    if (input->Size() != item.DataSize()) {
+      MS_LOG(ERROR) << "Input " << i << " data size " << item.DataSize() << " not match model input data size "
+                    << input->Size();
+      return FAILED;
+    }
+    auto ret = memcpy_s(input->data_c(), input->Size(), item.Data(), item.DataSize());
+    if (ret != SUCCESS) {
+      MS_LOG(ERROR) << "Tensor copy failed";
+      return FAILED;
+    }
+    inputs.push_back(input);
+  }
+  vector<tensor::TensorPtr> outputs = RunGraph(inputs);
+  if (outputs.empty()) {
+    MS_LOG(ERROR) << "Execute Model Failed";
+    return FAILED;
+  }
+  reply->clear();
+  std::transform(outputs.begin(), outputs.end(), std::back_inserter(*reply),
+                 [](const tensor::TensorPtr &tensor) { return Buffer(tensor->data_c(), tensor->Size()); });
+  return SUCCESS;
+}
+
+Status GPUGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
+  MS_EXCEPTION_IF_NULL(outputs);
+  if (!load_flag_) {
+    Status ret = Load();
+    if (ret != SUCCESS) {
+      MS_LOG(ERROR) << "PrepareModel failed.";
+      return ret;
+    }
+  }
+
+  if (inputs.size() != inputs_.size()) {
+    MS_LOG(ERROR) << "inputs count not match, required count " << inputs_.size() << ", given count " << inputs.size();
+    return INVALID_INPUTS;
+  }
+
+  for (size_t i = 0; i < inputs_.size(); ++i) {
+    if (inputs[i].DataSize() != inputs_[i]->Size()) {
+      MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_[i]->Size()
+                    << ", given size " << inputs[i].DataSize();
+      return INVALID_INPUTS;
+    }
+  }
+  if (ExecuteModel(inputs, outputs) != SUCCESS) {
+    MS_LOG(ERROR) << "Execute Model Failed";
+    return FAILED;
+  }
+  if (outputs_.size() != outputs->size()) {
+    MS_LOG(ERROR) << "Predict output size " << outputs->size() << " not match output size got from model info "
+                  << outputs_.size();
+    return FAILED;
+  }
+
+  return SUCCESS;
+}
+
+Status GPUGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
+                                   std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
+  if (!load_flag_) {
+    Status ret = Load();
+    if (ret != SUCCESS) {
+      MS_LOG(ERROR) << "PrepareModel failed.";
+      return ret;
+    }
+  }
+
+  GraphUtils::ClearIfNotNull(names);
+  GraphUtils::ClearIfNotNull(shapes);
+  GraphUtils::ClearIfNotNull(data_types);
+  GraphUtils::ClearIfNotNull(mem_sizes);
+  for (size_t i = 0; i < inputs_.size(); i++) {
+    auto &tensor = inputs_[i];
+    GraphUtils::PushbackIfNotNull(names, input_names_[i]);
+    GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
+    GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
+    GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
+  }
+  return SUCCESS;
+}
+
+Status GPUGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
+                                    std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
+  if (!load_flag_) {
+    Status ret = Load();
+    if (ret != SUCCESS) {
+      MS_LOG(ERROR) << "PrepareModel failed.";
+      return ret;
+    }
+  }
+
+  GraphUtils::ClearIfNotNull(names);
+  GraphUtils::ClearIfNotNull(shapes);
+  GraphUtils::ClearIfNotNull(data_types);
+  GraphUtils::ClearIfNotNull(mem_sizes);
+  for (size_t i = 0; i < outputs_.size(); i++) {
+    auto &tensor = outputs_[i];
+    GraphUtils::PushbackIfNotNull(names, output_names_[i]);
+    GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
+    GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
+    GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
+  }
+
+  return SUCCESS;
+}
+
+}  // namespace mindspore::api
diff --git a/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.h b/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.h
new file mode 100644
index 00000000000..fca0323f82f
--- /dev/null
+++ b/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.h
@@ -0,0 +1,67 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_GPU_GRAPH_IMPL_H
+#define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_GPU_GRAPH_IMPL_H
+#include <memory>
+#include <string>
+#include <vector>
+#include <map>
+#include "include/api/status.h"
+#include "include/api/graph.h"
+#include "cxx_api/graph/graph_impl.h"
+#include "backend/session/session_basic.h"
+#include "ir/anf.h"
+#include "cxx_api/model/model_impl.h"
+#include "cxx_api/graph/graph_utils.h"
+
+namespace mindspore::api {
+class GPUGraphImpl : public GraphCell::GraphImpl {
+ public:
+  GPUGraphImpl();
+  ~GPUGraphImpl() override = default;
+
+  Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
+  Status Load() override;
+  Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
+                       std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
+  Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
+                        std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
+
+ private:
+  Status InitEnv();
+  Status FinalizeEnv();
+  Status CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr);
+  Status CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const;
+  std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs);
+  Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
+
+  std::shared_ptr<session::SessionBasic> session_impl_;
+  uint32_t graph_id_;
+  std::string device_type_;
+  uint32_t device_id_;
+  std::vector<tensor::TensorPtr> inputs_;
+  std::vector<tensor::TensorPtr> outputs_;
+  std::vector<std::string> input_names_;
+  std::vector<std::string> output_names_;
+  bool init_flag_;
+  bool load_flag_;
+
+  // tensor-rt
+  uint32_t batch_size_;
+  uint32_t workspace_size_;
+};
+}  // namespace mindspore::api
+#endif  // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_GPU_GRAPH_IMPL_H
diff --git a/mindspore/ccsrc/cxx_api/graph/graph_utils.h b/mindspore/ccsrc/cxx_api/graph/graph_utils.h
new file mode 100644
index 00000000000..6a087e019d6
--- /dev/null
+++ b/mindspore/ccsrc/cxx_api/graph/graph_utils.h
@@ -0,0 +1,63 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H
+#define MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H
+#include <map>
+#include <vector>
+#include "include/api/types.h"
+#include "ir/dtype/type_id.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore::api {
+class GraphUtils {
+ public:
+  static DataType TransTypeId2InferDataType(TypeId type_id) {
+    const std::map<TypeId, api::DataType> id2type_map{
+      {TypeId::kNumberTypeBegin, api::kMsUnknown},   {TypeId::kNumberTypeBool, api::kMsBool},
+      {TypeId::kNumberTypeFloat64, api::kMsFloat64}, {TypeId::kNumberTypeInt8, api::kMsInt8},
+      {TypeId::kNumberTypeUInt8, api::kMsUint8},     {TypeId::kNumberTypeInt16, api::kMsInt16},
+      {TypeId::kNumberTypeUInt16, api::kMsUint16},   {TypeId::kNumberTypeInt32, api::kMsInt32},
+      {TypeId::kNumberTypeUInt32, api::kMsUint32},   {TypeId::kNumberTypeInt64, api::kMsInt64},
+      {TypeId::kNumberTypeUInt64, api::kMsUint64},   {TypeId::kNumberTypeFloat16, api::kMsFloat16},
+      {TypeId::kNumberTypeFloat32, api::kMsFloat32},
+    };
+
+    auto it = id2type_map.find(type_id);
+    if (it != id2type_map.end()) {
+      return it->second;
+    }
+
+    MS_LOG(WARNING) << "Unsupported data id " << type_id;
+    return api::kMsUnknown;
+  }
+
+  template <class T>
+  inline static void ClearIfNotNull(T *vec) {
+    if (vec != nullptr) {
+      vec->clear();
+    }
+  }
+
+  template <class T, class U = std::vector<T>>
+  inline static void PushbackIfNotNull(U *vec, T &&item) {
+    if (vec != nullptr) {
+      vec->emplace_back(item);
+    }
+  }
+};
+}  // namespace mindspore::api
+
+#endif  // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H
diff --git a/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc b/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc
index 034d464d6ba..7349aba9432 100644
--- a/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc
+++ b/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc
@@ -22,6 +22,7 @@
 namespace mindspore {
 namespace api {
 API_FACTORY_REG(ModelImpl, Ascend910, MsModel);
+API_FACTORY_REG(ModelImpl, GPU, MsModel);
 
 Status MsModel::Build(const std::map<std::string, std::string> &) {
   MS_LOG(INFO) << "Start build model.";
diff --git a/mindspore/ccsrc/cxx_api/types.cc b/mindspore/ccsrc/cxx_api/types.cc
index 74d4c1bb99e..98178f108b7 100644
--- a/mindspore/ccsrc/cxx_api/types.cc
+++ b/mindspore/ccsrc/cxx_api/types.cc
@@ -21,6 +21,7 @@
 namespace mindspore::api {
 const char *kDeviceTypeAscend310 = "Ascend310";
 const char *kDeviceTypeAscend910 = "Ascend910";
+const char *kDeviceTypeGpu = "GPU";
 
 class DataImpl {
  public:
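With `kDeviceTypeGpu` exported from `types.h`/`types.cc` and `MsModel` now registered for `GPU`, GPU inference goes through the same C++ API flow as the existing Ascend path. Below is a hedged end-to-end sketch, assuming the `mindspore::api` surface used by the Ascend examples of this era (`Context::Instance()`, `Serialization::LoadModel`, `Model::Build`/`Predict`); the model file name is a placeholder.

```cpp
#include <iostream>
#include <vector>
#include "include/api/context.h"
#include "include/api/model.h"
#include "include/api/serialization.h"
#include "include/api/types.h"

int main() {
  using namespace mindspore::api;

  // Route the API to the GPU backend added by this change (kDeviceTypeGpu == "GPU").
  Context::Instance().SetDeviceTarget(kDeviceTypeGpu).SetDeviceID(0);

  // "model.mindir" is a placeholder; GPUGraphImpl::Load() accepts only MindIR.
  auto graph = Serialization::LoadModel("model.mindir", ModelType::kMindIR);
  Model model((GraphCell(graph)));
  if (model.Build({}) != SUCCESS) {
    std::cerr << "Build failed" << std::endl;
    return 1;
  }

  std::vector<Buffer> inputs;   // fill with Buffers matching the model input sizes
  std::vector<Buffer> outputs;  // filled by Predict on success
  if (model.Predict(inputs, &outputs) != SUCCESS) {
    std::cerr << "Predict failed" << std::endl;
    return 1;
  }
  std::cout << "output count: " << outputs.size() << std::endl;
  return 0;
}
```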
diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_callback_register.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_callback_register.cc
index 202f28f7b05..c4056ab84b4 100644
--- a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_callback_register.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_callback_register.cc
@@ -31,7 +31,7 @@ constexpr Status PROF_FAILED = 0xFFFFFFFF;
 }  // namespace
 
 Status RegProfCtrlCallback(MsprofCtrlCallback func) {
-  if (VMCallbackRegister::GetInstance().registed()) {
+  if (VMCallbackRegister::GetInstance().registered()) {
     return VMCallbackRegister::GetInstance().DoRegProfCtrlCallback(func);
   } else {
     return PROF_SUCCESS;
@@ -39,7 +39,7 @@ Status RegProfCtrlCallback(MsprofCtrlCallback func) {
 }
 
 Status RegProfSetDeviceCallback(MsprofSetDeviceCallback func) {
-  if (VMCallbackRegister::GetInstance().registed()) {
+  if (VMCallbackRegister::GetInstance().registered()) {
     return VMCallbackRegister::GetInstance().DoRegProfSetDeviceCallback(func);
   } else {
     return PROF_SUCCESS;
@@ -47,7 +47,7 @@ Status RegProfSetDeviceCallback(MsprofSetDeviceCallback func) {
 }
 
 Status RegProfReporterCallback(MsprofReporterCallback func) {
-  if (VMCallbackRegister::GetInstance().registed()) {
+  if (VMCallbackRegister::GetInstance().registered()) {
     return VMCallbackRegister::GetInstance().DoRegProfReporterCallback(func);
   } else {
     return PROF_SUCCESS;
@@ -55,7 +55,7 @@ Status RegProfReporterCallback(MsprofReporterCallback func) {
 }
 
 Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t len) {
-  if (VMCallbackRegister::GetInstance().registed()) {
+  if (VMCallbackRegister::GetInstance().registered()) {
     return VMCallbackRegister::GetInstance().DoProfCommandHandle(type, data, len);
   } else {
     return PROF_SUCCESS;
@@ -69,16 +69,16 @@ VMCallbackRegister &VMCallbackRegister::GetInstance() {
   return instance;
 }
 
-bool VMCallbackRegister::Registe(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback),
-                                 Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback),
-                                 Status (*pRegProfReporterCallback)(MsprofReporterCallback),
-                                 Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t)) {
-  if (!registed_) {
+bool VMCallbackRegister::Register(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback),
+                                  Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback),
+                                  Status (*pRegProfReporterCallback)(MsprofReporterCallback),
+                                  Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t)) {
+  if (!registered_) {
     pRegProfCtrlCallback_ = pRegProfCtrlCallback;
     pRegProfSetDeviceCallback_ = pRegProfSetDeviceCallback;
     pRegProfReporterCallback_ = pRegProfReporterCallback;
     pProfCommandHandle_ = pProfCommandHandle;
-    registed_ = true;
+    registered_ = true;
     ForceMsprofilerInit();
     return true;
   }
diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_callback_register.h b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_callback_register.h
index a90d7e836d6..ace8c4631d3 100644
--- a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_callback_register.h
+++ b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_callback_register.h
@@ -49,12 +49,12 @@ class __attribute__((visibility("default"))) VMCallbackRegister {
   static VMCallbackRegister &GetInstance();
   VMCallbackRegister(const VMCallbackRegister &) = delete;
   VMCallbackRegister &operator=(const VMCallbackRegister &) = delete;
-  bool Registe(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback),
-               Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback),
-               Status (*pRegProfReporterCallback)(MsprofReporterCallback),
-               Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t));
+  bool Register(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback),
+                Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback),
+                Status (*pRegProfReporterCallback)(MsprofReporterCallback),
+                Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t));
   void ForceMsprofilerInit();
-  bool registed() { return registed_; }
+  bool registered() { return registered_; }
   Status DoRegProfCtrlCallback(MsprofCtrlCallback func) { return pRegProfCtrlCallback_(func); }
   Status DoRegProfSetDeviceCallback(MsprofSetDeviceCallback func) { return pRegProfSetDeviceCallback_(func); }
   Status DoRegProfReporterCallback(MsprofReporterCallback func) { return pRegProfReporterCallback_(func); }
@@ -64,7 +64,7 @@ class __attribute__((visibility("default"))) VMCallbackRegister {
 
  private:
   VMCallbackRegister()
-      : registed_(false),
+      : registered_(false),
         ms_profile_inited_(false),
         pRegProfCtrlCallback_(nullptr),
         pRegProfSetDeviceCallback_(nullptr),
@@ -72,7 +72,7 @@ class __attribute__((visibility("default"))) VMCallbackRegister {
         pProfCommandHandle_(nullptr) {}
   ~VMCallbackRegister() = default;
 
-  bool registed_;
+  bool registered_;
   bool ms_profile_inited_;
   Status (*pRegProfCtrlCallback_)(MsprofCtrlCallback);
   Status (*pRegProfSetDeviceCallback_)(MsprofSetDeviceCallback);
diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.cc
index 5be04d72efc..50232db9108 100644
--- a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.cc
@@ -299,8 +299,8 @@ Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t len) {
 
 bool DoRegiste() {
   MS_LOG(INFO) << "VM profiling register start";
-  return VMCallbackRegister::GetInstance().Registe(RegProfCtrlCallback, RegProfSetDeviceCallback,
-                                                   RegProfReporterCallback, ProfCommandHandle);
+  return VMCallbackRegister::GetInstance().Register(RegProfCtrlCallback, RegProfSetDeviceCallback,
+                                                    RegProfReporterCallback, ProfCommandHandle);
 }
 static bool doRegiste = DoRegiste();
 }  // namespace ascend
diff --git a/mindspore/core/utils/ms_context.h b/mindspore/core/utils/ms_context.h
index 3ece84d9651..6fd62ec2279 100644
--- a/mindspore/core/utils/ms_context.h
+++ b/mindspore/core/utils/ms_context.h
@@ -41,6 +41,7 @@ const char kCPUDevice[] = "CPU";
 const char kGPUDevice[] = "GPU";
 const char kAscendDevice[] = "Ascend";
 const char kDavinciInferenceDevice[] = "AscendInference";
+const char kGpuInferenceDevice[] = "GpuInference";
 const char kDavinciDevice[] = "Davinci";
 const char KNpuLog[] = "_npu_log";
 const unsigned int MAX_CALL_DEPTH_DEFAULT = 1000;
@@ -51,7 +52,7 @@ const float kDefaultMaxDeviceMemory = 1024;
 
 // enum definition for MindSpore Context Parameter
 enum MsCtxParam : unsigned {
-  // paramater of type bool
+  // parameter of type bool
   MS_CTX_TYPE_BOOL_BEGIN,
   MS_CTX_ENABLE_AUTO_MIXED_PRECISION = MS_CTX_TYPE_BOOL_BEGIN,
   MS_CTX_CHECK_BPROP_FLAG,
@@ -74,14 +75,15 @@ enum MsCtxParam : unsigned {
   MS_CTX_ENABLE_PROFILING,
   MS_CTX_SAVE_GRAPHS_FLAG,
   MS_CTX_ENABLE_PARALLEL_SPLIT,
+  MS_CTX_ENABLE_INFER_OPT,
   MS_CTX_TYPE_BOOL_END,
 
-  // paramater of type int
+  // parameter of type int
   MS_CTX_TYPE_INT_BEGIN = MS_CTX_TYPE_BOOL_END,
   MS_CTX_EXECUTION_MODE = MS_CTX_TYPE_INT_BEGIN,
   MS_CTX_TYPE_INT_END,
 
-  // paramater of type uint32
+  // parameter of type uint32
   MS_CTX_TYPE_UINT32_BEGIN = MS_CTX_TYPE_INT_END,
   MS_CTX_DEVICE_ID = MS_CTX_TYPE_UINT32_BEGIN,
   MS_CTX_GE_REF,
@@ -89,12 +91,12 @@ enum MsCtxParam : unsigned {
   MS_CTX_TSD_REF,
   MS_CTX_TYPE_UINT32_END,
 
-  // paramater of type float
+  // parameter of type float
   MS_CTX_TYPE_FLOAT_BEGIN = MS_CTX_TYPE_UINT32_END,
   MS_CTX_MAX_DEVICE_MEMORY = MS_CTX_TYPE_FLOAT_BEGIN,
   MS_CTX_TYPE_FLOAT_END,
 
-  // paramater of type string
+  // parameter of type string
   MS_CTX_TYPE_STRING_BEGIN = MS_CTX_TYPE_FLOAT_END,
   MS_CTX_DEVICE_TARGET = MS_CTX_TYPE_STRING_BEGIN,
   MS_CTX_GRAPH_MEMORY_MAX_SIZE,
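The `MsCtxParam` enum above keeps parameters of each type in a contiguous `[BEGIN, END)` range — which is why `MS_CTX_ENABLE_INFER_OPT` must be inserted before `MS_CTX_TYPE_BOOL_END`. Typed accessors like the `get_param<bool>`/`set_param<std::string>` calls in `gpu_session.cc` and `gpu_graph_impl.cc` can then index per-type storage with `param - BEGIN`. A compact, self-contained sketch of the idea with illustrative names:

```cpp
#include <array>
#include <cstdint>
#include <iostream>

// Parameters of each type occupy a contiguous range in one enum.
enum CtxParam : unsigned {
  TYPE_BOOL_BEGIN,
  ENABLE_INFER_OPT = TYPE_BOOL_BEGIN,
  SAVE_GRAPHS_FLAG,
  TYPE_BOOL_END,

  TYPE_UINT32_BEGIN = TYPE_BOOL_END,
  DEVICE_ID = TYPE_UINT32_BEGIN,
  TYPE_UINT32_END,
};

class Ctx {
 public:
  template <typename T>
  void set_param(CtxParam p, const T &v);
  template <typename T>
  const T &get_param(CtxParam p) const;

 private:
  // Range sizes come straight from the enum layout.
  std::array<bool, TYPE_BOOL_END - TYPE_BOOL_BEGIN> bools_{};
  std::array<uint32_t, TYPE_UINT32_END - TYPE_UINT32_BEGIN> uint32s_{};
};

// Each specialization indexes its type's storage with (param - BEGIN).
template <>
void Ctx::set_param<bool>(CtxParam p, const bool &v) { bools_[p - TYPE_BOOL_BEGIN] = v; }
template <>
const bool &Ctx::get_param<bool>(CtxParam p) const { return bools_[p - TYPE_BOOL_BEGIN]; }
template <>
void Ctx::set_param<uint32_t>(CtxParam p, const uint32_t &v) { uint32s_[p - TYPE_UINT32_BEGIN] = v; }
template <>
const uint32_t &Ctx::get_param<uint32_t>(CtxParam p) const { return uint32s_[p - TYPE_UINT32_BEGIN]; }

int main() {
  Ctx ctx;
  ctx.set_param<bool>(ENABLE_INFER_OPT, true);
  ctx.set_param<uint32_t>(DEVICE_ID, 0);
  std::cout << ctx.get_param<bool>(ENABLE_INFER_OPT) << " " << ctx.get_param<uint32_t>(DEVICE_ID) << "\n";
  return 0;
}
```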