!11918 [MS_LITE] support tf model

From: @YeFeng_24
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-02-02 14:43:56 +08:00 committed by Gitee
commit c2763dbb6f
15 changed files with 528 additions and 76 deletions

View File

@ -78,14 +78,124 @@ Registry BatchToSpaceRegistry(schema::PrimitiveType_BatchToSpace, BatchToSpaceCr
namespace {
// BatchToSpace always produces exactly one output tensor.
constexpr int kBatchToSpaceOutputNum = 1;
// NOTE(review): appears superseded by the One/Three-input constants below
// after the multi-input change — confirm it is still referenced.
constexpr int kBatchToSpaceInputNum = 1;
// Attribute form: only the data tensor is supplied (block_shape/crops on op).
constexpr int kBatchToSpaceOneInput = 1;
// Tensor form: data + block_shape tensor + crops tensor.
constexpr int kBatchToSpaceThreeInput = 3;
// block_shape holds [block_h, block_w].
constexpr int kBlockShapeSize = 2;
// crops holds 4 values: [top, bottom, left, right].
constexpr int kCropsSize = 4;
}  // namespace
// Infers the BatchToSpace output shape from the primitive's own attributes
// (single-input form, where block_shape and crops are stored on the op):
//   out_n = in_n / (block_shape[0] * block_shape[1])
//   out_h = in_h * block_shape[0] - crops[0] - crops[1]
//   out_w = in_w * block_shape[1] - crops[2] - crops[3]
// Returns RET_OK on success or RET_PARAM_INVALID on any malformed attribute.
int BatchToSpace::SetOutputShapeFromParam(const std::vector<lite::Tensor *> inputs,
                                          std::vector<lite::Tensor *> outputs) {
  auto input_shape = inputs[0]->shape();
  if (input_shape.size() != kQuadrupleNum) {
    MS_LOG(ERROR) << "input shape dimension size should == " << kQuadrupleNum;
    return RET_PARAM_INVALID;
  }
  auto block_shape = GetBlockShape();
  if (block_shape.size() != kBlockShapeSize) {
    MS_LOG(ERROR) << "Block shape size should be " << kBlockShapeSize;
    return RET_PARAM_INVALID;
  }
  auto crops = GetCrops();
  if (crops.size() != kCropsSize) {
    MS_LOG(ERROR) << "Crops size should be " << kCropsSize;
    return RET_PARAM_INVALID;
  }
  mul_block_shape_ = 1;
  for (size_t i = 0; i < kBlockShapeSize; ++i) {
    if (block_shape[i] <= 0) {
      MS_LOG(ERROR) << "Input block_shape should > 0!";
      return RET_PARAM_INVALID;
    }
    if (input_shape[NHWC_N] % block_shape[i]) {
      MS_LOG(ERROR) << "Dimension n " << input_shape[NHWC_N] << " can not divide block_shape[" << i << "] "
                    << block_shape[i];
      // Fixed: previously returned the magic number 1; use the explicit error
      // code so callers get a consistent status.
      return RET_PARAM_INVALID;
    }
    mul_block_shape_ *= block_shape[i];
  }
  if (input_shape[NHWC_N] < mul_block_shape_) {
    MS_LOG(ERROR) << "Dimension n " << input_shape[NHWC_N] << " < product of block shape!";
    return RET_PARAM_INVALID;
  }
  for (size_t i = 0; i < kCropsSize; ++i) {
    if (crops[i] < 0) {
      MS_LOG(ERROR) << "Input crops should >= 0";
      return RET_PARAM_INVALID;
    }
  }
  std::vector<int32_t> output_shape(input_shape.size());
  output_shape[NHWC_N] = input_shape[NHWC_N] / mul_block_shape_;
  output_shape[NHWC_H] = input_shape[NHWC_H] * block_shape[0] - crops[0] - crops[1];
  output_shape[NHWC_W] = input_shape[NHWC_W] * block_shape[1] - crops[2] - crops[3];
  // input_shape.size() is guaranteed to be kQuadrupleNum here; the guard is
  // kept for symmetry with SetOutputShapeFromInput.
  if (input_shape.size() > 3) {
    output_shape[NHWC_C] = input_shape[NHWC_C];
  }
  outputs[0]->set_shape(output_shape);
  return RET_OK;
}
// Infers the BatchToSpace output shape when block_shape (input 1) and crops
// (input 2) arrive as tensors rather than attributes. Same formula as
// SetOutputShapeFromParam. Returns RET_OK or RET_PARAM_INVALID.
int BatchToSpace::SetOutputShapeFromInput(const std::vector<lite::Tensor *> inputs,
                                          std::vector<lite::Tensor *> outputs) {
  auto input_shape = inputs[0]->shape();
  if (input_shape.size() != kQuadrupleNum) {
    MS_LOG(ERROR) << "input shape dimension size should == " << kQuadrupleNum;
    return RET_PARAM_INVALID;
  }
  // Validate element counts BEFORE touching the raw buffers (the original
  // cast/read order could index past a short tensor).
  if (inputs[1]->ElementsNum() != kBlockShapeSize) {
    MS_LOG(ERROR) << "Block shape size should be " << kBlockShapeSize;
    return RET_PARAM_INVALID;
  }
  if (inputs[2]->ElementsNum() != kCropsSize) {
    MS_LOG(ERROR) << "Crops size should be " << kCropsSize;
    return RET_PARAM_INVALID;
  }
  auto block_shape_data = inputs[1]->data_c();
  auto crops_data = inputs[2]->data_c();
  // Hard null check: the caller only MS_ASSERTs these, which is compiled out
  // in release builds.
  if (block_shape_data == nullptr || crops_data == nullptr) {
    MS_LOG(ERROR) << "block_shape or crops data is nullptr";
    return RET_PARAM_INVALID;
  }
  auto block_shape = static_cast<int *>(block_shape_data);
  auto crops = static_cast<int *>(crops_data);
  mul_block_shape_ = 1;
  for (size_t i = 0; i < kBlockShapeSize; ++i) {
    if (block_shape[i] <= 0) {
      MS_LOG(ERROR) << "Input block_shape should > 0!";
      return RET_PARAM_INVALID;
    }
    if (input_shape[NHWC_N] % block_shape[i]) {
      MS_LOG(ERROR) << "Dimension n " << input_shape[NHWC_N] << " can not divide block_shape[" << i << "] "
                    << block_shape[i];
      // Fixed: previously returned the magic number 1 instead of an error code.
      return RET_PARAM_INVALID;
    }
    mul_block_shape_ *= block_shape[i];
  }
  if (input_shape[NHWC_N] < mul_block_shape_) {
    MS_LOG(ERROR) << "Dimension n " << input_shape[NHWC_N] << " < product of block shape!";
    return RET_PARAM_INVALID;
  }
  for (size_t i = 0; i < kCropsSize; ++i) {
    if (crops[i] < 0) {
      MS_LOG(ERROR) << "Input crops should >= 0";
      return RET_PARAM_INVALID;
    }
  }
  std::vector<int32_t> output_shape(input_shape.size());
  output_shape[NHWC_N] = input_shape[NHWC_N] / mul_block_shape_;
  output_shape[NHWC_H] = input_shape[NHWC_H] * block_shape[0] - crops[0] - crops[1];
  output_shape[NHWC_W] = input_shape[NHWC_W] * block_shape[1] - crops[2] - crops[3];
  if (input_shape.size() > 3) {
    output_shape[NHWC_C] = input_shape[NHWC_C];
  }
  outputs[0]->set_shape(output_shape);
  return RET_OK;
}
int BatchToSpace::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
MS_ASSERT(this->primitive_ != nullptr);
if (outputs.size() != kBatchToSpaceOutputNum || inputs.size() != kBatchToSpaceInputNum) {
if (outputs.size() != kBatchToSpaceOutputNum ||
(inputs.size() != kBatchToSpaceOneInput && inputs.size() != kBatchToSpaceThreeInput)) {
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
return RET_PARAM_INVALID;
}
@ -100,54 +210,21 @@ int BatchToSpace::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
if (!infer_flag()) {
return RET_INFER_INVALID;
}
auto input_shape = input->shape();
if (input_shape.size() != kQuadrupleNum) {
MS_LOG(ERROR) << "input shape dimension size should == " << kQuadrupleNum;
return RET_PARAM_INVALID;
}
auto block_shape = GetBlockShape();
if (block_shape.size() != kBlockShapeSize) {
MS_LOG(ERROR) << "Block shape size should be " << kBlockShapeSize;
return RET_PARAM_INVALID;
if (inputs.size() == kBatchToSpaceOneInput) {
auto ret = SetOutputShapeFromParam(inputs, outputs);
return ret;
}
auto crops = GetCrops();
if (crops.size() != kCropsSize) {
MS_LOG(ERROR) << "Crops size should be " << kCropsSize;
return RET_PARAM_INVALID;
}
int mul_block_shape = 1;
for (size_t i = 0; i < kBlockShapeSize; ++i) {
if (block_shape[i] <= 0) {
MS_LOG(ERROR) << "Input block_shape should > 0!";
return RET_PARAM_INVALID;
if (inputs.size() == kBatchToSpaceThreeInput) {
if (inputs[0]->data_c() == nullptr) {
return RET_INFER_INVALID;
}
if (input_shape[NHWC_N] % block_shape[i]) {
MS_LOG(ERROR) << "Dimension n " << input_shape[NHWC_N] << " can not divide block_shape[" << i << "] "
<< block_shape[i];
return 1;
}
mul_block_shape *= block_shape[i];
MS_ASSERT(inputs[1]->data_c() != nullptr);
MS_ASSERT(inputs[2]->data_c() != nullptr);
auto ret = SetOutputShapeFromInput(inputs, outputs);
return ret;
}
if (input_shape[NHWC_N] < mul_block_shape) {
MS_LOG(ERROR) << "Dimension n " << input_shape[NHWC_N] << " < product of block shape!";
return RET_PARAM_INVALID;
}
for (size_t i = 0; i < kCropsSize; ++i) {
if (crops[i] < 0) {
MS_LOG(ERROR) << "Input crops should >= 0";
return RET_PARAM_INVALID;
}
}
std::vector<int32_t> output_shape(input_shape.size());
output_shape[NHWC_N] = input_shape[NHWC_N] / mul_block_shape;
output_shape[NHWC_H] = input_shape[NHWC_H] * block_shape[0] - crops[0] - crops[1];
output_shape[NHWC_W] = input_shape[NHWC_W] * block_shape[1] - crops[2] - crops[3];
output_shape[NHWC_C] = input_shape[NHWC_C];
outputs[0]->set_shape(output_shape);
return RET_OK;
}
} // namespace lite

View File

@ -40,6 +40,11 @@ class BatchToSpace : public PrimitiveC {
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
std::vector<int> GetBlockShape() const;
std::vector<int> GetCrops() const;
private:
int SetOutputShapeFromParam(const std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs);
int SetOutputShapeFromInput(const std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs);
int mul_block_shape_;
};
} // namespace lite
} // namespace mindspore

View File

@ -396,6 +396,9 @@ int Conv2D::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
return RET_INFER_INVALID;
}
auto in_shape = input_tensor->shape();
if (in_shape.size() == 0) {
return RET_INFER_INVALID;
}
int input_h = in_shape.at(1);
int input_w = in_shape.at(2);
int output_w = 0, output_h = 0;

View File

@ -34,6 +34,9 @@ OpParameter *PopulateBatchToSpaceParameter(const mindspore::lite::PrimitiveC *pr
batch_space_param->op_parameter_.type_ = primitive->Type();
auto param = reinterpret_cast<mindspore::lite::BatchToSpace *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
auto block_shape = param->GetBlockShape();
if (block_shape.empty()) {
return reinterpret_cast<OpParameter *>(batch_space_param);
}
if (block_shape.size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) {
MS_LOG(ERROR) << "batch_to_space blockShape size should be " << BATCH_TO_SPACE_BLOCK_SHAPE_SIZE;
free(batch_space_param);
@ -41,6 +44,9 @@ OpParameter *PopulateBatchToSpaceParameter(const mindspore::lite::PrimitiveC *pr
}
auto crops = param->GetCrops();
if (crops.empty()) {
return reinterpret_cast<OpParameter *>(batch_space_param);
}
if (crops.size() != COMM_SHAPE_SIZE) {
MS_LOG(ERROR) << "batch_to_space crops size should be " << COMM_SHAPE_SIZE;
free(batch_space_param);

View File

@ -31,6 +31,9 @@ OpParameter *PopulateSpaceToBatchNDParameter(const mindspore::lite::PrimitiveC *
space_batch_param_nd->op_parameter_.type_ = primitive->Type();
auto block_sizes = ((mindspore::lite::SpaceToBatchND *)primitive)->GetBlockShape();
if (block_sizes.empty()) {
return reinterpret_cast<OpParameter *>(space_batch_param_nd);
}
space_batch_param_nd->m_ = block_sizes.size();
if (block_sizes.size() > std::numeric_limits<size_t>::max() / sizeof(int)) {
MS_LOG(ERROR) << "The value of block_sizes.size() is too big";
@ -39,6 +42,9 @@ OpParameter *PopulateSpaceToBatchNDParameter(const mindspore::lite::PrimitiveC *
}
memcpy(space_batch_param_nd->block_sizes_, (block_sizes.data()), block_sizes.size() * sizeof(int));
auto paddings = ((mindspore::lite::SpaceToBatchND *)primitive)->GetPaddings();
if (paddings.empty()) {
return reinterpret_cast<OpParameter *>(space_batch_param_nd);
}
if (paddings.size() > std::numeric_limits<size_t>::max() / sizeof(int)) {
MS_LOG(ERROR) << "The value of paddings.size() is too big";
free(space_batch_param_nd);

View File

@ -26,7 +26,8 @@ namespace mindspore {
namespace lite {
namespace {
constexpr int kSpaceToBatchNDOutputNum = 1;
constexpr int kSpaceToBatchNDInputNum = 1;
constexpr int kSpaceToBatchNDOneInput = 1;
constexpr int kSpaceToBatchNDThreeInput = 3;
} // namespace
#ifdef PRIMITIVE_WRITEABLE
@ -86,23 +87,9 @@ Registry SpaceToBatchNDRegistry(schema::PrimitiveType_SpaceToBatchND, SpaceToBat
#endif // PRIMITIVE_WRITEABLE
int SpaceToBatchND::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
if (outputs.size() != kSpaceToBatchNDOutputNum || inputs.size() != kSpaceToBatchNDInputNum) {
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
return 1;
}
auto input = inputs.at(0);
if (input->format() != schema::Format::Format_NHWC) {
MS_LOG(ERROR) << "space_to_batch_nd only support NHWC now!";
return RET_ERROR;
}
outputs.at(0)->set_data_type(input->data_type());
outputs.at(0)->set_format(input->format());
if (!infer_flag()) {
return RET_INFER_INVALID;
}
auto input_shape = input->shape();
int SpaceToBatchND::SetOutputShapeFromParam(const std::vector<lite::Tensor *> inputs,
std::vector<lite::Tensor *> outputs) {
auto input_shape = inputs[0]->shape();
if (input_shape.size() != kQuadrupleNum) {
MS_LOG(ERROR) << "input shape dimension size only support " << kQuadrupleNum << " now!";
return RET_ERROR;
@ -133,9 +120,94 @@ int SpaceToBatchND::InferShape(std::vector<lite::Tensor *> inputs, std::vector<l
return RET_ERROR;
}
output_shape.at(NHWC_W) = (input_shape.at(NHWC_W) + padding_left + padding_right) / block_w;
output_shape.at(NHWC_C) = input_shape.at(NHWC_C);
if (input_shape.size() > 3) {
output_shape.at(NHWC_C) = input_shape.at(NHWC_C);
}
outputs.at(0)->set_shape(output_shape);
return RET_OK;
}
// Infers the SpaceToBatchND output shape when block_shape (input 1) and
// paddings (input 2) are provided as tensors rather than attributes:
//   out_n = in_n * block_h * block_w
//   out_h = (in_h + pad[0] + pad[1]) / block_h
//   out_w = (in_w + pad_left + pad_right) / block_w
// Returns RET_OK on success, RET_ERROR on invalid input.
int SpaceToBatchND::SetOutputShapeFromInput(const std::vector<lite::Tensor *> inputs,
                                            std::vector<lite::Tensor *> outputs) {
  auto input_shape = inputs[0]->shape();
  if (input_shape.size() != kQuadrupleNum) {
    MS_LOG(ERROR) << "input shape dimension size only support " << kQuadrupleNum << " now!";
    return RET_ERROR;
  }
  auto block_shape_data = inputs[1]->data_c();
  auto padding_data = inputs[2]->data_c();
  // Hardened: these preconditions were only MS_ASSERTed before, which is
  // compiled out in release builds and left null/short buffers dereferenced.
  if (block_shape_data == nullptr || padding_data == nullptr) {
    MS_LOG(ERROR) << "block_shape or paddings data is nullptr";
    return RET_ERROR;
  }
  if (inputs[1]->ElementsNum() < 1 || inputs[2]->ElementsNum() != 4) {
    MS_LOG(ERROR) << "block_shape or paddings element num is invalid";
    return RET_ERROR;
  }
  auto block_shape = static_cast<int *>(block_shape_data);
  auto padding = static_cast<int *>(padding_data);
  int padding_left = 0;
  int padding_right = 0;
  int block_w = 1;
  // 2-element block_shape means a 2-D (H, W) block; otherwise only H is split.
  if (inputs[1]->ElementsNum() == 2) {
    padding_left = padding[2];
    padding_right = padding[3];
    block_w = block_shape[1];
  }
  // Both divisors must be positive before the divisions below.
  if (block_shape[0] <= 0 || block_w <= 0) {
    MS_LOG(ERROR) << "block shape should be > 0";
    return RET_ERROR;
  }
  std::vector<int32_t> output_shape(input_shape.size());
  // Guard N > 0 so the overflow pre-check never divides by zero.
  if (input_shape.at(NHWC_N) > 0 &&
      block_shape[0] * block_w > std::numeric_limits<int>::max() / input_shape.at(NHWC_N)) {
    MS_LOG(ERROR) << "The value of block_shape.at(0) * block_w is too big";
    return RET_ERROR;
  }
  output_shape.at(NHWC_N) = input_shape.at(NHWC_N) * block_shape[0] * block_w;
  if (padding[0] + padding[1] > std::numeric_limits<int>::max() - input_shape.at(NHWC_H)) {
    MS_LOG(ERROR) << "The value of padding.at(0) + padding.at(1) is too big";
    return RET_ERROR;
  }
  output_shape.at(NHWC_H) = (input_shape.at(NHWC_H) + padding[0] + padding[1]) / block_shape[0];
  if (padding_left + padding_right > std::numeric_limits<int>::max() - input_shape.at(NHWC_W)) {
    MS_LOG(ERROR) << "The value of padding_left + padding_right is too big";
    return RET_ERROR;
  }
  output_shape.at(NHWC_W) = (input_shape.at(NHWC_W) + padding_left + padding_right) / block_w;
  // Size was validated to kQuadrupleNum above; guard kept for symmetry.
  if (input_shape.size() > 3) {
    output_shape.at(NHWC_C) = input_shape.at(NHWC_C);
  }
  outputs.at(0)->set_shape(output_shape);
  return RET_OK;
}
// Shape inference entry point for SpaceToBatchND. Dispatches to the
// attribute-based path (one input) or the tensor-based path (three inputs).
int SpaceToBatchND::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
  const bool sizes_valid =
    outputs.size() == kSpaceToBatchNDOutputNum &&
    (inputs.size() == kSpaceToBatchNDOneInput || inputs.size() == kSpaceToBatchNDThreeInput);
  if (!sizes_valid) {
    MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
    return 1;
  }
  auto data_tensor = inputs.at(0);
  if (data_tensor->format() != schema::Format::Format_NHWC) {
    MS_LOG(ERROR) << "space_to_batch_nd only support NHWC now!";
    return RET_ERROR;
  }
  auto out_tensor = outputs.at(0);
  out_tensor->set_data_type(data_tensor->data_type());
  out_tensor->set_format(data_tensor->format());
  if (!infer_flag()) {
    return RET_INFER_INVALID;
  }
  // Three-input form: block_shape/paddings come from tensors, which must
  // already hold data for inference to proceed.
  if (inputs.size() == kSpaceToBatchNDThreeInput) {
    if (data_tensor->data_c() == nullptr) {
      return RET_INFER_INVALID;
    }
    MS_ASSERT(inputs[1]->data_c() != nullptr);
    MS_ASSERT(inputs[2]->data_c() != nullptr);
    auto status = SetOutputShapeFromInput(inputs, outputs);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "SetOutputShapeFromInput failed";
      return status;
    }
    return RET_OK;
  }
  // One-input form: block_shape/paddings live on the primitive's attributes.
  auto status = SetOutputShapeFromParam(inputs, outputs);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "SetOutputShapeFromParam failed";
    return status;
  }
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -39,6 +39,8 @@ class SpaceToBatchND : public PrimitiveC {
std::vector<int> GetBlockShape() const;
std::vector<int> GetPaddings() const;
int InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) override;
int SetOutputShapeFromParam(const std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs);
int SetOutputShapeFromInput(const std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs);
};
} // namespace lite
} // namespace mindspore

View File

@ -16,23 +16,46 @@
#include "src/runtime/kernel/arm/fp32/batch_to_space_fp32.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/ops/batch_to_space.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_BatchToSpace;
using mindspore::schema::PrimitiveType_BatchToSpaceND;
namespace mindspore::kernel {
int BatchToSpaceCPUKernel::Processinput() {
MS_ASSERT(in_tensors_[1]->data_c() != nullptr);
MS_ASSERT(in_tensors_[2]->data_c() != nullptr);
auto block_shape_data = in_tensors_[1]->data_c();
auto crops_data = in_tensors_[2]->data_c();
auto block_shape = static_cast<int *>(block_shape_data);
auto crops = static_cast<int *>(crops_data);
for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) {
block_shape_[i] = block_shape[i];
}
no_crop_ = true;
for (int i = 0; i < COMM_SHAPE_SIZE; ++i) {
crops_[i] = crops[i];
if (crops_[i] != 0) {
no_crop_ = false;
}
}
return RET_OK;
}
// One-time kernel setup; defers shape-dependent work to ReSize() until
// shapes have been inferred.
int BatchToSpaceCPUKernel::Init() {
  // The CPU implementation only handles NHWC layout.
  MS_ASSERT(in_tensors_.at(0)->format() == schema::Format::Format_NHWC);
  if (!InferShapeDone()) {
    // Removed the unreachable duplicate `return lite::RET_OK;` merge leftover;
    // RET_OK is already in scope via the using-declaration.
    return RET_OK;
  }
  return ReSize();
}
// Re-checks preconditions after (re)inference; the kernel requires a 4-D
// (NHWC) input tensor.
int BatchToSpaceCPUKernel::ReSize() {
  MS_ASSERT(in_tensors_.at(0)->shape().size() == 4);
  // Removed the unreachable duplicate `return lite::RET_OK;` merge leftover.
  return RET_OK;
}
int BatchToSpaceCPUKernel::Run() {
@ -42,17 +65,29 @@ int BatchToSpaceCPUKernel::Run() {
float *output_data = reinterpret_cast<float *>(output->MutableData());
auto in_shape = input->shape();
auto out_shape = output->shape();
BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->op_parameter_);
if (param->no_crop_) {
BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_,
sizeof(float));
} else {
BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, param->crops_,
sizeof(float));
if (in_tensors_.size() == 1) {
BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->op_parameter_);
if (param->no_crop_) {
BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_,
sizeof(float));
} else {
BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, param->crops_,
sizeof(float));
}
}
return lite::RET_OK;
if (in_tensors_.size() == 3) {
auto ret = Processinput();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Processinput failed in BatchToSpace.";
return ret;
}
if (no_crop_) {
BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, sizeof(float));
} else {
BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, crops_, sizeof(float));
}
}
return RET_OK;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BatchToSpace, LiteKernelCreator<BatchToSpaceCPUKernel>)

View File

@ -34,6 +34,12 @@ class BatchToSpaceCPUKernel : public LiteKernel {
int Init() override;
int ReSize() override;
int Run() override;
int Processinput();
private:
int32_t block_shape_[BATCH_TO_SPACE_BLOCK_SHAPE_SIZE];
int32_t crops_[COMM_SHAPE_SIZE];
bool no_crop_;
};
} // namespace mindspore::kernel

View File

@ -25,6 +25,32 @@ using mindspore::schema::PrimitiveType_SpaceToBatch;
using mindspore::schema::PrimitiveType_SpaceToBatchND;
namespace mindspore::kernel {
// Caches the 4-D input/output shapes and their strides into param_, then
// copies the runtime block_sizes (input 1) and paddings (input 2) tensor
// contents into param_ so the NHWC kernels can consume them.
void SpaceToBatchCPUKernel::ProcessInput() {
  MS_ASSERT(in_tensors_[1] != nullptr);
  MS_ASSERT(in_tensors_[2] != nullptr);
  auto input_tensor = in_tensors_.at(0);
  MS_ASSERT(input_tensor);
  auto output_tensor = out_tensors_.at(0);
  MS_ASSERT(output_tensor);
  MS_ASSERT(param_);
  // Shapes are assumed 4-D here (DIMENSION_4D accesses via .at()).
  for (size_t i = 0; i < DIMENSION_4D; i++) {
    param_->input_shape_[i] = input_tensor->shape().at(i);
    param_->output_shape_[i] = output_tensor->shape().at(i);
  }
  ComputeStrides(param_->input_shape_, param_->in_stride_, DIMENSION_4D);
  ComputeStrides(param_->output_shape_, param_->out_stride_, DIMENSION_4D);
  // NOTE(review): the copy loops below trust ElementsNum() to fit within
  // param_->block_sizes_ / param_->paddings_ — confirm those array capacities
  // are enforced upstream (e.g. by InferShape validation).
  auto block_shape_data = in_tensors_[1]->data_c();
  auto block_shape = static_cast<int *>(block_shape_data);
  for (int i = 0; i < in_tensors_[1]->ElementsNum(); i++) {
    param_->block_sizes_[i] = block_shape[i];
  }
  auto padding_data = in_tensors_[2]->data_c();
  auto padding = static_cast<int *>(padding_data);
  for (int i = 0; i < in_tensors_[2]->ElementsNum(); i++) {
    param_->paddings_[i] = padding[i];
  }
}
int SpaceToBatchCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
@ -39,6 +65,12 @@ int SpaceToBatchFp32Run(void *cdata, int task_id) {
}
int SpaceToBatchCPUKernel::ReSize() {
if (in_tensors_.size() == 3) {
if (in_tensors_[1] != nullptr && in_tensors_[1]->IsConst() && in_tensors_[2] != nullptr &&
in_tensors_[2]->IsConst()) {
ProcessInput();
}
}
auto input_tensor = in_tensors_.at(0);
MS_ASSERT(input_tensor);
auto output_tensor = out_tensors_.at(0);
@ -61,8 +93,14 @@ void SpaceToBatchCPUKernel::DoRun(int task_id) {
}
int SpaceToBatchCPUKernel::Run() {
MS_ASSERT(in_tensors_[0] != nullptr);
input_ptr_ = reinterpret_cast<float *>(in_tensors_.at(0)->data_c());
output_ptr_ = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
if (in_tensors_.size() == 3) {
if (!in_tensors_[1]->IsConst() || !in_tensors_[2]->IsConst()) {
ProcessInput();
}
}
ParallelLaunch(this->context_->thread_pool_, SpaceToBatchFp32Run, this, op_parameter_->thread_num_);

View File

@ -35,6 +35,7 @@ class SpaceToBatchCPUKernel : public LiteKernel {
int Init() override;
int ReSize() override;
int Run() override;
void ProcessInput();
public:
void DoRun(int task_id);

View File

@ -0,0 +1,65 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_batch_to_space_nd_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
// Parses a TensorFlow BatchToSpaceND/BatchToSpace node into a lite
// BatchToSpace primitive. block_shape and crops remain graph inputs and are
// forwarded through *inputs; *output_size is always 1.
// Returns RET_OK, RET_NULL_PTR on null out-params, or RET_ERROR on failure.
STATUS TFBatchToSpaceNDParser::Parse(const tensorflow::NodeDef &tf_op,
                                     const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                                     PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
  // Trace message demoted from WARNING: it fired on every node and polluted
  // normal conversion logs.
  MS_LOG(DEBUG) << "TF BatchToSpaceNDParser";
  // `inputs` is dereferenced by AddOpInput below, so it must be checked too.
  if (primitiveC == nullptr || output_size == nullptr || inputs == nullptr) {
    MS_LOG(ERROR) << "input parameters contain nullptr";
    return RET_NULL_PTR;
  }
  auto primitive = std::make_unique<schema::PrimitiveT>();
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "New PrimitiveT failed";
    return RET_NULL_PTR;
  }
  auto attr = std::make_unique<schema::BatchToSpaceT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new attr failed";
    return RET_NULL_PTR;
  }
  primitive->value.type = schema::PrimitiveType_BatchToSpace;
  primitive->value.value = attr.release();
  *primitiveC = PrimitiveC::Create(primitive.release());
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }
  *output_size = 1;
  // Forward every TF input edge (data, block_shape, crops) unchanged.
  for (int i = 0; i < tf_op.input_size(); ++i) {
    auto status = AddOpInput(tf_op, i, inputs);
    if (status != RET_OK) {
      return status;
    }
  }
  return RET_OK;
}
TFNodeRegistrar g_tfBatchToSpaceNDParser("BatchToSpaceND", new TFBatchToSpaceNDParser());
TFNodeRegistrar g_tfBatchToSpaceParser("BatchToSpace", new TFBatchToSpaceNDParser());
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_TO_SPACE_ND_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_TO_SPACE_ND_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
// Converts TensorFlow BatchToSpaceND / BatchToSpace graph nodes into the
// lite schema's BatchToSpace primitive (registered for both op names).
class TFBatchToSpaceNDParser : public TFNodeParser {
 public:
  TFBatchToSpaceNDParser() = default;
  ~TFBatchToSpaceNDParser() override = default;

  // Builds *primitiveC from tf_op, appends the node's input tensor names to
  // *inputs, and sets *output_size to the number of outputs produced.
  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_TO_SPACE_ND_PARSER_H_

View File

@ -0,0 +1,64 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_space_to_batch_nd_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
// Parses a TensorFlow SpaceToBatchND node into a lite SpaceToBatchND
// primitive. block_shape and paddings remain graph inputs and are forwarded
// through *inputs; *output_size is always 1.
// Returns RET_OK, RET_NULL_PTR on null out-params, or RET_ERROR on failure.
STATUS TFSpaceToBatchNDParser::Parse(const tensorflow::NodeDef &tf_op,
                                     const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                                     PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
  // Trace message demoted from WARNING: it fired on every node and polluted
  // normal conversion logs.
  MS_LOG(DEBUG) << "TF SpaceToBatchNDParser";
  // `inputs` is dereferenced by AddOpInput below, so it must be checked too.
  if (primitiveC == nullptr || output_size == nullptr || inputs == nullptr) {
    MS_LOG(ERROR) << "input parameters contain nullptr";
    return RET_NULL_PTR;
  }
  auto primitive = std::make_unique<schema::PrimitiveT>();
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "New PrimitiveT failed";
    return RET_NULL_PTR;
  }
  auto attr = std::make_unique<schema::SpaceToBatchNDT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new attr failed";
    return RET_NULL_PTR;
  }
  primitive->value.type = schema::PrimitiveType_SpaceToBatchND;
  primitive->value.value = attr.release();
  *primitiveC = PrimitiveC::Create(primitive.release());
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }
  *output_size = 1;
  // Forward every TF input edge (data, block_shape, paddings) unchanged.
  for (int i = 0; i < tf_op.input_size(); ++i) {
    auto status = AddOpInput(tf_op, i, inputs);
    if (status != RET_OK) {
      return status;
    }
  }
  return RET_OK;
}
TFNodeRegistrar g_tfSpaceToBatchNDParser("SpaceToBatchND", new TFSpaceToBatchNDParser());
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPACE_TO_BATCH_ND_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPACE_TO_BATCH_ND_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
// Converts TensorFlow SpaceToBatchND graph nodes into the lite schema's
// SpaceToBatchND primitive.
class TFSpaceToBatchNDParser : public TFNodeParser {
 public:
  TFSpaceToBatchNDParser() = default;
  ~TFSpaceToBatchNDParser() override = default;

  // Builds *primitiveC from tf_op, appends the node's input tensor names to
  // *inputs, and sets *output_size to the number of outputs produced.
  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPACE_TO_BATCH_ND_PARSER_H_