remove IsControlFlowPattern

This commit is contained in:
mengyuanli 2021-10-12 16:43:28 +08:00
parent 72fc71fc07
commit 60fe78333f
2 changed files with 3 additions and 25 deletions

View File

@ -283,13 +283,11 @@ int Scheduler::CheckCpuValid(std::vector<kernel::LiteKernel *> *dst_kernels) {
int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> *dst_kernels) {
#ifndef CONTROLFLOW_TENSORLIST_CLIP
if (IsControlFlowParttern(*dst_kernels)) {
*is_control_flow_ = true;
if (*is_control_flow_) {
return ConstructControlFlowMainGraph(dst_kernels);
}
#endif
*is_control_flow_ = false;
auto src_kernel = *dst_kernels;
dst_kernels->clear();
std::map<const kernel::LiteKernel *, bool> is_kernel_finish;
@ -582,15 +580,7 @@ int Scheduler::InferNodeShape(const lite::Model::Node *node) {
ret = KernelInferShape(inputs, outputs, parameter);
#ifndef CONTROLFLOW_TENSORLIST_CLIP
bool not_able_to_infer = false;
for (auto &input : inputs) {
if (input->data_type() == kObjectTypeTensorType) {
not_able_to_infer = true;
break;
}
}
if (not_able_to_infer) {
if (*is_control_flow_) {
for (auto &output : outputs) {
output->set_shape({-1});
}
@ -713,6 +703,7 @@ int Scheduler::InferCallShape(const lite::Model::Node *node) {
#ifndef CONTROLFLOW_TENSORLIST_CLIP
auto switch_input = NodeInputIsSwitch(node);
if (switch_input) {
*is_control_flow_ = true;
return InferSwitchShape(switch_input);
}
#endif
@ -1749,18 +1740,6 @@ void CopyTensorList(TensorList *dst_tensor, TensorList *src_tensor) {
dst_tensor->set_tensors(cpy_tensors);
}
bool Scheduler::IsControlFlowParttern(const std::vector<kernel::LiteKernel *> &kernels) {
if (std::any_of(kernels.begin(), kernels.end(), [](kernel::LiteKernel *item) {
if (item->op_parameter()) {
return item->op_parameter()->type_ == schema::PrimitiveType_PartialFusion;
}
return false;
})) {
return true;
}
return false;
}
int Scheduler::ConstructControlFlowMainGraph(std::vector<kernel::LiteKernel *> *kernels) {
auto back_kernels = *kernels;
kernels->clear();

View File

@ -131,7 +131,6 @@ class Scheduler {
bool SubGraphHasScheduled(const int &index);
void SubGraphMarkScheduled(const int &index);
void SetSubgraphForPartialNode();
bool IsControlFlowParttern(const std::vector<kernel::LiteKernel *> &kernels);
int ConstructControlFlowMainGraph(std::vector<kernel::LiteKernel *> *kernels);
#endif