forked from mindspore-Ecosystem/mindspore
!11918 [MS_LITE] support tf model
From: @YeFeng_24 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
c2763dbb6f
|
@ -78,14 +78,124 @@ Registry BatchToSpaceRegistry(schema::PrimitiveType_BatchToSpace, BatchToSpaceCr
|
|||
|
||||
namespace {
|
||||
constexpr int kBatchToSpaceOutputNum = 1;
|
||||
constexpr int kBatchToSpaceInputNum = 1;
|
||||
constexpr int kBatchToSpaceOneInput = 1;
|
||||
constexpr int kBatchToSpaceThreeInput = 3;
|
||||
constexpr int kBlockShapeSize = 2;
|
||||
constexpr int kCropsSize = 4;
|
||||
} // namespace
|
||||
|
||||
int BatchToSpace::SetOutputShapeFromParam(const std::vector<lite::Tensor *> inputs,
|
||||
std::vector<lite::Tensor *> outputs) {
|
||||
auto input_shape = inputs[0]->shape();
|
||||
if (input_shape.size() != kQuadrupleNum) {
|
||||
MS_LOG(ERROR) << "input shape dimension size should == " << kQuadrupleNum;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
auto block_shape = GetBlockShape();
|
||||
if (block_shape.size() != kBlockShapeSize) {
|
||||
MS_LOG(ERROR) << "Block shape size should be " << kBlockShapeSize;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
auto crops = GetCrops();
|
||||
if (crops.size() != kCropsSize) {
|
||||
MS_LOG(ERROR) << "Crops size should be " << kCropsSize;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
mul_block_shape_ = 1;
|
||||
|
||||
for (size_t i = 0; i < kBlockShapeSize; ++i) {
|
||||
if (block_shape[i] <= 0) {
|
||||
MS_LOG(ERROR) << "Input block_shape should > 0!";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
if (input_shape[NHWC_N] % block_shape[i]) {
|
||||
MS_LOG(ERROR) << "Dimension n " << input_shape[NHWC_N] << " can not divide block_shape[" << i << "] "
|
||||
<< block_shape[i];
|
||||
return 1;
|
||||
}
|
||||
mul_block_shape_ *= block_shape[i];
|
||||
}
|
||||
|
||||
if (input_shape[NHWC_N] < mul_block_shape_) {
|
||||
MS_LOG(ERROR) << "Dimension n " << input_shape[NHWC_N] << " < product of block shape!";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
for (size_t i = 0; i < kCropsSize; ++i) {
|
||||
if (crops[i] < 0) {
|
||||
MS_LOG(ERROR) << "Input crops should >= 0";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
std::vector<int32_t> output_shape(input_shape.size());
|
||||
output_shape[NHWC_N] = input_shape[NHWC_N] / mul_block_shape_;
|
||||
output_shape[NHWC_H] = input_shape[NHWC_H] * block_shape[0] - crops[0] - crops[1];
|
||||
output_shape[NHWC_W] = input_shape[NHWC_W] * block_shape[1] - crops[2] - crops[3];
|
||||
if (input_shape.size() > 3) {
|
||||
output_shape[NHWC_C] = input_shape[NHWC_C];
|
||||
}
|
||||
outputs[0]->set_shape(output_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int BatchToSpace::SetOutputShapeFromInput(const std::vector<lite::Tensor *> inputs,
|
||||
std::vector<lite::Tensor *> outputs) {
|
||||
auto input_shape = inputs[0]->shape();
|
||||
if (input_shape.size() != kQuadrupleNum) {
|
||||
MS_LOG(ERROR) << "input shape dimension size should == " << kQuadrupleNum;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
auto block_shape_data = inputs[1]->data_c();
|
||||
auto crops_data = inputs[2]->data_c();
|
||||
auto block_shape = static_cast<int *>(block_shape_data);
|
||||
auto crops = static_cast<int *>(crops_data);
|
||||
if (inputs[1]->ElementsNum() != kBlockShapeSize) {
|
||||
MS_LOG(ERROR) << "Block shape size should be " << kBlockShapeSize;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
if (inputs[2]->ElementsNum() != kCropsSize) {
|
||||
MS_LOG(ERROR) << "Crops size should be " << kCropsSize;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
mul_block_shape_ = 1;
|
||||
|
||||
for (size_t i = 0; i < kBlockShapeSize; ++i) {
|
||||
if (block_shape[i] <= 0) {
|
||||
MS_LOG(ERROR) << "Input block_shape should > 0!";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
if (input_shape[NHWC_N] % block_shape[i]) {
|
||||
MS_LOG(ERROR) << "Dimension n " << input_shape[NHWC_N] << " can not divide block_shape[" << i << "] "
|
||||
<< block_shape[i];
|
||||
return 1;
|
||||
}
|
||||
mul_block_shape_ *= block_shape[i];
|
||||
}
|
||||
|
||||
if (input_shape[NHWC_N] < mul_block_shape_) {
|
||||
MS_LOG(ERROR) << "Dimension n " << input_shape[NHWC_N] << " < product of block shape!";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
for (size_t i = 0; i < kCropsSize; ++i) {
|
||||
if (crops[i] < 0) {
|
||||
MS_LOG(ERROR) << "Input crops should >= 0";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
std::vector<int32_t> output_shape(input_shape.size());
|
||||
output_shape[NHWC_N] = input_shape[NHWC_N] / mul_block_shape_;
|
||||
output_shape[NHWC_H] = input_shape[NHWC_H] * block_shape[0] - crops[0] - crops[1];
|
||||
output_shape[NHWC_W] = input_shape[NHWC_W] * block_shape[1] - crops[2] - crops[3];
|
||||
if (input_shape.size() > 3) {
|
||||
output_shape[NHWC_C] = input_shape[NHWC_C];
|
||||
}
|
||||
outputs[0]->set_shape(output_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int BatchToSpace::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
if (outputs.size() != kBatchToSpaceOutputNum || inputs.size() != kBatchToSpaceInputNum) {
|
||||
if (outputs.size() != kBatchToSpaceOutputNum ||
|
||||
(inputs.size() != kBatchToSpaceOneInput && inputs.size() != kBatchToSpaceThreeInput)) {
|
||||
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
|
@ -100,54 +210,21 @@ int BatchToSpace::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
|
|||
if (!infer_flag()) {
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
auto input_shape = input->shape();
|
||||
if (input_shape.size() != kQuadrupleNum) {
|
||||
MS_LOG(ERROR) << "input shape dimension size should == " << kQuadrupleNum;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto block_shape = GetBlockShape();
|
||||
if (block_shape.size() != kBlockShapeSize) {
|
||||
MS_LOG(ERROR) << "Block shape size should be " << kBlockShapeSize;
|
||||
return RET_PARAM_INVALID;
|
||||
if (inputs.size() == kBatchToSpaceOneInput) {
|
||||
auto ret = SetOutputShapeFromParam(inputs, outputs);
|
||||
return ret;
|
||||
}
|
||||
auto crops = GetCrops();
|
||||
if (crops.size() != kCropsSize) {
|
||||
MS_LOG(ERROR) << "Crops size should be " << kCropsSize;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
int mul_block_shape = 1;
|
||||
|
||||
for (size_t i = 0; i < kBlockShapeSize; ++i) {
|
||||
if (block_shape[i] <= 0) {
|
||||
MS_LOG(ERROR) << "Input block_shape should > 0!";
|
||||
return RET_PARAM_INVALID;
|
||||
if (inputs.size() == kBatchToSpaceThreeInput) {
|
||||
if (inputs[0]->data_c() == nullptr) {
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
if (input_shape[NHWC_N] % block_shape[i]) {
|
||||
MS_LOG(ERROR) << "Dimension n " << input_shape[NHWC_N] << " can not divide block_shape[" << i << "] "
|
||||
<< block_shape[i];
|
||||
return 1;
|
||||
}
|
||||
mul_block_shape *= block_shape[i];
|
||||
MS_ASSERT(inputs[1]->data_c() != nullptr);
|
||||
MS_ASSERT(inputs[2]->data_c() != nullptr);
|
||||
auto ret = SetOutputShapeFromInput(inputs, outputs);
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (input_shape[NHWC_N] < mul_block_shape) {
|
||||
MS_LOG(ERROR) << "Dimension n " << input_shape[NHWC_N] << " < product of block shape!";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
for (size_t i = 0; i < kCropsSize; ++i) {
|
||||
if (crops[i] < 0) {
|
||||
MS_LOG(ERROR) << "Input crops should >= 0";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
std::vector<int32_t> output_shape(input_shape.size());
|
||||
output_shape[NHWC_N] = input_shape[NHWC_N] / mul_block_shape;
|
||||
output_shape[NHWC_H] = input_shape[NHWC_H] * block_shape[0] - crops[0] - crops[1];
|
||||
output_shape[NHWC_W] = input_shape[NHWC_W] * block_shape[1] - crops[2] - crops[3];
|
||||
output_shape[NHWC_C] = input_shape[NHWC_C];
|
||||
|
||||
outputs[0]->set_shape(output_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -40,6 +40,11 @@ class BatchToSpace : public PrimitiveC {
|
|||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
std::vector<int> GetBlockShape() const;
|
||||
std::vector<int> GetCrops() const;
|
||||
|
||||
private:
|
||||
int SetOutputShapeFromParam(const std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs);
|
||||
int SetOutputShapeFromInput(const std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs);
|
||||
int mul_block_shape_;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -396,6 +396,9 @@ int Conv2D::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
|
|||
return RET_INFER_INVALID;
|
||||
}
|
||||
auto in_shape = input_tensor->shape();
|
||||
if (in_shape.size() == 0) {
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
int input_h = in_shape.at(1);
|
||||
int input_w = in_shape.at(2);
|
||||
int output_w = 0, output_h = 0;
|
||||
|
|
|
@ -34,6 +34,9 @@ OpParameter *PopulateBatchToSpaceParameter(const mindspore::lite::PrimitiveC *pr
|
|||
batch_space_param->op_parameter_.type_ = primitive->Type();
|
||||
auto param = reinterpret_cast<mindspore::lite::BatchToSpace *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
auto block_shape = param->GetBlockShape();
|
||||
if (block_shape.empty()) {
|
||||
return reinterpret_cast<OpParameter *>(batch_space_param);
|
||||
}
|
||||
if (block_shape.size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) {
|
||||
MS_LOG(ERROR) << "batch_to_space blockShape size should be " << BATCH_TO_SPACE_BLOCK_SHAPE_SIZE;
|
||||
free(batch_space_param);
|
||||
|
@ -41,6 +44,9 @@ OpParameter *PopulateBatchToSpaceParameter(const mindspore::lite::PrimitiveC *pr
|
|||
}
|
||||
|
||||
auto crops = param->GetCrops();
|
||||
if (crops.empty()) {
|
||||
return reinterpret_cast<OpParameter *>(batch_space_param);
|
||||
}
|
||||
if (crops.size() != COMM_SHAPE_SIZE) {
|
||||
MS_LOG(ERROR) << "batch_to_space crops size should be " << COMM_SHAPE_SIZE;
|
||||
free(batch_space_param);
|
||||
|
|
|
@ -31,6 +31,9 @@ OpParameter *PopulateSpaceToBatchNDParameter(const mindspore::lite::PrimitiveC *
|
|||
|
||||
space_batch_param_nd->op_parameter_.type_ = primitive->Type();
|
||||
auto block_sizes = ((mindspore::lite::SpaceToBatchND *)primitive)->GetBlockShape();
|
||||
if (block_sizes.empty()) {
|
||||
return reinterpret_cast<OpParameter *>(space_batch_param_nd);
|
||||
}
|
||||
space_batch_param_nd->m_ = block_sizes.size();
|
||||
if (block_sizes.size() > std::numeric_limits<size_t>::max() / sizeof(int)) {
|
||||
MS_LOG(ERROR) << "The value of block_sizes.size() is too big";
|
||||
|
@ -39,6 +42,9 @@ OpParameter *PopulateSpaceToBatchNDParameter(const mindspore::lite::PrimitiveC *
|
|||
}
|
||||
memcpy(space_batch_param_nd->block_sizes_, (block_sizes.data()), block_sizes.size() * sizeof(int));
|
||||
auto paddings = ((mindspore::lite::SpaceToBatchND *)primitive)->GetPaddings();
|
||||
if (paddings.empty()) {
|
||||
return reinterpret_cast<OpParameter *>(space_batch_param_nd);
|
||||
}
|
||||
if (paddings.size() > std::numeric_limits<size_t>::max() / sizeof(int)) {
|
||||
MS_LOG(ERROR) << "The value of paddings.size() is too big";
|
||||
free(space_batch_param_nd);
|
||||
|
|
|
@ -26,7 +26,8 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
namespace {
|
||||
constexpr int kSpaceToBatchNDOutputNum = 1;
|
||||
constexpr int kSpaceToBatchNDInputNum = 1;
|
||||
constexpr int kSpaceToBatchNDOneInput = 1;
|
||||
constexpr int kSpaceToBatchNDThreeInput = 3;
|
||||
} // namespace
|
||||
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -86,23 +87,9 @@ Registry SpaceToBatchNDRegistry(schema::PrimitiveType_SpaceToBatchND, SpaceToBat
|
|||
|
||||
#endif // PRIMITIVE_WRITEABLE
|
||||
|
||||
int SpaceToBatchND::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
|
||||
if (outputs.size() != kSpaceToBatchNDOutputNum || inputs.size() != kSpaceToBatchNDInputNum) {
|
||||
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto input = inputs.at(0);
|
||||
if (input->format() != schema::Format::Format_NHWC) {
|
||||
MS_LOG(ERROR) << "space_to_batch_nd only support NHWC now!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
outputs.at(0)->set_data_type(input->data_type());
|
||||
outputs.at(0)->set_format(input->format());
|
||||
if (!infer_flag()) {
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
auto input_shape = input->shape();
|
||||
int SpaceToBatchND::SetOutputShapeFromParam(const std::vector<lite::Tensor *> inputs,
|
||||
std::vector<lite::Tensor *> outputs) {
|
||||
auto input_shape = inputs[0]->shape();
|
||||
if (input_shape.size() != kQuadrupleNum) {
|
||||
MS_LOG(ERROR) << "input shape dimension size only support " << kQuadrupleNum << " now!";
|
||||
return RET_ERROR;
|
||||
|
@ -133,9 +120,94 @@ int SpaceToBatchND::InferShape(std::vector<lite::Tensor *> inputs, std::vector<l
|
|||
return RET_ERROR;
|
||||
}
|
||||
output_shape.at(NHWC_W) = (input_shape.at(NHWC_W) + padding_left + padding_right) / block_w;
|
||||
output_shape.at(NHWC_C) = input_shape.at(NHWC_C);
|
||||
if (input_shape.size() > 3) {
|
||||
output_shape.at(NHWC_C) = input_shape.at(NHWC_C);
|
||||
}
|
||||
outputs.at(0)->set_shape(output_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int SpaceToBatchND::SetOutputShapeFromInput(const std::vector<lite::Tensor *> inputs,
|
||||
std::vector<lite::Tensor *> outputs) {
|
||||
auto input_shape = inputs[0]->shape();
|
||||
if (input_shape.size() != kQuadrupleNum) {
|
||||
MS_LOG(ERROR) << "input shape dimension size only support " << kQuadrupleNum << " now!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_ASSERT(inputs[2]->ElementsNum() == 4);
|
||||
auto block_shape_data = inputs[1]->data_c();
|
||||
auto block_shape = static_cast<int *>(block_shape_data);
|
||||
auto padding_data = inputs[2]->data_c();
|
||||
auto padding = static_cast<int *>(padding_data);
|
||||
int padding_left = 0;
|
||||
int padding_right = 0;
|
||||
int block_w = 1;
|
||||
if (inputs[1]->ElementsNum() == 2) {
|
||||
padding_left = padding[2];
|
||||
padding_right = padding[3];
|
||||
block_w = block_shape[1];
|
||||
}
|
||||
std::vector<int32_t> output_shape(input_shape.size());
|
||||
if (block_shape[0] * block_w > std::numeric_limits<int>::max() / input_shape.at(NHWC_N)) {
|
||||
MS_LOG(ERROR) << "The value of block_shape.at(0) * block_w is too big";
|
||||
return RET_ERROR;
|
||||
}
|
||||
output_shape.at(NHWC_N) = input_shape.at(NHWC_N) * block_shape[0] * block_w;
|
||||
if (padding[0] + padding[1] > std::numeric_limits<int>::max() - input_shape.at(NHWC_H)) {
|
||||
MS_LOG(ERROR) << "The value of padding.at(0) + padding.at(1) is too big";
|
||||
return RET_ERROR;
|
||||
}
|
||||
output_shape.at(NHWC_H) = (input_shape.at(NHWC_H) + padding[0] + padding[1]) / block_shape[0];
|
||||
if (padding_left + padding_right > std::numeric_limits<int>::max() - input_shape.at(NHWC_W)) {
|
||||
MS_LOG(ERROR) << "The value of padding_left + padding_right is too big";
|
||||
return RET_ERROR;
|
||||
}
|
||||
output_shape.at(NHWC_W) = (input_shape.at(NHWC_W) + padding_left + padding_right) / block_w;
|
||||
if (input_shape.size() > 3) {
|
||||
output_shape.at(NHWC_C) = input_shape.at(NHWC_C);
|
||||
}
|
||||
outputs.at(0)->set_shape(output_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int SpaceToBatchND::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
|
||||
if (outputs.size() != kSpaceToBatchNDOutputNum ||
|
||||
(inputs.size() != kSpaceToBatchNDOneInput && inputs.size() != kSpaceToBatchNDThreeInput)) {
|
||||
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto input = inputs.at(0);
|
||||
if (input->format() != schema::Format::Format_NHWC) {
|
||||
MS_LOG(ERROR) << "space_to_batch_nd only support NHWC now!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
outputs.at(0)->set_data_type(input->data_type());
|
||||
outputs.at(0)->set_format(input->format());
|
||||
if (!infer_flag()) {
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
|
||||
if (inputs.size() == kSpaceToBatchNDOneInput) {
|
||||
auto ret = SetOutputShapeFromParam(inputs, outputs);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetOutputShapeFromParam failed";
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
if (inputs.size() == kSpaceToBatchNDThreeInput) {
|
||||
if (inputs[0]->data_c() == nullptr) {
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
MS_ASSERT(inputs[1]->data_c() != nullptr);
|
||||
MS_ASSERT(inputs[2]->data_c() != nullptr);
|
||||
auto ret = SetOutputShapeFromInput(inputs, outputs);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetOutputShapeFromInput failed";
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -39,6 +39,8 @@ class SpaceToBatchND : public PrimitiveC {
|
|||
std::vector<int> GetBlockShape() const;
|
||||
std::vector<int> GetPaddings() const;
|
||||
int InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) override;
|
||||
int SetOutputShapeFromParam(const std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs);
|
||||
int SetOutputShapeFromInput(const std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,23 +16,46 @@
|
|||
#include "src/runtime/kernel/arm/fp32/batch_to_space_fp32.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/ops/batch_to_space.h"
|
||||
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_BatchToSpace;
|
||||
using mindspore::schema::PrimitiveType_BatchToSpaceND;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
int BatchToSpaceCPUKernel::Processinput() {
|
||||
MS_ASSERT(in_tensors_[1]->data_c() != nullptr);
|
||||
MS_ASSERT(in_tensors_[2]->data_c() != nullptr);
|
||||
auto block_shape_data = in_tensors_[1]->data_c();
|
||||
auto crops_data = in_tensors_[2]->data_c();
|
||||
auto block_shape = static_cast<int *>(block_shape_data);
|
||||
auto crops = static_cast<int *>(crops_data);
|
||||
for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) {
|
||||
block_shape_[i] = block_shape[i];
|
||||
}
|
||||
no_crop_ = true;
|
||||
for (int i = 0; i < COMM_SHAPE_SIZE; ++i) {
|
||||
crops_[i] = crops[i];
|
||||
if (crops_[i] != 0) {
|
||||
no_crop_ = false;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int BatchToSpaceCPUKernel::Init() {
|
||||
MS_ASSERT(in_tensors_.at(0)->format() == schema::Format::Format_NHWC);
|
||||
if (!InferShapeDone()) {
|
||||
return lite::RET_OK;
|
||||
return RET_OK;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
int BatchToSpaceCPUKernel::ReSize() {
|
||||
MS_ASSERT(in_tensors_.at(0)->shape().size() == 4);
|
||||
return lite::RET_OK;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int BatchToSpaceCPUKernel::Run() {
|
||||
|
@ -42,17 +65,29 @@ int BatchToSpaceCPUKernel::Run() {
|
|||
float *output_data = reinterpret_cast<float *>(output->MutableData());
|
||||
auto in_shape = input->shape();
|
||||
auto out_shape = output->shape();
|
||||
BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->op_parameter_);
|
||||
|
||||
if (param->no_crop_) {
|
||||
BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_,
|
||||
sizeof(float));
|
||||
} else {
|
||||
BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, param->crops_,
|
||||
sizeof(float));
|
||||
if (in_tensors_.size() == 1) {
|
||||
BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->op_parameter_);
|
||||
if (param->no_crop_) {
|
||||
BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_,
|
||||
sizeof(float));
|
||||
} else {
|
||||
BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, param->crops_,
|
||||
sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
return lite::RET_OK;
|
||||
if (in_tensors_.size() == 3) {
|
||||
auto ret = Processinput();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Processinput failed in BatchToSpace.";
|
||||
return ret;
|
||||
}
|
||||
if (no_crop_) {
|
||||
BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, sizeof(float));
|
||||
} else {
|
||||
BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, crops_, sizeof(float));
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BatchToSpace, LiteKernelCreator<BatchToSpaceCPUKernel>)
|
||||
|
|
|
@ -34,6 +34,12 @@ class BatchToSpaceCPUKernel : public LiteKernel {
|
|||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
int Processinput();
|
||||
|
||||
private:
|
||||
int32_t block_shape_[BATCH_TO_SPACE_BLOCK_SHAPE_SIZE];
|
||||
int32_t crops_[COMM_SHAPE_SIZE];
|
||||
bool no_crop_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
|
@ -25,6 +25,32 @@ using mindspore::schema::PrimitiveType_SpaceToBatch;
|
|||
using mindspore::schema::PrimitiveType_SpaceToBatchND;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
void SpaceToBatchCPUKernel::ProcessInput() {
|
||||
MS_ASSERT(in_tensors_[1] != nullptr);
|
||||
MS_ASSERT(in_tensors_[2] != nullptr);
|
||||
auto input_tensor = in_tensors_.at(0);
|
||||
MS_ASSERT(input_tensor);
|
||||
auto output_tensor = out_tensors_.at(0);
|
||||
MS_ASSERT(output_tensor);
|
||||
MS_ASSERT(param_);
|
||||
for (size_t i = 0; i < DIMENSION_4D; i++) {
|
||||
param_->input_shape_[i] = input_tensor->shape().at(i);
|
||||
param_->output_shape_[i] = output_tensor->shape().at(i);
|
||||
}
|
||||
ComputeStrides(param_->input_shape_, param_->in_stride_, DIMENSION_4D);
|
||||
ComputeStrides(param_->output_shape_, param_->out_stride_, DIMENSION_4D);
|
||||
auto block_shape_data = in_tensors_[1]->data_c();
|
||||
auto block_shape = static_cast<int *>(block_shape_data);
|
||||
for (int i = 0; i < in_tensors_[1]->ElementsNum(); i++) {
|
||||
param_->block_sizes_[i] = block_shape[i];
|
||||
}
|
||||
auto padding_data = in_tensors_[2]->data_c();
|
||||
auto padding = static_cast<int *>(padding_data);
|
||||
for (int i = 0; i < in_tensors_[2]->ElementsNum(); i++) {
|
||||
param_->paddings_[i] = padding[i];
|
||||
}
|
||||
}
|
||||
|
||||
int SpaceToBatchCPUKernel::Init() {
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
|
@ -39,6 +65,12 @@ int SpaceToBatchFp32Run(void *cdata, int task_id) {
|
|||
}
|
||||
|
||||
int SpaceToBatchCPUKernel::ReSize() {
|
||||
if (in_tensors_.size() == 3) {
|
||||
if (in_tensors_[1] != nullptr && in_tensors_[1]->IsConst() && in_tensors_[2] != nullptr &&
|
||||
in_tensors_[2]->IsConst()) {
|
||||
ProcessInput();
|
||||
}
|
||||
}
|
||||
auto input_tensor = in_tensors_.at(0);
|
||||
MS_ASSERT(input_tensor);
|
||||
auto output_tensor = out_tensors_.at(0);
|
||||
|
@ -61,8 +93,14 @@ void SpaceToBatchCPUKernel::DoRun(int task_id) {
|
|||
}
|
||||
|
||||
int SpaceToBatchCPUKernel::Run() {
|
||||
MS_ASSERT(in_tensors_[0] != nullptr);
|
||||
input_ptr_ = reinterpret_cast<float *>(in_tensors_.at(0)->data_c());
|
||||
output_ptr_ = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
|
||||
if (in_tensors_.size() == 3) {
|
||||
if (!in_tensors_[1]->IsConst() || !in_tensors_[2]->IsConst()) {
|
||||
ProcessInput();
|
||||
}
|
||||
}
|
||||
|
||||
ParallelLaunch(this->context_->thread_pool_, SpaceToBatchFp32Run, this, op_parameter_->thread_num_);
|
||||
|
||||
|
|
|
@ -35,6 +35,7 @@ class SpaceToBatchCPUKernel : public LiteKernel {
|
|||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
void ProcessInput();
|
||||
|
||||
public:
|
||||
void DoRun(int task_id);
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tf/tf_batch_to_space_nd_parser.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TFBatchToSpaceNDParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
|
||||
MS_LOG(WARNING) << "TF BatchToSpaceNDParser";
|
||||
if (primitiveC == nullptr || output_size == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "New PrimitiveT failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto attr = std::make_unique<schema::BatchToSpaceT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_BatchToSpace;
|
||||
primitive->value.value = attr.release();
|
||||
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||
if (*primitiveC == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
*output_size = 1;
|
||||
for (int i = 0; i < tf_op.input_size(); ++i) {
|
||||
auto status = AddOpInput(tf_op, i, inputs);
|
||||
if (status != RET_OK) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
TFNodeRegistrar g_tfBatchToSpaceNDParser("BatchToSpaceND", new TFBatchToSpaceNDParser());
|
||||
TFNodeRegistrar g_tfBatchToSpaceParser("BatchToSpace", new TFBatchToSpaceNDParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_TO_SPACE_ND_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_TO_SPACE_ND_PARSER_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFBatchToSpaceNDParser : public TFNodeParser {
|
||||
public:
|
||||
TFBatchToSpaceNDParser() = default;
|
||||
~TFBatchToSpaceNDParser() override = default;
|
||||
|
||||
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_TO_SPACE_ND_PARSER_H_
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tf/tf_space_to_batch_nd_parser.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TFSpaceToBatchNDParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
|
||||
MS_LOG(WARNING) << "TF SpaceToBatchNDParser";
|
||||
if (primitiveC == nullptr || output_size == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "New PrimitiveT failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto attr = std::make_unique<schema::SpaceToBatchNDT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_SpaceToBatchND;
|
||||
primitive->value.value = attr.release();
|
||||
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||
if (*primitiveC == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
*output_size = 1;
|
||||
for (int i = 0; i < tf_op.input_size(); ++i) {
|
||||
auto status = AddOpInput(tf_op, i, inputs);
|
||||
if (status != RET_OK) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
TFNodeRegistrar g_tfSpaceToBatchNDParser("SpaceToBatchND", new TFSpaceToBatchNDParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPACE_TO_BATCH_ND_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPACE_TO_BATCH_ND_PARSER_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFSpaceToBatchNDParser : public TFNodeParser {
|
||||
public:
|
||||
TFSpaceToBatchNDParser() = default;
|
||||
~TFSpaceToBatchNDParser() override = default;
|
||||
|
||||
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPACE_TO_BATCH_ND_PARSER_H_
|
Loading…
Reference in New Issue