Forked from mindspore-Ecosystem/mindspore
!9109 [MS][LITE] Fix Range
From: @gongdaguo  Reviewed-by: @zhang_xue_tong, @zhanghaibo5  Signed-off-by: @zhang_xue_tong
Commit: fa45308d67
@@ -16,9 +16,14 @@
 
 #include "nnacl/fp32/range_fp32.h"
 
-void Range(float *output_ptr, int start, int limit, int delta) {
-  size_t index = 0;
-  for (size_t i = start; i < limit; i += delta) {
-    output_ptr[index++] = (float)(i);
+void Range(float *output_ptr, float start, float delta, int nums) {
+  for (int i = 0; i < nums; ++i, start += delta) {
+    output_ptr[i] = start;
   }
 }
+
+void RangeInt(int *output_ptr, int start, int delta, int nums) {
+  for (int i = 0; i < nums; ++i, start += delta) {
+    output_ptr[i] = start;
+  }
+}
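The rewritten nnacl kernels no longer loop from start toward limit; they take a precomputed element count (nums) and iterate exactly that many times. A minimal sketch of the new contract, assuming only the two declarations from the hunk above (the main() driver and the expected values are illustrative, not part of the commit):

#include <cstdio>

extern "C" {
void Range(float *output_ptr, float start, float delta, int nums);
void RangeInt(int *output_ptr, int start, int delta, int nums);
}

int main() {
  float fbuf[4];
  Range(fbuf, 2.0f, 0.5f, 4);  // fills {2.0, 2.5, 3.0, 3.5}
  int ibuf[3];
  RangeInt(ibuf, 10, -2, 3);   // fills {10, 8, 6}; descending ranges now work
  for (int i = 0; i < 4; ++i) printf("%g ", fbuf[i]);
  for (int i = 0; i < 3; ++i) printf("%d ", ibuf[i]);
  printf("\n");
  return 0;
}

The old loop, for (size_t i = start; i < limit; i += delta), converted a negative start or delta to a huge unsigned value, so descending and negative ranges were silently wrong; counting up to nums sidesteps the unsigned comparison entirely.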
@@ -30,7 +30,8 @@ typedef struct RangeParameter {
 #ifdef __cplusplus
 extern "C" {
 #endif
-void Range(float *output_ptr, int start, int limit, int delta);
+void Range(float *output_ptr, float start, float delta, int nums);
+void RangeInt(int *output_ptr, int start, int delta, int nums);
 #ifdef __cplusplus
 }
 #endif
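Because nnacl is compiled as C, a single overloaded Range name is not possible: the float signature changes in place and the integer variant gets the separate RangeInt entry point, with the extern "C" guard keeping both symbols unmangled and callable from the C++ kernel code.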
@@ -64,7 +64,11 @@ int Range::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
   auto output = outputs_.front();
   MS_ASSERT(output != nullptr);
 
-  output->set_data_type(mindspore::kNumberTypeFloat32);
+  if (inputs_.size() == 3) {
+    output->set_data_type(input->data_type());
+  } else {
+    output->set_data_type(mindspore::kNumberTypeInt32);
+  }
   output->set_format(input->format());
   if (!infer_flag()) {
     return RET_OK;
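With this change the output tensor's type follows the first input whenever Range carries its start/limit/delta as three runtime input tensors, instead of being hard-coded to kNumberTypeFloat32; the attribute-only form now defaults to kNumberTypeInt32.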
@@ -72,14 +76,36 @@ int Range::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
 
-  int shape_size = std::ceil(static_cast<float>(GetLimit() - GetStart()) / GetDelta());
-  std::vector<int> in_shape;
-  in_shape.push_back(shape_size);
-  output->set_shape(in_shape);
+  int shape_size = 0;
+  if (inputs_.size() == 3) {
+    shape_size = -1;
+    if ((inputs_.at(0)->data_c() == nullptr) || (inputs_.at(1)->data_c() == nullptr) ||
+        (inputs_.at(2)->data_c() == nullptr)) {
+      return RET_INFER_INVALID;
+    }
+    switch (inputs_.at(0)->data_type()) {
+      case kNumberTypeInt:
+      case kNumberTypeInt32: {
+        auto start = *reinterpret_cast<int *>(inputs_.at(0)->data_c());
+        auto limit = *reinterpret_cast<int *>(inputs_.at(1)->data_c());
+        auto delta = *reinterpret_cast<int *>(inputs_.at(2)->data_c());
+        shape_size = std::max(static_cast<int>(std::ceil(static_cast<float>(limit - start) / delta)), 0);
+      } break;
+      case kNumberTypeFloat32:
+      case kNumberTypeFloat: {
+        auto start = *reinterpret_cast<float *>(inputs_.at(0)->data_c());
+        auto limit = *reinterpret_cast<float *>(inputs_.at(1)->data_c());
+        auto delta = *reinterpret_cast<float *>(inputs_.at(2)->data_c());
+        shape_size = std::max(static_cast<int>(std::ceil(static_cast<float>(limit - start) / delta)), 0);
+      } break;
+      default: {
+        MS_LOG(ERROR) << "Range has unsupported dataType: " << inputs_.at(0)->data_type();
+        return RET_INFER_ERR;
+      }
+    }
+  } else {
+    shape_size = std::ceil(static_cast<float>(GetLimit() - GetStart()) / GetDelta());
+  }
+  std::vector<int> in_shape = {shape_size};
+  output->set_shape(in_shape);
 
   return RET_OK;
 }
 } // namespace lite
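The inferred length follows max(ceil((limit - start) / delta), 0). A standalone illustration of the arithmetic (the RangeSize helper is hypothetical, written here only to check the formula, not part of the commit):

#include <algorithm>
#include <cassert>
#include <cmath>

// Hypothetical helper mirroring the inference formula in the hunk above.
int RangeSize(float start, float limit, float delta) {
  return std::max(static_cast<int>(std::ceil((limit - start) / delta)), 0);
}

int main() {
  assert(RangeSize(2, 10, 3) == 3);   // {2, 5, 8}
  assert(RangeSize(10, 2, -2) == 4);  // {10, 8, 6, 4}
  assert(RangeSize(5, 2, 1) == 0);    // clamped: limit < start with positive delta
  return 0;
}

Clamping with std::max keeps an empty range from producing a negative dimension, while the early RET_INFER_INVALID defers shape inference until all three input tensors actually hold data.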
@@ -27,29 +27,44 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_Range;
 
 namespace mindspore::kernel {
-int RangeCPUKernel::Init() { return RET_OK; }
+int RangeCPUKernel::Init() {
+  if (!InferShapeDone()) {
+    return RET_OK;
+  }
+  return ReSize();
+}
 
-int RangeCPUKernel::ReSize() { return RET_OK; }
+int RangeCPUKernel::ReSize() {
+  if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat16 ||
+      in_tensors_[0]->data_type() == kNumberTypeFloat) {
+    data_type_ = kDataTypeFloat;
+  } else {
+    data_type_ = kDataTypeInt;
+  }
+  return RET_OK;
+}
 
 int RangeCPUKernel::Run() {
-  size_t start = (reinterpret_cast<RangeParameter *>(op_parameter_))->start_;
-  size_t limit = (reinterpret_cast<RangeParameter *>(op_parameter_))->limit_;
-  size_t delta = (reinterpret_cast<RangeParameter *>(op_parameter_))->delta_;
   if (in_tensors_.size() == 3) {
-    if ((in_tensors_.at(0)->data_type() == mindspore::kNumberTypeInt32) &&
-        (in_tensors_.at(1)->data_type() == mindspore::kNumberTypeInt32) &&
-        (in_tensors_.at(2)->data_type() == mindspore::kNumberTypeInt32)) {
-      start = *reinterpret_cast<int *>(in_tensors_.at(0)->data_c());
-      limit = *reinterpret_cast<int *>(in_tensors_.at(1)->data_c());
-      delta = *reinterpret_cast<int *>(in_tensors_.at(2)->data_c());
+    if (data_type_ == kDataTypeInt) {
+      RangeInt(reinterpret_cast<int *>(out_tensors_.at(0)->data_c()),
+               *reinterpret_cast<int *>(in_tensors_.at(0)->data_c()),
+               *reinterpret_cast<int *>(in_tensors_.at(2)->data_c()), out_tensors_.at(0)->shape()[0]);
+    } else {
+      Range(reinterpret_cast<float *>(out_tensors_.at(0)->data_c()),
+            *reinterpret_cast<float *>(in_tensors_.at(0)->data_c()),
+            *reinterpret_cast<float *>(in_tensors_.at(2)->data_c()), out_tensors_.at(0)->shape()[0]);
+    }
+  } else {
+    if (data_type_ == kDataTypeInt) {
+      RangeInt(reinterpret_cast<int *>(out_tensors_.at(0)->data_c()),
+               (reinterpret_cast<RangeParameter *>(op_parameter_))->start_,
+               (reinterpret_cast<RangeParameter *>(op_parameter_))->delta_, out_tensors_.at(0)->shape()[0]);
     } else {
       MS_LOG(ERROR) << "Unsupported parameter type : " << in_tensors_.at(0)->data_type() << ".";
      return RET_ERROR;
     }
   }
-  auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
-  MS_ASSERT(output_ptr);
-  Range(output_ptr, start, limit, delta);
   return RET_OK;
 }
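Note the bug being fixed here: the old Run() staged start, limit, and delta in size_t, so negative parameters wrapped to huge unsigned values before ever reaching Range(), and integer inputs were funneled through the float path. A two-line demonstration of the wrap (illustrative, a 64-bit target assumed):

#include <cstdio>

int main() {
  size_t delta = static_cast<size_t>(-2);  // -2 wraps modulo 2^64
  printf("%zu\n", delta);                  // prints 18446744073709551614
  return 0;
}

The rewrite instead dispatches on data_type_, which ReSize() derived from the input tensor, and takes the element count from the already-inferred output shape, so no signed value ever passes through an unsigned variable.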
@@ -77,5 +92,7 @@ kernel::LiteKernel *CpuRangeFp32KernelCreator(const std::vector<lite::Tensor *>
 }
 
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Range, CpuRangeFp32KernelCreator)
-
+REG_KERNEL(kCPU, kNumberTypeFloat, PrimitiveType_Range, CpuRangeFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Range, CpuRangeFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_Range, CpuRangeFp32KernelCreator)
 } // namespace mindspore::kernel
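Registering the same creator for kNumberTypeFloat, kNumberTypeInt32, and kNumberTypeInt means integer Range graphs now resolve to this CPU kernel as well; sharing one creator works because the kernel itself branches on data_type_ rather than on the registered type.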
@@ -33,6 +33,9 @@ class RangeCPUKernel : public LiteKernel {
   int Init() override;
   int ReSize() override;
   int Run() override;
+
+ private:
+  LiteDataType data_type_ = kDataTypeFloat;
 };
 } // namespace mindspore::kernel
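The new private data_type_ member defaults to kDataTypeFloat and is overwritten in ReSize(), so the float path remains the fallback if Run() is ever reached before the input type has been classified.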