fix bug of cast operator.

This commit is contained in:
wang_shaocong 2020-10-28 18:00:57 +08:00
parent 23b07aee4c
commit 80bf20db84
4 changed files with 23 additions and 2 deletions
mindspore/lite
nnacl/fp32
src/runtime/kernel/arm/fp32
test

View File

@ -64,3 +64,15 @@ void Float32ToInt32(const float *input, int32_t *output, int number) {
output[i] = (int32_t)input[i];
}
}
void Float32ToInt64(const float *input, int64_t *output, int number) {
for (int i = 0; i < number; ++i) {
output[i] = (int64_t)input[i];
}
}
void Int32ToInt64(const int32_t *input, int64_t *output, int number) {
for (int i = 0; i < number; ++i) {
output[i] = (int64_t)input[i];
}
}

View File

@ -39,6 +39,8 @@ void Int32ToFloat32(const int32_t *input, float *output, int number);
void Fp16ToFloat32(const uint16_t *input, float *output, int number);
void Float32ToFp16(const float *input, uint16_t *output, int number);
void Float32ToInt32(const float *input, int32_t *output, int number);
void Float32ToInt64(const float *input, int64_t *output, int number);
void Int32ToInt64(const int32_t *input, int64_t *output, int number);
#ifdef __cplusplus
}
#endif

View File

@ -71,13 +71,18 @@ int CastCPUKernel::DoCast(int thread_id) {
auto input_data_type = input->data_type();
auto output_data_type = output->data_type();
if (output_data_type != kNumberTypeFloat32) {
if (input_data_type == kNumberTypeFloat32 &&
(output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) {
if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt64) {
Float32ToInt64(reinterpret_cast<float *>(input->data_c()) + offset,
reinterpret_cast<int64_t *>(output_data) + offset, data_num);
} else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt32) {
Float32ToInt32(reinterpret_cast<float *>(input->data_c()) + offset,
reinterpret_cast<int32_t *>(output_data) + offset, data_num);
} else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeFloat16) {
Float32ToFp16(reinterpret_cast<float *>(input->data_c()) + offset,
reinterpret_cast<uint16_t *>(output_data) + offset, data_num);
} else if (input_data_type == kNumberTypeInt32 && output_data_type == kNumberTypeInt64) {
Int32ToInt64(reinterpret_cast<int32_t *>(input->data_c()) + offset,
reinterpret_cast<int64_t *>(output_data) + offset, data_num);
} else if (input_data_type == kNumberTypeInt32 &&
(output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) {
memcpy(reinterpret_cast<int32_t *>(output_data) + offset, reinterpret_cast<int32_t *>(input->data_c()) + offset,

View File

@ -9,3 +9,5 @@ crnn_lite_lstm_v2.onnx;32,32,32,1
psenet_lite_mbv2.onnx;1,32,32,3
super-resolution-10.onnx;1,224,224,1
tinyyolov2-8.onnx;1,416,416,3
ml_2012_ocr_cn.onnx
ml_2012_ocr_cn_noLSTM.onnx