forked from mindspore-Ecosystem/mindspore
fix bug when zero shape
This commit is contained in:
parent
5a35626b1c
commit
65488f4c10
|
@ -406,8 +406,8 @@ int Conv2D::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
|
|||
this->ConvInferShape(input_h, input_w, &output_h, &output_w);
|
||||
|
||||
std::vector<int> out_shape{input_tensor->shape()};
|
||||
out_shape.at(1) = output_h > 0 ? output_h : 1;
|
||||
out_shape.at(2) = output_w > 0 ? output_w : 1;
|
||||
out_shape.at(1) = output_h >= 0 ? output_h : 1;
|
||||
out_shape.at(2) = output_w >= 0 ? output_w : 1;
|
||||
out_shape.at(3) = weight_tensor->shape()[0];
|
||||
out_tensor->set_shape(out_shape);
|
||||
|
||||
|
|
|
@ -66,14 +66,25 @@ PrimitiveC *MergeCreator(const schema::Primitive *primitive) { return PrimitiveC
|
|||
Registry MergeRegistry(schema::PrimitiveType_Merge, MergeCreator);
|
||||
#endif
|
||||
|
||||
int Merge::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
MS_ASSERT(inputs_.size() == 2 * outputs_.size());
|
||||
if (!infer_flag()) {
|
||||
return RET_INFER_INVALID;
|
||||
InferStatus Merge::AbleToInfer(const std::vector<lite::Tensor *> &inputs) {
|
||||
for (auto &input : inputs) {
|
||||
if (input->shape().empty()) {
|
||||
return HasZeroShape;
|
||||
}
|
||||
for (size_t i = 0; i < inputs_.size() / 2; i++) {
|
||||
auto *input = inputs_[i];
|
||||
auto *output = outputs_[i];
|
||||
if (input->root_tensor() != nullptr && input->root_tensor()->data_c() != nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (input->data_c() == nullptr) {
|
||||
return NotAble;
|
||||
}
|
||||
}
|
||||
return Able;
|
||||
}
|
||||
|
||||
int Merge::Infer(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs) {
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
auto *input = inputs[i];
|
||||
auto *output = outputs[i];
|
||||
if (input == nullptr) {
|
||||
MS_LOG(ERROR) << "input tensor is nullptr";
|
||||
return RET_ERROR;
|
||||
|
@ -98,5 +109,35 @@ int Merge::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
|
|||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int Merge::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
MS_ASSERT(inputs_.size() == 2 * outputs_.size());
|
||||
for (size_t i = 0; i < outputs_.size(); ++i) {
|
||||
outputs_[i]->set_data_type(inputs_[i]->data_type());
|
||||
}
|
||||
if (!infer_flag()) {
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
|
||||
std::vector<Tensor *> left_part_inputs{};
|
||||
left_part_inputs.assign(inputs_.begin(), inputs_.begin() + inputs_.size() / 2);
|
||||
|
||||
std::vector<Tensor *> right_part_inputs{};
|
||||
right_part_inputs.assign(inputs_.begin() + inputs_.size() / 2, inputs_.end());
|
||||
|
||||
if (AbleToInfer(left_part_inputs) == Able) {
|
||||
return Infer(left_part_inputs, outputs_);
|
||||
}
|
||||
|
||||
if (AbleToInfer(right_part_inputs) == Able) {
|
||||
return Infer(right_part_inputs, outputs_);
|
||||
}
|
||||
|
||||
if (AbleToInfer(left_part_inputs) == HasZeroShape && AbleToInfer(right_part_inputs) == HasZeroShape) {
|
||||
return Infer(left_part_inputs, outputs_);
|
||||
}
|
||||
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
enum InferStatus { Able, NotAble, HasZeroShape };
|
||||
|
||||
class Merge : public PrimitiveC {
|
||||
public:
|
||||
|
@ -37,6 +38,10 @@ class Merge : public PrimitiveC {
|
|||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
|
||||
private:
|
||||
static InferStatus AbleToInfer(const std::vector<lite::Tensor *> &inputs);
|
||||
static int Infer(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -116,12 +116,12 @@ int Reshape::CalNewShape(const Tensor *in_tensor, std::vector<int> *out_shape) c
|
|||
for (size_t i = 0; i < in_tensor->shape().size(); i++) {
|
||||
in_shape_size *= in_tensor->shape().at(i);
|
||||
}
|
||||
int64_t inferIndex = -1;
|
||||
size_t out_shapeSize = 1;
|
||||
int64_t infer_index = -1;
|
||||
size_t out_shape_size = 1;
|
||||
for (size_t i = 0; i < out_shape->size(); i++) {
|
||||
if (out_shape->at(i) == -1) {
|
||||
if (inferIndex == -1) {
|
||||
inferIndex = i;
|
||||
if (infer_index == -1) {
|
||||
infer_index = i;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "output shape should has no more than one dim which need infer";
|
||||
return RET_INFER_ERR;
|
||||
|
@ -130,18 +130,23 @@ int Reshape::CalNewShape(const Tensor *in_tensor, std::vector<int> *out_shape) c
|
|||
MS_LOG(ERROR) << "output shape dim should be non-negative";
|
||||
return RET_INFER_ERR;
|
||||
} else if (out_shape->at(i) == 0) {
|
||||
if (in_tensor->ElementsNum() != 0) {
|
||||
out_shape->at(i) = in_tensor->shape().at(i);
|
||||
out_shapeSize *= out_shape->at(i);
|
||||
out_shape_size *= out_shape->at(i);
|
||||
} else {
|
||||
out_shapeSize *= out_shape->at(i);
|
||||
out_shape_size = 0;
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
out_shape_size *= out_shape->at(i);
|
||||
}
|
||||
}
|
||||
if (inferIndex == -1 && out_shapeSize != in_shape_size) {
|
||||
MS_LOG(ERROR) << "output shapeSize: " << out_shapeSize << " should be equal to input shapeSize: " << in_shape_size;
|
||||
if (infer_index == -1 && out_shape_size != in_shape_size) {
|
||||
MS_LOG(ERROR) << "output shapeSize: " << out_shape_size << " should be equal to input shapeSize: " << in_shape_size;
|
||||
return RET_INFER_ERR;
|
||||
}
|
||||
if (inferIndex != -1) {
|
||||
out_shape->at(inferIndex) = in_shape_size / out_shapeSize;
|
||||
if (infer_index != -1) {
|
||||
out_shape->at(infer_index) = in_shape_size / out_shape_size;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -39,7 +39,7 @@ int CarryDataKernel::MoveData(std::vector<lite::Tensor *>::iterator dst_begin,
|
|||
}
|
||||
lite::STATUS ret;
|
||||
if (src_tensor->data_type() == kObjectTypeTensorType && dst_tensor->data_type() == kObjectTypeTensorType) {
|
||||
ret = MoveTensorLiteData(reinterpret_cast<lite::TensorList *>(dst_tensor),
|
||||
ret = MoveTensorListData(reinterpret_cast<lite::TensorList *>(dst_tensor),
|
||||
reinterpret_cast<lite::TensorList *>(src_tensor));
|
||||
} else {
|
||||
ret = MoveTensorData(dst_tensor, src_tensor);
|
||||
|
@ -55,7 +55,13 @@ int CarryDataKernel::MoveData(std::vector<lite::Tensor *>::iterator dst_begin,
|
|||
int CarryDataKernel::MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor) {
|
||||
if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format() ||
|
||||
!(dst_tensor->shape() == src_tensor->shape() || (dst_tensor->shape().empty() && src_tensor->shape().empty()))) {
|
||||
MS_LOG(ERROR) << "input tensor and output tensor is incompatible";
|
||||
MS_LOG(ERROR) << "input tensor and output tensor is incompatible.";
|
||||
MS_LOG(ERROR) << "input tensor data_type: " << src_tensor->data_type() << " vs "
|
||||
<< "output tensor data_type: " << dst_tensor->data_type()
|
||||
<< "input tensor format: " << src_tensor->format() << " vs "
|
||||
<< "output tensor format: " << dst_tensor->format() << "input tensor shape: " << src_tensor->shape()
|
||||
<< " vs "
|
||||
<< "output tensor shape: " << dst_tensor->shape();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (src_tensor->root_tensor() == nullptr) {
|
||||
|
@ -83,18 +89,19 @@ int CarryDataKernel::MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int CarryDataKernel::MoveTensorLiteData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor) {
|
||||
int CarryDataKernel::MoveTensorListData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor) {
|
||||
// shape may change, because tensors.size() can be change in RunGraph
|
||||
if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format()) {
|
||||
MS_LOG(ERROR) << "input tensorlist and output tensorlist data_type or format is incompatible";
|
||||
MS_LOG(ERROR) << "input tensor data_type: " << src_tensor->data_type() << " vs "
|
||||
<< "output tensor data_type: " << dst_tensor->data_type()
|
||||
<< "input tensor format: " << src_tensor->format() << " vs "
|
||||
<< "output tensor format: " << dst_tensor->format();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (dst_tensor->element_shape().empty()) {
|
||||
// when tensorlist malloc is done. this need to check element_shape compatibility
|
||||
dst_tensor->set_element_shape(src_tensor->element_shape());
|
||||
} else if (dst_tensor->element_shape() != src_tensor->element_shape()) {
|
||||
MS_LOG(ERROR) << "input tensorlist and output tensorlist element shape is incompatible";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto update_data_type = kTypeUnknown;
|
||||
auto dst_tensor_data_type = dst_tensor->tensors_data_type();
|
||||
auto src_tensor_data_type = src_tensor->tensors_data_type();
|
||||
|
|
|
@ -34,7 +34,7 @@ class CarryDataKernel : public LiteKernel {
|
|||
int MoveData(std::vector<lite::Tensor *>::iterator dst_begin, std::vector<lite::Tensor *>::iterator dst_end,
|
||||
std::vector<lite::Tensor *>::iterator src_begin, std::vector<lite::Tensor *>::iterator src_limit);
|
||||
static int MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor);
|
||||
static int MoveTensorLiteData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor);
|
||||
static int MoveTensorListData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor);
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
|
@ -146,6 +146,14 @@ int TensorListStackCPUKernel::MergeSubShape(const std::vector<int> &shape) {
|
|||
}
|
||||
|
||||
int TensorListStackCPUKernel::Run() {
|
||||
if (dtype_ == kTypeUnknown) {
|
||||
dtype_ = input0_->tensors_data_type();
|
||||
#ifdef ENABLE_FP16
|
||||
if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && dtype_ == kNumberTypeFloat32) {
|
||||
dtype_ = kNumberTypeFloat16;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if (CheckParam() != RET_OK) {
|
||||
MS_LOG(ERROR) << "CheckParam failed!";
|
||||
return RET_ERROR;
|
||||
|
@ -169,7 +177,10 @@ int TensorListStackCPUKernel::Run() {
|
|||
MS_ASSERT(out_data != nullptr);
|
||||
for (int i = 0; i < num_element_; ++i) {
|
||||
auto in_ptr = input0_->GetTensor(i);
|
||||
MS_ASSERT(in_ptr != nullptr);
|
||||
if (in_ptr == nullptr) {
|
||||
MS_LOG(DEBUG) << "no need to stack.";
|
||||
continue;
|
||||
}
|
||||
if (in_ptr->data_type() != kTypeUnknown) {
|
||||
int data_size = in_ptr->ElementsNum() * lite::DataTypeSize(dtype_);
|
||||
auto in_data = in_ptr->data_c();
|
||||
|
|
|
@ -44,3 +44,4 @@ ml_video_edit_style_transfer_gongnongbing.onnx
|
|||
ml_video_edit_style_transfer_starry.onnx
|
||||
ml_video_edit_judge.onnx
|
||||
ml_video_edit_vignet.onnx
|
||||
ssd_mobilenet_v1_10.onnx;1,383,640,3
|
||||
|
|
|
@ -23,7 +23,7 @@ namespace lite {
|
|||
lite::PrimitiveC *OnnxNonZeroParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx NonZeroParser";
|
||||
auto attr = std::make_unique<schema::NonZeroT>();
|
||||
auto attr = std::make_unique<schema::WhereT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
|
@ -33,7 +33,7 @@ lite::PrimitiveC *OnnxNonZeroParser::ParseLitePrimitive(const onnx::GraphProto &
|
|||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_NonZero;
|
||||
primitive->value.type = schema::PrimitiveType_Where;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue