!9109 [MS][LITE]Fix Range

From: @gongdaguo
Reviewed-by: @zhang_xue_tong,@zhanghaibo5,@zhanghaibo5
Signed-off-by: @zhang_xue_tong
This commit is contained in:
mindspore-ci-bot 2020-11-28 15:35:12 +08:00 committed by Gitee
commit fa45308d67
5 changed files with 77 additions and 25 deletions

View File

@ -16,9 +16,14 @@
#include "nnacl/fp32/range_fp32.h"
void Range(float *output_ptr, int start, int limit, int delta) {
size_t index = 0;
for (size_t i = start; i < limit; i += delta) {
output_ptr[index++] = (float)(i);
void Range(float *output_ptr, float start, float delta, int nums) {
for (int i = 0; i < nums; ++i, start += delta) {
output_ptr[i] = start;
}
}
void RangeInt(int *output_ptr, int start, int delta, int nums) {
for (int i = 0; i < nums; ++i, start += delta) {
output_ptr[i] = start;
}
}

View File

@ -30,7 +30,8 @@ typedef struct RangeParameter {
#ifdef __cplusplus
extern "C" {
#endif
void Range(float *output_ptr, int start, int limit, int delta);
void Range(float *output_ptr, float start, float delta, int nums);
void RangeInt(int *output_ptr, int start, int delta, int nums);
#ifdef __cplusplus
}
#endif

View File

@ -64,7 +64,11 @@ int Range::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(mindspore::kNumberTypeFloat32);
if (inputs_.size() == 3) {
output->set_data_type(input->data_type());
} else {
output->set_data_type(mindspore::kNumberTypeInt32);
}
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
@ -72,14 +76,36 @@ int Range::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
int shape_size = 0;
if (inputs_.size() == 3) {
shape_size = -1;
if ((inputs_.at(0)->data_c() == nullptr) || (inputs_.at(1)->data_c() == nullptr) ||
(inputs_.at(2)->data_c() == nullptr)) {
return RET_INFER_INVALID;
}
switch (inputs_.at(0)->data_type()) {
case kNumberTypeInt:
case kNumberTypeInt32: {
auto start = *reinterpret_cast<int *>(inputs_.at(0)->data_c());
auto limit = *reinterpret_cast<int *>(inputs_.at(1)->data_c());
auto delta = *reinterpret_cast<int *>(inputs_.at(2)->data_c());
shape_size = std::max(static_cast<int>(std::ceil(static_cast<float>(limit - start) / delta)), 0);
} break;
case kNumberTypeFloat32:
case kNumberTypeFloat: {
auto start = *reinterpret_cast<float *>(inputs_.at(0)->data_c());
auto limit = *reinterpret_cast<float *>(inputs_.at(1)->data_c());
auto delta = *reinterpret_cast<float *>(inputs_.at(2)->data_c());
shape_size = std::max(static_cast<int>(std::ceil(static_cast<float>(limit - start) / delta)), 0);
} break;
default: {
MS_LOG(ERROR) << "Range has unsupported dataType: " << inputs_.at(0)->data_type();
return RET_INFER_ERR;
}
}
} else {
shape_size = std::ceil(static_cast<float>(GetLimit() - GetStart()) / GetDelta());
}
std::vector<int> in_shape;
in_shape.push_back(shape_size);
output->set_shape(in_shape);
std::vector<int> in_shape = {shape_size};
output->set_shape(in_shape);
return RET_OK;
}
} // namespace lite

View File

@ -27,29 +27,44 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Range;
namespace mindspore::kernel {
int RangeCPUKernel::Init() { return RET_OK; }
int RangeCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int RangeCPUKernel::ReSize() { return RET_OK; }
int RangeCPUKernel::ReSize() {
if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat16 ||
in_tensors_[0]->data_type() == kNumberTypeFloat) {
data_type_ = kDataTypeFloat;
} else {
data_type_ = kDataTypeInt;
}
return RET_OK;
}
int RangeCPUKernel::Run() {
size_t start = (reinterpret_cast<RangeParameter *>(op_parameter_))->start_;
size_t limit = (reinterpret_cast<RangeParameter *>(op_parameter_))->limit_;
size_t delta = (reinterpret_cast<RangeParameter *>(op_parameter_))->delta_;
if (in_tensors_.size() == 3) {
if ((in_tensors_.at(0)->data_type() == mindspore::kNumberTypeInt32) &&
(in_tensors_.at(1)->data_type() == mindspore::kNumberTypeInt32) &&
(in_tensors_.at(2)->data_type() == mindspore::kNumberTypeInt32)) {
start = *reinterpret_cast<int *>(in_tensors_.at(0)->data_c());
limit = *reinterpret_cast<int *>(in_tensors_.at(1)->data_c());
delta = *reinterpret_cast<int *>(in_tensors_.at(2)->data_c());
if (data_type_ == kDataTypeInt) {
RangeInt(reinterpret_cast<int *>(out_tensors_.at(0)->data_c()),
*reinterpret_cast<int *>(in_tensors_.at(0)->data_c()),
*reinterpret_cast<int *>(in_tensors_.at(2)->data_c()), out_tensors_.at(0)->shape()[0]);
} else {
Range(reinterpret_cast<float *>(out_tensors_.at(0)->data_c()),
*reinterpret_cast<float *>(in_tensors_.at(0)->data_c()),
*reinterpret_cast<float *>(in_tensors_.at(2)->data_c()), out_tensors_.at(0)->shape()[0]);
}
} else {
if (data_type_ == kDataTypeInt) {
RangeInt(reinterpret_cast<int *>(out_tensors_.at(0)->data_c()),
(reinterpret_cast<RangeParameter *>(op_parameter_))->start_,
(reinterpret_cast<RangeParameter *>(op_parameter_))->delta_, out_tensors_.at(0)->shape()[0]);
} else {
MS_LOG(ERROR) << "Unsupported parameter type : " << in_tensors_.at(0)->data_type() << ".";
return RET_ERROR;
}
}
auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
MS_ASSERT(output_ptr);
Range(output_ptr, start, limit, delta);
return RET_OK;
}
@ -77,5 +92,7 @@ kernel::LiteKernel *CpuRangeFp32KernelCreator(const std::vector<lite::Tensor *>
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Range, CpuRangeFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat, PrimitiveType_Range, CpuRangeFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Range, CpuRangeFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_Range, CpuRangeFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -33,6 +33,9 @@ class RangeCPUKernel : public LiteKernel {
int Init() override;
int ReSize() override;
int Run() override;
private:
LiteDataType data_type_ = kDataTypeFloat;
};
} // namespace mindspore::kernel