From 881179fa10985d69047f3bcf2ebf2ccf9e9ee5c0 Mon Sep 17 00:00:00 2001 From: zhoufeng Date: Thu, 25 Nov 2021 15:37:18 +0800 Subject: [PATCH] malloc ts memory for label Signed-off-by: zhoufeng --- .../device/ascend/ge_runtime/task/label_manager.cc | 10 +++++++++- tests/ut/cpp/stub/runtime/runtime_stub.cc | 6 +++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_manager.cc b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_manager.cc index e8c9b48a2d9..b1dd9b667a2 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_manager.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_manager.cc @@ -17,6 +17,7 @@ #include #include #include "runtime/mem.h" +#include "runtime/dev.h" #include "runtime/rt_model.h" #include "mindspore/core/utils/log_adapter.h" #include "mindspore/core/utils/convert_utils_base.h" @@ -98,7 +99,14 @@ std::shared_ptr LabelManager::GetLabelInfo(rtModel_t model, const st return nullptr; } uint32_t label_info_size = SizeToUint(sizeof(rtLabelDevInfo) * label_list.size()); - rt_ret = rtMalloc(&label_info, label_info_size, RT_MEMORY_HBM); + int64_t value = 0; + rt_ret = rtGetRtCapability(FEATURE_TYPE_MEMORY, MEMORY_INFO_TS_LIMITED, &value); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(ERROR) << "Call rt api rtGetRtCapability failed, ret: " << rt_ret; + return nullptr; + } + + rt_ret = rtMalloc(&label_info, label_info_size, (value == RT_CAPABILITY_SUPPORT) ? RT_MEMORY_TS : RT_MEMORY_HBM); if (rt_ret != RT_ERROR_NONE) { MS_LOG(ERROR) << "Call rt api rtMalloc failed, ret: " << rt_ret; return nullptr; diff --git a/tests/ut/cpp/stub/runtime/runtime_stub.cc b/tests/ut/cpp/stub/runtime/runtime_stub.cc index 7e5d8f43411..600a97230d9 100644 --- a/tests/ut/cpp/stub/runtime/runtime_stub.cc +++ b/tests/ut/cpp/stub/runtime/runtime_stub.cc @@ -197,7 +197,9 @@ RTS_API rtError_t rtLabelSwitchByIndex(void *ptr, uint32_t max, void *labelInfoP RTS_API rtError_t rtProfilerTrace(uint64_t id, bool notify, uint32_t flags, rtStream_t stream) { return RT_ERROR_NONE; } -RTS_API rtError_t rtProfilerTraceEx(uint64_t id, uint64_t modelId, uint16_t tagId, rtStream_t stream) { return RT_ERROR_NONE; } +RTS_API rtError_t rtProfilerTraceEx(uint64_t id, uint64_t modelId, uint16_t tagId, rtStream_t stream) { + return RT_ERROR_NONE; +} RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim, void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags) { @@ -207,3 +209,5 @@ RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim RTS_API rtError_t rtMemGetInfoEx(rtMemInfoType_t memInfoType, size_t *free, size_t *total) { return RT_ERROR_NONE; } RTS_API rtError_t rtProfRegisterCtrlCallback(uint32_t moduleId, rtProfCtrlHandle callback) { return RT_ERROR_NONE; } + +RTS_API rtError_t rtGetRtCapability(rtFeatureType_t, int32_t, int64_t *) { return RT_ERROR_NONE; }