fix bug: set graph input tensor format in anf_exporter and propagate op output format

cjh9368 2020-07-29 17:24:24 +08:00
parent 1b69923472
commit 78c9122897
44 changed files with 100 additions and 60 deletions
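
Taken together, the hunks below make one recurring change: every op's InferShape now copies the input tensor's format onto its output alongside shape and data type, and the ANF exporter assigns NHWC to graph input tensors that carry no data. The standalone sketch that follows only illustrates that pattern; the simplified Tensor, Format, RET_OK, and helper names are hypothetical stand-ins, not the real mindspore::lite classes.

#include <cassert>
#include <iostream>
#include <vector>

// Minimal stand-ins for the tensor and format types used in the diff.
// Illustrative only; not the actual MindSpore Lite API.
enum Format { Format_NCHW, Format_NHWC };

struct Tensor {
  Format format = Format_NCHW;
  std::vector<int> shape;
  int data_type = 0;
  std::vector<char> data;

  Format GetFormat() const { return format; }
  void SetFormat(Format f) { format = f; }
  void set_shape(const std::vector<int> &s) { shape = s; }
  void set_data_type(int t) { data_type = t; }
};

constexpr int RET_OK = 0;

// Per-op pattern added in most hunks below: InferShape propagates the input
// tensor's format to the output instead of leaving it default-initialized.
int InferShapeLikeInput(const std::vector<Tensor *> &inputs,
                        const std::vector<Tensor *> &outputs) {
  Tensor *input = inputs.front();
  Tensor *output = outputs.front();
  assert(input != nullptr && output != nullptr);
  output->SetFormat(input->GetFormat());  // the line this commit adds
  output->set_shape(input->shape);
  output->set_data_type(input->data_type);
  return RET_OK;
}

// Exporter-side pattern: a graph input tensor has no data, so it is given an
// explicit NHWC format before its index is recorded as a graph input.
// (The real exporter also tags it as a value node; omitted here.)
void MarkGraphInput(Tensor *tensor, int index, std::vector<int> *input_index) {
  if (tensor->data.empty()) {
    tensor->SetFormat(Format_NHWC);
    input_index->push_back(index);
  }
}

int main() {
  Tensor in, out;
  in.SetFormat(Format_NHWC);
  in.set_shape({1, 16, 16, 3});
  InferShapeLikeInput({&in}, {&out});
  std::cout << "output inherits NHWC: " << (out.GetFormat() == Format_NHWC) << std::endl;
  return 0;
}

Propagating the layout in InferShape means downstream kernels and format-aware passes no longer see a default-initialized format on op outputs, which appears to be the bug the commit title refers to.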

@ -150,6 +150,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
auto tensor = metaGraphT->allTensors[input].get();
if (tensor->data.empty()) {
tensor->nodeType = schema::NodeType_ValueNode;
tensor->format = schema::Format_NHWC;
// tensor->refCount = lite::MSCONST_WEIGHT_REFCOUNT;
metaGraphT->inputIndex.emplace_back(input);
}

@ -36,6 +36,7 @@ int AddN::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
return RET_INPUT_TENSOR_ERROR;
}
}
output->SetFormat(input->GetFormat());
output->set_shape(input->shape());
output->set_data_type(input->data_type());
return RET_OK;

@ -40,6 +40,7 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
}
output_shape.erase(output_shape.begin() + axis);
output->SetFormat(input->GetFormat());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
return RET_OK;

@ -39,9 +39,9 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
std::vector<int> output_shape(input->shape());
output_shape.erase(output_shape.begin() + axis);
output->SetFormat(input->GetFormat());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace mindspore::lite

@ -39,7 +39,7 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
auto input_shape0 = input0->shape();
auto input_shape1 = input1->shape();
auto format = input0->GetFormat();
in_shape0_.resize(5);
in_shape1_.resize(5);
out_shape_.resize(5);
@ -57,6 +57,7 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
}
in_shape1_[i] = input_shape1[i];
}
format = input0->GetFormat();
} else if (input_shape0.size() > input_shape1.size()) {
ndim_ = input_shape0.size();
auto fill_dim_num = input_shape0.size() - input_shape1.size();
@ -93,7 +94,7 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
}
output_shape.push_back(out_shape_[i]);
}
output->SetFormat(format);
output->set_shape(output_shape);
output->set_data_type(input0->data_type());
return RET_OK;

@ -26,9 +26,11 @@ int ArithmeticSelf::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(input->GetFormat());
output->set_shape(input->shape());
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace mindspore::lite

@ -85,9 +85,10 @@ int BatchToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
output_shape[kNHWC_h_index] = input_shape[kNHWC_h_index] * block_shape->Get(0) - crops->Get(0) - crops->Get(1);
output_shape[kNHWC_w_index] = input_shape[kNHWC_w_index] * block_shape->Get(1) - crops->Get(2) - crops->Get(3);
output_shape[kNHWC_c_index] = input_shape[kNHWC_c_index];
outputs[0]->SetFormat(input->GetFormat());
outputs[0]->set_shape(output_shape);
outputs[0]->set_data_type(input->data_type());
return RET_OK;
}
} // namespace mindspore::lite

@ -58,9 +58,9 @@ int BroadcastTo::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<te
shape[i] = dst_shape[i];
--input_shape_index;
}
outputs[0]->SetFormat(input->GetFormat());
outputs[0]->set_shape(shape);
outputs[0]->set_data_type(input->data_type());
return RET_OK;
}
} // namespace mindspore::lite

@ -44,9 +44,9 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_LOG(ERROR) << "Invalid output datatype " << cast_prim->dstT();
return RET_INPUT_TENSOR_ERROR;
}
output->SetFormat(input->GetFormat());
output->set_shape(input->shape());
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace mindspore::lite

@ -70,7 +70,8 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
output_shape[axis] = output_axis_dim;
outputs_[0]->set_shape(output_shape);
output->set_data_type(input0->data_type());
output->SetFormat(input0->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -32,7 +32,8 @@ int Crop::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T
return RET_PARAM_INVALID;
}
outputs[0]->set_shape(inputs[1]->shape());
outputs[0]->SetFormat(inputs[1]->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -23,7 +23,7 @@ namespace mindspore::lite {
namespace {
constexpr int kDepthToSpaceOutputNum = 1;
constexpr int kDepthToSpaceInputNum = 1;
}
} // namespace
int DepthToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
@ -56,7 +56,8 @@ int DepthToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
output_shape[kNHWC_c_index] = input_shape[kNHWC_c_index] / (block_size * block_size);
outputs[0]->set_shape(output_shape);
outputs[0]->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -45,7 +45,8 @@ int ExpandDims::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
out_shape.insert(out_shape.begin() + dim, 1, 1);
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -42,7 +42,8 @@ int Fill::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
(void)output_shape.insert(output_shape.begin(), fill_prim->dims()->begin(), fill_prim->dims()->end());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -43,7 +43,8 @@ int Flatten::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
}
output->set_shape(output_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -56,7 +56,8 @@ int FullConnection::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
out_shape[fc_prim->axis()] = input1->shape()[0];
output->set_shape(out_shape);
output->set_data_type(input0->data_type());
output->SetFormat(input0->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -71,7 +71,8 @@ int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -59,7 +59,8 @@ int GatherNd::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -57,7 +57,8 @@ int MatMul::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
y_shape[y_shape_size - 1] = w_shape[w_shape.size() - 1];
output->set_shape(y_shape);
output->set_data_type(input0->data_type());
output->SetFormat(input0->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -67,6 +67,8 @@ int OneHot::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:
return RET_NULL_PTR;
}
output->set_data_type(on_value->data_type());
output->SetFormat(on_value->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -138,7 +138,8 @@ int Primitive::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -55,9 +55,9 @@ int Pad::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Te
if (output == nullptr) {
return RET_NULL_PTR;
}
output->SetFormat(input->GetFormat());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace mindspore::lite

@ -74,6 +74,7 @@ int Pooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
input_shape.at(2) = output_w;
output->set_shape(input_shape);
output->set_data_type(input->data_type());
// todo: temp fix
output->SetFormat(schema::Format_NHWC);
return RET_OK;

@ -34,7 +34,8 @@ int Range::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
in_shape.push_back(shape_size);
output->set_shape(in_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -29,7 +29,8 @@ int Rank::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
std::vector<int> in_shape(1, 1);
output->set_shape(in_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -73,6 +73,8 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
}
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -114,7 +114,8 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -45,7 +45,8 @@ int Resize::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
output_shape.push_back(input->Channel());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -57,7 +57,8 @@ int ScatterND::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
std::vector<int> out_shape(shape_data, shape_data + sizeof(shape_data) / sizeof(shape_data[0]));
output->set_shape(out_shape);
output->set_data_type(update->data_type());
output->SetFormat(update->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -23,7 +23,7 @@ namespace mindspore::lite {
namespace {
constexpr int kSliceInputNum = 1;
constexpr int kSliceOutputNum = 1;
}
} // namespace
int Slice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
@ -47,13 +47,13 @@ int Slice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
return RET_PARAM_INVALID;
}
if (input_shape[i] <= slice_begin[i]) {
MS_LOG(ERROR) << "Invalid begin input!begin[" << i << "]=" << slice_begin[i] << " which should be <= "
<< input_shape[i];
MS_LOG(ERROR) << "Invalid begin input!begin[" << i << "]=" << slice_begin[i]
<< " which should be <= " << input_shape[i];
return RET_PARAM_INVALID;
}
if (slice_size[i] > (input_shape[i] - slice_begin[i])) {
MS_LOG(ERROR) << "Invalid size input " << slice_size[i] << " which should be <= "
<< input_shape[i] - slice_begin[i];
MS_LOG(ERROR) << "Invalid size input " << slice_size[i]
<< " which should be <= " << input_shape[i] - slice_begin[i];
return RET_PARAM_INVALID;
}
@ -62,7 +62,8 @@ int Slice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
outputs[0]->set_shape(output_shape);
outputs[0]->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -28,7 +28,8 @@ int SoftMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -55,8 +55,8 @@ int Split::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
output_shape[split_dim] = split_dim_i;
outputs_[i]->set_shape(output_shape);
outputs_[i]->set_data_type(input->data_type());
outputs_[i]->SetFormat(input->GetFormat());
}
return RET_OK;
}
} // namespace mindspore::lite

@ -23,7 +23,7 @@ namespace mindspore::lite {
namespace {
constexpr int kSqueezeInputNum = 1;
constexpr int kSqueezeOutputNum = 1;
}
} // namespace
int Squeeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
if (kSqueezeInputNum != inputs_.size()) {
@ -45,31 +45,31 @@ int Squeeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
std::vector<int> axes_;
for (auto iter = axis->begin(); iter != axis->end(); iter++) {
axes_.push_back(*iter);
}
}
if (axes_.size() == 0) {
for (int i = 0; i < in_shape.size(); i++) {
if (in_shape[i] != 1) {
out_shape.push_back(in_shape[i]);
}
if (in_shape[i] != 1) {
out_shape.push_back(in_shape[i]);
}
}
} else {
int axisIdx = 0;
for (int i = 0; i < in_shape.size(); i++) {
if (axisIdx < axes_.size() && axes_[axisIdx] == i) {
MS_ASSERT(in_shape[i] == 1);
axisIdx++;
continue;
} else {
out_shape.push_back(in_shape[i]);
}
int axisIdx = 0;
for (int i = 0; i < in_shape.size(); i++) {
if (axisIdx < axes_.size() && axes_[axisIdx] == i) {
MS_ASSERT(in_shape[i] == 1);
axisIdx++;
continue;
} else {
out_shape.push_back(in_shape[i]);
}
}
}
outputs_.front()->set_shape(out_shape);
outputs_.front()->set_data_type(in_tensor->data_type());
outputs_.front()->SetFormat(in_tensor->GetFormat());
return 0;
}
} // namespace mindspore::lite

@ -23,7 +23,7 @@ namespace mindspore::lite {
namespace {
constexpr int kStackOutputNum = 1;
constexpr int kStackMinInputNum = 2;
}
} // namespace
int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
@ -61,7 +61,8 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
output_shape.insert(output_shape.begin() + axis, inputs.size());
outputs[0]->set_shape(output_shape);
outputs[0]->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -157,6 +157,8 @@ int StridedSlice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
outputs.front()->set_shape(output_shape);
outputs.front()->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -37,9 +37,9 @@ int Tile::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
out_shape.push_back(tmp);
}
output->SetFormat(input->GetFormat());
output->set_shape(out_shape);
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace mindspore::lite

@ -37,12 +37,13 @@ int TopK::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
output0->set_shape(input->shape());
output0->set_data_type(input->data_type());
// output0->shape().back() = topk_prim->k();
// output0->shape().back() = topk_prim->k();
output1->set_shape(input->shape());
output1->set_data_type(input->data_type());
// output1->shape().back() = topk_prim->k();
// output1->shape().back() = topk_prim->k();
output1->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -47,7 +47,8 @@ int Transpose::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -36,7 +36,9 @@ int Unique::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
output0->set_data_type(input->data_type());
output1->set_shape(input->shape());
output1->set_data_type(kNumberTypeInt32);
output1->SetFormat(input->GetFormat());
output0->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -65,9 +65,9 @@ int Unsqueeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
}
}
output->SetFormat(input->GetFormat());
output->set_shape(out_shape);
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace mindspore::lite

@ -41,8 +41,8 @@ int Unstack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor
MS_ASSERT(out != nullptr);
out->set_shape(output_shape);
out->set_data_type(input->data_type());
out->SetFormat(input->GetFormat());
}
return RET_OK;
}
} // namespace mindspore::lite

@ -73,7 +73,8 @@ int Where::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
output_shape[axisout] = nummax;
outputs_[0]->set_shape(output_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -33,7 +33,8 @@ int ZerosLike::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
}
output->set_shape(input->shape());
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@ -330,6 +330,7 @@ int Benchmark::MarkAccuracy() {
}
ReadCalibData();
CompareOutput();
if (cleanData) {
for (auto &msOutput : msOutputs) {
for (auto &outputTensor : msOutput.second) {