forked from mindspore-Ecosystem/mindspore
!4372 Modify the method for getting output index of metagraph.
Merge pull request !4372 from wangshaocong/lite
This commit is contained in:
commit
49fd9fa978
|
@ -58,7 +58,7 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
for (size_t i = 0; i < in_shape.size(); i++) {
|
||||
bool reduce_axis = false;
|
||||
for (int idx = 0; idx < num_axes; ++idx) {
|
||||
if (static_cast<size_t>((*axes)[idx]) == i) {
|
||||
if (static_cast<size_t>((*axes)[idx]) == i || static_cast<size_t>((*axes)[idx] + in_shape.size()) == i) {
|
||||
reduce_axis = true;
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -71,7 +71,7 @@ int ReduceCPUKernel::CheckParameters() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
for (auto i = 0; i < num_axes_; i++) {
|
||||
if (axes_[i] < -static_cast<int>(input_rank) || static_cast<size_t>(axes_[i]) >= input_rank) {
|
||||
if (axes_[i] < -static_cast<int>(input_rank) || axes_[i] >= static_cast<int>(input_rank)) {
|
||||
MS_LOG(ERROR) << "Reduce got invalid axis " << axes_[i] << ", axis should be in ["
|
||||
<< -static_cast<int>(input_rank) << ", " << input_rank - 1 << "].";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -236,18 +236,31 @@ void TfliteModelParser::SetInputTensor(const std::unique_ptr<tflite::SubGraphT>
|
|||
}
|
||||
}
|
||||
|
||||
void TfliteModelParser::SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache,
|
||||
void TfliteModelParser::SetGraphTensorIndex(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const mindspore::lite::TensorCache &tensorCache,
|
||||
schema::MetaGraphT *subGraphDef) {
|
||||
auto opGraph = OpGraphT::Build(subGraphDef);
|
||||
auto graphInputs = tensorCache.GetGraphInputs();
|
||||
auto graphOutputs = opGraph->GetOutputNode();
|
||||
|
||||
subGraphDef->inputIndex.assign(graphInputs.begin(), graphInputs.end());
|
||||
|
||||
for (const auto &output : graphOutputs) {
|
||||
auto op = opMap[output->ID()];
|
||||
for (auto outputIndex : op->outputIndex) {
|
||||
subGraphDef->outputIndex.emplace_back(outputIndex);
|
||||
for (auto outputIndex : tflite_subgraph->outputs) {
|
||||
int i = 0;
|
||||
bool found = false;
|
||||
for (const auto &tfliteOp : tflite_subgraph->operators) {
|
||||
int j = 0;
|
||||
auto opType = GetTfliteNodeType(tfliteOp, tflite_model);
|
||||
std::string opName = opType + "-" + std::to_string(i++);
|
||||
for (auto opOutputIndex : tfliteOp->outputs) {
|
||||
if (outputIndex == opOutputIndex) {
|
||||
subGraphDef->outputIndex.emplace_back(opMap[opName]->outputIndex[j]);
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
j++;
|
||||
}
|
||||
if (found) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -284,7 +297,7 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
SetGraphTensorIndex(tensorCache, subGraph.get());
|
||||
SetGraphTensorIndex(tflite_subgraph, tflite_model, tensorCache, subGraph.get());
|
||||
SetAllTensors(tensorCache, subGraph.get());
|
||||
return subGraph.release();
|
||||
}
|
||||
|
|
|
@ -50,7 +50,10 @@ class TfliteModelParser : public ModelParser {
|
|||
|
||||
void SetInputTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, TensorCache *tensor_cache);
|
||||
|
||||
void SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, schema::MetaGraphT *subGraphDef);
|
||||
void SetGraphTensorIndex(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const mindspore::lite::TensorCache &tensorCache,
|
||||
schema::MetaGraphT *subGraphDef);
|
||||
|
||||
STATUS ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::MetaGraphT *sub_graph,
|
||||
|
|
Loading…
Reference in New Issue