fix bug when zero shape

mengyuanli 2021-01-28 16:28:26 +08:00
parent 5a35626b1c
commit 65488f4c10
9 changed files with 103 additions and 33 deletions

View File

@@ -406,8 +406,8 @@ int Conv2D::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
   this->ConvInferShape(input_h, input_w, &output_h, &output_w);
   std::vector<int> out_shape{input_tensor->shape()};
-  out_shape.at(1) = output_h > 0 ? output_h : 1;
-  out_shape.at(2) = output_w > 0 ? output_w : 1;
+  out_shape.at(1) = output_h >= 0 ? output_h : 1;
+  out_shape.at(2) = output_w >= 0 ? output_w : 1;
   out_shape.at(3) = weight_tensor->shape()[0];
   out_tensor->set_shape(out_shape);
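
The substance of this one-character change: with ">", an H or W dimension that legitimately infers to 0 was silently promoted to 1; with ">=", only negative (unresolved) dimensions fall back to the placeholder 1. A minimal standalone sketch of the difference, not the kernel code itself:

#include <iostream>

int clamp_old(int dim) { return dim > 0 ? dim : 1; }   // old: 0 collapses to 1
int clamp_new(int dim) { return dim >= 0 ? dim : 1; }  // new: 0 survives

int main() {
  std::cout << clamp_old(0) << "\n";   // prints 1 -- zero shape lost
  std::cout << clamp_new(0) << "\n";   // prints 0 -- zero shape preserved
  std::cout << clamp_new(-1) << "\n";  // prints 1 -- invalid dim still defaults
  return 0;
}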

View File

@@ -66,14 +66,25 @@ PrimitiveC *MergeCreator(const schema::Primitive *primitive) { return PrimitiveC
 Registry MergeRegistry(schema::PrimitiveType_Merge, MergeCreator);
 #endif
-int Merge::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
-  MS_ASSERT(inputs_.size() == 2 * outputs_.size());
-  if (!infer_flag()) {
-    return RET_INFER_INVALID;
+InferStatus Merge::AbleToInfer(const std::vector<lite::Tensor *> &inputs) {
+  for (auto &input : inputs) {
+    if (input->shape().empty()) {
+      return HasZeroShape;
+    }
+    if (input->root_tensor() != nullptr && input->root_tensor()->data_c() != nullptr) {
+      continue;
+    }
+    if (input->data_c() == nullptr) {
+      return NotAble;
+    }
   }
-  for (size_t i = 0; i < inputs_.size() / 2; i++) {
-    auto *input = inputs_[i];
-    auto *output = outputs_[i];
+  return Able;
+}
+
+int Merge::Infer(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs) {
+  for (size_t i = 0; i < inputs.size(); i++) {
+    auto *input = inputs[i];
+    auto *output = outputs[i];
     if (input == nullptr) {
       MS_LOG(ERROR) << "input tensor is nullptr";
       return RET_ERROR;
@@ -98,5 +109,35 @@ int Merge::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
   }
   return RET_OK;
 }
+
+int Merge::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
+  MS_ASSERT(inputs_.size() == 2 * outputs_.size());
+  for (size_t i = 0; i < outputs_.size(); ++i) {
+    outputs_[i]->set_data_type(inputs_[i]->data_type());
+  }
+  if (!infer_flag()) {
+    return RET_INFER_INVALID;
+  }
+  std::vector<Tensor *> left_part_inputs{};
+  left_part_inputs.assign(inputs_.begin(), inputs_.begin() + inputs_.size() / 2);
+  std::vector<Tensor *> right_part_inputs{};
+  right_part_inputs.assign(inputs_.begin() + inputs_.size() / 2, inputs_.end());
+  if (AbleToInfer(left_part_inputs) == Able) {
+    return Infer(left_part_inputs, outputs_);
+  }
+  if (AbleToInfer(right_part_inputs) == Able) {
+    return Infer(right_part_inputs, outputs_);
+  }
+  if (AbleToInfer(left_part_inputs) == HasZeroShape && AbleToInfer(right_part_inputs) == HasZeroShape) {
+    return Infer(left_part_inputs, outputs_);
+  }
+  return RET_INFER_INVALID;
+}
 } // namespace lite
 } // namespace mindspore
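
For readers outside the codebase: Merge receives 2 * N inputs (per the MS_ASSERT, the first N fed by one control-flow branch and the last N by the other) and produces N outputs. The rewritten InferShape infers from whichever half is ready, and if both halves carry only zero shapes it still infers from the left half instead of bailing out. A simplified standalone model of that dispatch (Tensor here is a stand-in, not lite::Tensor):

#include <cstddef>
#include <vector>

enum InferStatus { Able, NotAble, HasZeroShape };

struct Tensor {
  std::vector<int> shape;
  const void *data = nullptr;
};

InferStatus AbleToInfer(const std::vector<Tensor *> &inputs) {
  for (auto *input : inputs) {
    if (input->shape.empty()) return HasZeroShape;
    if (input->data == nullptr) return NotAble;
  }
  return Able;
}

// Returns 0 to infer from the left half, 1 for the right half, and -1 for
// "cannot infer yet" (RET_INFER_INVALID in the real code).
int PickHalf(const std::vector<Tensor *> &inputs) {
  std::vector<Tensor *> left(inputs.begin(), inputs.begin() + inputs.size() / 2);
  std::vector<Tensor *> right(inputs.begin() + inputs.size() / 2, inputs.end());
  if (AbleToInfer(left) == Able) return 0;
  if (AbleToInfer(right) == Able) return 1;
  if (AbleToInfer(left) == HasZeroShape && AbleToInfer(right) == HasZeroShape) return 0;
  return -1;
}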

View File

@@ -24,6 +24,7 @@
 namespace mindspore {
 namespace lite {
+enum InferStatus { Able, NotAble, HasZeroShape };
 class Merge : public PrimitiveC {
  public:
@@ -37,6 +38,10 @@ class Merge : public PrimitiveC {
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
 #endif
   int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
+
+ private:
+  static InferStatus AbleToInfer(const std::vector<lite::Tensor *> &inputs);
+  static int Infer(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs);
 };
 } // namespace lite
 } // namespace mindspore

View File

@@ -116,12 +116,12 @@ int Reshape::CalNewShape(const Tensor *in_tensor, std::vector<int> *out_shape) const {
   for (size_t i = 0; i < in_tensor->shape().size(); i++) {
     in_shape_size *= in_tensor->shape().at(i);
   }
-  int64_t inferIndex = -1;
-  size_t out_shapeSize = 1;
+  int64_t infer_index = -1;
+  size_t out_shape_size = 1;
   for (size_t i = 0; i < out_shape->size(); i++) {
     if (out_shape->at(i) == -1) {
-      if (inferIndex == -1) {
-        inferIndex = i;
+      if (infer_index == -1) {
+        infer_index = i;
       } else {
         MS_LOG(ERROR) << "output shape should has no more than one dim which need infer";
         return RET_INFER_ERR;
@@ -130,18 +130,23 @@ int Reshape::CalNewShape(const Tensor *in_tensor, std::vector<int> *out_shape) const {
       MS_LOG(ERROR) << "output shape dim should be non-negative";
       return RET_INFER_ERR;
     } else if (out_shape->at(i) == 0) {
-      out_shape->at(i) = in_tensor->shape().at(i);
-      out_shapeSize *= out_shape->at(i);
+      if (in_tensor->ElementsNum() != 0) {
+        out_shape->at(i) = in_tensor->shape().at(i);
+        out_shape_size *= out_shape->at(i);
+      } else {
+        out_shape_size = 0;
+        break;
+      }
     } else {
-      out_shapeSize *= out_shape->at(i);
+      out_shape_size *= out_shape->at(i);
     }
   }
-  if (inferIndex == -1 && out_shapeSize != in_shape_size) {
-    MS_LOG(ERROR) << "output shapeSize: " << out_shapeSize << " should be equal to input shapeSize: " << in_shape_size;
+  if (infer_index == -1 && out_shape_size != in_shape_size) {
+    MS_LOG(ERROR) << "output shapeSize: " << out_shape_size << " should be equal to input shapeSize: " << in_shape_size;
     return RET_INFER_ERR;
   }
-  if (inferIndex != -1) {
-    out_shape->at(inferIndex) = in_shape_size / out_shapeSize;
+  if (infer_index != -1) {
+    out_shape->at(infer_index) = in_shape_size / out_shape_size;
   }
   return RET_OK;
 }
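
In this reshape convention a target dimension of -1 means "infer me" (at most one allowed) and 0 means "copy the input's dimension". The fix guards the copy: when the input has zero elements, copying a zero dimension into the running product would later feed a zero divisor into the -1 inference, so the product is pinned to 0 and the scan stops. A self-contained sketch of these rules under that reading, with an explicit divide-by-zero guard that the sketch adds for safety:

#include <vector>

// Returns false for an invalid target shape; out_shape holds -1/0 placeholders
// on entry and concrete dimensions on success.
bool CalNewShapeSketch(const std::vector<int> &in_shape, std::vector<int> *out_shape) {
  long in_size = 1;
  for (int d : in_shape) in_size *= d;
  long out_size = 1;
  int infer_index = -1;
  for (size_t i = 0; i < out_shape->size(); ++i) {
    int d = out_shape->at(i);
    if (d == -1) {
      if (infer_index != -1) return false;  // at most one inferred dimension
      infer_index = static_cast<int>(i);
    } else if (d == 0) {
      if (in_size == 0) { out_size = 0; break; }  // zero-element input: stop copying
      out_shape->at(i) = in_shape.at(i);
      out_size *= in_shape.at(i);
    } else if (d > 0) {
      out_size *= d;
    } else {
      return false;  // negative dims other than -1 are rejected
    }
  }
  if (infer_index != -1) {
    if (out_size == 0) return false;  // guard added in this sketch
    out_shape->at(infer_index) = static_cast<int>(in_size / out_size);
    return true;
  }
  return out_size == in_size;  // element counts must match
}
// Example: in_shape {4, 3, 2}, target {0, -1} -> {4, 6} (0 copies the 4; 24 / 4 = 6).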

View File

@@ -39,7 +39,7 @@ int CarryDataKernel::MoveData(std::vector<lite::Tensor *>::iterator dst_begin,
   }
   lite::STATUS ret;
   if (src_tensor->data_type() == kObjectTypeTensorType && dst_tensor->data_type() == kObjectTypeTensorType) {
-    ret = MoveTensorLiteData(reinterpret_cast<lite::TensorList *>(dst_tensor),
+    ret = MoveTensorListData(reinterpret_cast<lite::TensorList *>(dst_tensor),
                              reinterpret_cast<lite::TensorList *>(src_tensor));
   } else {
     ret = MoveTensorData(dst_tensor, src_tensor);
@@ -55,7 +55,13 @@ int CarryDataKernel::MoveData(std::vector<lite::Tensor *>::iterator dst_begin,
 int CarryDataKernel::MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor) {
   if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format() ||
       !(dst_tensor->shape() == src_tensor->shape() || (dst_tensor->shape().empty() && src_tensor->shape().empty()))) {
-    MS_LOG(ERROR) << "input tensor and output tensor is incompatible";
+    MS_LOG(ERROR) << "input tensor and output tensor is incompatible.";
+    MS_LOG(ERROR) << "input tensor data_type: " << src_tensor->data_type() << " vs "
+                  << "output tensor data_type: " << dst_tensor->data_type()
+                  << "input tensor format: " << src_tensor->format() << " vs "
+                  << "output tensor format: " << dst_tensor->format() << "input tensor shape: " << src_tensor->shape()
+                  << " vs "
+                  << "output tensor shape: " << dst_tensor->shape();
     return RET_ERROR;
   }
   if (src_tensor->root_tensor() == nullptr) {
@@ -83,18 +89,19 @@ int CarryDataKernel::MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor) {
   return RET_OK;
 }
-int CarryDataKernel::MoveTensorLiteData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor) {
+int CarryDataKernel::MoveTensorListData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor) {
   // shape may change, because tensors.size() can be change in RunGraph
   if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format()) {
-    MS_LOG(ERROR) << "input tensorlist and output tensorlist data_type or format is incompatible";
+    MS_LOG(ERROR) << "input tensor data_type: " << src_tensor->data_type() << " vs "
+                  << "output tensor data_type: " << dst_tensor->data_type()
+                  << "input tensor format: " << src_tensor->format() << " vs "
+                  << "output tensor format: " << dst_tensor->format();
     return RET_ERROR;
   }
-  if (dst_tensor->element_shape().empty()) {
-    dst_tensor->set_element_shape(src_tensor->element_shape());
-  } else if (dst_tensor->element_shape() != src_tensor->element_shape()) {
-    MS_LOG(ERROR) << "input tensorlist and output tensorlist element shape is incompatible";
-    return RET_ERROR;
-  }
+  // when tensorlist malloc is done. this need to check element_shape compatibility
+  dst_tensor->set_element_shape(src_tensor->element_shape());
   auto update_data_type = kTypeUnknown;
   auto dst_tensor_data_type = dst_tensor->tensors_data_type();
   auto src_tensor_data_type = src_tensor->tensors_data_type();
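
Two things change here: the misnamed MoveTensorLiteData becomes MoveTensorListData, and the strict element-shape equality check is dropped in favor of always adopting the source's element shape, since (per the new in-code comment) a TensorList's shape can legitimately change between runs of the graph, with compatibility re-checked once the tensorlist's buffers are allocated. A toy illustration of the relaxed rule, using hypothetical types rather than the lite classes:

#include <vector>

struct TensorListSketch {
  std::vector<int> element_shape;
};

void MoveElementShape(TensorListSketch *dst, const TensorListSketch &src) {
  // Old rule: reject when dst's element shape was already set and differed from
  // src's -- this failed on graphs whose element shape changes across runs.
  // New rule: adopt src's shape unconditionally and defer the compatibility
  // check until the tensorlist's storage is actually materialized.
  dst->element_shape = src.element_shape;
}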

View File

@@ -34,7 +34,7 @@ class CarryDataKernel : public LiteKernel {
   int MoveData(std::vector<lite::Tensor *>::iterator dst_begin, std::vector<lite::Tensor *>::iterator dst_end,
                std::vector<lite::Tensor *>::iterator src_begin, std::vector<lite::Tensor *>::iterator src_limit);
   static int MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor);
-  static int MoveTensorLiteData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor);
+  static int MoveTensorListData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor);
 };
 } // namespace mindspore::kernel

View File

@@ -146,6 +146,14 @@ int TensorListStackCPUKernel::MergeSubShape(const std::vector<int> &shape) {
 }
 int TensorListStackCPUKernel::Run() {
+  if (dtype_ == kTypeUnknown) {
+    dtype_ = input0_->tensors_data_type();
+#ifdef ENABLE_FP16
+    if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && dtype_ == kNumberTypeFloat32) {
+      dtype_ = kNumberTypeFloat16;
+    }
+#endif
+  }
   if (CheckParam() != RET_OK) {
     MS_LOG(ERROR) << "CheckParam failed!";
     return RET_ERROR;
@@ -169,7 +177,10 @@ int TensorListStackCPUKernel::Run() {
   MS_ASSERT(out_data != nullptr);
   for (int i = 0; i < num_element_; ++i) {
     auto in_ptr = input0_->GetTensor(i);
-    MS_ASSERT(in_ptr != nullptr);
+    if (in_ptr == nullptr) {
+      MS_LOG(DEBUG) << "no need to stack.";
+      continue;
+    }
     if (in_ptr->data_type() != kTypeUnknown) {
       int data_size = in_ptr->ElementsNum() * lite::DataTypeSize(dtype_);
       auto in_data = in_ptr->data_c();
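
Run() now resolves an unknown output dtype from the list's stored element dtype (promoting float32 to float16 when the FP16 CPU path is enabled) and treats a missing element tensor as "nothing to stack" rather than an assertion failure. A sketch of that skip pattern with a hypothetical helper, not the kernel itself:

#include <cstring>
#include <vector>

// Packs fixed-size elements into a contiguous output buffer; null entries are
// skipped, leaving their slot in the output untouched.
void StackWithHoles(const std::vector<const char *> &elements, size_t elem_bytes,
                    char *out) {
  for (size_t i = 0; i < elements.size(); ++i) {
    if (elements[i] == nullptr) continue;  // "no need to stack."
    std::memcpy(out + i * elem_bytes, elements[i], elem_bytes);
  }
}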

View File

@@ -44,3 +44,4 @@ ml_video_edit_style_transfer_gongnongbing.onnx
 ml_video_edit_style_transfer_starry.onnx
 ml_video_edit_judge.onnx
 ml_video_edit_vignet.onnx
+ssd_mobilenet_v1_10.onnx;1,383,640,3

View File

@@ -23,7 +23,7 @@ namespace lite {
 lite::PrimitiveC *OnnxNonZeroParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
                                                         const onnx::NodeProto &onnx_node) {
   MS_LOG(DEBUG) << "onnx NonZeroParser";
-  auto attr = std::make_unique<schema::NonZeroT>();
+  auto attr = std::make_unique<schema::WhereT>();
   if (attr == nullptr) {
     MS_LOG(ERROR) << "new op failed";
     return nullptr;
@@ -33,7 +33,7 @@ lite::PrimitiveC *OnnxNonZeroParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
     MS_LOG(ERROR) << "new primitive failed";
     return nullptr;
   }
-  primitive->value.type = schema::PrimitiveType_NonZero;
+  primitive->value.type = schema::PrimitiveType_Where;
   primitive->value.value = attr.release();
   return PrimitiveC::Create(primitive.release());
 }
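
The converter now lowers ONNX NonZero to the runtime's Where primitive; the apparent assumption is that lite's single-input Where form covers NonZero's semantics, namely returning the coordinates of every non-zero element of the input (ONNX lays the result out as a rank-by-count int64 tensor). For reference, a generic sketch of those semantics, independent of the lite kernels:

#include <cstddef>
#include <vector>

// Collects the multi-dimensional coordinates of each set element by
// unravelling the flat index against the shape.
std::vector<std::vector<int>> NonZeroCoords(const std::vector<int> &shape,
                                            const std::vector<bool> &flat) {
  std::vector<std::vector<int>> coords;
  for (size_t idx = 0; idx < flat.size(); ++idx) {
    if (!flat[idx]) continue;
    std::vector<int> c(shape.size());
    size_t rem = idx;
    for (int d = static_cast<int>(shape.size()) - 1; d >= 0; --d) {
      c[d] = static_cast<int>(rem % shape[d]);
      rem /= shape[d];
    }
    coords.push_back(c);
  }
  return coords;
}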