!9282 [lite]fix strided slice multi inputs bug
From: @xu_anyue Reviewed-by: @hangangqiang,@zhanghaibo5 Signed-off-by: @hangangqiang
This commit is contained in:
commit
44ea3902b8
|
@ -29,12 +29,6 @@ using mindspore::lite::RET_OK;
|
||||||
using mindspore::schema::PrimitiveType_StridedSlice;
|
using mindspore::schema::PrimitiveType_StridedSlice;
|
||||||
|
|
||||||
namespace mindspore::kernel {
|
namespace mindspore::kernel {
|
||||||
namespace {
|
|
||||||
constexpr size_t kMultiInputsSize = 4;
|
|
||||||
constexpr size_t kBeginsIndex = 1;
|
|
||||||
constexpr size_t kEndsIndex = 2;
|
|
||||||
constexpr size_t kStridesInex = 3;
|
|
||||||
} // namespace
|
|
||||||
int StridedSliceCPUKernel::Init() {
|
int StridedSliceCPUKernel::Init() {
|
||||||
if (!InferShapeDone()) {
|
if (!InferShapeDone()) {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
|
@ -57,38 +51,6 @@ int StridedSliceCPUKernel::ReSize() {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int StridedSliceCPUKernel::HandleMultiInputs() {
|
|
||||||
if (in_tensors_.size() != kMultiInputsSize) {
|
|
||||||
MS_LOG(ERROR) << "Inputs size should be " << kMultiInputsSize << ", got " << in_tensors_.size();
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
if (param_ == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "StridedSliceParamater cast nullptr";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
auto begins = in_tensors_.at(kBeginsIndex);
|
|
||||||
MS_ASSERT(begins != nullptr);
|
|
||||||
int axis_num = begins->ElementsNum();
|
|
||||||
if (axis_num > DIMENSION_6D) {
|
|
||||||
MS_LOG(ERROR) << "StridedSlice supports max dimension " << DIMENSION_6D << ", input begins dim is " << axis_num;
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
memcpy(param_->begins_, begins->MutableData(), axis_num * sizeof(int));
|
|
||||||
|
|
||||||
auto ends = in_tensors_.at(kEndsIndex);
|
|
||||||
MS_ASSERT(ends != nullptr);
|
|
||||||
MS_ASSERT(axis_num == ends->ElementsNum());
|
|
||||||
memcpy(param_->ends_, ends->MutableData(), axis_num * sizeof(int));
|
|
||||||
|
|
||||||
auto strides = in_tensors_.at(kStridesInex);
|
|
||||||
MS_ASSERT(strides != nullptr);
|
|
||||||
MS_ASSERT(axis_num == strides->ElementsNum());
|
|
||||||
memcpy(param_->strides_, strides->MutableData(), axis_num * sizeof(int));
|
|
||||||
|
|
||||||
param_->num_axes_ = axis_num;
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
int StridedSliceCPUKernel::Run() {
|
int StridedSliceCPUKernel::Run() {
|
||||||
auto input = in_tensors_.at(0);
|
auto input = in_tensors_.at(0);
|
||||||
MS_ASSERT(input);
|
MS_ASSERT(input);
|
||||||
|
@ -108,13 +70,6 @@ int StridedSliceCPUKernel::Run() {
|
||||||
}
|
}
|
||||||
auto output = out_tensors_.at(0);
|
auto output = out_tensors_.at(0);
|
||||||
MS_ASSERT(output);
|
MS_ASSERT(output);
|
||||||
// inputs order: input, begin, end, stride
|
|
||||||
if (in_tensors().size() == kMultiInputsSize) {
|
|
||||||
auto ret = HandleMultiInputs();
|
|
||||||
if (ret != RET_OK) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto ret = DoStridedSlice(input->MutableData(), output->MutableData(), param_);
|
auto ret = DoStridedSlice(input->MutableData(), output->MutableData(), param_);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "StridedSlice error error_code[" << ret << "]";
|
MS_LOG(ERROR) << "StridedSlice error error_code[" << ret << "]";
|
||||||
|
|
|
@ -36,9 +36,6 @@ class StridedSliceCPUKernel : public LiteKernel {
|
||||||
int ReSize() override;
|
int ReSize() override;
|
||||||
int Run() override;
|
int Run() override;
|
||||||
|
|
||||||
private:
|
|
||||||
int HandleMultiInputs();
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
StridedSliceParameter *param_;
|
StridedSliceParameter *param_;
|
||||||
};
|
};
|
||||||
|
|
|
@ -60,7 +60,6 @@ bool TfliteInputsOrderExchangePass::Run(const FuncGraphPtr &graph) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (opt::GetCNodeType(node) == schema::PrimitiveType_Reduce ||
|
if (opt::GetCNodeType(node) == schema::PrimitiveType_Reduce ||
|
||||||
opt::GetCNodeType(node) == schema::PrimitiveType_StridedSlice ||
|
|
||||||
opt::GetCNodeType(node) == schema::PrimitiveType_ArgMin ||
|
opt::GetCNodeType(node) == schema::PrimitiveType_ArgMin ||
|
||||||
opt::GetCNodeType(node) == schema::PrimitiveType_ArgMax ||
|
opt::GetCNodeType(node) == schema::PrimitiveType_ArgMax ||
|
||||||
opt::GetCNodeType(node) == schema::PrimitiveType_SpaceToBatch ||
|
opt::GetCNodeType(node) == schema::PrimitiveType_SpaceToBatch ||
|
||||||
|
|
Loading…
Reference in New Issue