delete ENABLE_D from opentsd/closetsd
Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
parent
aac5d76186
commit
d6589a68c5
|
@ -20,7 +20,6 @@
|
|||
#include "cxx_api/akg_kernel_register.h"
|
||||
#include "cxx_api/acl_utils.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "runtime/device/context_extends.h"
|
||||
#include "mindspore/core/base/base_ref_utils.h"
|
||||
#include "backend/common/session/session_factory.h"
|
||||
#include "backend/common/session/executor_manager.h"
|
||||
|
@ -28,6 +27,7 @@
|
|||
#include "runtime/dev.h"
|
||||
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
|
||||
#include "include/common/utils/python_adapter.h"
|
||||
#include "runtime/hardware/device_context_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace {
|
||||
|
@ -47,12 +47,15 @@ void InitHccl() {
|
|||
#ifndef ENABLE_SECURITY
|
||||
runtime_instance->PreInit();
|
||||
#endif
|
||||
(void)context::OpenTsd(ms_context);
|
||||
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
|
||||
{kAscendDevice, ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
|
||||
(void)device_context->GetDeprecatedInterface()->OpenTsd(ms_context);
|
||||
|
||||
if (!runtime_instance->Init()) {
|
||||
MS_LOG(EXCEPTION) << "Runtime init failed.";
|
||||
}
|
||||
} else {
|
||||
(void)context::OpenTsd(ms_context);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -381,7 +384,11 @@ AscendGraphImpl::MsEnvGuard::~MsEnvGuard() {
|
|||
|
||||
if (ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
|
||||
PythonEnvGuard guard;
|
||||
if (!context::CloseTsd(ms_context)) {
|
||||
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
|
||||
{kAscendDevice, ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
|
||||
if (!device_context->GetDeprecatedInterface()->CloseTsd(ms_context)) {
|
||||
MS_LOG(ERROR) << "CloseTsd failed!";
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -44,7 +44,6 @@
|
|||
#include "include/common/utils/config_manager.h"
|
||||
#include "include/common/utils/convert_utils.h"
|
||||
#include "include/common/utils/convert_utils_py.h"
|
||||
#include "runtime/device/context_extends.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "utils/shape_utils.h"
|
||||
#include "utils/info.h"
|
||||
|
@ -1376,10 +1375,17 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba
|
|||
std::string name = MsContext::GetInstance()->backend_policy();
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
#ifndef NO_DLIB
|
||||
if (!context::IsTsdOpened(ms_context)) {
|
||||
InitPipeline();
|
||||
#ifdef WITH_BACKEND
|
||||
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
|
||||
auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
|
||||
{kAscendDevice, ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
|
||||
if (!device_context->GetDeprecatedInterface()->IsTsdOpened(ms_context)) {
|
||||
InitPipeline();
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
if (iter_num == -1) {
|
||||
iter_num = INT32_MAX;
|
||||
|
@ -1390,10 +1396,8 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba
|
|||
std::string backend = ms_context->backend_policy();
|
||||
#ifdef WITH_BACKEND
|
||||
if (backend == "ge") {
|
||||
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
|
||||
auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
|
||||
{MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET),
|
||||
MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
|
||||
{ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET), ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
|
||||
|
||||
|
@ -1534,15 +1538,15 @@ void InitHccl() {
|
|||
#endif
|
||||
|
||||
mindspore::python_adapter::set_python_env_flag(true);
|
||||
std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
if (common::UseMPI() && ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
|
||||
if (common::UseMPI() && device_name == kAscendDevice) {
|
||||
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
|
||||
{ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET), ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
|
||||
{device_name, ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
|
||||
device_id = device_context->GetDeprecatedInterface()->InitCollective();
|
||||
}
|
||||
std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
ms_context->set_param<bool>(MS_CTX_ENABLE_HCCL, true);
|
||||
if (ms_context->backend_policy() == "ms" &&
|
||||
ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
|
||||
|
@ -1551,12 +1555,14 @@ void InitHccl() {
|
|||
#ifndef ENABLE_SECURITY
|
||||
runtime_instance->PreInit();
|
||||
#endif
|
||||
(void)context::OpenTsd(ms_context);
|
||||
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
|
||||
{device_name, ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
|
||||
(void)device_context->GetDeprecatedInterface()->OpenTsd(ms_context);
|
||||
if (!runtime_instance->Init()) {
|
||||
MS_LOG(EXCEPTION) << "Runtime init failed.";
|
||||
}
|
||||
} else {
|
||||
(void)context::OpenTsd(ms_context);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1633,11 +1639,18 @@ FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_
|
|||
return func_graph;
|
||||
}
|
||||
|
||||
void ReleaseGeTsd() {
|
||||
void CloseTsd(bool force) {
|
||||
#ifdef WITH_BACKEND
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
if (context_ptr != nullptr) {
|
||||
(void)context::CloseTsd(context_ptr, true);
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
|
||||
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
|
||||
{kAscendDevice, context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
|
||||
(void)device_context->GetDeprecatedInterface()->CloseTsd(context_ptr, force);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void InitPipeline() {
|
||||
|
@ -1648,22 +1661,23 @@ void InitPipeline() {
|
|||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
#ifdef WITH_BACKEND
|
||||
auto backend = ms_context->backend_policy();
|
||||
auto device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
|
||||
{device_name, ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
if (backend == "ge") {
|
||||
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
|
||||
{ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET), ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
|
||||
device_context->Initialize();
|
||||
}
|
||||
#endif
|
||||
if (!context::OpenTsd(ms_context)) {
|
||||
MS_LOG(EXCEPTION) << "Open tsd failed";
|
||||
if (device_name == kAscendDevice) {
|
||||
MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
|
||||
if (!device_context->GetDeprecatedInterface()->OpenTsd(ms_context)) {
|
||||
MS_LOG(EXCEPTION) << "Open tsd failed";
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void FinalizeBackend() {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
(void)context::CloseTsd(context_ptr);
|
||||
}
|
||||
void FinalizeBackend() { CloseTsd(); }
|
||||
|
||||
void MemoryRecycle() {
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
|
@ -1768,8 +1782,6 @@ void ClearResAtexit() {
|
|||
device::DeviceContextManager::GetInstance().ClearDeviceContexts();
|
||||
MS_LOG(INFO) << "End clear device context.";
|
||||
|
||||
ReleaseGeTsd();
|
||||
|
||||
MS_LOG(INFO) << "Start clear AnalysisResultCacheMgr...";
|
||||
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
|
||||
MS_LOG(INFO) << "End clear AnalysisResultCacheMgr.";
|
||||
|
|
|
@ -179,7 +179,7 @@ uint32_t GetHcclRankSize();
|
|||
void InitPipeline();
|
||||
void FinalizeBackend();
|
||||
void ClearResAtexit();
|
||||
void ReleaseGeTsd();
|
||||
void CloseTsd(bool force = false);
|
||||
void MemoryRecycle();
|
||||
|
||||
FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_t key_len, const std::string &dec_mode,
|
||||
|
|
|
@ -38,7 +38,6 @@
|
|||
#include "include/common/utils/utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "runtime/device/context_extends.h"
|
||||
#include "include/common/utils/config_manager.h"
|
||||
#include "include/common/utils/convert_utils_py.h"
|
||||
#include "include/common/utils/scoped_long_running.h"
|
||||
|
@ -2505,11 +2504,18 @@ MsBackendPolicy ForwardExecutor::GetBackendPolicy(const OpExecInfoPtr &op_exec_i
|
|||
if (ms_context->backend_policy() == "ge") {
|
||||
MS_LOG(EXCEPTION) << "In PyNative mode, not support ge backend!";
|
||||
}
|
||||
if (!context::IsTsdOpened(ms_context)) {
|
||||
if (!context::OpenTsd(ms_context)) {
|
||||
#ifdef WITH_BACKEND
|
||||
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
|
||||
{kAscendDevice, ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
|
||||
|
||||
if (!device_context->GetDeprecatedInterface()->IsTsdOpened(ms_context)) {
|
||||
if (!device_context->GetDeprecatedInterface()->OpenTsd(ms_context)) {
|
||||
MS_LOG(EXCEPTION) << "Open tsd failed";
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
return backend_policy;
|
||||
}
|
||||
|
|
|
@ -25,7 +25,6 @@
|
|||
#include "plugin/device/ascend/hal/device/ascend_device_address.h"
|
||||
#include "plugin/device/ascend/hal/device/distribute/ascend_collective.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "runtime/device/context_extends.h"
|
||||
#include "include/common/utils/mpi/mpi_config.h"
|
||||
#include "runtime/device/ms_device_shape_transfer.h"
|
||||
#include "runtime/rt.h"
|
||||
|
@ -1275,9 +1274,6 @@ bool AscendKernelRuntime::ResetDevice(uint32_t device_id) {
|
|||
bool AscendKernelRuntime::HcclInit() {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!context::IsTsdOpened(context_ptr)) {
|
||||
MS_LOG(EXCEPTION) << "Hccl dependent tsd is not open";
|
||||
}
|
||||
MS_LOG(INFO) << "Do hcom init.";
|
||||
auto device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
std::string rank_id_str = GetRankIdStr();
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
ascend
|
|
@ -13,23 +13,19 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "include/common/utils/tensorprint_utils.h"
|
||||
#include "plugin/device/ascend/hal/device/tensorprint_utils.h"
|
||||
#include <atomic>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "ir/tensor.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "utils/shape_utils.h"
|
||||
#include "mindspore/core/utils/file_utils.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace mindspore {
|
||||
|
||||
#ifndef NO_DLIB
|
||||
static std::map<aclDataType, TypeId> print_acl_data_type_map = {
|
||||
{ACL_INT8, TypeId::kNumberTypeInt8}, {ACL_UINT8, TypeId::kNumberTypeUInt8},
|
||||
{ACL_INT16, TypeId::kNumberTypeInt16}, {ACL_UINT16, TypeId::kNumberTypeUInt16},
|
||||
|
@ -75,7 +71,7 @@ bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *co
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void PrintScalarToString(const char *str_data_ptr, const aclDataType &acl_data_type, std::ostringstream *const buf) {
|
||||
void PrintScalarToString(const void *str_data_ptr, const aclDataType &acl_data_type, std::ostringstream *const buf) {
|
||||
MS_EXCEPTION_IF_NULL(str_data_ptr);
|
||||
MS_EXCEPTION_IF_NULL(buf);
|
||||
*buf << "Tensor(shape=[], dtype=" << GetParseType(acl_data_type) << ", value=";
|
||||
|
@ -101,13 +97,13 @@ void PrintScalarToBoolString(const char *str_data_ptr, const aclDataType &acl_da
|
|||
}
|
||||
}
|
||||
|
||||
void convertDataItem2Scalar(const char *str_data_ptr, const aclDataType &acl_data_type, std::ostringstream *const buf) {
|
||||
void ConvertDataItem2Scalar(const void *str_data_ptr, const aclDataType &acl_data_type, std::ostringstream *const buf) {
|
||||
MS_EXCEPTION_IF_NULL(str_data_ptr);
|
||||
MS_EXCEPTION_IF_NULL(buf);
|
||||
auto type_iter = print_acl_data_type_map.find(acl_data_type);
|
||||
auto type_id = type_iter->second;
|
||||
if (type_id == TypeId::kNumberTypeBool) {
|
||||
PrintScalarToBoolString(str_data_ptr, acl_data_type, buf);
|
||||
PrintScalarToBoolString(reinterpret_cast<const char *>(str_data_ptr), acl_data_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeInt8) {
|
||||
PrintScalarToString<int8_t>(str_data_ptr, acl_data_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeUInt8) {
|
||||
|
@ -178,7 +174,7 @@ bool ConvertDataset2Tensor(acltdtDataset *acl_dataset) {
|
|||
if (!judgeLengthValid(acl_data_size, acl_data_type)) {
|
||||
MS_LOG(EXCEPTION) << "Print op receive data length is invalid.";
|
||||
}
|
||||
convertDataItem2Scalar(acl_data, acl_data_type, &buf);
|
||||
ConvertDataItem2Scalar(reinterpret_cast<void *>(acl_data), acl_data_type, &buf);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -210,7 +206,6 @@ bool SaveDataset2File(acltdtDataset *acl_dataset, const std::string &print_file_
|
|||
acltdtDataItem *item = acltdtGetDataItem(acl_dataset, i);
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
acltdtTensorType acl_tensor_type = acltdtGetTensorTypeFromItem(item);
|
||||
|
||||
if (acl_tensor_type == ACL_TENSOR_DATA_END_OF_SEQUENCE) {
|
||||
MS_LOG(INFO) << "Acl channel received end-of-sequence for print op.";
|
||||
ret_end_thread = true;
|
||||
|
@ -358,5 +353,4 @@ void TensorPrint::operator()() {
|
|||
TensorPrintOut2File(acl_handle_, print_file_path_);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
} // namespace mindspore
|
|
@ -14,13 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_TENSORPRINT_UTILS_H_
|
||||
#define MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_TENSORPRINT_UTILS_H_
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_HAL_DEVICE_TENSORPRINT_UTILS_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_HAL_DEVICE_TENSORPRINT_UTILS_H_
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "ir/dtype/type.h"
|
||||
#ifndef NO_DLIB
|
||||
#include "acl/acl_tdt.h"
|
||||
#include "tdt/tsd_client.h"
|
||||
#include "tdt/data_common.h"
|
||||
|
@ -41,5 +40,4 @@ class COMMON_EXPORT TensorPrint {
|
|||
const acltdtChannelHandle *acl_handle_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif
|
||||
#endif // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_TENSORPRINT_UTILS_H_
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_HAL_DEVICE_TENSORPRINT_UTILS_H_
|
|
@ -15,12 +15,207 @@
|
|||
*/
|
||||
|
||||
#include "plugin/device/ascend/hal/hardware/ascend_deprecated_interface.h"
|
||||
#include <algorithm>
|
||||
#include "plugin/device/ascend/hal/hardware/ge_device_context.h"
|
||||
#include "include/transform/graph_ir/types.h"
|
||||
#include "include/transform/graph_ir/utils.h"
|
||||
#include "include/common/utils/scoped_long_running.h"
|
||||
#include "graph/model.h"
|
||||
#include "transform/graph_ir/op_adapter_map.h"
|
||||
#include "plugin/device/ascend/hal/device/tensorprint_utils.h"
|
||||
#include "acl/acl_tdt.h"
|
||||
#include "runtime/dev.h"
|
||||
#include "toolchain/plog.h"
|
||||
#include "common/util/error_manager/error_manager.h"
|
||||
#include "plugin/device/ascend/hal/device/distribute/ascend_collective.h"
|
||||
#include "plugin/device/ascend/hal/profiler/parallel_strategy_profiling.h"
|
||||
|
||||
using mindspore::abstract::AbstractScalar;
|
||||
using mindspore::abstract::AbstractTensor;
|
||||
using mindspore::abstract::AbstractTuple;
|
||||
using mindspore::abstract::AbstractTuplePtr;
|
||||
using mindspore::transform::GeTensorPtr;
|
||||
using mindspore::transform::MeTensorPtr;
|
||||
using mindspore::transform::Status;
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
namespace {
|
||||
constexpr auto kUnknowErrorString = "Unknown error occurred";
|
||||
|
||||
void ConvertObjectToTensors(const py::dict &dict, transform::TensorOrderMap *const tensors) {
|
||||
for (auto item : dict) {
|
||||
if ((!py::isinstance<py::str>(item.first))) {
|
||||
MS_LOG(WARNING) << "Type of key of py_dict is not string, ignore it.";
|
||||
continue;
|
||||
}
|
||||
std::shared_ptr<tensor::Tensor> tensor;
|
||||
std::string name = py::cast<std::string>(item.first);
|
||||
if (py::isinstance<py::float_>(item.second.attr("data"))) {
|
||||
// convert float to tensor with shape([1])
|
||||
tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, std::vector<int64_t>({1}));
|
||||
*(static_cast<float *>(tensor->data_c())) = py::cast<float>(item.second.attr("data"));
|
||||
} else if (py::isinstance<py::int_>(item.second.attr("data"))) {
|
||||
// convert int64_t to tensor with shape([1])
|
||||
tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, std::vector<int64_t>({1}));
|
||||
*(static_cast<float *>(tensor->data_c())) = py::cast<float>(item.second.attr("data"));
|
||||
} else if (py::isinstance<tensor::Tensor>(item.second.attr("data"))) {
|
||||
// cast tensor
|
||||
tensor = py::cast<std::shared_ptr<tensor::Tensor>>(item.second.attr("data"));
|
||||
}
|
||||
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Get default value for " << name << " failed";
|
||||
}
|
||||
(void)tensors->emplace(name, tensor);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void AscendDeprecatedInterface::DoExecNonInputGraph(const std::string &phase) {
|
||||
std::vector<GeTensorPtr> ge_tensors;
|
||||
std::vector<GeTensorPtr> ge_outputs;
|
||||
transform::RunOptions run_options;
|
||||
run_options.name = phase;
|
||||
auto graph_runner = transform::GetGraphRunner();
|
||||
if (graph_runner == nullptr) {
|
||||
MS_LOG(ERROR) << "Can not found GraphRunner";
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
// Release GIL before calling into (potentially long-running) C++ code
|
||||
ScopedLongRunning release;
|
||||
Status ret = transform::RunGraph(graph_runner, run_options, ge_tensors, &ge_outputs);
|
||||
if (ret != Status::SUCCESS) {
|
||||
MS_LOG(ERROR) << "Exec graph:" << run_options.name << " failed";
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool AscendDeprecatedInterface::InitExecDataset(const std::string &queue_name, int64_t size, int64_t batch_size,
|
||||
const std::vector<TypePtr> &types,
|
||||
const std::vector<std::vector<int64_t>> &shapes,
|
||||
const std::vector<int64_t> &input_indexes, const std::string &phase) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
ge_device_context_->Initialize();
|
||||
std::vector<int64_t> ge_types;
|
||||
(void)std::transform(types.begin(), types.end(), std::back_inserter(ge_types),
|
||||
[](const TypePtr &i) -> int64_t { return transform::ConvertDataType(i->type_id()); });
|
||||
|
||||
ConfigManager::GetInstance().set_dataset_mode(DatasetMode::DS_SINK_MODE);
|
||||
ConfigManager::GetInstance().set_iter_num(queue_name, size);
|
||||
ConfigManager::GetInstance().set_dataset_phase(phase);
|
||||
|
||||
DatasetGraphParam param(queue_name, size, batch_size, ge_types, shapes, input_indexes);
|
||||
ConfigManager::GetInstance().set_dataset_param(param);
|
||||
|
||||
auto env_ge = common::GetEnv("MS_ENABLE_GE");
|
||||
auto env_training = common::GetEnv("MS_GE_TRAIN");
|
||||
bool training = false;
|
||||
if (env_ge == "1" && env_training == "1") {
|
||||
training = true;
|
||||
}
|
||||
if (training) {
|
||||
(void)setenv("GE_TRAIN", "1", 1);
|
||||
} else {
|
||||
(void)setenv("GE_TRAIN", "0", 1);
|
||||
}
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
|
||||
if (!ms_context->get_param<bool>(MS_CTX_ENABLE_GE_HETEROGENOUS)) {
|
||||
if (transform::CompileDatasetGraph(param, phase) != transform::SUCCESS) {
|
||||
MS_LOG(ERROR) << "Build dateset graph failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
GeDeviceResManager::CreateSessionAndGraphRunner(training);
|
||||
|
||||
MS_LOG(INFO) << "DoExecNonInputGraph:" << phase;
|
||||
DoExecNonInputGraph(phase);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void AscendDeprecatedInterface::ExportDFGraph(const std::string &file_name, const std::string &phase,
|
||||
const py::object &encrypt, char *key) {
|
||||
MS_LOG(DEBUG) << "Export graph begin.";
|
||||
transform::DfGraphWrapperPtr wrap_ptr = transform::GetGraphByName(phase);
|
||||
if (wrap_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Get graph form DfGraphManager failed, phase = " << phase;
|
||||
return;
|
||||
}
|
||||
|
||||
transform::DfGraphPtr ge_graph = wrap_ptr->graph_ptr_;
|
||||
if (ge_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Graph is null!";
|
||||
return;
|
||||
}
|
||||
if (key != nullptr) {
|
||||
if (py::isinstance<py::none()>(encrypt)) {
|
||||
MS_LOG(ERROR) << "ERROR: encrypt is not a function";
|
||||
return;
|
||||
}
|
||||
// get model stream
|
||||
ge::Model model("", "");
|
||||
model.SetGraph(*ge_graph);
|
||||
ge::Buffer model_data;
|
||||
auto ge_ret = model.Save(model_data);
|
||||
if (ge_ret != ge::SUCCESS) {
|
||||
MS_LOG(ERROR) << "ERROR: GE model save fail";
|
||||
return;
|
||||
}
|
||||
// convert model and key into py::bytes
|
||||
const std::string str(reinterpret_cast<char *>(model_data.GetData()), model_data.GetSize());
|
||||
py::bytes model_bytes(str);
|
||||
py::bytes key_bytes(key);
|
||||
|
||||
// call python encrypt func
|
||||
py::bytes encrypted_model_stream = encrypt(model_bytes, key_bytes);
|
||||
if (encrypted_model_stream == py::none()) {
|
||||
MS_LOG(ERROR) << "ERROR: Model encrypt fail";
|
||||
return;
|
||||
}
|
||||
// save to file
|
||||
std::ofstream ofs(file_name);
|
||||
if (!ofs.is_open()) {
|
||||
MS_LOG(ERROR) << "ERROR: Open File '" << file_name << "' failed!";
|
||||
return;
|
||||
}
|
||||
ofs << std::string(encrypted_model_stream);
|
||||
ofs.close();
|
||||
} else {
|
||||
if (ge_graph->SaveToFile(file_name) != 0) {
|
||||
MS_LOG(EXCEPTION) << "Export air model failed.";
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "Export air model finish.";
|
||||
}
|
||||
|
||||
FuncGraphPtr AscendDeprecatedInterface::BuildDFGraph(const FuncGraphPtr &anf_graph, const pybind11::dict &init_params) {
|
||||
MS_EXCEPTION_IF_NULL(anf_graph);
|
||||
transform::TensorOrderMap init_tensors{};
|
||||
ConvertObjectToTensors(init_params, &init_tensors);
|
||||
return GeGraphExecutor::BuildDFGraph(anf_graph, init_tensors, true);
|
||||
}
|
||||
|
||||
void AscendDeprecatedInterface::ClearGraphWrapper() { transform::DfGraphManager::GetInstance().ClearGraph(); }
|
||||
|
||||
void AscendDeprecatedInterface::ClearOpAdapterMap() { transform::OpAdapterMap::get().clear(); }
|
||||
|
||||
void AscendDeprecatedInterface::EraseGeResource() {
|
||||
transform::DfGraphManager::GetInstance().DeleteGraphRunner();
|
||||
transform::DfGraphManager::GetInstance().EraseAnfGraph();
|
||||
transform::DfGraphManager::GetInstance().DeleteGeSession();
|
||||
}
|
||||
|
||||
uint32_t AscendDeprecatedInterface::InitCollective() {
|
||||
#ifdef WITH_BACKEND
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
|
@ -45,6 +240,112 @@ uint32_t AscendDeprecatedInterface::InitCollective() {
|
|||
void AscendDeprecatedInterface::DumpProfileParallelStrategy(const FuncGraphPtr &func_graph) {
|
||||
return profiler::ascend::DumpProfileParallelStrategy(func_graph);
|
||||
}
|
||||
|
||||
bool AscendDeprecatedInterface::OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
|
||||
if (ms_context_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
|
||||
if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) != 0) {
|
||||
MS_LOG(DEBUG) << "ACLTDT Dataset client is already opened.";
|
||||
ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
|
||||
return true;
|
||||
}
|
||||
|
||||
auto role = common::GetEnv("MS_ROLE");
|
||||
if (strcmp(role.c_str(), "MS_SCHED") == 0 || strcmp(role.c_str(), "MS_PSERVER") == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
uint32_t device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
|
||||
uint32_t rank_size;
|
||||
auto rank_size_env = common::GetEnv("RANK_SIZE");
|
||||
if (rank_size_env.empty()) {
|
||||
MS_LOG(INFO) << "Should config rank size.";
|
||||
rank_size = 1;
|
||||
} else {
|
||||
int rank_env = std::stoi(rank_size_env);
|
||||
if (rank_env <= 0) {
|
||||
MS_LOG(EXCEPTION) << "Error rank size " << rank_env << ".";
|
||||
}
|
||||
rank_size = IntToUint(rank_env);
|
||||
}
|
||||
|
||||
int log_ret = DlogReportInitialize();
|
||||
if (log_ret != 0) {
|
||||
MS_LOG(WARNING) << "Init slog failed, ret = " << log_ret;
|
||||
}
|
||||
|
||||
if (ErrorManager::GetInstance().Init() != 0) {
|
||||
MS_LOG(WARNING) << "Init ascend error manager failed, some ascend error log may be left out.";
|
||||
}
|
||||
MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << ".";
|
||||
auto ret = rtSetDevice(static_cast<int32_t>(device_id));
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
const std::string &error_message = ErrorManager::GetInstance().GetErrorMessage();
|
||||
if (!error_message.empty() && error_message.find(kUnknowErrorString) == std::string::npos) {
|
||||
MS_LOG(ERROR) << "Ascend error occurred, error message:\n" << error_message;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Device " << device_id << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]";
|
||||
}
|
||||
ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
|
||||
#ifdef ENABLE_TDTQUE
|
||||
auto thread_crt = [](const std::string &path, const acltdtChannelHandle *acl_handle) {
|
||||
return std::thread(TensorPrint(path, acl_handle));
|
||||
};
|
||||
ms_context_ptr->CreateTensorPrintThread(thread_crt);
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AscendDeprecatedInterface::CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
|
||||
if (ms_context_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "ms_context_prt is nullptr";
|
||||
}
|
||||
MS_LOG(INFO) << "Start to close tsd, ref = " << ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF);
|
||||
if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
|
||||
return true;
|
||||
}
|
||||
ms_context_ptr->decrease_param<uint32_t>(MS_CTX_TSD_REF);
|
||||
if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
|
||||
ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0);
|
||||
#ifdef ENABLE_TDTQUE
|
||||
pybind11::gil_scoped_release gil_release;
|
||||
ms_context_ptr->DestroyTensorPrintThread();
|
||||
#endif
|
||||
if (ErrorManager::GetInstance().Init() != 0) {
|
||||
MS_LOG(WARNING) << "Init ascend error manager failed, some ascend error log may be left out.";
|
||||
}
|
||||
uint32_t device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
auto ret = rtDeviceReset(static_cast<int32_t>(device_id));
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
const std::string &error_message = ErrorManager::GetInstance().GetErrorMessage();
|
||||
if (!error_message.empty() && error_message.find(kUnknowErrorString) == std::string::npos) {
|
||||
MS_LOG(ERROR) << "Ascend error occurred, error message:\n" << error_message;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Device " << device_id << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]";
|
||||
}
|
||||
ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
|
||||
MS_LOG(INFO) << "Call rtDeviceReset, destroy and close tsd successful, ret[" << static_cast<int>(ret) << "]";
|
||||
(void)DlogReportFinalize();
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Acltdt Dataset client is used, no need to close, tsd reference = "
|
||||
<< ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AscendDeprecatedInterface::IsTsdOpened(const std::shared_ptr<MsContext> &ms_context_ptr) {
|
||||
if (ms_context_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
return ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) > 0;
|
||||
}
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,21 +27,36 @@
|
|||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
class AscendDeviceContext;
|
||||
class GeDeviceContext;
|
||||
|
||||
class AscendDeprecatedInterface : public DeprecatedInterface {
|
||||
public:
|
||||
explicit AscendDeprecatedInterface(AscendDeviceContext *ascend_device_context)
|
||||
: ascend_device_context_(ascend_device_context) {}
|
||||
explicit AscendDeprecatedInterface(GeDeviceContext *ge_device_context) : ge_device_context_(ge_device_context) {}
|
||||
~AscendDeprecatedInterface() override = default;
|
||||
|
||||
// for ge
|
||||
void DoExecNonInputGraph(const std::string &phase) override;
|
||||
bool InitExecDataset(const std::string &queue_name, int64_t size, int64_t batch_size,
|
||||
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
|
||||
const std::vector<int64_t> &input_indexes, const std::string &phase) override;
|
||||
void ExportDFGraph(const std::string &file_name, const std::string &phase, const pybind11::object &encrypt,
|
||||
char *key) override;
|
||||
FuncGraphPtr BuildDFGraph(const FuncGraphPtr &anf_graph, const pybind11::dict &init_params) override;
|
||||
void ClearGraphWrapper() override;
|
||||
void ClearOpAdapterMap() override;
|
||||
void EraseGeResource() override;
|
||||
// for ascend
|
||||
uint32_t InitCollective() override;
|
||||
void DumpProfileParallelStrategy(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) override;
|
||||
bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) override;
|
||||
bool IsTsdOpened(const std::shared_ptr<MsContext> &inst_context) override;
|
||||
|
||||
private:
|
||||
AscendDeviceContext *const ascend_device_context_;
|
||||
GeDeviceContext *const ge_device_context_;
|
||||
};
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_DEPRECATED_INTERFACE_H_
|
||||
|
|
|
@ -65,6 +65,9 @@ void AscendDeviceContext::Destroy() {
|
|||
#endif
|
||||
MS_LOG(INFO) << "Status record: Enter Destroy...";
|
||||
if (!initialized_) {
|
||||
if (deprecated_interface_ != nullptr) {
|
||||
deprecated_interface_->CloseTsd(MsContext::GetInstance(), true);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -77,6 +80,9 @@ void AscendDeviceContext::Destroy() {
|
|||
if (runtime_instance_) {
|
||||
runtime_instance_ = nullptr;
|
||||
}
|
||||
if (deprecated_interface_ != nullptr) {
|
||||
deprecated_interface_->CloseTsd(MsContext::GetInstance(), true);
|
||||
}
|
||||
initialized_ = false;
|
||||
MS_LOG(INFO) << "Status record: Destroy success.";
|
||||
}
|
||||
|
@ -101,7 +107,7 @@ RunMode AscendDeviceContext::GetRunMode(const FuncGraphPtr &func_graph) const {
|
|||
DeprecatedInterface *AscendDeviceContext::GetDeprecatedInterface() {
|
||||
// need lock when multi-threads
|
||||
if (deprecated_interface_ == nullptr) {
|
||||
deprecated_interface_ = std::make_unique<AscendDeprecatedInterface>(this);
|
||||
deprecated_interface_ = std::make_unique<AscendDeprecatedInterface>(nullptr);
|
||||
}
|
||||
return deprecated_interface_.get();
|
||||
}
|
||||
|
|
|
@ -1,211 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/ascend/hal/hardware/ge_deprecated_interface.h"
|
||||
#include <algorithm>
|
||||
#include "plugin/device/ascend/hal/hardware/ge_device_context.h"
|
||||
#include "include/transform/graph_ir/types.h"
|
||||
#include "include/transform/graph_ir/utils.h"
|
||||
#include "include/common/utils/scoped_long_running.h"
|
||||
#include "graph/model.h"
|
||||
#include "transform/graph_ir/op_adapter_map.h"
|
||||
|
||||
using mindspore::abstract::AbstractScalar;
|
||||
using mindspore::abstract::AbstractTensor;
|
||||
using mindspore::abstract::AbstractTuple;
|
||||
using mindspore::abstract::AbstractTuplePtr;
|
||||
using mindspore::transform::GeTensorPtr;
|
||||
using mindspore::transform::MeTensorPtr;
|
||||
using mindspore::transform::Status;
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
namespace {
|
||||
void ConvertObjectToTensors(const py::dict &dict, transform::TensorOrderMap *const tensors) {
|
||||
for (auto item : dict) {
|
||||
if ((!py::isinstance<py::str>(item.first))) {
|
||||
MS_LOG(WARNING) << "Type of key of py_dict is not string, ignore it.";
|
||||
continue;
|
||||
}
|
||||
std::shared_ptr<tensor::Tensor> tensor;
|
||||
std::string name = py::cast<std::string>(item.first);
|
||||
if (py::isinstance<py::float_>(item.second.attr("data"))) {
|
||||
// convert float to tensor with shape([1])
|
||||
tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, std::vector<int64_t>({1}));
|
||||
*(static_cast<float *>(tensor->data_c())) = py::cast<float>(item.second.attr("data"));
|
||||
} else if (py::isinstance<py::int_>(item.second.attr("data"))) {
|
||||
// convert int64_t to tensor with shape([1])
|
||||
tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, std::vector<int64_t>({1}));
|
||||
*(static_cast<float *>(tensor->data_c())) = py::cast<float>(item.second.attr("data"));
|
||||
} else if (py::isinstance<tensor::Tensor>(item.second.attr("data"))) {
|
||||
// cast tensor
|
||||
tensor = py::cast<std::shared_ptr<tensor::Tensor>>(item.second.attr("data"));
|
||||
}
|
||||
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Get default value for " << name << " failed";
|
||||
}
|
||||
(void)tensors->emplace(name, tensor);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void GeDeprecatedInterface::DoExecNonInputGraph(const std::string &phase) {
|
||||
std::vector<GeTensorPtr> ge_tensors;
|
||||
std::vector<GeTensorPtr> ge_outputs;
|
||||
transform::RunOptions run_options;
|
||||
run_options.name = phase;
|
||||
auto graph_runner = transform::GetGraphRunner();
|
||||
if (graph_runner == nullptr) {
|
||||
MS_LOG(ERROR) << "Can not found GraphRunner";
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
// Release GIL before calling into (potentially long-running) C++ code
|
||||
ScopedLongRunning release;
|
||||
Status ret = transform::RunGraph(graph_runner, run_options, ge_tensors, &ge_outputs);
|
||||
if (ret != Status::SUCCESS) {
|
||||
MS_LOG(ERROR) << "Exec graph:" << run_options.name << " failed";
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool GeDeprecatedInterface::InitExecDataset(const std::string &queue_name, int64_t size, int64_t batch_size,
|
||||
const std::vector<TypePtr> &types,
|
||||
const std::vector<std::vector<int64_t>> &shapes,
|
||||
const std::vector<int64_t> &input_indexes, const std::string &phase) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
ge_device_context_->Initialize();
|
||||
std::vector<int64_t> ge_types;
|
||||
(void)std::transform(types.begin(), types.end(), std::back_inserter(ge_types),
|
||||
[](const TypePtr &i) -> int64_t { return transform::ConvertDataType(i->type_id()); });
|
||||
|
||||
ConfigManager::GetInstance().set_dataset_mode(DatasetMode::DS_SINK_MODE);
|
||||
ConfigManager::GetInstance().set_iter_num(queue_name, size);
|
||||
ConfigManager::GetInstance().set_dataset_phase(phase);
|
||||
|
||||
DatasetGraphParam param(queue_name, size, batch_size, ge_types, shapes, input_indexes);
|
||||
ConfigManager::GetInstance().set_dataset_param(param);
|
||||
|
||||
auto env_ge = common::GetEnv("MS_ENABLE_GE");
|
||||
auto env_training = common::GetEnv("MS_GE_TRAIN");
|
||||
bool training = false;
|
||||
if (env_ge == "1" && env_training == "1") {
|
||||
training = true;
|
||||
}
|
||||
if (training) {
|
||||
(void)setenv("GE_TRAIN", "1", 1);
|
||||
} else {
|
||||
(void)setenv("GE_TRAIN", "0", 1);
|
||||
}
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
|
||||
if (!ms_context->get_param<bool>(MS_CTX_ENABLE_GE_HETEROGENOUS)) {
|
||||
if (transform::CompileDatasetGraph(param, phase) != transform::SUCCESS) {
|
||||
MS_LOG(ERROR) << "Build dateset graph failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
GeDeviceResManager::CreateSessionAndGraphRunner(training);
|
||||
|
||||
MS_LOG(INFO) << "DoExecNonInputGraph:" << phase;
|
||||
DoExecNonInputGraph(phase);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void GeDeprecatedInterface::ExportDFGraph(const std::string &file_name, const std::string &phase,
|
||||
const py::object &encrypt, char *key) {
|
||||
MS_LOG(DEBUG) << "Export graph begin.";
|
||||
transform::DfGraphWrapperPtr wrap_ptr = transform::GetGraphByName(phase);
|
||||
if (wrap_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Get graph form DfGraphManager failed, phase = " << phase;
|
||||
return;
|
||||
}
|
||||
|
||||
transform::DfGraphPtr ge_graph = wrap_ptr->graph_ptr_;
|
||||
if (ge_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Graph is null!";
|
||||
return;
|
||||
}
|
||||
if (key != nullptr) {
|
||||
if (py::isinstance<py::none()>(encrypt)) {
|
||||
MS_LOG(ERROR) << "ERROR: encrypt is not a function";
|
||||
return;
|
||||
}
|
||||
// get model stream
|
||||
ge::Model model("", "");
|
||||
model.SetGraph(*ge_graph);
|
||||
ge::Buffer model_data;
|
||||
auto ge_ret = model.Save(model_data);
|
||||
if (ge_ret != ge::SUCCESS) {
|
||||
MS_LOG(ERROR) << "ERROR: GE model save fail";
|
||||
return;
|
||||
}
|
||||
// convert model and key into py::bytes
|
||||
const std::string str(reinterpret_cast<char *>(model_data.GetData()), model_data.GetSize());
|
||||
py::bytes model_bytes(str);
|
||||
py::bytes key_bytes(key);
|
||||
|
||||
// call python encrypt func
|
||||
py::bytes encrypted_model_stream = encrypt(model_bytes, key_bytes);
|
||||
if (encrypted_model_stream == py::none()) {
|
||||
MS_LOG(ERROR) << "ERROR: Model encrypt fail";
|
||||
return;
|
||||
}
|
||||
// save to file
|
||||
std::ofstream ofs(file_name);
|
||||
if (!ofs.is_open()) {
|
||||
MS_LOG(ERROR) << "ERROR: Open File '" << file_name << "' failed!";
|
||||
return;
|
||||
}
|
||||
ofs << std::string(encrypted_model_stream);
|
||||
ofs.close();
|
||||
} else {
|
||||
if (ge_graph->SaveToFile(file_name) != 0) {
|
||||
MS_LOG(EXCEPTION) << "Export air model failed.";
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "Export air model finish.";
|
||||
}
|
||||
|
||||
FuncGraphPtr GeDeprecatedInterface::BuildDFGraph(const FuncGraphPtr &anf_graph, const pybind11::dict &init_params) {
|
||||
MS_EXCEPTION_IF_NULL(anf_graph);
|
||||
transform::TensorOrderMap init_tensors{};
|
||||
ConvertObjectToTensors(init_params, &init_tensors);
|
||||
return GeGraphExecutor::BuildDFGraph(anf_graph, init_tensors, true);
|
||||
}
|
||||
|
||||
void GeDeprecatedInterface::ClearGraphWrapper() { transform::DfGraphManager::GetInstance().ClearGraph(); }
|
||||
|
||||
void GeDeprecatedInterface::ClearOpAdapterMap() { transform::OpAdapterMap::get().clear(); }
|
||||
|
||||
void GeDeprecatedInterface::EraseGeResource() {
|
||||
transform::DfGraphManager::GetInstance().DeleteGraphRunner();
|
||||
transform::DfGraphManager::GetInstance().EraseAnfGraph();
|
||||
transform::DfGraphManager::GetInstance().DeleteGeSession();
|
||||
}
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
|
@ -1,53 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_GE_DEPRECATED_INTERFACE_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_GE_DEPRECATED_INTERFACE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "runtime/hardware/device_context.h"
|
||||
#include "runtime/device/memory_manager.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
class GeDeviceContext;
|
||||
|
||||
class GeDeprecatedInterface : public DeprecatedInterface {
|
||||
public:
|
||||
explicit GeDeprecatedInterface(GeDeviceContext *ge_device_context) : ge_device_context_(ge_device_context) {}
|
||||
void DoExecNonInputGraph(const std::string &phase) override;
|
||||
bool InitExecDataset(const std::string &queue_name, int64_t size, int64_t batch_size,
|
||||
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
|
||||
const std::vector<int64_t> &input_indexes, const std::string &phase) override;
|
||||
void ExportDFGraph(const std::string &file_name, const std::string &phase, const pybind11::object &encrypt,
|
||||
char *key) override;
|
||||
FuncGraphPtr BuildDFGraph(const FuncGraphPtr &anf_graph, const pybind11::dict &init_params) override;
|
||||
void ClearGraphWrapper() override;
|
||||
void ClearOpAdapterMap() override;
|
||||
void EraseGeResource() override;
|
||||
|
||||
private:
|
||||
GeDeviceContext *const ge_device_context_;
|
||||
};
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_GE_DEPRECATED_INTERFACE_H_
|
|
@ -427,7 +427,12 @@ void GeDeviceContext::Initialize() {
|
|||
initialized_ = InitGe(MsContext::GetInstance());
|
||||
}
|
||||
|
||||
void GeDeviceContext::Destroy() { (void)FinalizeGe(MsContext::GetInstance()); }
|
||||
void GeDeviceContext::Destroy() {
|
||||
(void)FinalizeGe(MsContext::GetInstance());
|
||||
if (deprecated_interface_ != nullptr) {
|
||||
deprecated_interface_->CloseTsd(MsContext::GetInstance(), true);
|
||||
}
|
||||
}
|
||||
|
||||
void GeDeviceResManager::Initialize() {
|
||||
if (mem_manager_ == nullptr) {
|
||||
|
@ -748,7 +753,7 @@ FuncGraphPtr GeGraphExecutor::BuildDFGraph(const FuncGraphPtr &anf_graph,
|
|||
DeprecatedInterface *GeDeviceContext::GetDeprecatedInterface() {
|
||||
// need lock when multi-threads
|
||||
if (deprecated_interface_ == nullptr) {
|
||||
deprecated_interface_ = std::make_unique<GeDeprecatedInterface>(this);
|
||||
deprecated_interface_ = std::make_unique<AscendDeprecatedInterface>(this);
|
||||
}
|
||||
return deprecated_interface_.get();
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "plugin/device/ascend/hal/hardware/ge_deprecated_interface.h"
|
||||
#include "plugin/device/ascend/hal/hardware/ascend_deprecated_interface.h"
|
||||
#include "runtime/hardware/device_context.h"
|
||||
#include "runtime/device/memory_manager.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
@ -93,7 +93,7 @@ class GeDeviceContext : public DeviceInterface<GeGraphExecutor, GeDeviceResManag
|
|||
void SetHcclOptions(const std::shared_ptr<MsContext> &inst_context, std::map<std::string, std::string> *ge_options);
|
||||
void SetDisableReuseMemoryFlag(std::map<std::string, std::string> *ge_options);
|
||||
|
||||
std::unique_ptr<GeDeprecatedInterface> deprecated_interface_;
|
||||
std::unique_ptr<AscendDeprecatedInterface> deprecated_interface_;
|
||||
bool initialized_;
|
||||
};
|
||||
} // namespace ascend
|
||||
|
|
|
@ -15,151 +15,14 @@
|
|||
*/
|
||||
|
||||
#include "runtime/device/context_extends.h"
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "include/common/utils/config_manager.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
#ifndef NO_DLIB
|
||||
#include "acl/acl_tdt.h"
|
||||
#include "runtime/dev.h"
|
||||
#include "toolchain/plog.h"
|
||||
#include "common/util/error_manager/error_manager.h"
|
||||
#endif
|
||||
#ifdef ENABLE_D
|
||||
#include "debug/data_dump/dump_json_parser.h"
|
||||
#include "include/transform/graph_ir/utils.h"
|
||||
#endif
|
||||
#include "profiler/device/profiling.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace context {
|
||||
#ifdef ENABLE_D
|
||||
namespace {
|
||||
constexpr auto kMindsporeDumpConfig = "MINDSPORE_DUMP_CONFIG";
|
||||
const std::vector<std::string> kGeDumpMode = {"all", "input", "output"};
|
||||
} // namespace
|
||||
#endif
|
||||
|
||||
constexpr auto kUnknowErrorString = "Unknown error occurred";
|
||||
#ifndef NO_DLIB
|
||||
// Open tdt dataset
|
||||
bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
|
||||
if (ms_context_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
|
||||
if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) != 0) {
|
||||
MS_LOG(DEBUG) << "ACLTDT Dataset client is already opened.";
|
||||
ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
|
||||
return true;
|
||||
}
|
||||
|
||||
auto role = common::GetEnv("MS_ROLE");
|
||||
if (strcmp(role.c_str(), "MS_SCHED") == 0 || strcmp(role.c_str(), "MS_PSERVER") == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
uint32_t rank_size = 1;
|
||||
uint32_t device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
|
||||
auto rank_size_env = common::GetEnv("RANK_SIZE");
|
||||
if (rank_size_env.empty()) {
|
||||
MS_LOG(INFO) << "Should config rank size.";
|
||||
rank_size = 1;
|
||||
} else {
|
||||
int rank_env = std::stoi(rank_size_env);
|
||||
if (rank_env <= 0) {
|
||||
MS_LOG(EXCEPTION) << "Error rank size " << rank_env << ".";
|
||||
}
|
||||
rank_size = IntToUint(rank_env);
|
||||
}
|
||||
|
||||
int log_ret = DlogReportInitialize();
|
||||
if (log_ret != 0) {
|
||||
MS_LOG(WARNING) << "Init slog failed, ret = " << log_ret;
|
||||
}
|
||||
|
||||
if (ErrorManager::GetInstance().Init() != 0) {
|
||||
MS_LOG(WARNING) << "Init ascend error manager failed, some ascend error log may be left out.";
|
||||
}
|
||||
MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << ".";
|
||||
auto ret = rtSetDevice(static_cast<int32_t>(device_id));
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
const std::string &error_message = ErrorManager::GetInstance().GetErrorMessage();
|
||||
if (!error_message.empty() && error_message.find(kUnknowErrorString) == std::string::npos) {
|
||||
MS_LOG(ERROR) << "Ascend error occurred, error message:\n" << error_message;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Device " << device_id << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]";
|
||||
}
|
||||
ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
|
||||
#ifdef ENABLE_TDTQUE
|
||||
auto thread_crt = [](const std::string &path, const acltdtChannelHandle *acl_handle) {
|
||||
return std::thread(TensorPrint(path, acl_handle));
|
||||
};
|
||||
ms_context_ptr->CreateTensorPrintThread(thread_crt);
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
|
||||
if (ms_context_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "ms_context_prt is nullptr";
|
||||
}
|
||||
if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
|
||||
return true;
|
||||
}
|
||||
ms_context_ptr->decrease_param<uint32_t>(MS_CTX_TSD_REF);
|
||||
if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
|
||||
ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0);
|
||||
|
||||
#ifdef ENABLE_TDTQUE
|
||||
py::gil_scoped_release gil_release;
|
||||
ms_context_ptr->DestroyTensorPrintThread();
|
||||
#endif
|
||||
if (ErrorManager::GetInstance().Init() != 0) {
|
||||
MS_LOG(WARNING) << "Init ascend error manager failed, some ascend error log may be left out.";
|
||||
}
|
||||
uint32_t device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
auto ret = rtDeviceReset(static_cast<int32_t>(device_id));
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
const std::string &error_message = ErrorManager::GetInstance().GetErrorMessage();
|
||||
if (!error_message.empty() && error_message.find(kUnknowErrorString) == std::string::npos) {
|
||||
MS_LOG(ERROR) << "Ascend error occurred, error message:\n" << error_message;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Device " << device_id << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]";
|
||||
}
|
||||
ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
|
||||
MS_LOG(INFO) << "Call rtDeviceReset, destroy and close tsd successful, ret[" << static_cast<int>(ret) << "]";
|
||||
(void)DlogReportFinalize();
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Acltdt Dataset client is used, no need to close, tsd reference = "
|
||||
<< ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
#else
|
||||
bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) { return true; }
|
||||
bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool) { return true; }
|
||||
#endif
|
||||
|
||||
bool IsTsdOpened(const std::shared_ptr<MsContext> &ms_context_ptr) {
|
||||
if (ms_context_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
return ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) > 0;
|
||||
}
|
||||
|
||||
// Register for device type.
|
||||
struct DeviceTypeSetRegister {
|
||||
DeviceTypeSetRegister() {
|
||||
|
|
|
@ -17,18 +17,6 @@
|
|||
#ifndef MINDSPORE_CCSRC_UTILS_CONTEXT_CONTEXT_EXTENDS_H
|
||||
#define MINDSPORE_CCSRC_UTILS_CONTEXT_CONTEXT_EXTENDS_H
|
||||
|
||||
#include <memory>
|
||||
#include "utils/ms_context.h"
|
||||
#include "include/common/utils/tensorprint_utils.h"
|
||||
#include "include/backend/visible.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace context {
|
||||
BACKEND_EXPORT bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr);
|
||||
BACKEND_EXPORT bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force = false);
|
||||
|
||||
BACKEND_EXPORT bool IsTsdOpened(const std::shared_ptr<MsContext> &inst_context);
|
||||
} // namespace context
|
||||
} // namespace mindspore
|
||||
// this file makes static-code-checking happy
|
||||
|
||||
#endif // MINDSPORE_CCSRC_UTILS_CONTEXT_CONTEXT_EXTENDS_H
|
||||
|
|
|
@ -19,7 +19,9 @@
|
|||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
|
@ -47,7 +49,9 @@ class DeprecatedInterface {
|
|||
// ascend
|
||||
virtual uint32_t InitCollective() { return 0; } // return device id
|
||||
virtual void DumpProfileParallelStrategy(const FuncGraphPtr &func_graph) {}
|
||||
|
||||
virtual bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) { return true; }
|
||||
virtual bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force = false) { return true; }
|
||||
virtual bool IsTsdOpened(const std::shared_ptr<MsContext> &inst_context) { return true; }
|
||||
// gpu
|
||||
};
|
||||
} // namespace device
|
||||
|
|
|
@ -147,8 +147,8 @@ class MS_CORE_API MsContext {
|
|||
~MsContext() = default;
|
||||
MsContext(const MsContext &) = delete;
|
||||
MsContext &operator=(const MsContext &) = delete;
|
||||
using DeviceSeter = std::function<void(const std::string &device_target)>;
|
||||
using DeviceTypeSeter = std::function<void(std::shared_ptr<MsContext> &)>;
|
||||
using DeviceSeter = void (*)(const std::string &device_target);
|
||||
using DeviceTypeSeter = void (*)(std::shared_ptr<MsContext> &);
|
||||
static std::shared_ptr<MsContext> GetInstance();
|
||||
|
||||
bool enable_dump_ir() const;
|
||||
|
|
|
@ -150,6 +150,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_launch_transdata.cc"
|
||||
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/kernel_select_graph_kernel.cc"
|
||||
"../../../mindspore/ccsrc/runtime/device/convert_tensor_utils.cc"
|
||||
"../../../mindspore/ccsrc/runtime/device/context_extends.cc"
|
||||
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_bucket.cc"
|
||||
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_event.cc"
|
||||
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/kernel_build_ascend.cc"
|
||||
|
@ -163,7 +164,6 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_res_manager.cc"
|
||||
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_kernel_executor.cc"
|
||||
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_graph_executor.cc"
|
||||
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_deprecated_interface.cc"
|
||||
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_utils.cc"
|
||||
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_graph_optimization.cc"
|
||||
"../../../mindspore/ccsrc/plugin/device/cpu/hal/hardware/ms_collective_topo.cc"
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/ascend/hal/hardware/ascend_deprecated_interface.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
void AscendDeprecatedInterface::DoExecNonInputGraph(const std::string &) {}
|
||||
bool AscendDeprecatedInterface::InitExecDataset(const std::string &, int64_t, int64_t, const std::vector<TypePtr> &,
|
||||
const std::vector<std::vector<int64_t>> &, const std::vector<int64_t> &,
|
||||
const std::string &) {
|
||||
return true;
|
||||
}
|
||||
void AscendDeprecatedInterface::ExportDFGraph(const std::string &, const std::string &, const pybind11::object &,
|
||||
char *) {}
|
||||
FuncGraphPtr AscendDeprecatedInterface::BuildDFGraph(const FuncGraphPtr &, const pybind11::dict &) { return nullptr; }
|
||||
void AscendDeprecatedInterface::ClearGraphWrapper() {}
|
||||
void AscendDeprecatedInterface::ClearOpAdapterMap() {}
|
||||
void AscendDeprecatedInterface::EraseGeResource() {}
|
||||
// for ascend
|
||||
uint32_t AscendDeprecatedInterface::InitCollective() { return 0; }
|
||||
void AscendDeprecatedInterface::DumpProfileParallelStrategy(const FuncGraphPtr &) {}
|
||||
|
||||
bool AscendDeprecatedInterface::OpenTsd(const std::shared_ptr<MsContext> &) { return true; }
|
||||
bool AscendDeprecatedInterface::CloseTsd(const std::shared_ptr<MsContext> &, bool) { return true; }
|
||||
bool AscendDeprecatedInterface::IsTsdOpened(const std::shared_ptr<MsContext> &) { return true; }
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
|
@ -22,6 +22,7 @@
|
|||
#include "runtime/rt_model.h"
|
||||
#include "runtime/stream.h"
|
||||
#include "toolchain/adx_datadump_server.h"
|
||||
#include "toolchain/plog.h"
|
||||
|
||||
rtError_t rtEventSynchronize(rtEvent_t event) { return RT_ERROR_NONE; }
|
||||
|
||||
|
@ -219,3 +220,7 @@ RTS_API rtError_t rtKernelLaunchWithHandle(void *hdl, const uint64_t tilingKey,
|
|||
const void *kernelInfo) {
|
||||
return RT_ERROR_NONE;
|
||||
}
|
||||
|
||||
int DlogReportInitialize(void) { return 0; }
|
||||
|
||||
int DlogReportFinalize(void) { return 0; }
|
||||
|
|
Loading…
Reference in New Issue