malloc ts memory for label

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
zhoufeng 2021-11-25 15:37:18 +08:00
parent 953acc0335
commit 881179fa10
2 changed files with 14 additions and 2 deletions

View File

@ -17,6 +17,7 @@
#include <algorithm>
#include <string>
#include "runtime/mem.h"
#include "runtime/dev.h"
#include "runtime/rt_model.h"
#include "mindspore/core/utils/log_adapter.h"
#include "mindspore/core/utils/convert_utils_base.h"
@ -98,7 +99,14 @@ std::shared_ptr<LabelGuard> LabelManager::GetLabelInfo(rtModel_t model, const st
return nullptr;
}
uint32_t label_info_size = SizeToUint(sizeof(rtLabelDevInfo) * label_list.size());
rt_ret = rtMalloc(&label_info, label_info_size, RT_MEMORY_HBM);
int64_t value = 0;
rt_ret = rtGetRtCapability(FEATURE_TYPE_MEMORY, MEMORY_INFO_TS_LIMITED, &value);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Call rt api rtGetRtCapability failed, ret: " << rt_ret;
return nullptr;
}
rt_ret = rtMalloc(&label_info, label_info_size, (value == RT_CAPABILITY_SUPPORT) ? RT_MEMORY_TS : RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Call rt api rtMalloc failed, ret: " << rt_ret;
return nullptr;

View File

@ -197,7 +197,9 @@ RTS_API rtError_t rtLabelSwitchByIndex(void *ptr, uint32_t max, void *labelInfoP
RTS_API rtError_t rtProfilerTrace(uint64_t id, bool notify, uint32_t flags, rtStream_t stream) { return RT_ERROR_NONE; }
RTS_API rtError_t rtProfilerTraceEx(uint64_t id, uint64_t modelId, uint16_t tagId, rtStream_t stream) { return RT_ERROR_NONE; }
RTS_API rtError_t rtProfilerTraceEx(uint64_t id, uint64_t modelId, uint16_t tagId, rtStream_t stream) {
return RT_ERROR_NONE;
}
RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim, void *args, uint32_t argsSize,
rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags) {
@ -207,3 +209,5 @@ RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim
RTS_API rtError_t rtMemGetInfoEx(rtMemInfoType_t memInfoType, size_t *free, size_t *total) { return RT_ERROR_NONE; }
RTS_API rtError_t rtProfRegisterCtrlCallback(uint32_t moduleId, rtProfCtrlHandle callback) { return RT_ERROR_NONE; }
RTS_API rtError_t rtGetRtCapability(rtFeatureType_t, int32_t, int64_t *) { return RT_ERROR_NONE; }