!26832 Support ValueNode inputs json generation in CollectFusedJsonWithSingleKernel

Merge pull request !26832 from zichun_ye/akg_json_build
This commit is contained in:
i-robot 2021-11-29 04:19:54 +00:00 committed by Gitee
commit cdb618984f
1 changed files with 36 additions and 5 deletions

View File

@ -15,6 +15,7 @@
*/
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
#include "backend/optimizer/graph_kernel/core/graph_builder.h"
#ifdef ENABLE_GPU
#include <cuda.h>
@ -993,16 +994,46 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
bool AkgKernelJsonGenerator::CollectFusedJsonWithSingleKernel(const CNodePtr &c_node) {
kernel_json_ = nlohmann::json();
std::vector<AnfNodePtr> node_list, input_list, output_list;
node_list.push_back(c_node);
(void)input_list.insert(input_list.begin(), c_node->inputs().begin() + 1, c_node->inputs().end());
auto output_num = static_cast<int64_t>(AnfUtils::GetOutputTensorNum(c_node));
FuncGraphPtr fg = std::get<0>(BuildGraphFromNodes({c_node}));
FuncGraphManagerPtr mng = fg->manager();
if (mng == nullptr) {
mng = Manage(fg, false);
fg->set_manager(mng);
}
auto out_cnode = fg->output()->cast<CNodePtr>();
if (out_cnode == nullptr) {
MS_LOG(ERROR) << "Wrong graph generated for Kernel [" << c_node->fullname_with_scope() << "]";
return false;
}
// check all inputs in the cnodes: if it is a valuenode, replace it by a parameter
std::set<AnfNodePtr> value_nodes;
auto &inputs = out_cnode->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
const auto &tnode = inputs[i];
auto tensor = GetValueNode<tensor::TensorPtr>(tnode);
if (tensor) {
value_nodes.insert(tnode);
}
}
for (auto &vnode : value_nodes) {
auto parameter = fg->add_parameter();
parameter->set_abstract(vnode->abstract());
parameter->set_kernel_info(vnode->kernel_info_ptr());
mng->Replace(vnode, parameter);
}
node_list.push_back(out_cnode);
(void)input_list.insert(input_list.begin(), out_cnode->inputs().begin() + 1, out_cnode->inputs().end());
auto output_num = static_cast<int64_t>(AnfUtils::GetOutputTensorNum(out_cnode));
if (output_num > 1) {
for (int64_t idx = 0; idx < output_num; idx++) {
auto gt = c_node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), c_node, NewValueNode(idx)});
auto gt =
out_cnode->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), out_cnode, NewValueNode(idx)});
output_list.emplace_back(std::move(gt));
}
} else {
output_list.push_back(c_node);
output_list.push_back(out_cnode);
}
return CollectFusedJson(node_list, input_list, output_list, &kernel_json_);
}