From bcaf43e0fbc1d3f488d63ee0d1fc7788e33276ba Mon Sep 17 00:00:00 2001 From: liuwenhao4 Date: Fri, 6 Nov 2020 10:00:08 +0800 Subject: [PATCH] add supported data types for cast ops --- .../lite/src/runtime/kernel/arm/fp32/cast_fp32.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.cc index f67cdd8f024..3b826b5a3f4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.cc @@ -17,6 +17,7 @@ #include #include "schema/model_generated.h" #include "src/kernel_registry.h" +#include "src/tensor.h" #include "nnacl/fp32/cast.h" #include "nnacl/op_base.h" #include "src/runtime/runtime_api.h" @@ -70,6 +71,12 @@ int CastCPUKernel::DoCast(int thread_id) { MS_ASSERT(output_data != nullptr); auto input_data_type = input->data_type(); auto output_data_type = output->data_type(); + if (input_data_type == output_data_type) { + auto datalen = lite::DataTypeSize(input_data_type); + memcpy(reinterpret_cast(output_data) + offset * datalen, + reinterpret_cast(input->data_c()) + offset * datalen, data_num * datalen); + return RET_OK; + } if (output_data_type != kNumberTypeFloat32) { if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt64) { Float32ToInt64(reinterpret_cast(input->data_c()) + offset, @@ -83,9 +90,6 @@ int CastCPUKernel::DoCast(int thread_id) { } else if (input_data_type == kNumberTypeInt32 && output_data_type == kNumberTypeInt64) { Int32ToInt64(reinterpret_cast(input->data_c()) + offset, reinterpret_cast(output_data) + offset, data_num); - } else if (input_data_type == kNumberTypeInt32 && output_data_type == kNumberTypeInt32) { - memcpy(reinterpret_cast(output_data) + offset, reinterpret_cast(input->data_c()) + offset, - data_num * sizeof(int32_t)); } else { MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type; return RET_ERROR; @@ -108,10 +112,6 @@ int CastCPUKernel::DoCast(int thread_id) { Fp16ToFloat32(reinterpret_cast(input->MutableData()) + offset, reinterpret_cast(output_data) + offset, data_num); break; - case kNumberTypeFloat32: - memcpy(reinterpret_cast(output_data) + offset, reinterpret_cast(input->data_c()) + offset, - data_num * sizeof(float)); - break; default: MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; return RET_ERROR;