diff --git a/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.cc b/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.cc index 7bccbcf654d..fd9ba1734a4 100644 --- a/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.cc +++ b/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.cc @@ -20,7 +20,6 @@ #include "cxx_api/akg_kernel_register.h" #include "cxx_api/acl_utils.h" #include "utils/log_adapter.h" -#include "runtime/device/context_extends.h" #include "mindspore/core/base/base_ref_utils.h" #include "backend/common/session/session_factory.h" #include "backend/common/session/executor_manager.h" @@ -28,6 +27,7 @@ #include "runtime/dev.h" #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" #include "include/common/utils/python_adapter.h" +#include "runtime/hardware/device_context_manager.h" namespace mindspore { namespace { @@ -47,12 +47,15 @@ void InitHccl() { #ifndef ENABLE_SECURITY runtime_instance->PreInit(); #endif - (void)context::OpenTsd(ms_context); + const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext( + {kAscendDevice, ms_context->get_param(MS_CTX_DEVICE_ID)}); + MS_EXCEPTION_IF_NULL(device_context); + MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface()); + (void)device_context->GetDeprecatedInterface()->OpenTsd(ms_context); + if (!runtime_instance->Init()) { MS_LOG(EXCEPTION) << "Runtime init failed."; } - } else { - (void)context::OpenTsd(ms_context); } } @@ -381,7 +384,11 @@ AscendGraphImpl::MsEnvGuard::~MsEnvGuard() { if (ms_context->get_param(MS_CTX_ENABLE_HCCL)) { PythonEnvGuard guard; - if (!context::CloseTsd(ms_context)) { + const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext( + {kAscendDevice, ms_context->get_param(MS_CTX_DEVICE_ID)}); + MS_EXCEPTION_IF_NULL(device_context); + MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface()); + if (!device_context->GetDeprecatedInterface()->CloseTsd(ms_context)) { MS_LOG(ERROR) << "CloseTsd failed!"; return; } diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index a98311e467e..f35d5e00ae8 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -44,7 +44,6 @@ #include "include/common/utils/config_manager.h" #include "include/common/utils/convert_utils.h" #include "include/common/utils/convert_utils_py.h" -#include "runtime/device/context_extends.h" #include "utils/ms_context.h" #include "utils/shape_utils.h" #include "utils/info.h" @@ -1376,10 +1375,17 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba std::string name = MsContext::GetInstance()->backend_policy(); auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); -#ifndef NO_DLIB - if (!context::IsTsdOpened(ms_context)) { - InitPipeline(); +#ifdef WITH_BACKEND + if (ms_context->get_param(MS_CTX_DEVICE_TARGET) == kAscendDevice) { + auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext( + {kAscendDevice, ms_context->get_param(MS_CTX_DEVICE_ID)}); + MS_EXCEPTION_IF_NULL(device_context); + MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface()); + if (!device_context->GetDeprecatedInterface()->IsTsdOpened(ms_context)) { + InitPipeline(); + } } + #endif if (iter_num == -1) { iter_num = INT32_MAX; @@ -1390,10 +1396,8 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba std::string backend = ms_context->backend_policy(); #ifdef WITH_BACKEND if (backend == "ge") { - MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext( - {MsContext::GetInstance()->get_param(MS_CTX_DEVICE_TARGET), - MsContext::GetInstance()->get_param(MS_CTX_DEVICE_ID)}); + {ms_context->get_param(MS_CTX_DEVICE_TARGET), ms_context->get_param(MS_CTX_DEVICE_ID)}); MS_EXCEPTION_IF_NULL(device_context); MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface()); @@ -1534,15 +1538,15 @@ void InitHccl() { #endif mindspore::python_adapter::set_python_env_flag(true); + std::string device_name = ms_context->get_param(MS_CTX_DEVICE_TARGET); uint32_t device_id = ms_context->get_param(MS_CTX_DEVICE_ID); - if (common::UseMPI() && ms_context->get_param(MS_CTX_DEVICE_TARGET) == kAscendDevice) { + if (common::UseMPI() && device_name == kAscendDevice) { const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext( - {ms_context->get_param(MS_CTX_DEVICE_TARGET), ms_context->get_param(MS_CTX_DEVICE_ID)}); + {device_name, ms_context->get_param(MS_CTX_DEVICE_ID)}); MS_EXCEPTION_IF_NULL(device_context); MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface()); device_id = device_context->GetDeprecatedInterface()->InitCollective(); } - std::string device_name = ms_context->get_param(MS_CTX_DEVICE_TARGET); ms_context->set_param(MS_CTX_ENABLE_HCCL, true); if (ms_context->backend_policy() == "ms" && ms_context->get_param(MS_CTX_DEVICE_TARGET) == kAscendDevice) { @@ -1551,12 +1555,14 @@ void InitHccl() { #ifndef ENABLE_SECURITY runtime_instance->PreInit(); #endif - (void)context::OpenTsd(ms_context); + const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext( + {device_name, ms_context->get_param(MS_CTX_DEVICE_ID)}); + MS_EXCEPTION_IF_NULL(device_context); + MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface()); + (void)device_context->GetDeprecatedInterface()->OpenTsd(ms_context); if (!runtime_instance->Init()) { MS_LOG(EXCEPTION) << "Runtime init failed."; } - } else { - (void)context::OpenTsd(ms_context); } } @@ -1633,11 +1639,18 @@ FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_ return func_graph; } -void ReleaseGeTsd() { +void CloseTsd(bool force) { +#ifdef WITH_BACKEND auto context_ptr = MsContext::GetInstance(); - if (context_ptr != nullptr) { - (void)context::CloseTsd(context_ptr, true); + MS_EXCEPTION_IF_NULL(context_ptr); + if (context_ptr->get_param(MS_CTX_DEVICE_TARGET) == kAscendDevice) { + const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext( + {kAscendDevice, context_ptr->get_param(MS_CTX_DEVICE_ID)}); + MS_EXCEPTION_IF_NULL(device_context); + MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface()); + (void)device_context->GetDeprecatedInterface()->CloseTsd(context_ptr, force); } +#endif } void InitPipeline() { @@ -1648,22 +1661,23 @@ void InitPipeline() { MS_EXCEPTION_IF_NULL(ms_context); #ifdef WITH_BACKEND auto backend = ms_context->backend_policy(); + auto device_name = ms_context->get_param(MS_CTX_DEVICE_TARGET); + const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext( + {device_name, ms_context->get_param(MS_CTX_DEVICE_ID)}); + MS_EXCEPTION_IF_NULL(device_context); if (backend == "ge") { - const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext( - {ms_context->get_param(MS_CTX_DEVICE_TARGET), ms_context->get_param(MS_CTX_DEVICE_ID)}); device_context->Initialize(); } -#endif - if (!context::OpenTsd(ms_context)) { - MS_LOG(EXCEPTION) << "Open tsd failed"; + if (device_name == kAscendDevice) { + MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface()); + if (!device_context->GetDeprecatedInterface()->OpenTsd(ms_context)) { + MS_LOG(EXCEPTION) << "Open tsd failed"; + } } +#endif } -void FinalizeBackend() { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - (void)context::CloseTsd(context_ptr); -} +void FinalizeBackend() { CloseTsd(); } void MemoryRecycle() { #ifdef ENABLE_DUMP_IR @@ -1768,8 +1782,6 @@ void ClearResAtexit() { device::DeviceContextManager::GetInstance().ClearDeviceContexts(); MS_LOG(INFO) << "End clear device context."; - ReleaseGeTsd(); - MS_LOG(INFO) << "Start clear AnalysisResultCacheMgr..."; abstract::AnalysisResultCacheMgr::GetInstance().Clear(); MS_LOG(INFO) << "End clear AnalysisResultCacheMgr."; diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.h b/mindspore/ccsrc/pipeline/jit/pipeline.h index 43b55eacb49..ebeb8a77752 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.h +++ b/mindspore/ccsrc/pipeline/jit/pipeline.h @@ -179,7 +179,7 @@ uint32_t GetHcclRankSize(); void InitPipeline(); void FinalizeBackend(); void ClearResAtexit(); -void ReleaseGeTsd(); +void CloseTsd(bool force = false); void MemoryRecycle(); FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_t key_len, const std::string &dec_mode, diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 986a60e03d6..c027ffd8b9d 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -38,7 +38,6 @@ #include "include/common/utils/utils.h" #include "utils/ms_context.h" #include "utils/check_convert_utils.h" -#include "runtime/device/context_extends.h" #include "include/common/utils/config_manager.h" #include "include/common/utils/convert_utils_py.h" #include "include/common/utils/scoped_long_running.h" @@ -2505,11 +2504,18 @@ MsBackendPolicy ForwardExecutor::GetBackendPolicy(const OpExecInfoPtr &op_exec_i if (ms_context->backend_policy() == "ge") { MS_LOG(EXCEPTION) << "In PyNative mode, not support ge backend!"; } - if (!context::IsTsdOpened(ms_context)) { - if (!context::OpenTsd(ms_context)) { +#ifdef WITH_BACKEND + const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext( + {kAscendDevice, ms_context->get_param(MS_CTX_DEVICE_ID)}); + MS_EXCEPTION_IF_NULL(device_context); + MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface()); + + if (!device_context->GetDeprecatedInterface()->IsTsdOpened(ms_context)) { + if (!device_context->GetDeprecatedInterface()->OpenTsd(ms_context)) { MS_LOG(EXCEPTION) << "Open tsd failed"; } } +#endif } return backend_policy; } diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_kernel_runtime.cc b/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_kernel_runtime.cc index febc5721096..a8d4a98d022 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_kernel_runtime.cc @@ -25,7 +25,6 @@ #include "plugin/device/ascend/hal/device/ascend_device_address.h" #include "plugin/device/ascend/hal/device/distribute/ascend_collective.h" #include "utils/ms_context.h" -#include "runtime/device/context_extends.h" #include "include/common/utils/mpi/mpi_config.h" #include "runtime/device/ms_device_shape_transfer.h" #include "runtime/rt.h" @@ -1275,9 +1274,6 @@ bool AscendKernelRuntime::ResetDevice(uint32_t device_id) { bool AscendKernelRuntime::HcclInit() { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - if (!context::IsTsdOpened(context_ptr)) { - MS_LOG(EXCEPTION) << "Hccl dependent tsd is not open"; - } MS_LOG(INFO) << "Do hcom init."; auto device_id = context_ptr->get_param(MS_CTX_DEVICE_ID); std::string rank_id_str = GetRankIdStr(); diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/device/readme.md b/mindspore/ccsrc/plugin/device/ascend/hal/device/readme.md deleted file mode 100644 index 037e3b658ad..00000000000 --- a/mindspore/ccsrc/plugin/device/ascend/hal/device/readme.md +++ /dev/null @@ -1 +0,0 @@ -ascend diff --git a/mindspore/ccsrc/utils/tensorprint_utils.cc b/mindspore/ccsrc/plugin/device/ascend/hal/device/tensorprint_utils.cc similarity index 97% rename from mindspore/ccsrc/utils/tensorprint_utils.cc rename to mindspore/ccsrc/plugin/device/ascend/hal/device/tensorprint_utils.cc index 1cff5a4eba6..94637690bf1 100644 --- a/mindspore/ccsrc/utils/tensorprint_utils.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/tensorprint_utils.cc @@ -13,23 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "include/common/utils/tensorprint_utils.h" +#include "plugin/device/ascend/hal/device/tensorprint_utils.h" #include #include #include #include -#include #include "ir/tensor.h" #include "pybind11/pybind11.h" #include "include/common/utils/utils.h" -#include "utils/ms_utils.h" #include "utils/shape_utils.h" #include "mindspore/core/utils/file_utils.h" namespace py = pybind11; namespace mindspore { - -#ifndef NO_DLIB static std::map print_acl_data_type_map = { {ACL_INT8, TypeId::kNumberTypeInt8}, {ACL_UINT8, TypeId::kNumberTypeUInt8}, {ACL_INT16, TypeId::kNumberTypeInt16}, {ACL_UINT16, TypeId::kNumberTypeUInt16}, @@ -75,7 +71,7 @@ bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *co } template -void PrintScalarToString(const char *str_data_ptr, const aclDataType &acl_data_type, std::ostringstream *const buf) { +void PrintScalarToString(const void *str_data_ptr, const aclDataType &acl_data_type, std::ostringstream *const buf) { MS_EXCEPTION_IF_NULL(str_data_ptr); MS_EXCEPTION_IF_NULL(buf); *buf << "Tensor(shape=[], dtype=" << GetParseType(acl_data_type) << ", value="; @@ -101,13 +97,13 @@ void PrintScalarToBoolString(const char *str_data_ptr, const aclDataType &acl_da } } -void convertDataItem2Scalar(const char *str_data_ptr, const aclDataType &acl_data_type, std::ostringstream *const buf) { +void ConvertDataItem2Scalar(const void *str_data_ptr, const aclDataType &acl_data_type, std::ostringstream *const buf) { MS_EXCEPTION_IF_NULL(str_data_ptr); MS_EXCEPTION_IF_NULL(buf); auto type_iter = print_acl_data_type_map.find(acl_data_type); auto type_id = type_iter->second; if (type_id == TypeId::kNumberTypeBool) { - PrintScalarToBoolString(str_data_ptr, acl_data_type, buf); + PrintScalarToBoolString(reinterpret_cast(str_data_ptr), acl_data_type, buf); } else if (type_id == TypeId::kNumberTypeInt8) { PrintScalarToString(str_data_ptr, acl_data_type, buf); } else if (type_id == TypeId::kNumberTypeUInt8) { @@ -178,7 +174,7 @@ bool ConvertDataset2Tensor(acltdtDataset *acl_dataset) { if (!judgeLengthValid(acl_data_size, acl_data_type)) { MS_LOG(EXCEPTION) << "Print op receive data length is invalid."; } - convertDataItem2Scalar(acl_data, acl_data_type, &buf); + ConvertDataItem2Scalar(reinterpret_cast(acl_data), acl_data_type, &buf); continue; } @@ -210,7 +206,6 @@ bool SaveDataset2File(acltdtDataset *acl_dataset, const std::string &print_file_ acltdtDataItem *item = acltdtGetDataItem(acl_dataset, i); MS_EXCEPTION_IF_NULL(item); acltdtTensorType acl_tensor_type = acltdtGetTensorTypeFromItem(item); - if (acl_tensor_type == ACL_TENSOR_DATA_END_OF_SEQUENCE) { MS_LOG(INFO) << "Acl channel received end-of-sequence for print op."; ret_end_thread = true; @@ -358,5 +353,4 @@ void TensorPrint::operator()() { TensorPrintOut2File(acl_handle_, print_file_path_); } } -#endif } // namespace mindspore diff --git a/mindspore/ccsrc/include/common/utils/tensorprint_utils.h b/mindspore/ccsrc/plugin/device/ascend/hal/device/tensorprint_utils.h similarity index 81% rename from mindspore/ccsrc/include/common/utils/tensorprint_utils.h rename to mindspore/ccsrc/plugin/device/ascend/hal/device/tensorprint_utils.h index 6024e8e0fd0..7ae6a315599 100644 --- a/mindspore/ccsrc/include/common/utils/tensorprint_utils.h +++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/tensorprint_utils.h @@ -14,13 +14,12 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_TENSORPRINT_UTILS_H_ -#define MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_TENSORPRINT_UTILS_H_ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_HAL_DEVICE_TENSORPRINT_UTILS_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_HAL_DEVICE_TENSORPRINT_UTILS_H_ #include #include #include "ir/dtype/type.h" -#ifndef NO_DLIB #include "acl/acl_tdt.h" #include "tdt/tsd_client.h" #include "tdt/data_common.h" @@ -41,5 +40,4 @@ class COMMON_EXPORT TensorPrint { const acltdtChannelHandle *acl_handle_; }; } // namespace mindspore -#endif -#endif // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_TENSORPRINT_UTILS_H_ +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_HAL_DEVICE_TENSORPRINT_UTILS_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_deprecated_interface.cc b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_deprecated_interface.cc index f7212572695..af08a9f48ff 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_deprecated_interface.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_deprecated_interface.cc @@ -15,12 +15,207 @@ */ #include "plugin/device/ascend/hal/hardware/ascend_deprecated_interface.h" +#include +#include "plugin/device/ascend/hal/hardware/ge_device_context.h" +#include "include/transform/graph_ir/types.h" +#include "include/transform/graph_ir/utils.h" +#include "include/common/utils/scoped_long_running.h" +#include "graph/model.h" +#include "transform/graph_ir/op_adapter_map.h" +#include "plugin/device/ascend/hal/device/tensorprint_utils.h" +#include "acl/acl_tdt.h" +#include "runtime/dev.h" +#include "toolchain/plog.h" +#include "common/util/error_manager/error_manager.h" #include "plugin/device/ascend/hal/device/distribute/ascend_collective.h" #include "plugin/device/ascend/hal/profiler/parallel_strategy_profiling.h" +using mindspore::abstract::AbstractScalar; +using mindspore::abstract::AbstractTensor; +using mindspore::abstract::AbstractTuple; +using mindspore::abstract::AbstractTuplePtr; +using mindspore::transform::GeTensorPtr; +using mindspore::transform::MeTensorPtr; +using mindspore::transform::Status; + +namespace py = pybind11; + namespace mindspore { namespace device { namespace ascend { +namespace { +constexpr auto kUnknowErrorString = "Unknown error occurred"; + +void ConvertObjectToTensors(const py::dict &dict, transform::TensorOrderMap *const tensors) { + for (auto item : dict) { + if ((!py::isinstance(item.first))) { + MS_LOG(WARNING) << "Type of key of py_dict is not string, ignore it."; + continue; + } + std::shared_ptr tensor; + std::string name = py::cast(item.first); + if (py::isinstance(item.second.attr("data"))) { + // convert float to tensor with shape([1]) + tensor = std::make_shared(kNumberTypeFloat32, std::vector({1})); + *(static_cast(tensor->data_c())) = py::cast(item.second.attr("data")); + } else if (py::isinstance(item.second.attr("data"))) { + // convert int64_t to tensor with shape([1]) + tensor = std::make_shared(kNumberTypeInt32, std::vector({1})); + *(static_cast(tensor->data_c())) = py::cast(item.second.attr("data")); + } else if (py::isinstance(item.second.attr("data"))) { + // cast tensor + tensor = py::cast>(item.second.attr("data")); + } + + if (tensor == nullptr) { + MS_LOG(EXCEPTION) << "Get default value for " << name << " failed"; + } + (void)tensors->emplace(name, tensor); + } +} +} // namespace + +void AscendDeprecatedInterface::DoExecNonInputGraph(const std::string &phase) { + std::vector ge_tensors; + std::vector ge_outputs; + transform::RunOptions run_options; + run_options.name = phase; + auto graph_runner = transform::GetGraphRunner(); + if (graph_runner == nullptr) { + MS_LOG(ERROR) << "Can not found GraphRunner"; + return; + } + + { + // Release GIL before calling into (potentially long-running) C++ code + ScopedLongRunning release; + Status ret = transform::RunGraph(graph_runner, run_options, ge_tensors, &ge_outputs); + if (ret != Status::SUCCESS) { + MS_LOG(ERROR) << "Exec graph:" << run_options.name << " failed"; + return; + } + } +} + +bool AscendDeprecatedInterface::InitExecDataset(const std::string &queue_name, int64_t size, int64_t batch_size, + const std::vector &types, + const std::vector> &shapes, + const std::vector &input_indexes, const std::string &phase) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + ge_device_context_->Initialize(); + std::vector ge_types; + (void)std::transform(types.begin(), types.end(), std::back_inserter(ge_types), + [](const TypePtr &i) -> int64_t { return transform::ConvertDataType(i->type_id()); }); + + ConfigManager::GetInstance().set_dataset_mode(DatasetMode::DS_SINK_MODE); + ConfigManager::GetInstance().set_iter_num(queue_name, size); + ConfigManager::GetInstance().set_dataset_phase(phase); + + DatasetGraphParam param(queue_name, size, batch_size, ge_types, shapes, input_indexes); + ConfigManager::GetInstance().set_dataset_param(param); + + auto env_ge = common::GetEnv("MS_ENABLE_GE"); + auto env_training = common::GetEnv("MS_GE_TRAIN"); + bool training = false; + if (env_ge == "1" && env_training == "1") { + training = true; + } + if (training) { + (void)setenv("GE_TRAIN", "1", 1); + } else { + (void)setenv("GE_TRAIN", "0", 1); + } + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + + if (!ms_context->get_param(MS_CTX_ENABLE_GE_HETEROGENOUS)) { + if (transform::CompileDatasetGraph(param, phase) != transform::SUCCESS) { + MS_LOG(ERROR) << "Build dateset graph failed."; + return false; + } + + GeDeviceResManager::CreateSessionAndGraphRunner(training); + + MS_LOG(INFO) << "DoExecNonInputGraph:" << phase; + DoExecNonInputGraph(phase); + } + + return true; +} + +void AscendDeprecatedInterface::ExportDFGraph(const std::string &file_name, const std::string &phase, + const py::object &encrypt, char *key) { + MS_LOG(DEBUG) << "Export graph begin."; + transform::DfGraphWrapperPtr wrap_ptr = transform::GetGraphByName(phase); + if (wrap_ptr == nullptr) { + MS_LOG(ERROR) << "Get graph form DfGraphManager failed, phase = " << phase; + return; + } + + transform::DfGraphPtr ge_graph = wrap_ptr->graph_ptr_; + if (ge_graph == nullptr) { + MS_LOG(ERROR) << "Graph is null!"; + return; + } + if (key != nullptr) { + if (py::isinstance(encrypt)) { + MS_LOG(ERROR) << "ERROR: encrypt is not a function"; + return; + } + // get model stream + ge::Model model("", ""); + model.SetGraph(*ge_graph); + ge::Buffer model_data; + auto ge_ret = model.Save(model_data); + if (ge_ret != ge::SUCCESS) { + MS_LOG(ERROR) << "ERROR: GE model save fail"; + return; + } + // convert model and key into py::bytes + const std::string str(reinterpret_cast(model_data.GetData()), model_data.GetSize()); + py::bytes model_bytes(str); + py::bytes key_bytes(key); + + // call python encrypt func + py::bytes encrypted_model_stream = encrypt(model_bytes, key_bytes); + if (encrypted_model_stream == py::none()) { + MS_LOG(ERROR) << "ERROR: Model encrypt fail"; + return; + } + // save to file + std::ofstream ofs(file_name); + if (!ofs.is_open()) { + MS_LOG(ERROR) << "ERROR: Open File '" << file_name << "' failed!"; + return; + } + ofs << std::string(encrypted_model_stream); + ofs.close(); + } else { + if (ge_graph->SaveToFile(file_name) != 0) { + MS_LOG(EXCEPTION) << "Export air model failed."; + } + } + MS_LOG(INFO) << "Export air model finish."; +} + +FuncGraphPtr AscendDeprecatedInterface::BuildDFGraph(const FuncGraphPtr &anf_graph, const pybind11::dict &init_params) { + MS_EXCEPTION_IF_NULL(anf_graph); + transform::TensorOrderMap init_tensors{}; + ConvertObjectToTensors(init_params, &init_tensors); + return GeGraphExecutor::BuildDFGraph(anf_graph, init_tensors, true); +} + +void AscendDeprecatedInterface::ClearGraphWrapper() { transform::DfGraphManager::GetInstance().ClearGraph(); } + +void AscendDeprecatedInterface::ClearOpAdapterMap() { transform::OpAdapterMap::get().clear(); } + +void AscendDeprecatedInterface::EraseGeResource() { + transform::DfGraphManager::GetInstance().DeleteGraphRunner(); + transform::DfGraphManager::GetInstance().EraseAnfGraph(); + transform::DfGraphManager::GetInstance().DeleteGeSession(); +} + uint32_t AscendDeprecatedInterface::InitCollective() { #ifdef WITH_BACKEND auto ms_context = MsContext::GetInstance(); @@ -45,6 +240,112 @@ uint32_t AscendDeprecatedInterface::InitCollective() { void AscendDeprecatedInterface::DumpProfileParallelStrategy(const FuncGraphPtr &func_graph) { return profiler::ascend::DumpProfileParallelStrategy(func_graph); } + +bool AscendDeprecatedInterface::OpenTsd(const std::shared_ptr &ms_context_ptr) { + if (ms_context_ptr == nullptr) { + MS_LOG(EXCEPTION) << "nullptr"; + } + + if (ms_context_ptr->get_param(MS_CTX_IS_PYNATIVE_GE_INIT)) { + return true; + } + + if (ms_context_ptr->get_param(MS_CTX_TSD_REF) != 0) { + MS_LOG(DEBUG) << "ACLTDT Dataset client is already opened."; + ms_context_ptr->increase_param(MS_CTX_TSD_REF); + return true; + } + + auto role = common::GetEnv("MS_ROLE"); + if (strcmp(role.c_str(), "MS_SCHED") == 0 || strcmp(role.c_str(), "MS_PSERVER") == 0) { + return true; + } + + uint32_t device_id = ms_context_ptr->get_param(MS_CTX_DEVICE_ID); + + uint32_t rank_size; + auto rank_size_env = common::GetEnv("RANK_SIZE"); + if (rank_size_env.empty()) { + MS_LOG(INFO) << "Should config rank size."; + rank_size = 1; + } else { + int rank_env = std::stoi(rank_size_env); + if (rank_env <= 0) { + MS_LOG(EXCEPTION) << "Error rank size " << rank_env << "."; + } + rank_size = IntToUint(rank_env); + } + + int log_ret = DlogReportInitialize(); + if (log_ret != 0) { + MS_LOG(WARNING) << "Init slog failed, ret = " << log_ret; + } + + if (ErrorManager::GetInstance().Init() != 0) { + MS_LOG(WARNING) << "Init ascend error manager failed, some ascend error log may be left out."; + } + MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << "."; + auto ret = rtSetDevice(static_cast(device_id)); + if (ret != RT_ERROR_NONE) { + const std::string &error_message = ErrorManager::GetInstance().GetErrorMessage(); + if (!error_message.empty() && error_message.find(kUnknowErrorString) == std::string::npos) { + MS_LOG(ERROR) << "Ascend error occurred, error message:\n" << error_message; + } + MS_LOG(EXCEPTION) << "Device " << device_id << " call rtSetDevice failed, ret[" << static_cast(ret) << "]"; + } + ms_context_ptr->increase_param(MS_CTX_TSD_REF); +#ifdef ENABLE_TDTQUE + auto thread_crt = [](const std::string &path, const acltdtChannelHandle *acl_handle) { + return std::thread(TensorPrint(path, acl_handle)); + }; + ms_context_ptr->CreateTensorPrintThread(thread_crt); +#endif + return true; +} + +bool AscendDeprecatedInterface::CloseTsd(const std::shared_ptr &ms_context_ptr, bool force) { + if (ms_context_ptr == nullptr) { + MS_LOG(EXCEPTION) << "ms_context_prt is nullptr"; + } + MS_LOG(INFO) << "Start to close tsd, ref = " << ms_context_ptr->get_param(MS_CTX_TSD_REF); + if (ms_context_ptr->get_param(MS_CTX_TSD_REF) == 0) { + return true; + } + ms_context_ptr->decrease_param(MS_CTX_TSD_REF); + if (force || ms_context_ptr->get_param(MS_CTX_TSD_REF) == 0) { + ms_context_ptr->set_param(MS_CTX_TSD_REF, 0); +#ifdef ENABLE_TDTQUE + pybind11::gil_scoped_release gil_release; + ms_context_ptr->DestroyTensorPrintThread(); +#endif + if (ErrorManager::GetInstance().Init() != 0) { + MS_LOG(WARNING) << "Init ascend error manager failed, some ascend error log may be left out."; + } + uint32_t device_id = ms_context_ptr->get_param(MS_CTX_DEVICE_ID); + auto ret = rtDeviceReset(static_cast(device_id)); + if (ret != RT_ERROR_NONE) { + const std::string &error_message = ErrorManager::GetInstance().GetErrorMessage(); + if (!error_message.empty() && error_message.find(kUnknowErrorString) == std::string::npos) { + MS_LOG(ERROR) << "Ascend error occurred, error message:\n" << error_message; + } + MS_LOG(EXCEPTION) << "Device " << device_id << " call rtDeviceReset failed, ret[" << static_cast(ret) << "]"; + } + ms_context_ptr->set_param(MS_CTX_IS_PYNATIVE_GE_INIT, false); + MS_LOG(INFO) << "Call rtDeviceReset, destroy and close tsd successful, ret[" << static_cast(ret) << "]"; + (void)DlogReportFinalize(); + } else { + MS_LOG(DEBUG) << "Acltdt Dataset client is used, no need to close, tsd reference = " + << ms_context_ptr->get_param(MS_CTX_TSD_REF) << "."; + } + return true; +} + +bool AscendDeprecatedInterface::IsTsdOpened(const std::shared_ptr &ms_context_ptr) { + if (ms_context_ptr == nullptr) { + MS_LOG(EXCEPTION) << "nullptr"; + } + return ms_context_ptr->get_param(MS_CTX_TSD_REF) > 0; +} } // namespace ascend } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_deprecated_interface.h b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_deprecated_interface.h index ad91ee3fc62..7187c421894 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_deprecated_interface.h +++ b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_deprecated_interface.h @@ -27,21 +27,36 @@ namespace mindspore { namespace device { namespace ascend { -class AscendDeviceContext; +class GeDeviceContext; class AscendDeprecatedInterface : public DeprecatedInterface { public: - explicit AscendDeprecatedInterface(AscendDeviceContext *ascend_device_context) - : ascend_device_context_(ascend_device_context) {} + explicit AscendDeprecatedInterface(GeDeviceContext *ge_device_context) : ge_device_context_(ge_device_context) {} + ~AscendDeprecatedInterface() override = default; + // for ge + void DoExecNonInputGraph(const std::string &phase) override; + bool InitExecDataset(const std::string &queue_name, int64_t size, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes, const std::string &phase) override; + void ExportDFGraph(const std::string &file_name, const std::string &phase, const pybind11::object &encrypt, + char *key) override; + FuncGraphPtr BuildDFGraph(const FuncGraphPtr &anf_graph, const pybind11::dict &init_params) override; + void ClearGraphWrapper() override; + void ClearOpAdapterMap() override; + void EraseGeResource() override; + // for ascend uint32_t InitCollective() override; void DumpProfileParallelStrategy(const FuncGraphPtr &func_graph) override; + bool OpenTsd(const std::shared_ptr &ms_context_ptr) override; + bool CloseTsd(const std::shared_ptr &ms_context_ptr, bool force) override; + bool IsTsdOpened(const std::shared_ptr &inst_context) override; + private: - AscendDeviceContext *const ascend_device_context_; + GeDeviceContext *const ge_device_context_; }; } // namespace ascend } // namespace device } // namespace mindspore - #endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_DEPRECATED_INTERFACE_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_context.cc b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_context.cc index 6e812c5009d..8a6653b5ad0 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_context.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_context.cc @@ -65,6 +65,9 @@ void AscendDeviceContext::Destroy() { #endif MS_LOG(INFO) << "Status record: Enter Destroy..."; if (!initialized_) { + if (deprecated_interface_ != nullptr) { + deprecated_interface_->CloseTsd(MsContext::GetInstance(), true); + } return; } @@ -77,6 +80,9 @@ void AscendDeviceContext::Destroy() { if (runtime_instance_) { runtime_instance_ = nullptr; } + if (deprecated_interface_ != nullptr) { + deprecated_interface_->CloseTsd(MsContext::GetInstance(), true); + } initialized_ = false; MS_LOG(INFO) << "Status record: Destroy success."; } @@ -101,7 +107,7 @@ RunMode AscendDeviceContext::GetRunMode(const FuncGraphPtr &func_graph) const { DeprecatedInterface *AscendDeviceContext::GetDeprecatedInterface() { // need lock when multi-threads if (deprecated_interface_ == nullptr) { - deprecated_interface_ = std::make_unique(this); + deprecated_interface_ = std::make_unique(nullptr); } return deprecated_interface_.get(); } diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ge_deprecated_interface.cc b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ge_deprecated_interface.cc deleted file mode 100644 index b14a4247ce2..00000000000 --- a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ge_deprecated_interface.cc +++ /dev/null @@ -1,211 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/ascend/hal/hardware/ge_deprecated_interface.h" -#include -#include "plugin/device/ascend/hal/hardware/ge_device_context.h" -#include "include/transform/graph_ir/types.h" -#include "include/transform/graph_ir/utils.h" -#include "include/common/utils/scoped_long_running.h" -#include "graph/model.h" -#include "transform/graph_ir/op_adapter_map.h" - -using mindspore::abstract::AbstractScalar; -using mindspore::abstract::AbstractTensor; -using mindspore::abstract::AbstractTuple; -using mindspore::abstract::AbstractTuplePtr; -using mindspore::transform::GeTensorPtr; -using mindspore::transform::MeTensorPtr; -using mindspore::transform::Status; - -namespace py = pybind11; - -namespace mindspore { -namespace device { -namespace ascend { -namespace { -void ConvertObjectToTensors(const py::dict &dict, transform::TensorOrderMap *const tensors) { - for (auto item : dict) { - if ((!py::isinstance(item.first))) { - MS_LOG(WARNING) << "Type of key of py_dict is not string, ignore it."; - continue; - } - std::shared_ptr tensor; - std::string name = py::cast(item.first); - if (py::isinstance(item.second.attr("data"))) { - // convert float to tensor with shape([1]) - tensor = std::make_shared(kNumberTypeFloat32, std::vector({1})); - *(static_cast(tensor->data_c())) = py::cast(item.second.attr("data")); - } else if (py::isinstance(item.second.attr("data"))) { - // convert int64_t to tensor with shape([1]) - tensor = std::make_shared(kNumberTypeInt32, std::vector({1})); - *(static_cast(tensor->data_c())) = py::cast(item.second.attr("data")); - } else if (py::isinstance(item.second.attr("data"))) { - // cast tensor - tensor = py::cast>(item.second.attr("data")); - } - - if (tensor == nullptr) { - MS_LOG(EXCEPTION) << "Get default value for " << name << " failed"; - } - (void)tensors->emplace(name, tensor); - } -} -} // namespace - -void GeDeprecatedInterface::DoExecNonInputGraph(const std::string &phase) { - std::vector ge_tensors; - std::vector ge_outputs; - transform::RunOptions run_options; - run_options.name = phase; - auto graph_runner = transform::GetGraphRunner(); - if (graph_runner == nullptr) { - MS_LOG(ERROR) << "Can not found GraphRunner"; - return; - } - - { - // Release GIL before calling into (potentially long-running) C++ code - ScopedLongRunning release; - Status ret = transform::RunGraph(graph_runner, run_options, ge_tensors, &ge_outputs); - if (ret != Status::SUCCESS) { - MS_LOG(ERROR) << "Exec graph:" << run_options.name << " failed"; - return; - } - } -} - -bool GeDeprecatedInterface::InitExecDataset(const std::string &queue_name, int64_t size, int64_t batch_size, - const std::vector &types, - const std::vector> &shapes, - const std::vector &input_indexes, const std::string &phase) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - ge_device_context_->Initialize(); - std::vector ge_types; - (void)std::transform(types.begin(), types.end(), std::back_inserter(ge_types), - [](const TypePtr &i) -> int64_t { return transform::ConvertDataType(i->type_id()); }); - - ConfigManager::GetInstance().set_dataset_mode(DatasetMode::DS_SINK_MODE); - ConfigManager::GetInstance().set_iter_num(queue_name, size); - ConfigManager::GetInstance().set_dataset_phase(phase); - - DatasetGraphParam param(queue_name, size, batch_size, ge_types, shapes, input_indexes); - ConfigManager::GetInstance().set_dataset_param(param); - - auto env_ge = common::GetEnv("MS_ENABLE_GE"); - auto env_training = common::GetEnv("MS_GE_TRAIN"); - bool training = false; - if (env_ge == "1" && env_training == "1") { - training = true; - } - if (training) { - (void)setenv("GE_TRAIN", "1", 1); - } else { - (void)setenv("GE_TRAIN", "0", 1); - } - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - - if (!ms_context->get_param(MS_CTX_ENABLE_GE_HETEROGENOUS)) { - if (transform::CompileDatasetGraph(param, phase) != transform::SUCCESS) { - MS_LOG(ERROR) << "Build dateset graph failed."; - return false; - } - - GeDeviceResManager::CreateSessionAndGraphRunner(training); - - MS_LOG(INFO) << "DoExecNonInputGraph:" << phase; - DoExecNonInputGraph(phase); - } - - return true; -} - -void GeDeprecatedInterface::ExportDFGraph(const std::string &file_name, const std::string &phase, - const py::object &encrypt, char *key) { - MS_LOG(DEBUG) << "Export graph begin."; - transform::DfGraphWrapperPtr wrap_ptr = transform::GetGraphByName(phase); - if (wrap_ptr == nullptr) { - MS_LOG(ERROR) << "Get graph form DfGraphManager failed, phase = " << phase; - return; - } - - transform::DfGraphPtr ge_graph = wrap_ptr->graph_ptr_; - if (ge_graph == nullptr) { - MS_LOG(ERROR) << "Graph is null!"; - return; - } - if (key != nullptr) { - if (py::isinstance(encrypt)) { - MS_LOG(ERROR) << "ERROR: encrypt is not a function"; - return; - } - // get model stream - ge::Model model("", ""); - model.SetGraph(*ge_graph); - ge::Buffer model_data; - auto ge_ret = model.Save(model_data); - if (ge_ret != ge::SUCCESS) { - MS_LOG(ERROR) << "ERROR: GE model save fail"; - return; - } - // convert model and key into py::bytes - const std::string str(reinterpret_cast(model_data.GetData()), model_data.GetSize()); - py::bytes model_bytes(str); - py::bytes key_bytes(key); - - // call python encrypt func - py::bytes encrypted_model_stream = encrypt(model_bytes, key_bytes); - if (encrypted_model_stream == py::none()) { - MS_LOG(ERROR) << "ERROR: Model encrypt fail"; - return; - } - // save to file - std::ofstream ofs(file_name); - if (!ofs.is_open()) { - MS_LOG(ERROR) << "ERROR: Open File '" << file_name << "' failed!"; - return; - } - ofs << std::string(encrypted_model_stream); - ofs.close(); - } else { - if (ge_graph->SaveToFile(file_name) != 0) { - MS_LOG(EXCEPTION) << "Export air model failed."; - } - } - MS_LOG(INFO) << "Export air model finish."; -} - -FuncGraphPtr GeDeprecatedInterface::BuildDFGraph(const FuncGraphPtr &anf_graph, const pybind11::dict &init_params) { - MS_EXCEPTION_IF_NULL(anf_graph); - transform::TensorOrderMap init_tensors{}; - ConvertObjectToTensors(init_params, &init_tensors); - return GeGraphExecutor::BuildDFGraph(anf_graph, init_tensors, true); -} - -void GeDeprecatedInterface::ClearGraphWrapper() { transform::DfGraphManager::GetInstance().ClearGraph(); } - -void GeDeprecatedInterface::ClearOpAdapterMap() { transform::OpAdapterMap::get().clear(); } - -void GeDeprecatedInterface::EraseGeResource() { - transform::DfGraphManager::GetInstance().DeleteGraphRunner(); - transform::DfGraphManager::GetInstance().EraseAnfGraph(); - transform::DfGraphManager::GetInstance().DeleteGeSession(); -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ge_deprecated_interface.h b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ge_deprecated_interface.h deleted file mode 100644 index 393e649b55d..00000000000 --- a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ge_deprecated_interface.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_GE_DEPRECATED_INTERFACE_H_ -#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_GE_DEPRECATED_INTERFACE_H_ - -#include -#include -#include -#include -#include "runtime/hardware/device_context.h" -#include "runtime/device/memory_manager.h" -#include "utils/ms_context.h" - -namespace mindspore { -namespace device { -namespace ascend { -class GeDeviceContext; - -class GeDeprecatedInterface : public DeprecatedInterface { - public: - explicit GeDeprecatedInterface(GeDeviceContext *ge_device_context) : ge_device_context_(ge_device_context) {} - void DoExecNonInputGraph(const std::string &phase) override; - bool InitExecDataset(const std::string &queue_name, int64_t size, int64_t batch_size, - const std::vector &types, const std::vector> &shapes, - const std::vector &input_indexes, const std::string &phase) override; - void ExportDFGraph(const std::string &file_name, const std::string &phase, const pybind11::object &encrypt, - char *key) override; - FuncGraphPtr BuildDFGraph(const FuncGraphPtr &anf_graph, const pybind11::dict &init_params) override; - void ClearGraphWrapper() override; - void ClearOpAdapterMap() override; - void EraseGeResource() override; - - private: - GeDeviceContext *const ge_device_context_; -}; -} // namespace ascend -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_GE_DEPRECATED_INTERFACE_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ge_device_context.cc b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ge_device_context.cc index 9d0b23ded11..3bddc24d4ce 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ge_device_context.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ge_device_context.cc @@ -427,7 +427,12 @@ void GeDeviceContext::Initialize() { initialized_ = InitGe(MsContext::GetInstance()); } -void GeDeviceContext::Destroy() { (void)FinalizeGe(MsContext::GetInstance()); } +void GeDeviceContext::Destroy() { + (void)FinalizeGe(MsContext::GetInstance()); + if (deprecated_interface_ != nullptr) { + deprecated_interface_->CloseTsd(MsContext::GetInstance(), true); + } +} void GeDeviceResManager::Initialize() { if (mem_manager_ == nullptr) { @@ -748,7 +753,7 @@ FuncGraphPtr GeGraphExecutor::BuildDFGraph(const FuncGraphPtr &anf_graph, DeprecatedInterface *GeDeviceContext::GetDeprecatedInterface() { // need lock when multi-threads if (deprecated_interface_ == nullptr) { - deprecated_interface_ = std::make_unique(this); + deprecated_interface_ = std::make_unique(this); } return deprecated_interface_.get(); } diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ge_device_context.h b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ge_device_context.h index d7bf16c5fc1..5b7a9df34ef 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ge_device_context.h +++ b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ge_device_context.h @@ -20,7 +20,7 @@ #include #include #include -#include "plugin/device/ascend/hal/hardware/ge_deprecated_interface.h" +#include "plugin/device/ascend/hal/hardware/ascend_deprecated_interface.h" #include "runtime/hardware/device_context.h" #include "runtime/device/memory_manager.h" #include "utils/ms_context.h" @@ -93,7 +93,7 @@ class GeDeviceContext : public DeviceInterface &inst_context, std::map *ge_options); void SetDisableReuseMemoryFlag(std::map *ge_options); - std::unique_ptr deprecated_interface_; + std::unique_ptr deprecated_interface_; bool initialized_; }; } // namespace ascend diff --git a/mindspore/ccsrc/runtime/device/context_extends.cc b/mindspore/ccsrc/runtime/device/context_extends.cc index b857ec241fb..b2b93a4db84 100644 --- a/mindspore/ccsrc/runtime/device/context_extends.cc +++ b/mindspore/ccsrc/runtime/device/context_extends.cc @@ -15,151 +15,14 @@ */ #include "runtime/device/context_extends.h" -#include #include #include -#include #include -#include "pybind11/pybind11.h" -#include "include/common/utils/config_manager.h" #include "utils/ms_utils.h" -#include "utils/convert_utils_base.h" -#ifndef NO_DLIB -#include "acl/acl_tdt.h" -#include "runtime/dev.h" -#include "toolchain/plog.h" -#include "common/util/error_manager/error_manager.h" -#endif -#ifdef ENABLE_D -#include "debug/data_dump/dump_json_parser.h" -#include "include/transform/graph_ir/utils.h" -#endif -#include "profiler/device/profiling.h" - -namespace py = pybind11; +#include "utils/ms_context.h" namespace mindspore { namespace context { -#ifdef ENABLE_D -namespace { -constexpr auto kMindsporeDumpConfig = "MINDSPORE_DUMP_CONFIG"; -const std::vector kGeDumpMode = {"all", "input", "output"}; -} // namespace -#endif - -constexpr auto kUnknowErrorString = "Unknown error occurred"; -#ifndef NO_DLIB -// Open tdt dataset -bool OpenTsd(const std::shared_ptr &ms_context_ptr) { - if (ms_context_ptr == nullptr) { - MS_LOG(EXCEPTION) << "nullptr"; - } - - if (ms_context_ptr->get_param(MS_CTX_IS_PYNATIVE_GE_INIT)) { - return true; - } - - if (ms_context_ptr->get_param(MS_CTX_TSD_REF) != 0) { - MS_LOG(DEBUG) << "ACLTDT Dataset client is already opened."; - ms_context_ptr->increase_param(MS_CTX_TSD_REF); - return true; - } - - auto role = common::GetEnv("MS_ROLE"); - if (strcmp(role.c_str(), "MS_SCHED") == 0 || strcmp(role.c_str(), "MS_PSERVER") == 0) { - return true; - } - - uint32_t rank_size = 1; - uint32_t device_id = ms_context_ptr->get_param(MS_CTX_DEVICE_ID); - - auto rank_size_env = common::GetEnv("RANK_SIZE"); - if (rank_size_env.empty()) { - MS_LOG(INFO) << "Should config rank size."; - rank_size = 1; - } else { - int rank_env = std::stoi(rank_size_env); - if (rank_env <= 0) { - MS_LOG(EXCEPTION) << "Error rank size " << rank_env << "."; - } - rank_size = IntToUint(rank_env); - } - - int log_ret = DlogReportInitialize(); - if (log_ret != 0) { - MS_LOG(WARNING) << "Init slog failed, ret = " << log_ret; - } - - if (ErrorManager::GetInstance().Init() != 0) { - MS_LOG(WARNING) << "Init ascend error manager failed, some ascend error log may be left out."; - } - MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << "."; - auto ret = rtSetDevice(static_cast(device_id)); - if (ret != RT_ERROR_NONE) { - const std::string &error_message = ErrorManager::GetInstance().GetErrorMessage(); - if (!error_message.empty() && error_message.find(kUnknowErrorString) == std::string::npos) { - MS_LOG(ERROR) << "Ascend error occurred, error message:\n" << error_message; - } - MS_LOG(EXCEPTION) << "Device " << device_id << " call rtSetDevice failed, ret[" << static_cast(ret) << "]"; - } - ms_context_ptr->increase_param(MS_CTX_TSD_REF); -#ifdef ENABLE_TDTQUE - auto thread_crt = [](const std::string &path, const acltdtChannelHandle *acl_handle) { - return std::thread(TensorPrint(path, acl_handle)); - }; - ms_context_ptr->CreateTensorPrintThread(thread_crt); -#endif - return true; -} - -bool CloseTsd(const std::shared_ptr &ms_context_ptr, bool force) { - if (ms_context_ptr == nullptr) { - MS_LOG(EXCEPTION) << "ms_context_prt is nullptr"; - } - if (ms_context_ptr->get_param(MS_CTX_TSD_REF) == 0) { - return true; - } - ms_context_ptr->decrease_param(MS_CTX_TSD_REF); - if (force || ms_context_ptr->get_param(MS_CTX_TSD_REF) == 0) { - ms_context_ptr->set_param(MS_CTX_TSD_REF, 0); - -#ifdef ENABLE_TDTQUE - py::gil_scoped_release gil_release; - ms_context_ptr->DestroyTensorPrintThread(); -#endif - if (ErrorManager::GetInstance().Init() != 0) { - MS_LOG(WARNING) << "Init ascend error manager failed, some ascend error log may be left out."; - } - uint32_t device_id = ms_context_ptr->get_param(MS_CTX_DEVICE_ID); - auto ret = rtDeviceReset(static_cast(device_id)); - if (ret != RT_ERROR_NONE) { - const std::string &error_message = ErrorManager::GetInstance().GetErrorMessage(); - if (!error_message.empty() && error_message.find(kUnknowErrorString) == std::string::npos) { - MS_LOG(ERROR) << "Ascend error occurred, error message:\n" << error_message; - } - MS_LOG(EXCEPTION) << "Device " << device_id << " call rtDeviceReset failed, ret[" << static_cast(ret) << "]"; - } - ms_context_ptr->set_param(MS_CTX_IS_PYNATIVE_GE_INIT, false); - MS_LOG(INFO) << "Call rtDeviceReset, destroy and close tsd successful, ret[" << static_cast(ret) << "]"; - (void)DlogReportFinalize(); - } else { - MS_LOG(DEBUG) << "Acltdt Dataset client is used, no need to close, tsd reference = " - << ms_context_ptr->get_param(MS_CTX_TSD_REF) << "."; - } - return true; -} -#else -bool OpenTsd(const std::shared_ptr &ms_context_ptr) { return true; } -bool CloseTsd(const std::shared_ptr &ms_context_ptr, bool) { return true; } -#endif - -bool IsTsdOpened(const std::shared_ptr &ms_context_ptr) { - if (ms_context_ptr == nullptr) { - MS_LOG(EXCEPTION) << "nullptr"; - } - return ms_context_ptr->get_param(MS_CTX_TSD_REF) > 0; -} - // Register for device type. struct DeviceTypeSetRegister { DeviceTypeSetRegister() { diff --git a/mindspore/ccsrc/runtime/device/context_extends.h b/mindspore/ccsrc/runtime/device/context_extends.h index 82298c41baa..29b3c6aeaf5 100644 --- a/mindspore/ccsrc/runtime/device/context_extends.h +++ b/mindspore/ccsrc/runtime/device/context_extends.h @@ -17,18 +17,6 @@ #ifndef MINDSPORE_CCSRC_UTILS_CONTEXT_CONTEXT_EXTENDS_H #define MINDSPORE_CCSRC_UTILS_CONTEXT_CONTEXT_EXTENDS_H -#include -#include "utils/ms_context.h" -#include "include/common/utils/tensorprint_utils.h" -#include "include/backend/visible.h" - -namespace mindspore { -namespace context { -BACKEND_EXPORT bool OpenTsd(const std::shared_ptr &ms_context_ptr); -BACKEND_EXPORT bool CloseTsd(const std::shared_ptr &ms_context_ptr, bool force = false); - -BACKEND_EXPORT bool IsTsdOpened(const std::shared_ptr &inst_context); -} // namespace context -} // namespace mindspore +// this file makes static-code-checking happy #endif // MINDSPORE_CCSRC_UTILS_CONTEXT_CONTEXT_EXTENDS_H diff --git a/mindspore/ccsrc/runtime/hardware/deprecated_interface.h b/mindspore/ccsrc/runtime/hardware/deprecated_interface.h index 5755740a232..bd2f2e2efd4 100644 --- a/mindspore/ccsrc/runtime/hardware/deprecated_interface.h +++ b/mindspore/ccsrc/runtime/hardware/deprecated_interface.h @@ -19,7 +19,9 @@ #include #include +#include #include "pybind11/pybind11.h" +#include "utils/ms_context.h" namespace mindspore { namespace device { @@ -47,7 +49,9 @@ class DeprecatedInterface { // ascend virtual uint32_t InitCollective() { return 0; } // return device id virtual void DumpProfileParallelStrategy(const FuncGraphPtr &func_graph) {} - + virtual bool OpenTsd(const std::shared_ptr &ms_context_ptr) { return true; } + virtual bool CloseTsd(const std::shared_ptr &ms_context_ptr, bool force = false) { return true; } + virtual bool IsTsdOpened(const std::shared_ptr &inst_context) { return true; } // gpu }; } // namespace device diff --git a/mindspore/core/utils/ms_context.h b/mindspore/core/utils/ms_context.h index 908f8622d13..4cd7afef3bc 100644 --- a/mindspore/core/utils/ms_context.h +++ b/mindspore/core/utils/ms_context.h @@ -147,8 +147,8 @@ class MS_CORE_API MsContext { ~MsContext() = default; MsContext(const MsContext &) = delete; MsContext &operator=(const MsContext &) = delete; - using DeviceSeter = std::function; - using DeviceTypeSeter = std::function &)>; + using DeviceSeter = void (*)(const std::string &device_target); + using DeviceTypeSeter = void (*)(std::shared_ptr &); static std::shared_ptr GetInstance(); bool enable_dump_ir() const; diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index 2a4d9ce6610..cb4e83cec92 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -150,6 +150,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../../../mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_launch_transdata.cc" "../../../mindspore/ccsrc/plugin/device/ascend/hal/device/kernel_select_graph_kernel.cc" "../../../mindspore/ccsrc/runtime/device/convert_tensor_utils.cc" + "../../../mindspore/ccsrc/runtime/device/context_extends.cc" "../../../mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_bucket.cc" "../../../mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_event.cc" "../../../mindspore/ccsrc/plugin/device/ascend/hal/device/kernel_build_ascend.cc" @@ -163,7 +164,6 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_res_manager.cc" "../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_kernel_executor.cc" "../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_graph_executor.cc" - "../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_deprecated_interface.cc" "../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_utils.cc" "../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_graph_optimization.cc" "../../../mindspore/ccsrc/plugin/device/cpu/hal/hardware/ms_collective_topo.cc" diff --git a/tests/ut/cpp/stub/runtime/ascend_depreacted_interface_stub.cc b/tests/ut/cpp/stub/runtime/ascend_depreacted_interface_stub.cc new file mode 100644 index 00000000000..1dcba773f19 --- /dev/null +++ b/tests/ut/cpp/stub/runtime/ascend_depreacted_interface_stub.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/ascend/hal/hardware/ascend_deprecated_interface.h" + +namespace mindspore { +namespace device { +namespace ascend { +void AscendDeprecatedInterface::DoExecNonInputGraph(const std::string &) {} +bool AscendDeprecatedInterface::InitExecDataset(const std::string &, int64_t, int64_t, const std::vector &, + const std::vector> &, const std::vector &, + const std::string &) { + return true; +} +void AscendDeprecatedInterface::ExportDFGraph(const std::string &, const std::string &, const pybind11::object &, + char *) {} +FuncGraphPtr AscendDeprecatedInterface::BuildDFGraph(const FuncGraphPtr &, const pybind11::dict &) { return nullptr; } +void AscendDeprecatedInterface::ClearGraphWrapper() {} +void AscendDeprecatedInterface::ClearOpAdapterMap() {} +void AscendDeprecatedInterface::EraseGeResource() {} +// for ascend +uint32_t AscendDeprecatedInterface::InitCollective() { return 0; } +void AscendDeprecatedInterface::DumpProfileParallelStrategy(const FuncGraphPtr &) {} + +bool AscendDeprecatedInterface::OpenTsd(const std::shared_ptr &) { return true; } +bool AscendDeprecatedInterface::CloseTsd(const std::shared_ptr &, bool) { return true; } +bool AscendDeprecatedInterface::IsTsdOpened(const std::shared_ptr &) { return true; } +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/tests/ut/cpp/stub/runtime/runtime_stub.cc b/tests/ut/cpp/stub/runtime/runtime_stub.cc index 6f5bdfff91c..c2d5cc24f39 100644 --- a/tests/ut/cpp/stub/runtime/runtime_stub.cc +++ b/tests/ut/cpp/stub/runtime/runtime_stub.cc @@ -22,6 +22,7 @@ #include "runtime/rt_model.h" #include "runtime/stream.h" #include "toolchain/adx_datadump_server.h" +#include "toolchain/plog.h" rtError_t rtEventSynchronize(rtEvent_t event) { return RT_ERROR_NONE; } @@ -219,3 +220,7 @@ RTS_API rtError_t rtKernelLaunchWithHandle(void *hdl, const uint64_t tilingKey, const void *kernelInfo) { return RT_ERROR_NONE; } + +int DlogReportInitialize(void) { return 0; } + +int DlogReportFinalize(void) { return 0; }