diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc index b5a6799db0b..a4bdcc9eae6 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc @@ -150,6 +150,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { auto tensor = metaGraphT->allTensors[input].get(); if (tensor->data.empty()) { tensor->nodeType = schema::NodeType_ValueNode; + tensor->format = schema::Format_NHWC; // tensor->refCount = lite::MSCONST_WEIGHT_REFCOUNT; metaGraphT->inputIndex.emplace_back(input); } diff --git a/mindspore/lite/src/ops/addn.cc b/mindspore/lite/src/ops/addn.cc index 6b9252cd39f..c8a8e4e0a83 100644 --- a/mindspore/lite/src/ops/addn.cc +++ b/mindspore/lite/src/ops/addn.cc @@ -36,6 +36,7 @@ int AddN::InferShape(std::vector inputs_, std::vectorSetFormat(input->GetFormat()); output->set_shape(input->shape()); output->set_data_type(input->data_type()); return RET_OK; diff --git a/mindspore/lite/src/ops/argmax.cc b/mindspore/lite/src/ops/argmax.cc index af94e597e7e..d71910e438e 100644 --- a/mindspore/lite/src/ops/argmax.cc +++ b/mindspore/lite/src/ops/argmax.cc @@ -40,6 +40,7 @@ int ArgMax::InferShape(std::vector inputs_, std::vectorSetFormat(input->GetFormat()); output->set_shape(output_shape); output->set_data_type(input->data_type()); return RET_OK; diff --git a/mindspore/lite/src/ops/argmin.cc b/mindspore/lite/src/ops/argmin.cc index 2323af643f4..b501b14b1c9 100644 --- a/mindspore/lite/src/ops/argmin.cc +++ b/mindspore/lite/src/ops/argmin.cc @@ -39,9 +39,9 @@ int ArgMin::InferShape(std::vector inputs_, std::vector output_shape(input->shape()); output_shape.erase(output_shape.begin() + axis); + output->SetFormat(input->GetFormat()); output->set_shape(output_shape); output->set_data_type(input->data_type()); return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/arithmetic.cc b/mindspore/lite/src/ops/arithmetic.cc index 2a6ce1320e2..2b15e22608c 100644 --- a/mindspore/lite/src/ops/arithmetic.cc +++ b/mindspore/lite/src/ops/arithmetic.cc @@ -39,7 +39,7 @@ int Arithmetic::InferShape(std::vector inputs_, std::vectorshape(); auto input_shape1 = input1->shape(); - + auto format = input0->GetFormat(); in_shape0_.resize(5); in_shape1_.resize(5); out_shape_.resize(5); @@ -57,6 +57,7 @@ int Arithmetic::InferShape(std::vector inputs_, std::vectorGetFormat(); } else if (input_shape0.size() > input_shape1.size()) { ndim_ = input_shape0.size(); auto fill_dim_num = input_shape0.size() - input_shape1.size(); @@ -93,7 +94,7 @@ int Arithmetic::InferShape(std::vector inputs_, std::vectorSetFormat(format); output->set_shape(output_shape); output->set_data_type(input0->data_type()); return RET_OK; diff --git a/mindspore/lite/src/ops/arithmetic_self.cc b/mindspore/lite/src/ops/arithmetic_self.cc index 567a190f6a6..3d2210e746c 100644 --- a/mindspore/lite/src/ops/arithmetic_self.cc +++ b/mindspore/lite/src/ops/arithmetic_self.cc @@ -26,9 +26,11 @@ int ArithmeticSelf::InferShape(std::vector inputs_, std::vecto MS_ASSERT(input != nullptr); auto output = outputs_.front(); MS_ASSERT(output != nullptr); + + output->SetFormat(input->GetFormat()); output->set_shape(input->shape()); output->set_data_type(input->data_type()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/batch_to_space.cc b/mindspore/lite/src/ops/batch_to_space.cc index a3ca0b2b49d..41412c58c35 100644 --- a/mindspore/lite/src/ops/batch_to_space.cc +++ b/mindspore/lite/src/ops/batch_to_space.cc @@ -85,9 +85,10 @@ int BatchToSpace::InferShape(std::vector inputs, std::vectorGet(0) - crops->Get(0) - crops->Get(1); output_shape[kNHWC_w_index] = input_shape[kNHWC_w_index] * block_shape->Get(1) - crops->Get(2) - crops->Get(3); output_shape[kNHWC_c_index] = input_shape[kNHWC_c_index]; + + outputs[0]->SetFormat(input->GetFormat()); outputs[0]->set_shape(output_shape); outputs[0]->set_data_type(input->data_type()); return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/broadcast_to.cc b/mindspore/lite/src/ops/broadcast_to.cc index 225e34d6147..51e59146770 100644 --- a/mindspore/lite/src/ops/broadcast_to.cc +++ b/mindspore/lite/src/ops/broadcast_to.cc @@ -58,9 +58,9 @@ int BroadcastTo::InferShape(std::vector inputs, std::vectorSetFormat(input->GetFormat()); outputs[0]->set_shape(shape); outputs[0]->set_data_type(input->data_type()); return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/cast.cc b/mindspore/lite/src/ops/cast.cc index 796f80cbee0..13de84ff5e8 100644 --- a/mindspore/lite/src/ops/cast.cc +++ b/mindspore/lite/src/ops/cast.cc @@ -44,9 +44,9 @@ int Cast::InferShape(std::vector inputs_, std::vectordstT(); return RET_INPUT_TENSOR_ERROR; } + output->SetFormat(input->GetFormat()); output->set_shape(input->shape()); output->set_data_type(input->data_type()); return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/concat.cc b/mindspore/lite/src/ops/concat.cc index 2d966676d50..e69e2707a35 100644 --- a/mindspore/lite/src/ops/concat.cc +++ b/mindspore/lite/src/ops/concat.cc @@ -70,7 +70,8 @@ int Concat::InferShape(std::vector inputs_, std::vectorset_shape(output_shape); output->set_data_type(input0->data_type()); + output->SetFormat(input0->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/crop.cc b/mindspore/lite/src/ops/crop.cc index b58b8a27f44..dceab29f9bf 100644 --- a/mindspore/lite/src/ops/crop.cc +++ b/mindspore/lite/src/ops/crop.cc @@ -32,7 +32,8 @@ int Crop::InferShape(std::vector inputs, std::vectorset_shape(inputs[1]->shape()); + outputs[0]->SetFormat(inputs[1]->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/depth_to_space.cc b/mindspore/lite/src/ops/depth_to_space.cc index 025c1ad3603..f09fddfb588 100644 --- a/mindspore/lite/src/ops/depth_to_space.cc +++ b/mindspore/lite/src/ops/depth_to_space.cc @@ -23,7 +23,7 @@ namespace mindspore::lite { namespace { constexpr int kDepthToSpaceOutputNum = 1; constexpr int kDepthToSpaceInputNum = 1; -} +} // namespace int DepthToSpace::InferShape(std::vector inputs, std::vector outputs) { MS_ASSERT(this->primitive != nullptr); @@ -56,7 +56,8 @@ int DepthToSpace::InferShape(std::vector inputs, std::vectorset_shape(output_shape); outputs[0]->set_data_type(input->data_type()); + outputs[0]->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/expand_dims.cc b/mindspore/lite/src/ops/expand_dims.cc index 588710f886c..5b0391d6541 100644 --- a/mindspore/lite/src/ops/expand_dims.cc +++ b/mindspore/lite/src/ops/expand_dims.cc @@ -45,7 +45,8 @@ int ExpandDims::InferShape(std::vector inputs_, std::vectorset_shape(out_shape); output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/fill.cc b/mindspore/lite/src/ops/fill.cc index 361a5e2b8df..f4bd0c1952f 100644 --- a/mindspore/lite/src/ops/fill.cc +++ b/mindspore/lite/src/ops/fill.cc @@ -42,7 +42,8 @@ int Fill::InferShape(std::vector inputs_, std::vectordims()->begin(), fill_prim->dims()->end()); output->set_shape(output_shape); output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/flatten.cc b/mindspore/lite/src/ops/flatten.cc index c2264afcf93..bde0cd16c56 100644 --- a/mindspore/lite/src/ops/flatten.cc +++ b/mindspore/lite/src/ops/flatten.cc @@ -43,7 +43,8 @@ int Flatten::InferShape(std::vector inputs_, std::vectorset_shape(output_shape); output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/fullconnection.cc b/mindspore/lite/src/ops/fullconnection.cc index 0b44faecfd8..7b4b1e051fe 100644 --- a/mindspore/lite/src/ops/fullconnection.cc +++ b/mindspore/lite/src/ops/fullconnection.cc @@ -56,7 +56,8 @@ int FullConnection::InferShape(std::vector inputs_, std::vecto out_shape[fc_prim->axis()] = input1->shape()[0]; output->set_shape(out_shape); output->set_data_type(input0->data_type()); + output->SetFormat(input0->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/gather.cc b/mindspore/lite/src/ops/gather.cc index 0e5cd619186..328de9ba2f8 100644 --- a/mindspore/lite/src/ops/gather.cc +++ b/mindspore/lite/src/ops/gather.cc @@ -71,7 +71,8 @@ int Gather::InferShape(std::vector inputs_, std::vectorset_shape(out_shape); output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/gather_nd.cc b/mindspore/lite/src/ops/gather_nd.cc index 4f5598817bf..681e2d207ba 100644 --- a/mindspore/lite/src/ops/gather_nd.cc +++ b/mindspore/lite/src/ops/gather_nd.cc @@ -59,7 +59,8 @@ int GatherNd::InferShape(std::vector inputs_, std::vectorset_shape(out_shape); output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/matmul.cc b/mindspore/lite/src/ops/matmul.cc index d7cb772f41b..2d031378bfb 100644 --- a/mindspore/lite/src/ops/matmul.cc +++ b/mindspore/lite/src/ops/matmul.cc @@ -57,7 +57,8 @@ int MatMul::InferShape(std::vector inputs_, std::vectorset_shape(y_shape); output->set_data_type(input0->data_type()); + output->SetFormat(input0->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/one_hot.cc b/mindspore/lite/src/ops/one_hot.cc index eb96edad891..878813c9951 100644 --- a/mindspore/lite/src/ops/one_hot.cc +++ b/mindspore/lite/src/ops/one_hot.cc @@ -67,6 +67,8 @@ int OneHot::InferShape(std::vector inputs, std::vectorset_data_type(on_value->data_type()); + output->SetFormat(on_value->GetFormat()); + return RET_OK; } } // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/ops.cc b/mindspore/lite/src/ops/ops.cc index 642b2898fbb..63fe262aa62 100644 --- a/mindspore/lite/src/ops/ops.cc +++ b/mindspore/lite/src/ops/ops.cc @@ -138,7 +138,8 @@ int Primitive::InferShape(std::vector inputs_, std::vectorset_shape(input->shape()); output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/pad.cc b/mindspore/lite/src/ops/pad.cc index 8604da24e3f..3bdbe04235e 100644 --- a/mindspore/lite/src/ops/pad.cc +++ b/mindspore/lite/src/ops/pad.cc @@ -55,9 +55,9 @@ int Pad::InferShape(std::vector inputs, std::vectorSetFormat(input->GetFormat()); output->set_shape(output_shape); output->set_data_type(input->data_type()); return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/pooling.cc b/mindspore/lite/src/ops/pooling.cc index c25a558bbce..20745e7cd0b 100644 --- a/mindspore/lite/src/ops/pooling.cc +++ b/mindspore/lite/src/ops/pooling.cc @@ -74,6 +74,7 @@ int Pooling::InferShape(std::vector inputs_, std::vectorset_shape(input_shape); output->set_data_type(input->data_type()); + // todo: temp fix output->SetFormat(schema::Format_NHWC); return RET_OK; diff --git a/mindspore/lite/src/ops/range.cc b/mindspore/lite/src/ops/range.cc index 4adafa26898..53180a8d518 100644 --- a/mindspore/lite/src/ops/range.cc +++ b/mindspore/lite/src/ops/range.cc @@ -34,7 +34,8 @@ int Range::InferShape(std::vector inputs_, std::vectorset_shape(in_shape); output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/rank.cc b/mindspore/lite/src/ops/rank.cc index c7b70930b14..5939396d163 100644 --- a/mindspore/lite/src/ops/rank.cc +++ b/mindspore/lite/src/ops/rank.cc @@ -29,7 +29,8 @@ int Rank::InferShape(std::vector inputs_, std::vector in_shape(1, 1); output->set_shape(in_shape); output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/reduce.cc b/mindspore/lite/src/ops/reduce.cc index 888a61df875..76ce8199776 100644 --- a/mindspore/lite/src/ops/reduce.cc +++ b/mindspore/lite/src/ops/reduce.cc @@ -73,6 +73,8 @@ int Reduce::InferShape(std::vector inputs_, std::vectorset_shape(out_shape); output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/reshape.cc b/mindspore/lite/src/ops/reshape.cc index 3e2a9c6eef8..1358769bb37 100644 --- a/mindspore/lite/src/ops/reshape.cc +++ b/mindspore/lite/src/ops/reshape.cc @@ -114,7 +114,8 @@ int Reshape::InferShape(std::vector inputs_, std::vectorset_shape(out_shape); output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/resize.cc b/mindspore/lite/src/ops/resize.cc index 7dd387c6369..9b7edd2d9e6 100644 --- a/mindspore/lite/src/ops/resize.cc +++ b/mindspore/lite/src/ops/resize.cc @@ -45,7 +45,8 @@ int Resize::InferShape(std::vector inputs_, std::vectorChannel()); output->set_shape(output_shape); output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/scatter_nd.cc b/mindspore/lite/src/ops/scatter_nd.cc index cf9f4dfbc35..446a37d8720 100644 --- a/mindspore/lite/src/ops/scatter_nd.cc +++ b/mindspore/lite/src/ops/scatter_nd.cc @@ -57,7 +57,8 @@ int ScatterND::InferShape(std::vector inputs_, std::vector out_shape(shape_data, shape_data + sizeof(shape_data) / sizeof(shape_data[0])); output->set_shape(out_shape); output->set_data_type(update->data_type()); + output->SetFormat(update->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/slice.cc b/mindspore/lite/src/ops/slice.cc index 994fabb1cc5..c1fe1d396db 100644 --- a/mindspore/lite/src/ops/slice.cc +++ b/mindspore/lite/src/ops/slice.cc @@ -23,7 +23,7 @@ namespace mindspore::lite { namespace { constexpr int kSliceInputNum = 1; constexpr int kSliceOutputNum = 1; -} +} // namespace int Slice::InferShape(std::vector inputs, std::vector outputs) { MS_ASSERT(this->primitive != nullptr); @@ -47,13 +47,13 @@ int Slice::InferShape(std::vector inputs, std::vector