forked from mindspore-Ecosystem/mindspore
[MSLITE][Develop] fix bug of arm cpu fp16 op cast
parent a3f9be98c0
commit 27cc6d6c17
@@ -27,6 +27,18 @@ void Uint8ToFloat16(const uint8_t *input, float16_t *output, int number) {
   }
 }

+void Float16ToInt32(const float16_t *input, int32_t *output, int number) {
+  for (int i = 0; i < number; ++i) {
+    output[i] = (int32_t)input[i];
+  }
+}
+
+void Float16ToInt64(const float16_t *input, int64_t *output, int number) {
+  for (int i = 0; i < number; ++i) {
+    output[i] = (int64_t)input[i];
+  }
+}
+
 #ifndef ENABLE_ARM64
 void Float32ToFloat16(const float *input, float16_t *output, int number) {
   for (int i = 0; i < number; ++i) {
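As a quick sanity check on the two new converters, a self-contained snippet such as the one below (an illustration, not part of the commit; it assumes an aarch64 toolchain where float16_t comes from <arm_neon.h>) shows that the C-style casts truncate toward zero:

/* sketch.c -- illustration only, not part of the commit.
 * Assumes an aarch64 compiler where float16_t is provided by <arm_neon.h>.
 * The two converters are copied verbatim from the hunk above. */
#include <arm_neon.h>
#include <stdint.h>
#include <stdio.h>

void Float16ToInt32(const float16_t *input, int32_t *output, int number) {
  for (int i = 0; i < number; ++i) {
    output[i] = (int32_t)input[i];
  }
}

void Float16ToInt64(const float16_t *input, int64_t *output, int number) {
  for (int i = 0; i < number; ++i) {
    output[i] = (int64_t)input[i];
  }
}

int main(void) {
  float16_t in[3] = {1.25f, -2.75f, 3.0f};
  int32_t out32[3];
  int64_t out64[3];
  Float16ToInt32(in, out32, 3); /* truncation toward zero: 1, -2, 3 */
  Float16ToInt64(in, out64, 3);
  for (int i = 0; i < 3; ++i) {
    printf("%d %lld\n", out32[i], (long long)out64[i]);
  }
  return 0;
}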
@@ -24,6 +24,8 @@ extern "C" {
 #endif
 void BoolToFloat16(const bool *input, float16_t *output, int number);
 void Uint8ToFloat16(const uint8_t *input, float16_t *output, int number);
+void Float16ToInt32(const float16_t *input, int32_t *output, int number);
+void Float16ToInt64(const float16_t *input, int64_t *output, int number);
 void Float32ToFloat16(const float *input, float16_t *output, int number);
 void Float16ToFloat32(const float16_t *input, float *output, int number);
 #ifdef __cplusplus
@@ -65,25 +65,58 @@ int CastFp16CPUKernel::DoCast(int thread_id) {
   }

   auto offset = thread_id * stride_;
-  auto output_data = out_tensors_.at(0)->MutableData();
-  switch (input->data_type()) {
-    case kNumberTypeBool:
-      BoolToFloat16(reinterpret_cast<bool *>(input->MutableData()) + offset,
-                    reinterpret_cast<float16_t *>(output_data) + offset, data_num);
-    case kNumberTypeUInt8:
-      Uint8ToFloat16(reinterpret_cast<uint8_t *>(input->MutableData()) + offset,
-                     reinterpret_cast<float16_t *>(output_data) + offset, data_num);
-    case kNumberTypeFloat32:
-      Float32ToFloat16(reinterpret_cast<float *>(input->MutableData()) + offset,
-                       reinterpret_cast<float16_t *>(output_data) + offset, data_num);
-      break;
-    case kNumberTypeFloat16:
-      Float16ToFloat32(reinterpret_cast<float16_t *>(input->MutableData()) + offset,
-                       reinterpret_cast<float *>(output_data) + offset, data_num);
-      break;
-    default:
-      MS_LOG(ERROR) << "Unsupported input data type " << input->data_type();
-      return RET_ERROR;
-  }
+  auto output = out_tensors_.at(0);
+  auto output_data = output->data_c();
+  auto input_data_type = input->data_type();
+  auto output_data_type = output->data_type();
+
+  if (input_data_type == kNumberTypeFloat16) {
+    switch (output_data_type) {
+      case kNumberTypeInt64:
+        Float16ToInt64(reinterpret_cast<float16_t *>(input->data_c()) + offset,
+                       reinterpret_cast<int64_t *>(output_data) + offset, data_num);
+        break;
+      case kNumberTypeInt32:
+        Float16ToInt32(reinterpret_cast<float16_t *>(input->data_c()) + offset,
+                       reinterpret_cast<int32_t *>(output_data) + offset, data_num);
+        break;
+      case kNumberTypeFloat32:
+        Float16ToFloat32(reinterpret_cast<float16_t *>(input->MutableData()) + offset,
+                         reinterpret_cast<float *>(output_data) + offset, data_num);
+        break;
+      case kNumberTypeFloat16:
+        memcpy(reinterpret_cast<float16_t *>(output_data) + offset,
+               reinterpret_cast<float16_t *>(input->data_c()) + offset, data_num * sizeof(float16_t));
+        break;
+      default:
+        MS_LOG(ERROR) << "Unsupported output data type " << output_data_type;
+        return RET_ERROR;
+    }
+  } else if (input_data_type == kNumberTypeFloat32) {
+    switch (output_data_type) {
+      case kNumberTypeInt64:
+        Float32ToInt64(reinterpret_cast<float *>(input->data_c()) + offset,
+                       reinterpret_cast<int64_t *>(output_data) + offset, data_num);
+        break;
+      case kNumberTypeInt32:
+        Float32ToInt32(reinterpret_cast<float *>(input->data_c()) + offset,
+                       reinterpret_cast<int32_t *>(output_data) + offset, data_num);
+        break;
+      case kNumberTypeFloat32:
+        memcpy(reinterpret_cast<float *>(output_data) + offset, reinterpret_cast<float *>(input->data_c()) + offset,
+               data_num * sizeof(float));
+        break;
+      case kNumberTypeFloat16:
+        Float32ToFloat16(reinterpret_cast<float *>(input->MutableData()) + offset,
+                         reinterpret_cast<float16_t *>(output_data) + offset, data_num);
+        break;
+      default:
+        MS_LOG(ERROR) << "Unsupported output data type " << output_data_type;
+        return RET_ERROR;
+    }
+  } else {
+    MS_LOG(ERROR) << "Unsupported input data type " << input_data_type;
+    return RET_ERROR;
+  }
   return RET_OK;
 }
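The rewritten DoCast dispatches on both the input and the output data type; the old switch keyed only on the input type and, as the removed lines show, fell through from the kNumberTypeBool and kNumberTypeUInt8 cases for lack of a break. The standalone C++ sketch below is an illustration only (plain enums stand in for the lite::Tensor types, the fp16 branch and the nnacl converters are omitted); it models the new control flow so it can be compiled and checked in isolation:

// cast_dispatch_sketch.cc -- illustration only, not MindSpore code.
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

enum class DType { kFloat32, kInt32, kInt64 };

// Mirrors the new DoCast structure: for a given input type, select on the
// output type and report an error for unsupported pairs.
bool CastSlice(DType in_type, DType out_type, const void *in, void *out, int n) {
  if (in_type == DType::kFloat32) {
    const float *src = static_cast<const float *>(in);
    switch (out_type) {
      case DType::kInt32:
        for (int i = 0; i < n; ++i) static_cast<int32_t *>(out)[i] = static_cast<int32_t>(src[i]);
        return true;
      case DType::kInt64:
        for (int i = 0; i < n; ++i) static_cast<int64_t *>(out)[i] = static_cast<int64_t>(src[i]);
        return true;
      case DType::kFloat32:
        std::memcpy(out, src, n * sizeof(float));  // same-type cast degenerates to a copy
        return true;
      default:
        return false;  // unsupported output type
    }
  }
  return false;  // the fp16-input branch in the kernel follows the same pattern
}

int main() {
  std::vector<float> in = {1.9f, -2.5f, 3.0f};
  std::vector<int32_t> out(in.size());
  if (CastSlice(DType::kFloat32, DType::kInt32, in.data(), out.data(), static_cast<int>(in.size()))) {
    for (int32_t v : out) std::cout << v << ' ';  // 1 -2 3 (truncation toward zero)
    std::cout << '\n';
  }
  return 0;
}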
@@ -94,6 +94,4 @@ kernel::LiteKernel *CpuPadFp16KernelCreator(const std::vector<lite::Tensor *> &i
  }
  return kernel;
}

REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Pad, CpuPadFp16KernelCreator)
}  // namespace mindspore::kernel
@@ -304,7 +304,8 @@ TypeId Scheduler::GetFirstFp32Fp16OrInt8Type(const std::vector<Tensor *> &in_ten
       return dtype;
     }
   }
-  return kNumberTypeFloat32;
+  MS_ASSERT(in_tensors.size() > 0);
+  return in_tensors[0]->data_type();
 }

 void Scheduler::SetKernelTensorDataType(kernel::LiteKernel *kernel) {
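The effect of this fallback change, sketched below with a plain enum in place of TypeId (the loop body is an assumption inferred from the function name; only the fallback lines come from the hunk above): when none of the inputs is fp32, fp16, or int8, the scheduler now reports the first input's actual type, e.g. uint8, instead of unconditionally claiming fp32.

// scheduler_type_sketch.cc -- illustration only, not MindSpore code.
#include <cassert>
#include <iostream>
#include <vector>

enum TypeId { kNumberTypeFloat32, kNumberTypeFloat16, kNumberTypeInt8, kNumberTypeUInt8, kNumberTypeBool };

TypeId GetFirstFp32Fp16OrInt8Type(const std::vector<TypeId> &in_types) {
  for (TypeId dtype : in_types) {
    if (dtype == kNumberTypeFloat32 || dtype == kNumberTypeFloat16 || dtype == kNumberTypeInt8) {
      return dtype;
    }
  }
  // Old behaviour: return kNumberTypeFloat32;
  assert(!in_types.empty());
  return in_types[0];
}

int main() {
  std::vector<TypeId> uint8_only = {kNumberTypeUInt8, kNumberTypeBool};
  // With the old fallback this call returned kNumberTypeFloat32; now it returns kNumberTypeUInt8.
  std::cout << (GetFirstFp32Fp16OrInt8Type(uint8_only) == kNumberTypeUInt8) << '\n';  // prints 1
  return 0;
}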
@@ -346,7 +347,8 @@ kernel::SubGraphType Scheduler::GetKernelSubGraphType(kernel::LiteKernel *kernel
   if (desc.data_type == kNumberTypeFloat16) {
     return kernel::kCpuFP16SubGraph;
   } else if (desc.data_type == kNumberTypeFloat32 || desc.data_type == kNumberTypeInt8 ||
-             desc.data_type == kNumberTypeInt32 || desc.data_type == kNumberTypeBool) {
+             desc.data_type == kNumberTypeInt32 || desc.data_type == kNumberTypeBool ||
+             desc.data_type == kNumberTypeUInt8) {
     return kernel::kCpuFP32SubGraph;
   }
 }