forked from mindspore-Ecosystem/mindspore
fix bug of cast operator.
This commit is contained in:
parent
23b07aee4c
commit
80bf20db84
mindspore/lite
|
@ -64,3 +64,15 @@ void Float32ToInt32(const float *input, int32_t *output, int number) {
|
|||
output[i] = (int32_t)input[i];
|
||||
}
|
||||
}
|
||||
|
||||
void Float32ToInt64(const float *input, int64_t *output, int number) {
|
||||
for (int i = 0; i < number; ++i) {
|
||||
output[i] = (int64_t)input[i];
|
||||
}
|
||||
}
|
||||
|
||||
void Int32ToInt64(const int32_t *input, int64_t *output, int number) {
|
||||
for (int i = 0; i < number; ++i) {
|
||||
output[i] = (int64_t)input[i];
|
||||
}
|
||||
}
|
||||
|
|
|
@ -39,6 +39,8 @@ void Int32ToFloat32(const int32_t *input, float *output, int number);
|
|||
void Fp16ToFloat32(const uint16_t *input, float *output, int number);
|
||||
void Float32ToFp16(const float *input, uint16_t *output, int number);
|
||||
void Float32ToInt32(const float *input, int32_t *output, int number);
|
||||
void Float32ToInt64(const float *input, int64_t *output, int number);
|
||||
void Int32ToInt64(const int32_t *input, int64_t *output, int number);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -71,13 +71,18 @@ int CastCPUKernel::DoCast(int thread_id) {
|
|||
auto input_data_type = input->data_type();
|
||||
auto output_data_type = output->data_type();
|
||||
if (output_data_type != kNumberTypeFloat32) {
|
||||
if (input_data_type == kNumberTypeFloat32 &&
|
||||
(output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) {
|
||||
if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt64) {
|
||||
Float32ToInt64(reinterpret_cast<float *>(input->data_c()) + offset,
|
||||
reinterpret_cast<int64_t *>(output_data) + offset, data_num);
|
||||
} else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt32) {
|
||||
Float32ToInt32(reinterpret_cast<float *>(input->data_c()) + offset,
|
||||
reinterpret_cast<int32_t *>(output_data) + offset, data_num);
|
||||
} else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeFloat16) {
|
||||
Float32ToFp16(reinterpret_cast<float *>(input->data_c()) + offset,
|
||||
reinterpret_cast<uint16_t *>(output_data) + offset, data_num);
|
||||
} else if (input_data_type == kNumberTypeInt32 && output_data_type == kNumberTypeInt64) {
|
||||
Int32ToInt64(reinterpret_cast<int32_t *>(input->data_c()) + offset,
|
||||
reinterpret_cast<int64_t *>(output_data) + offset, data_num);
|
||||
} else if (input_data_type == kNumberTypeInt32 &&
|
||||
(output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) {
|
||||
memcpy(reinterpret_cast<int32_t *>(output_data) + offset, reinterpret_cast<int32_t *>(input->data_c()) + offset,
|
||||
|
|
|
@ -9,3 +9,5 @@ crnn_lite_lstm_v2.onnx;32,32,32,1
|
|||
psenet_lite_mbv2.onnx;1,32,32,3
|
||||
super-resolution-10.onnx;1,224,224,1
|
||||
tinyyolov2-8.onnx;1,416,416,3
|
||||
ml_2012_ocr_cn.onnx
|
||||
ml_2012_ocr_cn_noLSTM.onnx
|
||||
|
|
Loading…
Reference in New Issue