forked from mindspore-Ecosystem/mindspore
constant fold approve multi output
This commit is contained in:
parent
f42b3bbfbc
commit
e121bcd3be
|
@ -236,6 +236,136 @@ MetaGraphTptr BuildMixGraph() {
|
||||||
// final output
|
// final output
|
||||||
return meta_graph;
|
return meta_graph;
|
||||||
}
|
}
|
||||||
|
// Builds an all-constant test graph containing a multi-output node:
//   split(tensor 0)      -> tensors 1, 2
//   mul1(tensor 1, 3)    -> tensor 5
//   mul2(tensor 2, 4)    -> tensor 6
// allTensors layout: 0 = data1, 1/2 = split outputs, 3 = data2, 4 = data3, 5/6 = mul outputs.
MetaGraphTptr BuildSplitGraph() {
  auto meta_graph = std::make_shared<schema::MetaGraphT>();
  meta_graph->name = "graph";

  // split node: numberSplit = 2 along splitDim = 1
  auto split_node = std::make_unique<schema::CNodeT>();
  split_node->inputIndex = {0};
  split_node->outputIndex = {1, 2};
  split_node->primitive = std::make_unique<schema::PrimitiveT>();
  split_node->primitive->value.type = schema::PrimitiveType_Split;
  auto attr = std::make_unique<schema::SplitT>();
  attr->numberSplit = 2;
  attr->splitDim = 1;
  split_node->primitive->value.value = attr.release();
  split_node->name = "split";
  meta_graph->nodes.emplace_back(std::move(split_node));

  meta_graph->inputIndex = {0, 3, 4};
  meta_graph->outputIndex = {5, 6};

  // mul node 1: first split output * data2
  auto mul_node1 = std::make_unique<schema::CNodeT>();
  mul_node1->inputIndex = {1, 3};
  mul_node1->outputIndex = {5};
  mul_node1->primitive = std::make_unique<schema::PrimitiveT>();
  mul_node1->primitive->value.type = schema::PrimitiveType_Mul;
  auto mul_attr = std::make_unique<schema::MulT>();
  mul_node1->primitive->value.value = mul_attr.release();
  mul_node1->name = "mul1";
  meta_graph->nodes.emplace_back(std::move(mul_node1));

  // mul node 2: second split output * data3
  auto mul_node2 = std::make_unique<schema::CNodeT>();
  mul_node2->inputIndex = {2, 4};
  mul_node2->outputIndex = {6};
  mul_node2->primitive = std::make_unique<schema::PrimitiveT>();
  mul_node2->primitive->value.type = schema::PrimitiveType_Mul;
  auto mul2_attr = std::make_unique<schema::MulT>();
  mul_node2->primitive->value.value = mul2_attr.release();
  mul_node2->name = "mul2";
  meta_graph->nodes.emplace_back(std::move(mul_node2));

  // input 0: data1, values 0..11
  auto input0 = std::make_unique<schema::TensorT>();
  input0->nodeType = schema::NodeType::NodeType_ValueNode;
  input0->format = schema::Format_NHWC;
  input0->dataType = TypeId::kNumberTypeFloat32;
  input0->dims = {1, 2, 2, 3};
  input0->offset = -1;
  // A stack buffer replaces the unchecked `new (std::nothrow)` allocation.
  constexpr int kInput0Count = 2 * 2 * 3;
  float input0_values[kInput0Count];
  for (int i = 0; i < kInput0Count; i++) {
    input0_values[i] = static_cast<float>(i);
  }
  input0->data.resize(sizeof(input0_values));
  memcpy(input0->data.data(), input0_values, sizeof(input0_values));
  meta_graph->allTensors.emplace_back(std::move(input0));

  // split output1
  auto split_output1 = std::make_unique<schema::TensorT>();
  split_output1->nodeType = schema::NodeType::NodeType_Parameter;
  split_output1->format = schema::Format_NHWC;
  split_output1->dataType = TypeId::kNumberTypeFloat32;
  split_output1->dims = {1, 1, 2, 3};
  split_output1->offset = -1;
  // Placeholder buffer only: keep the bytes produced by resize() instead of
  // copying uninitialized heap memory over them, so the fixture is deterministic.
  split_output1->data.resize(sizeof(float) * 1 * 2 * 3);
  meta_graph->allTensors.emplace_back(std::move(split_output1));

  // split output2
  auto split_output2 = std::make_unique<schema::TensorT>();
  split_output2->nodeType = schema::NodeType::NodeType_Parameter;
  split_output2->format = schema::Format_NHWC;
  split_output2->dataType = TypeId::kNumberTypeFloat32;
  split_output2->dims = {1, 1, 2, 3};
  split_output2->offset = -1;
  // Placeholder buffer, see split output1.
  split_output2->data.resize(sizeof(float) * 1 * 2 * 3);
  meta_graph->allTensors.emplace_back(std::move(split_output2));

  // input 1: data2, values 0..5
  auto input1 = std::make_unique<schema::TensorT>();
  input1->nodeType = schema::NodeType::NodeType_ValueNode;
  input1->format = schema::Format_NHWC;
  input1->dataType = TypeId::kNumberTypeFloat32;
  input1->dims = {1, 1, 2, 3};
  input1->offset = -1;
  constexpr int kInput1Count = 2 * 3;
  float input1_values[kInput1Count];
  for (int i = 0; i < kInput1Count; i++) {
    input1_values[i] = static_cast<float>(i);
  }
  input1->data.resize(sizeof(input1_values));
  memcpy(input1->data.data(), input1_values, sizeof(input1_values));
  meta_graph->allTensors.emplace_back(std::move(input1));

  // input 2: data3, every element is 10
  auto input2 = std::make_unique<schema::TensorT>();
  input2->nodeType = schema::NodeType::NodeType_ValueNode;
  input2->format = schema::Format_NHWC;
  input2->dataType = TypeId::kNumberTypeFloat32;
  input2->dims = {1, 1, 2, 3};
  input2->offset = -1;
  constexpr int kInput2Count = 2 * 3;
  float input2_values[kInput2Count];
  for (int i = 0; i < kInput2Count; i++) {
    input2_values[i] = 10.0f;
  }
  input2->data.resize(sizeof(input2_values));
  memcpy(input2->data.data(), input2_values, sizeof(input2_values));
  meta_graph->allTensors.emplace_back(std::move(input2));

  // final mul output1
  auto mul_output = std::make_unique<schema::TensorT>();
  mul_output->nodeType = schema::NodeType::NodeType_Parameter;
  mul_output->format = schema::Format_NHWC;
  mul_output->dataType = TypeId::kNumberTypeFloat32;
  mul_output->dims = {1, 1, 2, 3};
  meta_graph->allTensors.emplace_back(std::move(mul_output));

  // final mul output2
  auto mul_output2 = std::make_unique<schema::TensorT>();
  mul_output2->nodeType = schema::NodeType::NodeType_Parameter;
  mul_output2->format = schema::Format_NHWC;
  mul_output2->dataType = TypeId::kNumberTypeFloat32;
  mul_output2->dims = {1, 1, 2, 3};
  meta_graph->allTensors.emplace_back(std::move(mul_output2));
  return meta_graph;
}
|
||||||
} // namespace
|
} // namespace
|
||||||
TEST_F(ConstantFoldingFusionTest, TestADDConstantFold) {
|
TEST_F(ConstantFoldingFusionTest, TestADDConstantFold) {
|
||||||
auto meta_graph = BuildGraph(schema::PrimitiveType_Add, new schema::AddT);
|
auto meta_graph = BuildGraph(schema::PrimitiveType_Add, new schema::AddT);
|
||||||
|
@ -483,4 +613,19 @@ TEST_F(ConstantFoldingFusionTest, TestCastDimsConstantFold) {
|
||||||
auto new_meta_graph = lite::Export(new_graph);
|
auto new_meta_graph = lite::Export(new_graph);
|
||||||
ASSERT_EQ(new_meta_graph->nodes.size(), 0);
|
ASSERT_EQ(new_meta_graph->nodes.size(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Constant folding over a graph whose Split node has two outputs: after the
// pass runs, the exported graph is expected to contain no nodes at all.
TEST_F(ConstantFoldingFusionTest, TestSplitConstantFold) {
  auto meta_graph = BuildSplitGraph();
  auto input_tensor = meta_graph->allTensors.at(0).get();
  input_tensor->dataType = kNumberTypeFloat32;
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  // Fail here with a clear message rather than inside the optimizer.
  ASSERT_NE(nullptr, func_graph);
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>("test", false);
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  auto new_meta_graph = lite::Export(new_graph);
  // The export result is dereferenced on the next line, so guard it first.
  ASSERT_NE(nullptr, new_meta_graph);
  ASSERT_EQ(new_meta_graph->nodes.size(), 0);
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -319,7 +319,7 @@ schema::PrimitiveType GetCNodeType(const BaseRef &n) {
|
||||||
if (utils::isa<PrimitiveCPtr>(value)) {
|
if (utils::isa<PrimitiveCPtr>(value)) {
|
||||||
auto primitive = value->cast<PrimitiveCPtr>();
|
auto primitive = value->cast<PrimitiveCPtr>();
|
||||||
MS_ASSERT(primitive != nullptr);
|
MS_ASSERT(primitive != nullptr);
|
||||||
return (schema::PrimitiveType)primitive->Type();
|
return (schema::PrimitiveType) primitive->Type();
|
||||||
} else if (utils::isa<Primitive>(value)) {
|
} else if (utils::isa<Primitive>(value)) {
|
||||||
auto primitive = value->cast<PrimitivePtr>();
|
auto primitive = value->cast<PrimitivePtr>();
|
||||||
MS_ASSERT(primitive != nullptr);
|
MS_ASSERT(primitive != nullptr);
|
||||||
|
@ -392,8 +392,8 @@ size_t GetOutputTensorNum(const AnfNodePtr &node) {
|
||||||
bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node) {
|
bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node) {
|
||||||
auto output_node_list = GetRealNodeUsedList(graph, node);
|
auto output_node_list = GetRealNodeUsedList(graph, node);
|
||||||
if (output_node_list->size() != 1) {
|
if (output_node_list->size() != 1) {
|
||||||
MS_LOG(DEBUG) << "fusion node has multi output nodes";
|
MS_LOG(DEBUG) << "fusion node has multi output nodes";
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -412,5 +412,50 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(con
|
||||||
std::copy(output_info_list.begin(), output_info_list.end(), std::back_inserter(*output_node_list));
|
std::copy(output_info_list.begin(), output_info_list.end(), std::back_inserter(*output_node_list));
|
||||||
return output_node_list;
|
return output_node_list;
|
||||||
}
|
}
|
||||||
|
size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
|
||||||
|
MS_ASSERT(tuple_get_item != nullptr);
|
||||||
|
if (tuple_get_item->size() != kTupleGetItemInputSize) {
|
||||||
|
MS_LOG(ERROR) << "The node tuple_get_item must have 2 inputs!";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
|
||||||
|
MS_ASSERT(output_index_value_node != nullptr);
|
||||||
|
auto value_node = output_index_value_node->cast<ValueNodePtr>();
|
||||||
|
MS_ASSERT(value_node != nullptr);
|
||||||
|
return IntToSize(GetValue<int>(value_node->value()));
|
||||||
|
}
|
||||||
|
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
|
||||||
|
const AnfNodePtr &node,
|
||||||
|
size_t output_index) {
|
||||||
|
MS_ASSERT(graph != nullptr);
|
||||||
|
MS_ASSERT(node != nullptr);
|
||||||
|
auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
|
||||||
|
auto manager = graph->manager();
|
||||||
|
MS_ASSERT(manager != nullptr);
|
||||||
|
auto iter = manager->node_users().find(node);
|
||||||
|
if (iter == manager->node_users().end()) {
|
||||||
|
MS_LOG(ERROR) << "node has no output in manager";
|
||||||
|
return output_node_list;
|
||||||
|
}
|
||||||
|
auto output_info_list = iter->second;
|
||||||
|
for (const auto &output_info : output_info_list) {
|
||||||
|
size_t used_output_index;
|
||||||
|
if (GetCNodeType(output_info.first) == schema::PrimitiveType_TupleGetItem) {
|
||||||
|
used_output_index = GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
|
||||||
|
} else if (GetCNodeType(node) == schema::PrimitiveType_TupleGetItem) {
|
||||||
|
used_output_index = output_index;
|
||||||
|
} else {
|
||||||
|
if (output_index != 0) {
|
||||||
|
MS_LOG(ERROR) << "node has no output in manager";
|
||||||
|
return output_node_list;
|
||||||
|
}
|
||||||
|
return output_node_list;
|
||||||
|
}
|
||||||
|
if (used_output_index == output_index) {
|
||||||
|
output_node_list->push_back(output_info);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return output_node_list;
|
||||||
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -63,6 +63,8 @@ bool CheckIsAllInputsParam(const AnfNodePtr &node);
|
||||||
size_t GetOutputTensorNum(const AnfNodePtr &node);
|
size_t GetOutputTensorNum(const AnfNodePtr &node);
|
||||||
|
|
||||||
bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node);
|
bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node);
|
||||||
|
|
||||||
|
// Returns the output index stored in the index input of a TupleGetItem node.
// When the node does not have the expected number of inputs, an error is
// logged and -1 converted to size_t is returned.
size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_
|
#endif // MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_
|
||||||
|
|
|
@ -41,7 +41,7 @@ std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
|
||||||
auto tensorT = tmp_meta_graph->allTensors.at(input_index).get();
|
auto tensorT = tmp_meta_graph->allTensors.at(input_index).get();
|
||||||
auto tensor_shape = tensorT->dims;
|
auto tensor_shape = tensorT->dims;
|
||||||
auto lite_tensor =
|
auto lite_tensor =
|
||||||
new (std::nothrow) Tensor(TypeId(tensorT->dataType), tensor_shape, tensorT->format, tensorT->nodeType);
|
new (std::nothrow) Tensor(TypeId(tensorT->dataType), tensor_shape, tensorT->format, tensorT->nodeType);
|
||||||
if (lite_tensor == nullptr) {
|
if (lite_tensor == nullptr) {
|
||||||
MS_LOG(ERROR) << "lite tensor is nullptr";
|
MS_LOG(ERROR) << "lite tensor is nullptr";
|
||||||
return input_tensors;
|
return input_tensors;
|
||||||
|
@ -106,7 +106,7 @@ kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tens
|
||||||
mindspore::lite::PrimitiveC *primitive) {
|
mindspore::lite::PrimitiveC *primitive) {
|
||||||
MS_ASSERT(nullptr != lite_primitive);
|
MS_ASSERT(nullptr != lite_primitive);
|
||||||
auto data_type = inputs.front()->data_type();
|
auto data_type = inputs.front()->data_type();
|
||||||
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, (schema::PrimitiveType)primitive->Type()};
|
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, (schema::PrimitiveType) primitive->Type()};
|
||||||
lite::Context context;
|
lite::Context context;
|
||||||
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
|
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
|
||||||
if (creator != nullptr) {
|
if (creator != nullptr) {
|
||||||
|
@ -115,6 +115,44 @@ kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tens
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Swaps the folded node `input_node` for parameter(s) built from `output_tensors`.
//
// Exactly one output tensor: a new parameter named after `input_node` is set
// as input `replace_index` of `any_node`.
// Any other count: for each output index k, the single user of that output
// must be a TupleGetItem; that user is replaced through the graph manager by a
// new parameter named "<input_node name>_const_<k>".
//
// Returns lite::RET_OK on success, lite::RET_ERROR when a parameter cannot be
// created or an output is not consumed by exactly one TupleGetItem.
lite::STATUS ReplaceCNode(const FuncGraphPtr &func_graph, const CNodePtr &any_node, const AnfNodePtr &input_node,
                          std::vector<Tensor *> output_tensors, size_t replace_index) {
  MS_ASSERT(func_graph != nullptr);
  auto manager = func_graph->manager();
  MS_ASSERT(manager != nullptr);

  // Single-output node: rewire the consumer's input directly.
  if (output_tensors.size() == 1) {
    auto folded_param = CreateNewParamter(func_graph, output_tensors.front());
    if (folded_param == nullptr) {
      MS_LOG(ERROR) << "CreateNewParamter failed, name: " << input_node->fullname_with_scope();
      return lite::RET_ERROR;
    }
    folded_param->set_name(input_node->fullname_with_scope());
    any_node->set_input(replace_index, folded_param);
    return lite::RET_OK;
  }

  // Multi-output node: each output is reached through its own TupleGetItem.
  for (size_t out_idx = 0; out_idx < output_tensors.size(); ++out_idx) {
    auto users = GetRealNodeUsedListByOutputIdx(func_graph, input_node, out_idx);
    if (users->size() != 1) {
      MS_LOG(ERROR) << " output must tuple_getitem";
      return lite::RET_ERROR;
    }
    auto get_item_node = users->at(0).first;
    if (GetCNodeType(get_item_node) != schema::PrimitiveType_TupleGetItem) {
      MS_LOG(ERROR) << " multi out tensor must connect tuple-getitem: " << input_node->fullname_with_scope();
      return lite::RET_ERROR;
    }
    auto folded_param = CreateNewParamter(func_graph, output_tensors.at(out_idx));
    if (folded_param == nullptr) {
      MS_LOG(ERROR) << "CreateNewParamter failed, name: " << input_node->fullname_with_scope();
      return lite::RET_ERROR;
    }
    folded_param->set_name(input_node->fullname_with_scope() + "_const_" + std::to_string(out_idx));
    manager->Replace(get_item_node, folded_param);
  }
  return lite::RET_OK;
}
|
||||||
} // namespace
|
} // namespace
|
||||||
void FreeTensors(std::vector<Tensor *> *input_tensor, std::vector<Tensor *> *output_tensor) {
|
void FreeTensors(std::vector<Tensor *> *input_tensor, std::vector<Tensor *> *output_tensor) {
|
||||||
if (input_tensor != nullptr) {
|
if (input_tensor != nullptr) {
|
||||||
|
@ -140,64 +178,66 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
|
||||||
}
|
}
|
||||||
auto any_node = node->cast<CNodePtr>();
|
auto any_node = node->cast<CNodePtr>();
|
||||||
CheckIfCNodeIsNull(any_node);
|
CheckIfCNodeIsNull(any_node);
|
||||||
|
bool changed = false;
|
||||||
for (size_t i = 1; i < any_node->inputs().size(); i++) {
|
for (size_t i = 1; i < any_node->inputs().size(); i++) {
|
||||||
auto input_node = any_node->input(i);
|
auto input_node = any_node->input(i);
|
||||||
if (input_node->isa<CNode>() && CheckIsAllInputsParam(input_node)) {
|
if (!input_node->isa<CNode>() || !CheckIsAllInputsParam(input_node)) {
|
||||||
auto input_cnode = input_node->cast<CNodePtr>();
|
continue;
|
||||||
auto input_tensors = GetCNodeInputTensors(input_cnode);
|
}
|
||||||
if (input_tensors.empty() || input_tensors.size() != input_cnode->inputs().size() - 1) {
|
auto input_cnode = input_node->cast<CNodePtr>();
|
||||||
FreeTensors(&input_tensors, nullptr);
|
auto input_tensors = GetCNodeInputTensors(input_cnode);
|
||||||
continue;
|
if (input_tensors.empty() || input_tensors.size() != input_cnode->inputs().size() - 1) {
|
||||||
|
FreeTensors(&input_tensors, nullptr);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
changed = true;
|
||||||
|
MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope();
|
||||||
|
auto output_nums = GetOutputTensorNum(input_cnode);
|
||||||
|
std::vector<Tensor *> output_tensors{output_nums, new Tensor()};
|
||||||
|
auto lite_primitive = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
|
||||||
|
if (lite_primitive == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "lite_primitive is nullptr";
|
||||||
|
FreeTensors(&input_tensors, &output_tensors);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
// here, input_tensor's format need to be transposed nhwc according to fmkType,
|
||||||
|
// but for the time being, we only transpose the tensor with 0/1/2/3D.
|
||||||
|
// Others should be added in future.
|
||||||
|
for (size_t j = 0; j < input_tensors.size(); ++j) {
|
||||||
|
input_tensors[j]->SetFormat(schema::Format_NHWC);
|
||||||
|
if (input_tensors[j]->shape().size() == 4) {
|
||||||
|
MS_LOG(INFO) << "init input_tensor format to nhwc";
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope();
|
}
|
||||||
auto output_nums = GetOutputTensorNum(input_cnode);
|
lite_primitive->InferShape(input_tensors, output_tensors);
|
||||||
std::vector<Tensor *> output_tensors{output_nums, new Tensor()};
|
auto parameter = kernel::PopulateParameter(lite_primitive.get());
|
||||||
auto lite_primitive = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
|
if (parameter == nullptr) {
|
||||||
if (lite_primitive == nullptr) {
|
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
|
||||||
MS_LOG(ERROR) << "lite_primitive is nullptr";
|
<< schema::EnumNamePrimitiveType((schema::PrimitiveType) (lite_primitive->Type()));
|
||||||
FreeTensors(&input_tensors, &output_tensors);
|
return nullptr;
|
||||||
return nullptr;
|
}
|
||||||
}
|
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, lite_primitive.get());
|
||||||
// here, input_tensor's format need to be transposed nhwc according to fmkType,
|
if (lite_kernel == nullptr) {
|
||||||
// but for the time being, we only transpose the tensor with 0/1/2/3D.
|
MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
|
||||||
// Others should be added in future.
|
FreeTensors(&input_tensors, &output_tensors);
|
||||||
for (size_t j = 0; j < input_tensors.size(); ++j) {
|
return nullptr;
|
||||||
input_tensors[j]->SetFormat(schema::Format_NHWC);
|
}
|
||||||
if (input_tensors[j]->shape().size() == 4) {
|
auto ret = lite_kernel->Run();
|
||||||
MS_LOG(INFO) << "init input_tensor format to nhwc";
|
if (0 != ret) {
|
||||||
}
|
FreeTensors(&input_tensors, &output_tensors);
|
||||||
}
|
MS_LOG(ERROR) << "run kernel failed, name: " << lite_kernel->name();
|
||||||
lite_primitive->InferShape(input_tensors, output_tensors);
|
return nullptr;
|
||||||
auto parameter = kernel::PopulateParameter(lite_primitive.get());
|
}
|
||||||
if (parameter == nullptr) {
|
// replace cnode by new param
|
||||||
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
|
if (ReplaceCNode(func_graph, any_node, input_node, output_tensors, i) != lite::RET_OK) {
|
||||||
<< schema::EnumNamePrimitiveType((schema::PrimitiveType)(lite_primitive->Type()));
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, lite_primitive.get());
|
|
||||||
if (lite_kernel == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
|
|
||||||
FreeTensors(&input_tensors, &output_tensors);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto ret = lite_kernel->Run();
|
|
||||||
if (0 != ret) {
|
|
||||||
FreeTensors(&input_tensors, &output_tensors);
|
|
||||||
MS_LOG(ERROR) << "run kernel failed, name: " << lite_kernel->name();
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto new_parameter = CreateNewParamter(func_graph, output_tensors.front());
|
|
||||||
if (new_parameter == nullptr) {
|
|
||||||
FreeTensors(&input_tensors, &output_tensors);
|
|
||||||
MS_LOG(ERROR) << "CreateNewParamter failed, name: " << lite_kernel->name();
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
new_parameter->set_name(input_node->fullname_with_scope());
|
|
||||||
any_node->set_input(i, new_parameter);
|
|
||||||
FreeTensors(&input_tensors, &output_tensors);
|
FreeTensors(&input_tensors, &output_tensors);
|
||||||
delete (lite_kernel);
|
delete (lite_kernel);
|
||||||
|
MS_LOG(ERROR) << "constant_folding replace cnode failed";
|
||||||
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
FreeTensors(&input_tensors, &output_tensors);
|
||||||
|
delete (lite_kernel);
|
||||||
}
|
}
|
||||||
return any_node;
|
return changed ? any_node : nullptr;
|
||||||
}
|
}
|
||||||
} // namespace mindspore::opt
|
} // namespace mindspore::opt
|
||||||
|
|
Loading…
Reference in New Issue