forked from mindspore-Ecosystem/mindspore
!26832 Support ValueNode inputs json generation in CollectFusedJsonWithSingleKernel
Merge pull request !26832 from zichun_ye/akg_json_build
This commit is contained in:
commit
cdb618984f
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
||||
#include "backend/optimizer/graph_kernel/core/graph_builder.h"
|
||||
|
||||
#ifdef ENABLE_GPU
|
||||
#include <cuda.h>
|
||||
|
@ -993,16 +994,46 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
|
|||
bool AkgKernelJsonGenerator::CollectFusedJsonWithSingleKernel(const CNodePtr &c_node) {
|
||||
kernel_json_ = nlohmann::json();
|
||||
std::vector<AnfNodePtr> node_list, input_list, output_list;
|
||||
node_list.push_back(c_node);
|
||||
(void)input_list.insert(input_list.begin(), c_node->inputs().begin() + 1, c_node->inputs().end());
|
||||
auto output_num = static_cast<int64_t>(AnfUtils::GetOutputTensorNum(c_node));
|
||||
FuncGraphPtr fg = std::get<0>(BuildGraphFromNodes({c_node}));
|
||||
FuncGraphManagerPtr mng = fg->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(fg, false);
|
||||
fg->set_manager(mng);
|
||||
}
|
||||
auto out_cnode = fg->output()->cast<CNodePtr>();
|
||||
if (out_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "Wrong graph generated for Kernel [" << c_node->fullname_with_scope() << "]";
|
||||
return false;
|
||||
}
|
||||
// check all inputs in the cnodes: if it is a valuenode, replace it by a parameter
|
||||
std::set<AnfNodePtr> value_nodes;
|
||||
auto &inputs = out_cnode->inputs();
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
const auto &tnode = inputs[i];
|
||||
auto tensor = GetValueNode<tensor::TensorPtr>(tnode);
|
||||
if (tensor) {
|
||||
value_nodes.insert(tnode);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &vnode : value_nodes) {
|
||||
auto parameter = fg->add_parameter();
|
||||
parameter->set_abstract(vnode->abstract());
|
||||
parameter->set_kernel_info(vnode->kernel_info_ptr());
|
||||
mng->Replace(vnode, parameter);
|
||||
}
|
||||
|
||||
node_list.push_back(out_cnode);
|
||||
(void)input_list.insert(input_list.begin(), out_cnode->inputs().begin() + 1, out_cnode->inputs().end());
|
||||
auto output_num = static_cast<int64_t>(AnfUtils::GetOutputTensorNum(out_cnode));
|
||||
if (output_num > 1) {
|
||||
for (int64_t idx = 0; idx < output_num; idx++) {
|
||||
auto gt = c_node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), c_node, NewValueNode(idx)});
|
||||
auto gt =
|
||||
out_cnode->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), out_cnode, NewValueNode(idx)});
|
||||
output_list.emplace_back(std::move(gt));
|
||||
}
|
||||
} else {
|
||||
output_list.push_back(c_node);
|
||||
output_list.push_back(out_cnode);
|
||||
}
|
||||
return CollectFusedJson(node_list, input_list, output_list, &kernel_json_);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue