cast support fp32->int

This commit is contained in:
chenjianping 2020-08-12 15:37:04 +08:00
parent 13c2b23356
commit 4b8c9da7e7
7 changed files with 36 additions and 9 deletions

View File

@ -37,7 +37,7 @@ constexpr uint32_t kNHWC_w_index = 2;
constexpr uint32_t kNHWC_c_index = 3;
constexpr uint32_t kDimension_4d = 4;
const std::set<int> kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt32};
const std::set<int> kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeFloat32};
class Primitive {
public:

View File

@ -65,17 +65,32 @@ int CastCPUKernel::DoCast(int thread_id) {
}
auto offset = thread_id * stride_;
auto output_data = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
switch (input->data_type()) {
auto output = out_tensors_.at(0);
auto output_data = output->Data();
auto input_data_type = input->data_type();
auto output_data_type = output->data_type();
if (output_data_type != kNumberTypeFloat32) {
if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt32) {
Float32ToInt32(reinterpret_cast<float *>(input->Data()) + offset,
reinterpret_cast<int32_t *>(output_data) + offset, data_num);
} else {
MS_LOG(ERROR) << "Unsupport datatype from " << input_data_type << " to " << output_data_type;
return RET_ERROR;
}
} else {
switch (input_data_type) {
case kNumberTypeUInt8:
Uint8ToFloat32(reinterpret_cast<uint8_t *>(input->Data()) + offset, output_data + offset, data_num);
Uint8ToFloat32(reinterpret_cast<uint8_t *>(input->Data()) + offset,
reinterpret_cast<float *>(output_data) + offset, data_num);
break;
case kNumberTypeInt32:
Int32ToFloat32(reinterpret_cast<int32_t *>(input->Data()) + offset, output_data + offset, data_num);
Int32ToFloat32(reinterpret_cast<int32_t *>(input->Data()) + offset,
reinterpret_cast<float *>(output_data) + offset, data_num);
break;
default:
MS_LOG(ERROR) << "Unsupport input data type " << input->data_type();
MS_LOG(ERROR) << "Unsupport input data type " << input_data_type;
return RET_ERROR;
}
}
return RET_OK;
}

View File

@ -47,8 +47,9 @@ int CropCPUKernel::CropParallelRun(int thread_id) {
auto output = out_tensors_[0];
float *input_data = reinterpret_cast<float *>(input->Data());
float *output_data = reinterpret_cast<float *>(output->Data());
Crop4D(input_data, output_data, input->shape().data(), output->shape().data(),
reinterpret_cast<CropParameter *>(op_parameter_));
auto param = reinterpret_cast<CropParameter *>(op_parameter_);
param->thread_id_ = thread_id;
Crop4D(input_data, output_data, input->shape().data(), output->shape().data(), param);
return RET_OK;
}

View File

@ -65,6 +65,10 @@ int SoftmaxCPUKernel::ReSize() {
free(sum_data_);
}
sum_data_ = reinterpret_cast<float *>(malloc(out_plane_size * in_plane_size * sizeof(float)));
if (sum_data_ == nullptr) {
MS_LOG(ERROR) << "malloc data for softmax fail!";
return RET_ERROR;
}
memset(sum_data_, 0, out_plane_size * in_plane_size * sizeof(float));
return RET_OK;
}

View File

@ -27,7 +27,7 @@ class SoftmaxCPUKernel : public SoftmaxBaseCPUKernel {
SoftmaxCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const lite::Primitive *primitive)
: SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
: SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx, primitive), sum_data_(nullptr) {}
~SoftmaxCPUKernel() override {
if (sum_data_ != nullptr) {
free(sum_data_);

View File

@ -40,6 +40,12 @@ void Int32ToFloat32(const int32_t *input, float *output, int number) {
}
}
void Float32ToInt32(const float *input, int32_t *output, int number) {
for (int i = 0; i < number; ++i) {
output[i] = (int32_t)input[i];
}
}
#ifdef ENABLE_FP16
void Float32ToFloat16(const float *input, float16_t *output, int number) {
for (int i = 0; i < number; ++i) {

View File

@ -32,6 +32,7 @@ void Uint8ToFloat32(const uint8_t *input, float *output, int number);
void Uint8ToInt8(const uint8_t *input, int8_t *output, int number);
void Int8ToUint8(const int8_t *input, uint8_t *output, int number);
void Int32ToFloat32(const int32_t *input, float *output, int number);
void Float32ToInt32(const float *input, int32_t *output, int number);
#ifdef ENABLE_FP16
void Float32ToFloat16(const float *input, float16_t *output, int number);
void Float16ToFloat32(const float16_t *input, float *output, int number);