transpose concat mean fp16 move allocator to Run

This commit is contained in:
zhaozhenlong 2020-08-19 17:18:15 +08:00
parent 5686315199
commit b39eca3826
3 changed files with 22 additions and 27 deletions

View File

@ -41,15 +41,7 @@ int ConcatFp16CPUKernel::Init() {
return ReSize();
}
int ConcatFp16CPUKernel::ReSize() {
FreeTmpBuffer();
auto ret = MallocTmpBuffer();
if (ret != RET_OK) {
FreeTmpBuffer();
return ret;
}
return ConcatBaseCPUKernel::ReSize();
}
int ConcatFp16CPUKernel::ReSize() { return ConcatBaseCPUKernel::ReSize(); }
int ConcatFp16CPUKernel::MallocTmpBuffer() {
for (const auto &in_tensor : in_tensors_) {
@ -105,6 +97,13 @@ int ConcatFp16CPUKernel::Run() {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
auto ret = MallocTmpBuffer();
if (ret != RET_OK) {
FreeTmpBuffer();
return ret;
}
auto input_num = in_tensors_.size();
std::vector<int *> inputs_output_shape(input_num + 1, nullptr);

View File

@ -58,17 +58,7 @@ int ReduceFp16CPUKernel::Init() {
}
int ReduceFp16CPUKernel::ReSize() {
FreeTmpBuffer();
auto ret = ReduceBaseCPUKernel::ReSize();
if (ret != RET_OK) {
return ret;
}
ret = MallocTmpBuffer();
if (ret != RET_OK) {
FreeTmpBuffer();
return ret;
}
return RET_OK;
return ReduceBaseCPUKernel::ReSize();
}
int ReduceFp16CPUKernel::CallReduceUnit(int task_id) {
@ -94,6 +84,12 @@ int ReduceFp16CPUKernel::Run() {
return prepare_ret;
}
auto ret = MallocTmpBuffer();
if (ret != RET_OK) {
FreeTmpBuffer();
return ret;
}
tmp_shape_ = in_tensors_.at(0)->shape();
auto in_tensor = in_tensors_.at(0);
if (in_tensor->data_type() == kNumberTypeFloat32 || in_tensor->data_type() == kNumberTypeFloat) {

View File

@ -59,12 +59,6 @@ int TransposeFp16CPUKernel::ReSize() {
param->out_strides_[i] = out_shape[i + 1] * param->out_strides_[i + 1];
}
FreeFp16Buffer();
auto ret = MallocFp16Buffer();
if (ret != RET_OK) {
FreeFp16Buffer();
return ret;
}
return RET_OK;
}
@ -149,10 +143,16 @@ int TransposeFp16CPUKernel::Run() {
auto &out_tensor = out_tensors_.front();
if (in_tensor == nullptr || out_tensor == nullptr) {
MS_LOG(ERROR) << "null pointer referencing.";
FreeFp16Buffer();
return RET_ERROR;
}
// malloc when Run
ret = MallocFp16Buffer();
if (ret != RET_OK) {
FreeFp16Buffer();
return ret;
}
if (in_tensor->data_type() == kNumberTypeFloat || in_tensor->data_type() == kNumberTypeFloat32) {
in_data_ = reinterpret_cast<float *>(in_tensor->Data());
Float32ToFloat16(in_data_, fp16_in_data_, in_tensor->ElementsNum());