forked from mindspore-Ecosystem/mindspore
malloc ts memory for label
Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
parent
953acc0335
commit
881179fa10
|
@ -17,6 +17,7 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "runtime/mem.h"
|
#include "runtime/mem.h"
|
||||||
|
#include "runtime/dev.h"
|
||||||
#include "runtime/rt_model.h"
|
#include "runtime/rt_model.h"
|
||||||
#include "mindspore/core/utils/log_adapter.h"
|
#include "mindspore/core/utils/log_adapter.h"
|
||||||
#include "mindspore/core/utils/convert_utils_base.h"
|
#include "mindspore/core/utils/convert_utils_base.h"
|
||||||
|
@ -98,7 +99,14 @@ std::shared_ptr<LabelGuard> LabelManager::GetLabelInfo(rtModel_t model, const st
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
uint32_t label_info_size = SizeToUint(sizeof(rtLabelDevInfo) * label_list.size());
|
uint32_t label_info_size = SizeToUint(sizeof(rtLabelDevInfo) * label_list.size());
|
||||||
rt_ret = rtMalloc(&label_info, label_info_size, RT_MEMORY_HBM);
|
int64_t value = 0;
|
||||||
|
rt_ret = rtGetRtCapability(FEATURE_TYPE_MEMORY, MEMORY_INFO_TS_LIMITED, &value);
|
||||||
|
if (rt_ret != RT_ERROR_NONE) {
|
||||||
|
MS_LOG(ERROR) << "Call rt api rtGetRtCapability failed, ret: " << rt_ret;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
rt_ret = rtMalloc(&label_info, label_info_size, (value == RT_CAPABILITY_SUPPORT) ? RT_MEMORY_TS : RT_MEMORY_HBM);
|
||||||
if (rt_ret != RT_ERROR_NONE) {
|
if (rt_ret != RT_ERROR_NONE) {
|
||||||
MS_LOG(ERROR) << "Call rt api rtMalloc failed, ret: " << rt_ret;
|
MS_LOG(ERROR) << "Call rt api rtMalloc failed, ret: " << rt_ret;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -197,7 +197,9 @@ RTS_API rtError_t rtLabelSwitchByIndex(void *ptr, uint32_t max, void *labelInfoP
|
||||||
|
|
||||||
RTS_API rtError_t rtProfilerTrace(uint64_t id, bool notify, uint32_t flags, rtStream_t stream) { return RT_ERROR_NONE; }
|
RTS_API rtError_t rtProfilerTrace(uint64_t id, bool notify, uint32_t flags, rtStream_t stream) { return RT_ERROR_NONE; }
|
||||||
|
|
||||||
RTS_API rtError_t rtProfilerTraceEx(uint64_t id, uint64_t modelId, uint16_t tagId, rtStream_t stream) { return RT_ERROR_NONE; }
|
RTS_API rtError_t rtProfilerTraceEx(uint64_t id, uint64_t modelId, uint16_t tagId, rtStream_t stream) {
|
||||||
|
return RT_ERROR_NONE;
|
||||||
|
}
|
||||||
|
|
||||||
RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim, void *args, uint32_t argsSize,
|
RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim, void *args, uint32_t argsSize,
|
||||||
rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags) {
|
rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags) {
|
||||||
|
@ -207,3 +209,5 @@ RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim
|
||||||
RTS_API rtError_t rtMemGetInfoEx(rtMemInfoType_t memInfoType, size_t *free, size_t *total) { return RT_ERROR_NONE; }
|
RTS_API rtError_t rtMemGetInfoEx(rtMemInfoType_t memInfoType, size_t *free, size_t *total) { return RT_ERROR_NONE; }
|
||||||
|
|
||||||
RTS_API rtError_t rtProfRegisterCtrlCallback(uint32_t moduleId, rtProfCtrlHandle callback) { return RT_ERROR_NONE; }
|
RTS_API rtError_t rtProfRegisterCtrlCallback(uint32_t moduleId, rtProfCtrlHandle callback) { return RT_ERROR_NONE; }
|
||||||
|
|
||||||
|
RTS_API rtError_t rtGetRtCapability(rtFeatureType_t, int32_t, int64_t *) { return RT_ERROR_NONE; }
|
||||||
|
|
Loading…
Reference in New Issue