delete ENABLE_D from opentsd/closetsd

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
zhoufeng 2022-07-28 16:54:36 +08:00
parent aac5d76186
commit d6589a68c5
22 changed files with 465 additions and 487 deletions
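
The commit removes the ENABLE_D/NO_DLIB-guarded TSD helpers (context::OpenTsd/CloseTsd/IsTsdOpened in runtime/device/context_extends) and routes every caller through the DeprecatedInterface of the Ascend device context. A minimal sketch of the call-site migration, assuming an Ascend target and a valid MsContext, mirroring the hunks below:

// Before: free helper guarded by ENABLE_D / NO_DLIB.
(void)context::OpenTsd(ms_context);

// After: resolve the device context, then go through its deprecated interface.
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
  {kAscendDevice, ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
(void)device_context->GetDeprecatedInterface()->OpenTsd(ms_context);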


@ -20,7 +20,6 @@
#include "cxx_api/akg_kernel_register.h"
#include "cxx_api/acl_utils.h"
#include "utils/log_adapter.h"
#include "runtime/device/context_extends.h"
#include "mindspore/core/base/base_ref_utils.h"
#include "backend/common/session/session_factory.h"
#include "backend/common/session/executor_manager.h"
@ -28,6 +27,7 @@
#include "runtime/dev.h"
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
#include "include/common/utils/python_adapter.h"
#include "runtime/hardware/device_context_manager.h"
namespace mindspore {
namespace {
@ -47,12 +47,15 @@ void InitHccl() {
#ifndef ENABLE_SECURITY
runtime_instance->PreInit();
#endif
(void)context::OpenTsd(ms_context);
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{kAscendDevice, ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
(void)device_context->GetDeprecatedInterface()->OpenTsd(ms_context);
if (!runtime_instance->Init()) {
MS_LOG(EXCEPTION) << "Runtime init failed.";
}
} else {
(void)context::OpenTsd(ms_context);
}
}
@ -381,7 +384,11 @@ AscendGraphImpl::MsEnvGuard::~MsEnvGuard() {
if (ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
PythonEnvGuard guard;
if (!context::CloseTsd(ms_context)) {
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{kAscendDevice, ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
if (!device_context->GetDeprecatedInterface()->CloseTsd(ms_context)) {
MS_LOG(ERROR) << "CloseTsd failed!";
return;
}


@ -44,7 +44,6 @@
#include "include/common/utils/config_manager.h"
#include "include/common/utils/convert_utils.h"
#include "include/common/utils/convert_utils_py.h"
#include "runtime/device/context_extends.h"
#include "utils/ms_context.h"
#include "utils/shape_utils.h"
#include "utils/info.h"
@ -1376,10 +1375,17 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba
std::string name = MsContext::GetInstance()->backend_policy();
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
#ifndef NO_DLIB
if (!context::IsTsdOpened(ms_context)) {
InitPipeline();
#ifdef WITH_BACKEND
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{kAscendDevice, ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
if (!device_context->GetDeprecatedInterface()->IsTsdOpened(ms_context)) {
InitPipeline();
}
}
#endif
if (iter_num == -1) {
iter_num = INT32_MAX;
@ -1390,10 +1396,8 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba
std::string backend = ms_context->backend_policy();
#ifdef WITH_BACKEND
if (backend == "ge") {
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET),
MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
{ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET), ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
@ -1534,15 +1538,15 @@ void InitHccl() {
#endif
mindspore::python_adapter::set_python_env_flag(true);
std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
if (common::UseMPI() && ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
if (common::UseMPI() && device_name == kAscendDevice) {
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET), ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
{device_name, ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
device_id = device_context->GetDeprecatedInterface()->InitCollective();
}
std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
ms_context->set_param<bool>(MS_CTX_ENABLE_HCCL, true);
if (ms_context->backend_policy() == "ms" &&
ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
@ -1551,12 +1555,14 @@ void InitHccl() {
#ifndef ENABLE_SECURITY
runtime_instance->PreInit();
#endif
(void)context::OpenTsd(ms_context);
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{device_name, ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
(void)device_context->GetDeprecatedInterface()->OpenTsd(ms_context);
if (!runtime_instance->Init()) {
MS_LOG(EXCEPTION) << "Runtime init failed.";
}
} else {
(void)context::OpenTsd(ms_context);
}
}
@ -1633,11 +1639,18 @@ FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_
return func_graph;
}
void ReleaseGeTsd() {
void CloseTsd(bool force) {
#ifdef WITH_BACKEND
auto context_ptr = MsContext::GetInstance();
if (context_ptr != nullptr) {
(void)context::CloseTsd(context_ptr, true);
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{kAscendDevice, context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
(void)device_context->GetDeprecatedInterface()->CloseTsd(context_ptr, force);
}
#endif
}
void InitPipeline() {
@ -1648,22 +1661,23 @@ void InitPipeline() {
MS_EXCEPTION_IF_NULL(ms_context);
#ifdef WITH_BACKEND
auto backend = ms_context->backend_policy();
auto device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{device_name, ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
MS_EXCEPTION_IF_NULL(device_context);
if (backend == "ge") {
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET), ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
device_context->Initialize();
}
#endif
if (!context::OpenTsd(ms_context)) {
MS_LOG(EXCEPTION) << "Open tsd failed";
if (device_name == kAscendDevice) {
MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
if (!device_context->GetDeprecatedInterface()->OpenTsd(ms_context)) {
MS_LOG(EXCEPTION) << "Open tsd failed";
}
}
#endif
}
void FinalizeBackend() {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
(void)context::CloseTsd(context_ptr);
}
void FinalizeBackend() { CloseTsd(); }
void MemoryRecycle() {
#ifdef ENABLE_DUMP_IR
@ -1768,8 +1782,6 @@ void ClearResAtexit() {
device::DeviceContextManager::GetInstance().ClearDeviceContexts();
MS_LOG(INFO) << "End clear device context.";
ReleaseGeTsd();
MS_LOG(INFO) << "Start clear AnalysisResultCacheMgr...";
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
MS_LOG(INFO) << "End clear AnalysisResultCacheMgr.";


@ -179,7 +179,7 @@ uint32_t GetHcclRankSize();
void InitPipeline();
void FinalizeBackend();
void ClearResAtexit();
void ReleaseGeTsd();
void CloseTsd(bool force = false);
void MemoryRecycle();
FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_t key_len, const std::string &dec_mode,


@ -38,7 +38,6 @@
#include "include/common/utils/utils.h"
#include "utils/ms_context.h"
#include "utils/check_convert_utils.h"
#include "runtime/device/context_extends.h"
#include "include/common/utils/config_manager.h"
#include "include/common/utils/convert_utils_py.h"
#include "include/common/utils/scoped_long_running.h"
@ -2505,11 +2504,18 @@ MsBackendPolicy ForwardExecutor::GetBackendPolicy(const OpExecInfoPtr &op_exec_i
if (ms_context->backend_policy() == "ge") {
MS_LOG(EXCEPTION) << "In PyNative mode, not support ge backend!";
}
if (!context::IsTsdOpened(ms_context)) {
if (!context::OpenTsd(ms_context)) {
#ifdef WITH_BACKEND
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{kAscendDevice, ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
if (!device_context->GetDeprecatedInterface()->IsTsdOpened(ms_context)) {
if (!device_context->GetDeprecatedInterface()->OpenTsd(ms_context)) {
MS_LOG(EXCEPTION) << "Open tsd failed";
}
}
#endif
}
return backend_policy;
}


@ -25,7 +25,6 @@
#include "plugin/device/ascend/hal/device/ascend_device_address.h"
#include "plugin/device/ascend/hal/device/distribute/ascend_collective.h"
#include "utils/ms_context.h"
#include "runtime/device/context_extends.h"
#include "include/common/utils/mpi/mpi_config.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "runtime/rt.h"
@ -1275,9 +1274,6 @@ bool AscendKernelRuntime::ResetDevice(uint32_t device_id) {
bool AscendKernelRuntime::HcclInit() {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!context::IsTsdOpened(context_ptr)) {
MS_LOG(EXCEPTION) << "Hccl dependent tsd is not open";
}
MS_LOG(INFO) << "Do hcom init.";
auto device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
std::string rank_id_str = GetRankIdStr();


@ -13,23 +13,19 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/common/utils/tensorprint_utils.h"
#include "plugin/device/ascend/hal/device/tensorprint_utils.h"
#include <atomic>
#include <fstream>
#include <memory>
#include <string>
#include <vector>
#include "ir/tensor.h"
#include "pybind11/pybind11.h"
#include "include/common/utils/utils.h"
#include "utils/ms_utils.h"
#include "utils/shape_utils.h"
#include "mindspore/core/utils/file_utils.h"
namespace py = pybind11;
namespace mindspore {
#ifndef NO_DLIB
static std::map<aclDataType, TypeId> print_acl_data_type_map = {
{ACL_INT8, TypeId::kNumberTypeInt8}, {ACL_UINT8, TypeId::kNumberTypeUInt8},
{ACL_INT16, TypeId::kNumberTypeInt16}, {ACL_UINT16, TypeId::kNumberTypeUInt16},
@ -75,7 +71,7 @@ bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *co
}
template <typename T>
void PrintScalarToString(const char *str_data_ptr, const aclDataType &acl_data_type, std::ostringstream *const buf) {
void PrintScalarToString(const void *str_data_ptr, const aclDataType &acl_data_type, std::ostringstream *const buf) {
MS_EXCEPTION_IF_NULL(str_data_ptr);
MS_EXCEPTION_IF_NULL(buf);
*buf << "Tensor(shape=[], dtype=" << GetParseType(acl_data_type) << ", value=";
@ -101,13 +97,13 @@ void PrintScalarToBoolString(const char *str_data_ptr, const aclDataType &acl_da
}
}
void convertDataItem2Scalar(const char *str_data_ptr, const aclDataType &acl_data_type, std::ostringstream *const buf) {
void ConvertDataItem2Scalar(const void *str_data_ptr, const aclDataType &acl_data_type, std::ostringstream *const buf) {
MS_EXCEPTION_IF_NULL(str_data_ptr);
MS_EXCEPTION_IF_NULL(buf);
auto type_iter = print_acl_data_type_map.find(acl_data_type);
auto type_id = type_iter->second;
if (type_id == TypeId::kNumberTypeBool) {
PrintScalarToBoolString(str_data_ptr, acl_data_type, buf);
PrintScalarToBoolString(reinterpret_cast<const char *>(str_data_ptr), acl_data_type, buf);
} else if (type_id == TypeId::kNumberTypeInt8) {
PrintScalarToString<int8_t>(str_data_ptr, acl_data_type, buf);
} else if (type_id == TypeId::kNumberTypeUInt8) {
@ -178,7 +174,7 @@ bool ConvertDataset2Tensor(acltdtDataset *acl_dataset) {
if (!judgeLengthValid(acl_data_size, acl_data_type)) {
MS_LOG(EXCEPTION) << "Print op receive data length is invalid.";
}
convertDataItem2Scalar(acl_data, acl_data_type, &buf);
ConvertDataItem2Scalar(reinterpret_cast<void *>(acl_data), acl_data_type, &buf);
continue;
}
@ -210,7 +206,6 @@ bool SaveDataset2File(acltdtDataset *acl_dataset, const std::string &print_file_
acltdtDataItem *item = acltdtGetDataItem(acl_dataset, i);
MS_EXCEPTION_IF_NULL(item);
acltdtTensorType acl_tensor_type = acltdtGetTensorTypeFromItem(item);
if (acl_tensor_type == ACL_TENSOR_DATA_END_OF_SEQUENCE) {
MS_LOG(INFO) << "Acl channel received end-of-sequence for print op.";
ret_end_thread = true;
@ -358,5 +353,4 @@ void TensorPrint::operator()() {
TensorPrintOut2File(acl_handle_, print_file_path_);
}
}
#endif
} // namespace mindspore


@ -14,13 +14,12 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_TENSORPRINT_UTILS_H_
#define MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_TENSORPRINT_UTILS_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_HAL_DEVICE_TENSORPRINT_UTILS_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_HAL_DEVICE_TENSORPRINT_UTILS_H_
#include <map>
#include <string>
#include "ir/dtype/type.h"
#ifndef NO_DLIB
#include "acl/acl_tdt.h"
#include "tdt/tsd_client.h"
#include "tdt/data_common.h"
@ -41,5 +40,4 @@ class COMMON_EXPORT TensorPrint {
const acltdtChannelHandle *acl_handle_;
};
} // namespace mindspore
#endif
#endif // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_TENSORPRINT_UTILS_H_
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_HAL_DEVICE_TENSORPRINT_UTILS_H_


@ -15,12 +15,207 @@
*/
#include "plugin/device/ascend/hal/hardware/ascend_deprecated_interface.h"
#include <algorithm>
#include <fstream>
#include "plugin/device/ascend/hal/hardware/ge_device_context.h"
#include "include/transform/graph_ir/types.h"
#include "include/transform/graph_ir/utils.h"
#include "include/common/utils/scoped_long_running.h"
#include "graph/model.h"
#include "transform/graph_ir/op_adapter_map.h"
#include "plugin/device/ascend/hal/device/tensorprint_utils.h"
#include "acl/acl_tdt.h"
#include "runtime/dev.h"
#include "toolchain/plog.h"
#include "common/util/error_manager/error_manager.h"
#include "plugin/device/ascend/hal/device/distribute/ascend_collective.h"
#include "plugin/device/ascend/hal/profiler/parallel_strategy_profiling.h"
using mindspore::abstract::AbstractScalar;
using mindspore::abstract::AbstractTensor;
using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractTuplePtr;
using mindspore::transform::GeTensorPtr;
using mindspore::transform::MeTensorPtr;
using mindspore::transform::Status;
namespace py = pybind11;
namespace mindspore {
namespace device {
namespace ascend {
namespace {
constexpr auto kUnknowErrorString = "Unknown error occurred";
void ConvertObjectToTensors(const py::dict &dict, transform::TensorOrderMap *const tensors) {
for (auto item : dict) {
if ((!py::isinstance<py::str>(item.first))) {
MS_LOG(WARNING) << "Type of key of py_dict is not string, ignore it.";
continue;
}
std::shared_ptr<tensor::Tensor> tensor;
std::string name = py::cast<std::string>(item.first);
if (py::isinstance<py::float_>(item.second.attr("data"))) {
// convert float to tensor with shape([1])
tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, std::vector<int64_t>({1}));
*(static_cast<float *>(tensor->data_c())) = py::cast<float>(item.second.attr("data"));
} else if (py::isinstance<py::int_>(item.second.attr("data"))) {
// convert int to tensor with shape([1])
tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, std::vector<int64_t>({1}));
*(static_cast<int32_t *>(tensor->data_c())) = py::cast<int32_t>(item.second.attr("data"));
} else if (py::isinstance<tensor::Tensor>(item.second.attr("data"))) {
// cast tensor
tensor = py::cast<std::shared_ptr<tensor::Tensor>>(item.second.attr("data"));
}
if (tensor == nullptr) {
MS_LOG(EXCEPTION) << "Get default value for " << name << " failed";
}
(void)tensors->emplace(name, tensor);
}
}
} // namespace
void AscendDeprecatedInterface::DoExecNonInputGraph(const std::string &phase) {
std::vector<GeTensorPtr> ge_tensors;
std::vector<GeTensorPtr> ge_outputs;
transform::RunOptions run_options;
run_options.name = phase;
auto graph_runner = transform::GetGraphRunner();
if (graph_runner == nullptr) {
MS_LOG(ERROR) << "Can not found GraphRunner";
return;
}
{
// Release GIL before calling into (potentially long-running) C++ code
ScopedLongRunning release;
Status ret = transform::RunGraph(graph_runner, run_options, ge_tensors, &ge_outputs);
if (ret != Status::SUCCESS) {
MS_LOG(ERROR) << "Exec graph:" << run_options.name << " failed";
return;
}
}
}
bool AscendDeprecatedInterface::InitExecDataset(const std::string &queue_name, int64_t size, int64_t batch_size,
const std::vector<TypePtr> &types,
const std::vector<std::vector<int64_t>> &shapes,
const std::vector<int64_t> &input_indexes, const std::string &phase) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
ge_device_context_->Initialize();
std::vector<int64_t> ge_types;
(void)std::transform(types.begin(), types.end(), std::back_inserter(ge_types),
[](const TypePtr &i) -> int64_t { return transform::ConvertDataType(i->type_id()); });
ConfigManager::GetInstance().set_dataset_mode(DatasetMode::DS_SINK_MODE);
ConfigManager::GetInstance().set_iter_num(queue_name, size);
ConfigManager::GetInstance().set_dataset_phase(phase);
DatasetGraphParam param(queue_name, size, batch_size, ge_types, shapes, input_indexes);
ConfigManager::GetInstance().set_dataset_param(param);
auto env_ge = common::GetEnv("MS_ENABLE_GE");
auto env_training = common::GetEnv("MS_GE_TRAIN");
bool training = false;
if (env_ge == "1" && env_training == "1") {
training = true;
}
if (training) {
(void)setenv("GE_TRAIN", "1", 1);
} else {
(void)setenv("GE_TRAIN", "0", 1);
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (!ms_context->get_param<bool>(MS_CTX_ENABLE_GE_HETEROGENOUS)) {
if (transform::CompileDatasetGraph(param, phase) != transform::SUCCESS) {
MS_LOG(ERROR) << "Build dateset graph failed.";
return false;
}
GeDeviceResManager::CreateSessionAndGraphRunner(training);
MS_LOG(INFO) << "DoExecNonInputGraph:" << phase;
DoExecNonInputGraph(phase);
}
return true;
}
void AscendDeprecatedInterface::ExportDFGraph(const std::string &file_name, const std::string &phase,
const py::object &encrypt, char *key) {
MS_LOG(DEBUG) << "Export graph begin.";
transform::DfGraphWrapperPtr wrap_ptr = transform::GetGraphByName(phase);
if (wrap_ptr == nullptr) {
MS_LOG(ERROR) << "Get graph form DfGraphManager failed, phase = " << phase;
return;
}
transform::DfGraphPtr ge_graph = wrap_ptr->graph_ptr_;
if (ge_graph == nullptr) {
MS_LOG(ERROR) << "Graph is null!";
return;
}
if (key != nullptr) {
if (py::isinstance<py::none>(encrypt)) {
MS_LOG(ERROR) << "ERROR: encrypt is not a function";
return;
}
// get model stream
ge::Model model("", "");
model.SetGraph(*ge_graph);
ge::Buffer model_data;
auto ge_ret = model.Save(model_data);
if (ge_ret != ge::SUCCESS) {
MS_LOG(ERROR) << "ERROR: GE model save fail";
return;
}
// convert model and key into py::bytes
const std::string str(reinterpret_cast<char *>(model_data.GetData()), model_data.GetSize());
py::bytes model_bytes(str);
py::bytes key_bytes(key);
// call python encrypt func
py::bytes encrypted_model_stream = encrypt(model_bytes, key_bytes);
if (encrypted_model_stream == py::none()) {
MS_LOG(ERROR) << "ERROR: Model encrypt fail";
return;
}
// save to file
std::ofstream ofs(file_name);
if (!ofs.is_open()) {
MS_LOG(ERROR) << "ERROR: Open File '" << file_name << "' failed!";
return;
}
ofs << std::string(encrypted_model_stream);
ofs.close();
} else {
if (ge_graph->SaveToFile(file_name) != 0) {
MS_LOG(EXCEPTION) << "Export air model failed.";
}
}
MS_LOG(INFO) << "Export air model finish.";
}
FuncGraphPtr AscendDeprecatedInterface::BuildDFGraph(const FuncGraphPtr &anf_graph, const pybind11::dict &init_params) {
MS_EXCEPTION_IF_NULL(anf_graph);
transform::TensorOrderMap init_tensors{};
ConvertObjectToTensors(init_params, &init_tensors);
return GeGraphExecutor::BuildDFGraph(anf_graph, init_tensors, true);
}
void AscendDeprecatedInterface::ClearGraphWrapper() { transform::DfGraphManager::GetInstance().ClearGraph(); }
void AscendDeprecatedInterface::ClearOpAdapterMap() { transform::OpAdapterMap::get().clear(); }
void AscendDeprecatedInterface::EraseGeResource() {
transform::DfGraphManager::GetInstance().DeleteGraphRunner();
transform::DfGraphManager::GetInstance().EraseAnfGraph();
transform::DfGraphManager::GetInstance().DeleteGeSession();
}
uint32_t AscendDeprecatedInterface::InitCollective() {
#ifdef WITH_BACKEND
auto ms_context = MsContext::GetInstance();
@ -45,6 +240,112 @@ uint32_t AscendDeprecatedInterface::InitCollective() {
void AscendDeprecatedInterface::DumpProfileParallelStrategy(const FuncGraphPtr &func_graph) {
return profiler::ascend::DumpProfileParallelStrategy(func_graph);
}
bool AscendDeprecatedInterface::OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
if (ms_context_ptr == nullptr) {
MS_LOG(EXCEPTION) << "nullptr";
}
if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
return true;
}
if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) != 0) {
MS_LOG(DEBUG) << "ACLTDT Dataset client is already opened.";
ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
return true;
}
auto role = common::GetEnv("MS_ROLE");
if (strcmp(role.c_str(), "MS_SCHED") == 0 || strcmp(role.c_str(), "MS_PSERVER") == 0) {
return true;
}
uint32_t device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
uint32_t rank_size;
auto rank_size_env = common::GetEnv("RANK_SIZE");
if (rank_size_env.empty()) {
MS_LOG(INFO) << "Should config rank size.";
rank_size = 1;
} else {
int rank_env = std::stoi(rank_size_env);
if (rank_env <= 0) {
MS_LOG(EXCEPTION) << "Error rank size " << rank_env << ".";
}
rank_size = IntToUint(rank_env);
}
int log_ret = DlogReportInitialize();
if (log_ret != 0) {
MS_LOG(WARNING) << "Init slog failed, ret = " << log_ret;
}
if (ErrorManager::GetInstance().Init() != 0) {
MS_LOG(WARNING) << "Init ascend error manager failed, some ascend error log may be left out.";
}
MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << ".";
auto ret = rtSetDevice(static_cast<int32_t>(device_id));
if (ret != RT_ERROR_NONE) {
const std::string &error_message = ErrorManager::GetInstance().GetErrorMessage();
if (!error_message.empty() && error_message.find(kUnknowErrorString) == std::string::npos) {
MS_LOG(ERROR) << "Ascend error occurred, error message:\n" << error_message;
}
MS_LOG(EXCEPTION) << "Device " << device_id << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]";
}
ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
#ifdef ENABLE_TDTQUE
auto thread_crt = [](const std::string &path, const acltdtChannelHandle *acl_handle) {
return std::thread(TensorPrint(path, acl_handle));
};
ms_context_ptr->CreateTensorPrintThread(thread_crt);
#endif
return true;
}
bool AscendDeprecatedInterface::CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
if (ms_context_ptr == nullptr) {
MS_LOG(EXCEPTION) << "ms_context_prt is nullptr";
}
MS_LOG(INFO) << "Start to close tsd, ref = " << ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF);
if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
return true;
}
ms_context_ptr->decrease_param<uint32_t>(MS_CTX_TSD_REF);
if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0);
#ifdef ENABLE_TDTQUE
pybind11::gil_scoped_release gil_release;
ms_context_ptr->DestroyTensorPrintThread();
#endif
if (ErrorManager::GetInstance().Init() != 0) {
MS_LOG(WARNING) << "Init ascend error manager failed, some ascend error log may be left out.";
}
uint32_t device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
auto ret = rtDeviceReset(static_cast<int32_t>(device_id));
if (ret != RT_ERROR_NONE) {
const std::string &error_message = ErrorManager::GetInstance().GetErrorMessage();
if (!error_message.empty() && error_message.find(kUnknowErrorString) == std::string::npos) {
MS_LOG(ERROR) << "Ascend error occurred, error message:\n" << error_message;
}
MS_LOG(EXCEPTION) << "Device " << device_id << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]";
}
ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
MS_LOG(INFO) << "Call rtDeviceReset, destroy and close tsd successful, ret[" << static_cast<int>(ret) << "]";
(void)DlogReportFinalize();
} else {
MS_LOG(DEBUG) << "Acltdt Dataset client is used, no need to close, tsd reference = "
<< ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
}
return true;
}
bool AscendDeprecatedInterface::IsTsdOpened(const std::shared_ptr<MsContext> &ms_context_ptr) {
if (ms_context_ptr == nullptr) {
MS_LOG(EXCEPTION) << "nullptr";
}
return ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) > 0;
}
} // namespace ascend
} // namespace device
} // namespace mindspore
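
OpenTsd/CloseTsd keep a reference count in MS_CTX_TSD_REF, so nested open/close pairs are cheap and only the last (or a forced) close resets the device. An illustrative sketch of the expected behaviour, where iface is a hypothetical pointer to the AscendDeprecatedInterface obtained from the device context:

auto ctx = MsContext::GetInstance();
(void)iface->OpenTsd(ctx);          // first open: rtSetDevice, print thread, ref -> 1
(void)iface->OpenTsd(ctx);          // already open: only bumps ref -> 2
(void)iface->CloseTsd(ctx, false);  // ref -> 1, device left untouched
(void)iface->CloseTsd(ctx, true);   // force: ref cleared, rtDeviceReset + DlogReportFinalize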


@ -27,21 +27,36 @@
namespace mindspore {
namespace device {
namespace ascend {
class AscendDeviceContext;
class GeDeviceContext;
class AscendDeprecatedInterface : public DeprecatedInterface {
public:
explicit AscendDeprecatedInterface(AscendDeviceContext *ascend_device_context)
: ascend_device_context_(ascend_device_context) {}
explicit AscendDeprecatedInterface(GeDeviceContext *ge_device_context) : ge_device_context_(ge_device_context) {}
~AscendDeprecatedInterface() override = default;
// for ge
void DoExecNonInputGraph(const std::string &phase) override;
bool InitExecDataset(const std::string &queue_name, int64_t size, int64_t batch_size,
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
const std::vector<int64_t> &input_indexes, const std::string &phase) override;
void ExportDFGraph(const std::string &file_name, const std::string &phase, const pybind11::object &encrypt,
char *key) override;
FuncGraphPtr BuildDFGraph(const FuncGraphPtr &anf_graph, const pybind11::dict &init_params) override;
void ClearGraphWrapper() override;
void ClearOpAdapterMap() override;
void EraseGeResource() override;
// for ascend
uint32_t InitCollective() override;
void DumpProfileParallelStrategy(const FuncGraphPtr &func_graph) override;
bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) override;
bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) override;
bool IsTsdOpened(const std::shared_ptr<MsContext> &inst_context) override;
private:
AscendDeviceContext *const ascend_device_context_;
GeDeviceContext *const ge_device_context_;
};
} // namespace ascend
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_DEPRECATED_INTERFACE_H_


@ -65,6 +65,9 @@ void AscendDeviceContext::Destroy() {
#endif
MS_LOG(INFO) << "Status record: Enter Destroy...";
if (!initialized_) {
if (deprecated_interface_ != nullptr) {
deprecated_interface_->CloseTsd(MsContext::GetInstance(), true);
}
return;
}
@ -77,6 +80,9 @@ void AscendDeviceContext::Destroy() {
if (runtime_instance_) {
runtime_instance_ = nullptr;
}
if (deprecated_interface_ != nullptr) {
deprecated_interface_->CloseTsd(MsContext::GetInstance(), true);
}
initialized_ = false;
MS_LOG(INFO) << "Status record: Destroy success.";
}
@ -101,7 +107,7 @@ RunMode AscendDeviceContext::GetRunMode(const FuncGraphPtr &func_graph) const {
DeprecatedInterface *AscendDeviceContext::GetDeprecatedInterface() {
// need lock when multi-threads
if (deprecated_interface_ == nullptr) {
deprecated_interface_ = std::make_unique<AscendDeprecatedInterface>(this);
deprecated_interface_ = std::make_unique<AscendDeprecatedInterface>(nullptr);
}
return deprecated_interface_.get();
}
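
The lazy creation above is still annotated "need lock when multi-threads". A hypothetical thread-safe variant, not part of this commit; the mutex and its <mutex> include are assumptions added for illustration:

DeprecatedInterface *AscendDeviceContext::GetDeprecatedInterface() {
  static std::mutex mtx;  // assumption: serialize concurrent first calls
  std::lock_guard<std::mutex> lock(mtx);
  if (deprecated_interface_ == nullptr) {
    deprecated_interface_ = std::make_unique<AscendDeprecatedInterface>(nullptr);
  }
  return deprecated_interface_.get();
}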


@ -1,211 +0,0 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/hal/hardware/ge_deprecated_interface.h"
#include <algorithm>
#include "plugin/device/ascend/hal/hardware/ge_device_context.h"
#include "include/transform/graph_ir/types.h"
#include "include/transform/graph_ir/utils.h"
#include "include/common/utils/scoped_long_running.h"
#include "graph/model.h"
#include "transform/graph_ir/op_adapter_map.h"
using mindspore::abstract::AbstractScalar;
using mindspore::abstract::AbstractTensor;
using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractTuplePtr;
using mindspore::transform::GeTensorPtr;
using mindspore::transform::MeTensorPtr;
using mindspore::transform::Status;
namespace py = pybind11;
namespace mindspore {
namespace device {
namespace ascend {
namespace {
void ConvertObjectToTensors(const py::dict &dict, transform::TensorOrderMap *const tensors) {
for (auto item : dict) {
if ((!py::isinstance<py::str>(item.first))) {
MS_LOG(WARNING) << "Type of key of py_dict is not string, ignore it.";
continue;
}
std::shared_ptr<tensor::Tensor> tensor;
std::string name = py::cast<std::string>(item.first);
if (py::isinstance<py::float_>(item.second.attr("data"))) {
// convert float to tensor with shape([1])
tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, std::vector<int64_t>({1}));
*(static_cast<float *>(tensor->data_c())) = py::cast<float>(item.second.attr("data"));
} else if (py::isinstance<py::int_>(item.second.attr("data"))) {
// convert int64_t to tensor with shape([1])
tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, std::vector<int64_t>({1}));
*(static_cast<float *>(tensor->data_c())) = py::cast<float>(item.second.attr("data"));
} else if (py::isinstance<tensor::Tensor>(item.second.attr("data"))) {
// cast tensor
tensor = py::cast<std::shared_ptr<tensor::Tensor>>(item.second.attr("data"));
}
if (tensor == nullptr) {
MS_LOG(EXCEPTION) << "Get default value for " << name << " failed";
}
(void)tensors->emplace(name, tensor);
}
}
} // namespace
void GeDeprecatedInterface::DoExecNonInputGraph(const std::string &phase) {
std::vector<GeTensorPtr> ge_tensors;
std::vector<GeTensorPtr> ge_outputs;
transform::RunOptions run_options;
run_options.name = phase;
auto graph_runner = transform::GetGraphRunner();
if (graph_runner == nullptr) {
MS_LOG(ERROR) << "Can not found GraphRunner";
return;
}
{
// Release GIL before calling into (potentially long-running) C++ code
ScopedLongRunning release;
Status ret = transform::RunGraph(graph_runner, run_options, ge_tensors, &ge_outputs);
if (ret != Status::SUCCESS) {
MS_LOG(ERROR) << "Exec graph:" << run_options.name << " failed";
return;
}
}
}
bool GeDeprecatedInterface::InitExecDataset(const std::string &queue_name, int64_t size, int64_t batch_size,
const std::vector<TypePtr> &types,
const std::vector<std::vector<int64_t>> &shapes,
const std::vector<int64_t> &input_indexes, const std::string &phase) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
ge_device_context_->Initialize();
std::vector<int64_t> ge_types;
(void)std::transform(types.begin(), types.end(), std::back_inserter(ge_types),
[](const TypePtr &i) -> int64_t { return transform::ConvertDataType(i->type_id()); });
ConfigManager::GetInstance().set_dataset_mode(DatasetMode::DS_SINK_MODE);
ConfigManager::GetInstance().set_iter_num(queue_name, size);
ConfigManager::GetInstance().set_dataset_phase(phase);
DatasetGraphParam param(queue_name, size, batch_size, ge_types, shapes, input_indexes);
ConfigManager::GetInstance().set_dataset_param(param);
auto env_ge = common::GetEnv("MS_ENABLE_GE");
auto env_training = common::GetEnv("MS_GE_TRAIN");
bool training = false;
if (env_ge == "1" && env_training == "1") {
training = true;
}
if (training) {
(void)setenv("GE_TRAIN", "1", 1);
} else {
(void)setenv("GE_TRAIN", "0", 1);
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (!ms_context->get_param<bool>(MS_CTX_ENABLE_GE_HETEROGENOUS)) {
if (transform::CompileDatasetGraph(param, phase) != transform::SUCCESS) {
MS_LOG(ERROR) << "Build dateset graph failed.";
return false;
}
GeDeviceResManager::CreateSessionAndGraphRunner(training);
MS_LOG(INFO) << "DoExecNonInputGraph:" << phase;
DoExecNonInputGraph(phase);
}
return true;
}
void GeDeprecatedInterface::ExportDFGraph(const std::string &file_name, const std::string &phase,
const py::object &encrypt, char *key) {
MS_LOG(DEBUG) << "Export graph begin.";
transform::DfGraphWrapperPtr wrap_ptr = transform::GetGraphByName(phase);
if (wrap_ptr == nullptr) {
MS_LOG(ERROR) << "Get graph form DfGraphManager failed, phase = " << phase;
return;
}
transform::DfGraphPtr ge_graph = wrap_ptr->graph_ptr_;
if (ge_graph == nullptr) {
MS_LOG(ERROR) << "Graph is null!";
return;
}
if (key != nullptr) {
if (py::isinstance<py::none()>(encrypt)) {
MS_LOG(ERROR) << "ERROR: encrypt is not a function";
return;
}
// get model stream
ge::Model model("", "");
model.SetGraph(*ge_graph);
ge::Buffer model_data;
auto ge_ret = model.Save(model_data);
if (ge_ret != ge::SUCCESS) {
MS_LOG(ERROR) << "ERROR: GE model save fail";
return;
}
// convert model and key into py::bytes
const std::string str(reinterpret_cast<char *>(model_data.GetData()), model_data.GetSize());
py::bytes model_bytes(str);
py::bytes key_bytes(key);
// call python encrypt func
py::bytes encrypted_model_stream = encrypt(model_bytes, key_bytes);
if (encrypted_model_stream == py::none()) {
MS_LOG(ERROR) << "ERROR: Model encrypt fail";
return;
}
// save to file
std::ofstream ofs(file_name);
if (!ofs.is_open()) {
MS_LOG(ERROR) << "ERROR: Open File '" << file_name << "' failed!";
return;
}
ofs << std::string(encrypted_model_stream);
ofs.close();
} else {
if (ge_graph->SaveToFile(file_name) != 0) {
MS_LOG(EXCEPTION) << "Export air model failed.";
}
}
MS_LOG(INFO) << "Export air model finish.";
}
FuncGraphPtr GeDeprecatedInterface::BuildDFGraph(const FuncGraphPtr &anf_graph, const pybind11::dict &init_params) {
MS_EXCEPTION_IF_NULL(anf_graph);
transform::TensorOrderMap init_tensors{};
ConvertObjectToTensors(init_params, &init_tensors);
return GeGraphExecutor::BuildDFGraph(anf_graph, init_tensors, true);
}
void GeDeprecatedInterface::ClearGraphWrapper() { transform::DfGraphManager::GetInstance().ClearGraph(); }
void GeDeprecatedInterface::ClearOpAdapterMap() { transform::OpAdapterMap::get().clear(); }
void GeDeprecatedInterface::EraseGeResource() {
transform::DfGraphManager::GetInstance().DeleteGraphRunner();
transform::DfGraphManager::GetInstance().EraseAnfGraph();
transform::DfGraphManager::GetInstance().DeleteGeSession();
}
} // namespace ascend
} // namespace device
} // namespace mindspore


@ -1,53 +0,0 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_GE_DEPRECATED_INTERFACE_H_
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_GE_DEPRECATED_INTERFACE_H_
#include <vector>
#include <memory>
#include <string>
#include <map>
#include "runtime/hardware/device_context.h"
#include "runtime/device/memory_manager.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace device {
namespace ascend {
class GeDeviceContext;
class GeDeprecatedInterface : public DeprecatedInterface {
public:
explicit GeDeprecatedInterface(GeDeviceContext *ge_device_context) : ge_device_context_(ge_device_context) {}
void DoExecNonInputGraph(const std::string &phase) override;
bool InitExecDataset(const std::string &queue_name, int64_t size, int64_t batch_size,
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
const std::vector<int64_t> &input_indexes, const std::string &phase) override;
void ExportDFGraph(const std::string &file_name, const std::string &phase, const pybind11::object &encrypt,
char *key) override;
FuncGraphPtr BuildDFGraph(const FuncGraphPtr &anf_graph, const pybind11::dict &init_params) override;
void ClearGraphWrapper() override;
void ClearOpAdapterMap() override;
void EraseGeResource() override;
private:
GeDeviceContext *const ge_device_context_;
};
} // namespace ascend
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_GE_DEPRECATED_INTERFACE_H_


@ -427,7 +427,12 @@ void GeDeviceContext::Initialize() {
initialized_ = InitGe(MsContext::GetInstance());
}
void GeDeviceContext::Destroy() { (void)FinalizeGe(MsContext::GetInstance()); }
void GeDeviceContext::Destroy() {
(void)FinalizeGe(MsContext::GetInstance());
if (deprecated_interface_ != nullptr) {
deprecated_interface_->CloseTsd(MsContext::GetInstance(), true);
}
}
void GeDeviceResManager::Initialize() {
if (mem_manager_ == nullptr) {
@ -748,7 +753,7 @@ FuncGraphPtr GeGraphExecutor::BuildDFGraph(const FuncGraphPtr &anf_graph,
DeprecatedInterface *GeDeviceContext::GetDeprecatedInterface() {
// need lock when multi-threads
if (deprecated_interface_ == nullptr) {
deprecated_interface_ = std::make_unique<GeDeprecatedInterface>(this);
deprecated_interface_ = std::make_unique<AscendDeprecatedInterface>(this);
}
return deprecated_interface_.get();
}
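
Both device contexts now hand out the same AscendDeprecatedInterface class; GeDeviceContext passes itself as backing context, while AscendDeviceContext constructs it with nullptr (see the GetDeprecatedInterface hunks in this commit), so the GE-specific entry points are presumably only valid when reached through the GE context:

deprecated_interface_ = std::make_unique<AscendDeprecatedInterface>(this);     // GeDeviceContext: GE-backed methods usable
deprecated_interface_ = std::make_unique<AscendDeprecatedInterface>(nullptr);  // AscendDeviceContext: ge_device_context_ stays null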


@ -20,7 +20,7 @@
#include <memory>
#include <string>
#include <map>
#include "plugin/device/ascend/hal/hardware/ge_deprecated_interface.h"
#include "plugin/device/ascend/hal/hardware/ascend_deprecated_interface.h"
#include "runtime/hardware/device_context.h"
#include "runtime/device/memory_manager.h"
#include "utils/ms_context.h"
@ -93,7 +93,7 @@ class GeDeviceContext : public DeviceInterface<GeGraphExecutor, GeDeviceResManag
void SetHcclOptions(const std::shared_ptr<MsContext> &inst_context, std::map<std::string, std::string> *ge_options);
void SetDisableReuseMemoryFlag(std::map<std::string, std::string> *ge_options);
std::unique_ptr<GeDeprecatedInterface> deprecated_interface_;
std::unique_ptr<AscendDeprecatedInterface> deprecated_interface_;
bool initialized_;
};
} // namespace ascend


@ -15,151 +15,14 @@
*/
#include "runtime/device/context_extends.h"
#include <cstdlib>
#include <string>
#include <memory>
#include <thread>
#include <vector>
#include "pybind11/pybind11.h"
#include "include/common/utils/config_manager.h"
#include "utils/ms_utils.h"
#include "utils/convert_utils_base.h"
#ifndef NO_DLIB
#include "acl/acl_tdt.h"
#include "runtime/dev.h"
#include "toolchain/plog.h"
#include "common/util/error_manager/error_manager.h"
#endif
#ifdef ENABLE_D
#include "debug/data_dump/dump_json_parser.h"
#include "include/transform/graph_ir/utils.h"
#endif
#include "profiler/device/profiling.h"
namespace py = pybind11;
#include "utils/ms_context.h"
namespace mindspore {
namespace context {
#ifdef ENABLE_D
namespace {
constexpr auto kMindsporeDumpConfig = "MINDSPORE_DUMP_CONFIG";
const std::vector<std::string> kGeDumpMode = {"all", "input", "output"};
} // namespace
#endif
constexpr auto kUnknowErrorString = "Unknown error occurred";
#ifndef NO_DLIB
// Open tdt dataset
bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
if (ms_context_ptr == nullptr) {
MS_LOG(EXCEPTION) << "nullptr";
}
if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
return true;
}
if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) != 0) {
MS_LOG(DEBUG) << "ACLTDT Dataset client is already opened.";
ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
return true;
}
auto role = common::GetEnv("MS_ROLE");
if (strcmp(role.c_str(), "MS_SCHED") == 0 || strcmp(role.c_str(), "MS_PSERVER") == 0) {
return true;
}
uint32_t rank_size = 1;
uint32_t device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
auto rank_size_env = common::GetEnv("RANK_SIZE");
if (rank_size_env.empty()) {
MS_LOG(INFO) << "Should config rank size.";
rank_size = 1;
} else {
int rank_env = std::stoi(rank_size_env);
if (rank_env <= 0) {
MS_LOG(EXCEPTION) << "Error rank size " << rank_env << ".";
}
rank_size = IntToUint(rank_env);
}
int log_ret = DlogReportInitialize();
if (log_ret != 0) {
MS_LOG(WARNING) << "Init slog failed, ret = " << log_ret;
}
if (ErrorManager::GetInstance().Init() != 0) {
MS_LOG(WARNING) << "Init ascend error manager failed, some ascend error log may be left out.";
}
MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << ".";
auto ret = rtSetDevice(static_cast<int32_t>(device_id));
if (ret != RT_ERROR_NONE) {
const std::string &error_message = ErrorManager::GetInstance().GetErrorMessage();
if (!error_message.empty() && error_message.find(kUnknowErrorString) == std::string::npos) {
MS_LOG(ERROR) << "Ascend error occurred, error message:\n" << error_message;
}
MS_LOG(EXCEPTION) << "Device " << device_id << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]";
}
ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
#ifdef ENABLE_TDTQUE
auto thread_crt = [](const std::string &path, const acltdtChannelHandle *acl_handle) {
return std::thread(TensorPrint(path, acl_handle));
};
ms_context_ptr->CreateTensorPrintThread(thread_crt);
#endif
return true;
}
bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
if (ms_context_ptr == nullptr) {
MS_LOG(EXCEPTION) << "ms_context_prt is nullptr";
}
if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
return true;
}
ms_context_ptr->decrease_param<uint32_t>(MS_CTX_TSD_REF);
if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0);
#ifdef ENABLE_TDTQUE
py::gil_scoped_release gil_release;
ms_context_ptr->DestroyTensorPrintThread();
#endif
if (ErrorManager::GetInstance().Init() != 0) {
MS_LOG(WARNING) << "Init ascend error manager failed, some ascend error log may be left out.";
}
uint32_t device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
auto ret = rtDeviceReset(static_cast<int32_t>(device_id));
if (ret != RT_ERROR_NONE) {
const std::string &error_message = ErrorManager::GetInstance().GetErrorMessage();
if (!error_message.empty() && error_message.find(kUnknowErrorString) == std::string::npos) {
MS_LOG(ERROR) << "Ascend error occurred, error message:\n" << error_message;
}
MS_LOG(EXCEPTION) << "Device " << device_id << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]";
}
ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
MS_LOG(INFO) << "Call rtDeviceReset, destroy and close tsd successful, ret[" << static_cast<int>(ret) << "]";
(void)DlogReportFinalize();
} else {
MS_LOG(DEBUG) << "Acltdt Dataset client is used, no need to close, tsd reference = "
<< ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
}
return true;
}
#else
bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) { return true; }
bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool) { return true; }
#endif
bool IsTsdOpened(const std::shared_ptr<MsContext> &ms_context_ptr) {
if (ms_context_ptr == nullptr) {
MS_LOG(EXCEPTION) << "nullptr";
}
return ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) > 0;
}
// Register for device type.
struct DeviceTypeSetRegister {
DeviceTypeSetRegister() {


@ -17,18 +17,6 @@
#ifndef MINDSPORE_CCSRC_UTILS_CONTEXT_CONTEXT_EXTENDS_H
#define MINDSPORE_CCSRC_UTILS_CONTEXT_CONTEXT_EXTENDS_H
#include <memory>
#include "utils/ms_context.h"
#include "include/common/utils/tensorprint_utils.h"
#include "include/backend/visible.h"
namespace mindspore {
namespace context {
BACKEND_EXPORT bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr);
BACKEND_EXPORT bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force = false);
BACKEND_EXPORT bool IsTsdOpened(const std::shared_ptr<MsContext> &inst_context);
} // namespace context
} // namespace mindspore
// this file makes static-code-checking happy
#endif // MINDSPORE_CCSRC_UTILS_CONTEXT_CONTEXT_EXTENDS_H


@ -19,7 +19,9 @@
#include <vector>
#include <string>
#include <memory>
#include "pybind11/pybind11.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace device {
@ -47,7 +49,9 @@ class DeprecatedInterface {
// ascend
virtual uint32_t InitCollective() { return 0; } // return device id
virtual void DumpProfileParallelStrategy(const FuncGraphPtr &func_graph) {}
virtual bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) { return true; }
virtual bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force = false) { return true; }
virtual bool IsTsdOpened(const std::shared_ptr<MsContext> &inst_context) { return true; }
// gpu
};
} // namespace device


@ -147,8 +147,8 @@ class MS_CORE_API MsContext {
~MsContext() = default;
MsContext(const MsContext &) = delete;
MsContext &operator=(const MsContext &) = delete;
using DeviceSeter = std::function<void(const std::string &device_target)>;
using DeviceTypeSeter = std::function<void(std::shared_ptr<MsContext> &)>;
using DeviceSeter = void (*)(const std::string &device_target);
using DeviceTypeSeter = void (*)(std::shared_ptr<MsContext> &);
static std::shared_ptr<MsContext> GetInstance();
bool enable_dump_ir() const;


@ -150,6 +150,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_launch_transdata.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/kernel_select_graph_kernel.cc"
"../../../mindspore/ccsrc/runtime/device/convert_tensor_utils.cc"
"../../../mindspore/ccsrc/runtime/device/context_extends.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_bucket.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_event.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/kernel_build_ascend.cc"
@ -163,7 +164,6 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_res_manager.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_kernel_executor.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_graph_executor.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_deprecated_interface.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_utils.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_graph_optimization.cc"
"../../../mindspore/ccsrc/plugin/device/cpu/hal/hardware/ms_collective_topo.cc"


@ -0,0 +1,43 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/hal/hardware/ascend_deprecated_interface.h"
namespace mindspore {
namespace device {
namespace ascend {
void AscendDeprecatedInterface::DoExecNonInputGraph(const std::string &) {}
bool AscendDeprecatedInterface::InitExecDataset(const std::string &, int64_t, int64_t, const std::vector<TypePtr> &,
const std::vector<std::vector<int64_t>> &, const std::vector<int64_t> &,
const std::string &) {
return true;
}
void AscendDeprecatedInterface::ExportDFGraph(const std::string &, const std::string &, const pybind11::object &,
char *) {}
FuncGraphPtr AscendDeprecatedInterface::BuildDFGraph(const FuncGraphPtr &, const pybind11::dict &) { return nullptr; }
void AscendDeprecatedInterface::ClearGraphWrapper() {}
void AscendDeprecatedInterface::ClearOpAdapterMap() {}
void AscendDeprecatedInterface::EraseGeResource() {}
// for ascend
uint32_t AscendDeprecatedInterface::InitCollective() { return 0; }
void AscendDeprecatedInterface::DumpProfileParallelStrategy(const FuncGraphPtr &) {}
bool AscendDeprecatedInterface::OpenTsd(const std::shared_ptr<MsContext> &) { return true; }
bool AscendDeprecatedInterface::CloseTsd(const std::shared_ptr<MsContext> &, bool) { return true; }
bool AscendDeprecatedInterface::IsTsdOpened(const std::shared_ptr<MsContext> &) { return true; }
} // namespace ascend
} // namespace device
} // namespace mindspore


@ -22,6 +22,7 @@
#include "runtime/rt_model.h"
#include "runtime/stream.h"
#include "toolchain/adx_datadump_server.h"
#include "toolchain/plog.h"
rtError_t rtEventSynchronize(rtEvent_t event) { return RT_ERROR_NONE; }
@ -219,3 +220,7 @@ RTS_API rtError_t rtKernelLaunchWithHandle(void *hdl, const uint64_t tilingKey,
const void *kernelInfo) {
return RT_ERROR_NONE;
}
int DlogReportInitialize(void) { return 0; }
int DlogReportFinalize(void) { return 0; }