forked from mindspore-Ecosystem/mindspore

commit 4b8c9da7e7 (parent 13c2b23356)

    cast support fp32->int
@@ -37,7 +37,7 @@ constexpr uint32_t kNHWC_w_index = 2;
 constexpr uint32_t kNHWC_c_index = 3;
 constexpr uint32_t kDimension_4d = 4;
 
-const std::set<int> kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt32};
+const std::set<int> kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeFloat32};
 
 class Primitive {
  public:
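Note: this hunk only widens the whitelist; the check that consults it is not part of the diff. A self-contained sketch of how such a whitelist is typically consulted (the call site and the stand-in enum are assumptions, not taken from this commit):

    #include <set>
    enum TypeId { kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeFloat32 };  // stand-in ids
    static const std::set<int> kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeFloat32};
    // Typical membership check before accepting an input tensor (hypothetical call site).
    static bool InputSupported(int type) { return kSupportDataType.count(type) != 0; }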
@@ -65,17 +65,32 @@ int CastCPUKernel::DoCast(int thread_id) {
   }
 
   auto offset = thread_id * stride_;
-  auto output_data = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
-  switch (input->data_type()) {
+  auto output = out_tensors_.at(0);
+  auto output_data = output->Data();
+  auto input_data_type = input->data_type();
+  auto output_data_type = output->data_type();
+  if (output_data_type != kNumberTypeFloat32) {
+    if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt32) {
+      Float32ToInt32(reinterpret_cast<float *>(input->Data()) + offset,
+                     reinterpret_cast<int32_t *>(output_data) + offset, data_num);
+    } else {
+      MS_LOG(ERROR) << "Unsupport datatype from " << input_data_type << " to " << output_data_type;
+      return RET_ERROR;
+    }
+  } else {
+    switch (input_data_type) {
     case kNumberTypeUInt8:
-      Uint8ToFloat32(reinterpret_cast<uint8_t *>(input->Data()) + offset, output_data + offset, data_num);
+      Uint8ToFloat32(reinterpret_cast<uint8_t *>(input->Data()) + offset,
+                     reinterpret_cast<float *>(output_data) + offset, data_num);
       break;
     case kNumberTypeInt32:
-      Int32ToFloat32(reinterpret_cast<int32_t *>(input->Data()) + offset, output_data + offset, data_num);
+      Int32ToFloat32(reinterpret_cast<int32_t *>(input->Data()) + offset,
+                     reinterpret_cast<float *>(output_data) + offset, data_num);
       break;
     default:
-      MS_LOG(ERROR) << "Unsupport input data type " << input->data_type();
+      MS_LOG(ERROR) << "Unsupport input data type " << input_data_type;
       return RET_ERROR;
     }
+  }
   return RET_OK;
 }
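After this hunk the kernel supports exactly three conversions: uint8->fp32, int32->fp32, and the new fp32->int32; any other pair logs and returns RET_ERROR. A condensed standalone sketch of the dispatch decision (names are stand-ins, not the runtime's API):

    enum TypeId { kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeFloat32 };  // stand-in ids
    // True when the reworked DoCast handles the (src, dst) pair.
    static bool CastSupported(TypeId src, TypeId dst) {
      if (dst != kNumberTypeFloat32) {
        // Non-fp32 destination: only the new fp32 -> int32 path.
        return src == kNumberTypeFloat32 && dst == kNumberTypeInt32;
      }
      // fp32 destination: the pre-existing uint8 / int32 -> fp32 paths.
      return src == kNumberTypeUInt8 || src == kNumberTypeInt32;
    }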
@@ -47,8 +47,9 @@ int CropCPUKernel::CropParallelRun(int thread_id) {
   auto output = out_tensors_[0];
   float *input_data = reinterpret_cast<float *>(input->Data());
   float *output_data = reinterpret_cast<float *>(output->Data());
-  Crop4D(input_data, output_data, input->shape().data(), output->shape().data(),
-         reinterpret_cast<CropParameter *>(op_parameter_));
+  auto param = reinterpret_cast<CropParameter *>(op_parameter_);
+  param->thread_id_ = thread_id;
+  Crop4D(input_data, output_data, input->shape().data(), output->shape().data(), param);
   return RET_OK;
 }
 
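The crop fix threads thread_id into CropParameter before the call, so each parallel invocation of Crop4D can work on its own slice instead of reusing whatever id the shared parameter last held. A generic sketch of the kind of split this enables (row-wise partitioning and thread_num are assumptions, not taken from this diff):

    // Sketch: each worker takes rows [begin, end) of the output.
    static void SliceForThread(int total_rows, int thread_num, int thread_id, int *begin, int *end) {
      int per_thread = (total_rows + thread_num - 1) / thread_num;  // ceil division
      *begin = thread_id * per_thread;
      if (*begin > total_rows) *begin = total_rows;
      int e = *begin + per_thread;
      *end = e < total_rows ? e : total_rows;
    }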
@@ -65,6 +65,10 @@ int SoftmaxCPUKernel::ReSize() {
     free(sum_data_);
   }
   sum_data_ = reinterpret_cast<float *>(malloc(out_plane_size * in_plane_size * sizeof(float)));
+  if (sum_data_ == nullptr) {
+    MS_LOG(ERROR) << "malloc data for softmax fail!";
+    return RET_ERROR;
+  }
   memset(sum_data_, 0, out_plane_size * in_plane_size * sizeof(float));
   return RET_OK;
 }
@@ -27,7 +27,7 @@ class SoftmaxCPUKernel : public SoftmaxBaseCPUKernel {
   SoftmaxCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                    const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
                    const lite::Primitive *primitive)
-      : SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
+      : SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx, primitive), sum_data_(nullptr) {}
   ~SoftmaxCPUKernel() override {
     if (sum_data_ != nullptr) {
       free(sum_data_);
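Together with the ReSize guard above, initializing sum_data_(nullptr) means the destructor never frees an uninitialized pointer when ReSize was never run or failed partway. The pattern in miniature (standalone sketch, not the kernel class):

    #include <cstdlib>
    // Start null so the cleanup below is safe even if allocation never happened.
    struct Buffer {
      float *data = nullptr;
      ~Buffer() {
        if (data != nullptr) free(data);
      }
    };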
@@ -40,6 +40,12 @@ void Int32ToFloat32(const int32_t *input, float *output, int number) {
   }
 }
 
+void Float32ToInt32(const float *input, int32_t *output, int number) {
+  for (int i = 0; i < number; ++i) {
+    output[i] = (int32_t)input[i];
+  }
+}
+
 #ifdef ENABLE_FP16
 void Float32ToFloat16(const float *input, float16_t *output, int number) {
   for (int i = 0; i < number; ++i) {
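One behavioral note on the new Float32ToInt32: the C-style (int32_t) cast truncates toward zero rather than rounding, and converting a float outside the int32 range is undefined behavior. A minimal standalone check (independent of nnacl):

    #include <cstdint>
    #include <cstdio>
    int main() {
      const float in[4] = {3.9f, -3.9f, 0.5f, -0.5f};
      for (int i = 0; i < 4; ++i) {
        printf("%.1f -> %d\n", in[i], (int32_t)in[i]);  // 3, -3, 0, 0: truncation, not rounding
      }
      return 0;
    }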
@@ -32,6 +32,7 @@ void Uint8ToFloat32(const uint8_t *input, float *output, int number);
 void Uint8ToInt8(const uint8_t *input, int8_t *output, int number);
 void Int8ToUint8(const int8_t *input, uint8_t *output, int number);
 void Int32ToFloat32(const int32_t *input, float *output, int number);
+void Float32ToInt32(const float *input, int32_t *output, int number);
 #ifdef ENABLE_FP16
 void Float32ToFloat16(const float *input, float16_t *output, int number);
 void Float16ToFloat32(const float16_t *input, float *output, int number);