control flow infershape

This commit is contained in:
yefeng 2022-09-15 19:22:04 +08:00
parent 3865101b68
commit 57e276875d
2 changed files with 14 additions and 0 deletions

View File

@ -888,6 +888,11 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index) {
MS_ASSERT(src_model_ != nullptr);
MS_ASSERT(!src_model_->graph_.sub_graphs_.empty());
MS_ASSERT(src_model_->graph_.sub_graphs_.size() > subgraph_index);
if (find(infer_subgraph_index_.begin(), infer_subgraph_index_.end(), subgraph_index) != infer_subgraph_index_.end()) {
MS_LOG(ERROR) << "The subgraph has been infer shape, subgraph index: " << subgraph_index;
return RET_INFER_INVALID;
}
infer_subgraph_index_.push_back(subgraph_index);
auto subgraph = src_model_->graph_.sub_graphs_.at(subgraph_index);
int subgraph_infershape_ret = RET_OK;
for (auto node_index : subgraph->node_indices_) {
@ -1184,6 +1189,10 @@ kernel::KernelExec *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
}
} else {
data_type = GetFirstFp32Fp16OrInt8Type(in_tensors);
if (data_type == kTypeUnknown) {
MS_LOG(ERROR) << "GetFirstFp32Fp16OrInt8Type is unknown.";
return nullptr;
}
}
if (context_->float_mode) {
for (auto tensor : out_tensors) {
@ -1693,6 +1702,10 @@ TypeId Scheduler::GetFirstFp32Fp16OrInt8Type(const std::vector<Tensor *> &in_ten
return dtype;
}
}
if (in_tensors.empty()) {
MS_LOG(ERROR) << "in tensor is empty.";
return kTypeUnknown;
}
MS_ASSERT(!in_tensors.empty());
return in_tensors[0]->data_type() == kObjectTypeTensorType ? kNumberTypeFloat32 : in_tensors[0]->data_type();
}

View File

@ -171,6 +171,7 @@ class Scheduler {
std::map<std::string, TypeId> *execution_plan_ = nullptr;
const std::map<std::string, std::map<std::string, std::string>> *config_info_ = nullptr;
std::shared_ptr<ShapeFusionPass> shape_fusion_pass_ = nullptr;
std::vector<int> infer_subgraph_index_;
};
} // namespace mindspore::lite