refactor processor setting

Derive the target processor from the global MsContext device target instead of
querying each node: drop the processor argument from AkgSearchCache,
AkgInsertCache and SetNewKernelInfo, move kernel::GetProcessorFromContext()
from graph_kernel_helper into common_utils, and add
kernel::GetStrProcessorFromContext() beside it.

wenfangpei 2021-03-30 20:52:41 +08:00
parent 8d42a57093
commit 4af448bdd3
16 changed files with 59 additions and 148 deletions
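
For orientation, here is a minimal standalone sketch of the lookup scheme the
commit switches to (the helper names below and the "cuda"/"aicore" string
values are stand-ins for illustration; the real code reads MsContext's
MS_CTX_DEVICE_TARGET parameter and uses the kernel::Processor enum, as the
hunks below show):

#include <iostream>
#include <string>

// Stand-in for kernel::Processor (sketch only).
enum class Processor { UNKNOWN, AICORE, CUDA };

// Mirrors GetProcessorFromContext(): the processor now follows the globally
// configured device target rather than a per-node lookup.
Processor ProcessorFromDeviceTarget(const std::string &device_target) {
  if (device_target == "GPU") return Processor::CUDA;
  if (device_target == "Ascend") return Processor::AICORE;
  return Processor::UNKNOWN;  // other targets (e.g. CPU) stay unmapped
}

// Mirrors GetStrProcessorFromContext(); the exact string constants are assumed.
std::string ProcessorToStr(Processor p) {
  switch (p) {
    case Processor::CUDA:
      return "cuda";
    case Processor::AICORE:
      return "aicore";
    default:
      return "unknown";
  }
}

int main() {
  std::cout << ProcessorToStr(ProcessorFromDeviceTarget("GPU")) << "\n";  // prints: cuda
  return 0;
}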

View File

@@ -44,7 +44,7 @@ std::vector<std::string> AkgKernelBuilder::GetNotCachedKernelJsons(const std::ve
auto kernel_name = json_generator.kernel_name();
MS_LOG(DEBUG) << "Akg start compile op: " << kernel_name;
-auto cached_kernel_pack = AkgSearchCache(kernel_name, GetProcessorStr(anf_node));
+auto cached_kernel_pack = AkgSearchCache(kernel_name);
if (cached_kernel_pack != nullptr) {
MS_LOG(DEBUG) << "Use cached kernel, kernel_name[" << kernel_name << "], fullname_with_scope["
<< anf_node->fullname_with_scope() << "].";
@@ -67,7 +67,7 @@ std::vector<std::string> AkgKernelBuilder::GetNotCachedKernelJsons(const std::ve
bool AkgKernelBuilder::InsertToCache(const std::vector<JsonNodePair> &build_args) {
for (const auto &[json_generator, anf_node] : build_args) {
auto kernel_name = json_generator.kernel_name();
-auto new_kernel_pack = AkgInsertCache(kernel_name, GetProcessorStr(anf_node));
+auto new_kernel_pack = AkgInsertCache(kernel_name);
if (new_kernel_pack == nullptr) {
MS_LOG(ERROR) << "Insert to cache failed, kernel_name[" << kernel_name << "], fullname_with_scope["
<< anf_node->fullname_with_scope() << "].";
@@ -82,7 +82,7 @@ bool AkgKernelBuilder::InsertToCache(const std::vector<JsonNodePair> &build_args
bool AkgKernelBuilder::HandleRepeatNodes() {
for (const auto &[json_generator, anf_node] : repeat_nodes_) {
auto kernel_name = json_generator.kernel_name();
-auto cached_kernel_pack = AkgSearchCache(kernel_name, GetProcessorStr(anf_node));
+auto cached_kernel_pack = AkgSearchCache(kernel_name);
if (cached_kernel_pack == nullptr) {
MS_LOG(ERROR) << "Use cached kernel failed, kernel_name[" << kernel_name << "], fullname_with_scope["
<< anf_node->fullname_with_scope() << "].";

View File

@@ -36,8 +36,8 @@ class AkgKernelBuilder {
~AkgKernelBuilder() = default;
virtual KernelBuildClient *GetClient() = 0;
-virtual KernelPackPtr AkgSearchCache(const std::string &kernel_name, const std::string &processor) = 0;
-virtual KernelPackPtr AkgInsertCache(const std::string &kernel_name, const std::string &processor) = 0;
+virtual KernelPackPtr AkgSearchCache(const std::string &kernel_name) = 0;
+virtual KernelPackPtr AkgInsertCache(const std::string &kernel_name) = 0;
virtual void AkgSetKernelMod(const KernelPackPtr &kernel_pack, const AkgKernelJsonGenerator &json_generator,
const AnfNodePtr &anf_node) = 0;
virtual void AkgSaveJsonInfo(const string &kernel_name, const string &kernel_json) = 0;

View File

@@ -544,7 +544,7 @@ bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::j
(*kernel_json)[kJsonKeyId] = GetOpCntInc();
(*kernel_json)[kJsonKeyOp] = kernel_name_;
(*kernel_json)[kJsonKeyPlatform] = "AKG";
-(*kernel_json)[kJsonKeyProcess] = GetProcessorStr(anf_node);
+(*kernel_json)[kJsonKeyProcess] = GetStrProcessorFromContext(); // GetProcessorStr(anf_node);
(*kernel_json)[kJsonKeyComposite] = false;
if (!GetIOSize(*kernel_json, &input_size_list_, &output_size_list_)) {
@@ -632,7 +632,7 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
(*kernel_json)[kJsonKeyId] = GetOpCntInc();
(*kernel_json)[kJsonKeyOp] = kernel_name_;
(*kernel_json)[kJsonKeyPlatform] = "AKG";
-(*kernel_json)[kJsonKeyProcess] = GetProcessorStr(anf_nodes[0]);
+(*kernel_json)[kJsonKeyProcess] = GetStrProcessorFromContext();
(*kernel_json)[kJsonKeyComposite] = true;
(*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString() + "." + fg->debug_info()->get_id();
@@ -765,7 +765,7 @@ void AkgKernelJsonGenerator::GenParallelJson(const std::vector<AnfNodePtr> &anf_
}
if (!sub_graphs_info.empty()) {
-auto processor = GetProcessorStr(anf_nodes[0]);
+auto processor = GetStrProcessorFromContext(); // GetProcessorStr(anf_nodes[0]);
if (processor != kProcessorCuda) {
MS_LOG(EXCEPTION) << "Parallel fusion does not support " << processor << " now.";
}

View File

@@ -33,12 +33,12 @@
namespace mindspore {
namespace kernel {
-KernelPackPtr AkgAscendKernelBuilder::AkgSearchCache(const std::string &kernel_name, const std::string &processor) {
-return tbe::TbeUtils::SearchCache(kernel_name, processor);
+KernelPackPtr AkgAscendKernelBuilder::AkgSearchCache(const std::string &kernel_name) {
+return tbe::TbeUtils::SearchCache(kernel_name, kProcessorAiCore);
}
-KernelPackPtr AkgAscendKernelBuilder::AkgInsertCache(const std::string &kernel_name, const std::string &processor) {
-return tbe::TbeUtils::InsertCache(kernel_name, processor);
+KernelPackPtr AkgAscendKernelBuilder::AkgInsertCache(const std::string &kernel_name) {
+return tbe::TbeUtils::InsertCache(kernel_name, kProcessorAiCore);
}
void AkgAscendKernelBuilder::AkgSetKernelMod(const KernelPackPtr &kernel_pack,

View File

@@ -32,8 +32,8 @@ class AkgAscendKernelBuilder : public AkgKernelBuilder {
~AkgAscendKernelBuilder() = default;
kernel::KernelBuildClient *GetClient() override { return &(kernel::AscendKernelBuildClient::Instance()); }
-KernelPackPtr AkgSearchCache(const std::string &kernel_name, const std::string &processor) override;
-KernelPackPtr AkgInsertCache(const std::string &kernel_name, const std::string &processor) override;
+KernelPackPtr AkgSearchCache(const std::string &kernel_name) override;
+KernelPackPtr AkgInsertCache(const std::string &kernel_name) override;
void AkgSetKernelMod(const KernelPackPtr &kernel_pack, const AkgKernelJsonGenerator &json_generator,
const AnfNodePtr &anf_node) override;
void AkgSaveJsonInfo(const string &kernel_name, const string &kernel_json) override;

View File

@@ -29,12 +29,12 @@ namespace kernel {
constexpr int32_t ARGS_SIZE = 1;
constexpr auto kCompileWithJsonFunc = "compilewithjson";
-KernelPackPtr AkgGpuKernelBuilder::AkgSearchCache(const std::string &kernel_name, const std::string &processor) {
-return SearchCache(kernel_name, processor);
+KernelPackPtr AkgGpuKernelBuilder::AkgSearchCache(const std::string &kernel_name) {
+return SearchCache(kernel_name, kProcessorCuda);
}
-KernelPackPtr AkgGpuKernelBuilder::AkgInsertCache(const std::string &kernel_name, const std::string &processor) {
-return InsertCache(kernel_name, processor);
+KernelPackPtr AkgGpuKernelBuilder::AkgInsertCache(const std::string &kernel_name) {
+return InsertCache(kernel_name, kProcessorCuda);
}
void AkgGpuKernelBuilder::AkgSetKernelMod(const KernelPackPtr &kernel_pack,
@@ -49,99 +49,5 @@ void AkgGpuKernelBuilder::AkgSaveJsonInfo(const string &kernel_name, const strin
kernel::SaveJsonInfo(kernel_name, kernel_json, kernel::KernelMeta::GetInstance()->kernel_meta_path());
}
-KernelPackPtr AkgGpuKernelBuilder::OpBuild(const AkgKernelJsonGenerator &json_generator, const AnfNodePtr &anf_node) {
-MS_EXCEPTION_IF_NULL(anf_node);
-auto processor = GetProcessorStr(anf_node);
-auto kernel_name = json_generator.kernel_name();
-auto cached_kernel_pack = SearchCache(kernel_name, processor);
-if (cached_kernel_pack != nullptr) {
-MS_LOG(INFO) << "Use cached kernel, kernel_name[" << kernel_name << "], fullname_with_scope["
-<< anf_node->fullname_with_scope() << "].";
-return cached_kernel_pack;
-}
-auto kernel_json = json_generator.kernel_json_str();
-kernel::SaveJsonInfo(kernel_name, kernel_json, kernel::KernelMeta::GetInstance()->kernel_meta_path());
-(void)alarm(AUTODIFF_COMPILE_OVERTIME);
-auto res = GpuKernelBuildClient::Instance().AkgCompileSingle(kernel_json);
-(void)alarm(0);
-if (!res) {
-MS_LOG(ERROR) << "Akg compile failed, json: " << kernel_json;
-return nullptr;
-}
-auto new_kernel_pack = InsertCache(kernel_name, processor);
-if (new_kernel_pack == nullptr) {
-MS_LOG(ERROR) << "Insert to cache failed, kernel_name[" << kernel_name << "], fullname_with_scope["
-<< anf_node->fullname_with_scope() << "].";
-return nullptr;
-}
-return new_kernel_pack;
-}
-KernelModPtr AkgGpuKernelBuilder::BuildByJson(const AnfNodePtr &anf_node) {
-MS_EXCEPTION_IF_NULL(anf_node);
-MS_LOG(INFO) << "Akg start compile, op[" << anf_node->fullname_with_scope() << "]";
-AkgKernelJsonGenerator json_generator;
-if (!json_generator.CollectJson(anf_node)) {
-MS_LOG(ERROR) << "Op[" << anf_node->fullname_with_scope() << "] create single kernel json failed.";
-}
-auto kernel_pack = OpBuild(json_generator, anf_node);
-if (kernel_pack == nullptr) {
-MS_LOG(ERROR) << "Akg build failed op[" << anf_node->fullname_with_scope() << "].";
-return nullptr;
-}
-auto kernel_mod_ptr = std::make_shared<GpuKernelMod>(kernel_pack);
-MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
-kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
-kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
-MS_LOG(INFO) << "Akg compile success, op[" << anf_node->fullname_with_scope() << "]";
-return kernel_mod_ptr;
-}
-KernelModPtr AkgGpuKernelBuilder::FuseByJson(const AnfNodePtr &anf_node) {
-MS_EXCEPTION_IF_NULL(anf_node);
-MS_LOG(INFO) << "Akg start compile, graph_kernel[" << anf_node->fullname_with_scope() << "]";
-auto fg = AnfAlgo::GetCNodeFuncGraphPtr(anf_node);
-MS_EXCEPTION_IF_NULL(fg);
-auto mng = fg->manager();
-if (mng == nullptr) {
-mng = Manage(fg, true);
-fg->set_manager(mng);
-}
-AnfNodePtrList node_list;
-AnfNodePtrList input_list;
-AnfNodePtrList output_list;
-GetValidKernelNodes(fg, &node_list, &input_list, &output_list);
-AkgKernelJsonGenerator json_generator;
-if (!json_generator.CollectFusedJson(node_list, input_list, output_list)) {
-MS_LOG(ERROR) << "Op[" << anf_node->fullname_with_scope() << "] create single kernel json failed.";
-}
-auto kernel_pack = OpBuild(json_generator, anf_node);
-if (kernel_pack == nullptr) {
-MS_LOG(ERROR) << "Akg build failed, graph_kernel[" << anf_node->fullname_with_scope() << "].";
-return nullptr;
-}
-auto kernel_mod_ptr = std::make_shared<GpuKernelMod>(kernel_pack);
-MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
-kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
-kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
-MS_LOG(INFO) << "Akg compile success, graph_kernel[" << anf_node->fullname_with_scope() << "]";
-return kernel_mod_ptr;
-}
-KernelModPtr AkgGpuKernelBuild(const AnfNodePtr &anf_node) {
-MS_EXCEPTION_IF_NULL(anf_node);
-AkgGpuKernelBuilder akg_gpu_kernel_builder;
-if (AnfAlgo::IsGraphKernel(anf_node)) {
-return akg_gpu_kernel_builder.FuseByJson(anf_node);
-}
-return akg_gpu_kernel_builder.BuildByJson(anf_node);
-}
} // namespace kernel
} // namespace mindspore

View File

@@ -28,19 +28,13 @@ class AkgGpuKernelBuilder : public AkgKernelBuilder {
~AkgGpuKernelBuilder() = default;
kernel::KernelBuildClient *GetClient() override { return &(kernel::GpuKernelBuildClient::Instance()); }
-KernelPackPtr AkgSearchCache(const std::string &kernel_name, const std::string &processor) override;
-KernelPackPtr AkgInsertCache(const std::string &kernel_name, const std::string &processor) override;
+KernelPackPtr AkgSearchCache(const std::string &kernel_name) override;
+KernelPackPtr AkgInsertCache(const std::string &kernel_name) override;
void AkgSetKernelMod(const KernelPackPtr &kernel_pack, const AkgKernelJsonGenerator &json_generator,
const AnfNodePtr &anf_node) override;
void AkgSaveJsonInfo(const string &kernel_name, const string &kernel_json) override;
-KernelModPtr BuildByJson(const AnfNodePtr &anf_node);
-KernelModPtr FuseByJson(const AnfNodePtr &anf_node);
-private:
-KernelPackPtr OpBuild(const AkgKernelJsonGenerator &json_generator, const AnfNodePtr &anf_node);
};
-KernelModPtr AkgGpuKernelBuild(const AnfNodePtr &anf_node);
} // namespace kernel
} // namespace mindspore

View File

@@ -29,6 +29,7 @@
#include "ir/meta_tensor.h"
#include "base/core_ops.h"
#include "ir/graph_utils.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace kernel {
@@ -803,6 +804,30 @@ std::string GetProcessorStr(const AnfNodePtr &anf_node) {
return processor;
}
+Processor GetProcessorFromContext() {
+kernel::Processor processor = kernel::Processor::UNKNOWN;
+auto context_ptr = MsContext::GetInstance();
+MS_EXCEPTION_IF_NULL(context_ptr);
+auto device_info = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
+if (device_info == kGPUDevice) {
+processor = kernel::Processor::CUDA;
+} else if (device_info == kAscendDevice) {
+processor = kernel::Processor::AICORE;
+}
+return processor;
+}
+std::string GetStrProcessorFromContext() {
+auto processor = GetProcessorFromContext();
+string str_processor = kernel::kProcessorUnknown;
+if (processor == kernel::Processor::CUDA) {
+str_processor = kernel::kProcessorCuda;
+} else if (processor == kernel::Processor::AICORE) {
+str_processor = kernel::kProcessorAiCore;
+}
+return str_processor;
+}
float Scaling(size_t in_size, size_t out_size, bool align_corners) {
return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float>(out_size - 1)
: in_size / static_cast<float>(out_size);
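
To illustrate the call-site effect of the helpers added above (a compilable
sketch with stand-in types and an assumed global variable, not the real
MindSpore API): functions like SetNewKernelInfo no longer receive a
kernel::Processor argument and instead resolve it from the context themselves.

#include <iostream>
#include <string>

enum class Processor { UNKNOWN, AICORE, CUDA };  // stand-in for kernel::Processor

// Stand-in for MsContext's MS_CTX_DEVICE_TARGET parameter.
std::string g_device_target = "GPU";

Processor GetProcessorFromContext() {
  if (g_device_target == "GPU") return Processor::CUDA;
  if (g_device_target == "Ascend") return Processor::AICORE;
  return Processor::UNKNOWN;
}

// After this commit the kernel-info builder takes no processor parameter;
// node, graph, input and output arguments are elided in this sketch.
void SetNewKernelInfo() {
  Processor p = GetProcessorFromContext();
  std::cout << "selected processor = " << static_cast<int>(p) << "\n";
}

int main() {
  SetNewKernelInfo();  // the call site no longer computes AnfAlgo::GetProcessor(node)
  return 0;
}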

View File

@@ -102,6 +102,8 @@ void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<An
bool IsWeightBoundary(const AnfNodePtr &node);
std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode);
std::string GetProcessorStr(const AnfNodePtr &anf_node);
+Processor GetProcessorFromContext();
+std::string GetStrProcessorFromContext();
float Scaling(size_t in_size, size_t out_size, bool align_corners);
float ScaleGrid(const int x, const float scale);
struct CachedInterpolation {
@@ -130,7 +132,6 @@ inline T ComputeLerp(T top_left, T top_right, T bottom_left, T bottom_right, T x
T bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
return top + (bottom - top) * y_lerp;
}
} // namespace kernel
} // namespace mindspore

View File

@@ -32,6 +32,7 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
#include "debug/anf_ir_dump.h"
#include "backend/kernel_compiler/common_utils.h"
namespace mindspore {
namespace opt {
@@ -85,7 +86,7 @@ inline int64_t CalNewIndex(int64_t old_index, int64_t reduce_index) {
}
} // namespace
std::shared_ptr<AtomicAddChecker> AtomicAddChecker::Init() {
-auto processor = GetProcessorFromContext();
+auto processor = kernel::GetProcessorFromContext();
if (processor == kernel::Processor::AICORE) {
return std::make_shared<AtomicAddCheckerAscend>();
} else if (processor == kernel::Processor::CUDA) {
@@ -401,8 +402,7 @@ CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphP
new_sub_graph->set_output(broadcast_to_node_inner);
auto broadcast_to_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph)});
broadcast_to_composite_node->set_abstract(broadcast_to_node_inner->abstract());
-SetNewKernelInfo(broadcast_to_composite_node, new_sub_graph, {}, {broadcast_to_node_inner},
-AnfAlgo::GetProcessor(atomic_add_node_));
+SetNewKernelInfo(broadcast_to_composite_node, new_sub_graph, {}, {broadcast_to_node_inner});
auto graph_attr = ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "atomic_clean");
new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr));
new_sub_graph->set_attr("composite_type", MakeValue("atomic_clean"));

View File

@@ -260,7 +260,7 @@ AnfNodePtr EliminateHangingOutput::ReplaceMakeTuple(const AnfNodePtr &node, cons
AnfNodePtrList outputs;
kernel::GetFuncGraphOutputNodes(func_graph, &outputs);
auto graph_kernel_node = CreateNewFuseCNode(node->func_graph(), func_graph, inputs, outputs);
-SetNewKernelInfo(graph_kernel_node, func_graph, inputs, outputs, AnfAlgo::GetProcessor(node));
+SetNewKernelInfo(graph_kernel_node, func_graph, inputs, outputs);
return graph_kernel_node;
}

View File

@@ -143,7 +143,7 @@ AnfNodePtr DefaultExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func
kernel::GetValidKernelNodes(new_func_graph, &kernel_nodes);
kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs);
auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs);
-SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs, AnfAlgo::GetProcessor(old_node));
+SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs);
MS_LOG(DEBUG) << "Expand node: " << old_node->fullname_with_scope()
<< " with: " << graph_kernel_node->fullname_with_scope();
return graph_kernel_node;

View File

@@ -34,7 +34,6 @@
#include "pipeline/jit/action.h"
#include "utils/context/graph_kernel_flags.h"
#include "vm/segment_runner.h"
#include "utils/ms_context.h"
#if ENABLE_GPU
#include "runtime/device/gpu/kernel_info_setter.h"
#endif
@@ -306,7 +305,7 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> MixedNodesTransToGraph(
}
void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
-const AnfNodePtrList &outputs, kernel::Processor processor) {
+const AnfNodePtrList &outputs) {
std::vector<std::string> graph_input_format;
std::vector<TypeId> graph_input_type;
std::vector<std::string> graph_output_format;
@@ -339,7 +338,7 @@ void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const
graph_info_builder.SetInputsDeviceType(graph_input_type);
graph_info_builder.SetOutputsFormat(graph_output_format);
graph_info_builder.SetOutputsDeviceType(graph_output_type);
-graph_info_builder.SetProcessor(processor);
+graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
auto graph_selected_info = graph_info_builder.Build();
@@ -443,7 +442,7 @@ std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<Anf
std::tie(fg, inputs, outputs) = MixedNodesTransToGraph(fuse_nodes, &src_outputs);
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs);
-SetNewKernelInfo(fuse_new_node, fg, inputs, outputs, AnfAlgo::GetProcessor(fuse_nodes[0]));
+SetNewKernelInfo(fuse_new_node, fg, inputs, outputs);
// Handle get-item problem.
ReplaceNewFuseCNode(kernel_graph, fuse_new_node, src_outputs);
@@ -807,19 +806,6 @@ std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node) {
return axis;
}
-kernel::Processor GetProcessorFromContext() {
-kernel::Processor processor = kernel::Processor::UNKNOWN;
-auto context_ptr = MsContext::GetInstance();
-MS_EXCEPTION_IF_NULL(context_ptr);
-auto device_info = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
-if (device_info == kGPUDevice) {
-processor = kernel::Processor::CUDA;
-} else if (device_info == kAscendDevice) {
-processor = kernel::Processor::AICORE;
-}
-return processor;
-}
CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info) {
// Limitation: 1. Node's attributes should be set out of this function; 2. only one output.
MS_EXCEPTION_IF_NULL(out_info.type);
@@ -876,7 +862,7 @@ CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &
info_builder.SetInputsDeviceType(input_types);
info_builder.SetOutputsFormat(output_formats);
info_builder.SetOutputsDeviceType(output_types);
-info_builder.SetProcessor(GetProcessorFromContext());
+info_builder.SetProcessor(kernel::GetProcessorFromContext());
info_builder.SetKernelType(KernelType::AKG_KERNEL);
info_builder.SetFusionType(kernel::FusionType::OPAQUE);
auto selected_info = info_builder.Build();

View File

@@ -57,7 +57,7 @@ bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *i
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> MixedNodesTransToGraph(const AnfNodePtrList &fuse_nodes,
AnfNodePtrList *src_outputs = nullptr);
void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
-const AnfNodePtrList &outputs, kernel::Processor processor);
+const AnfNodePtrList &outputs);
AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
const AnfNodePtrList &outputs);
void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode,
@@ -84,7 +84,6 @@ TypePtr GetType(const AnfNodePtr &node);
ShapeVector GetShape(const AnfNodePtr &node);
ShapeVector GetDeviceShape(const AnfNodePtr &node);
std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node);
-kernel::Processor GetProcessorFromContext();
CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info);
void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);

View File

@@ -484,7 +484,7 @@ class Splitter {
AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
AnfNodePtrList outputs;
kernel::GetFuncGraphOutputNodes(sub_func_graph, &outputs);
-SetNewKernelInfo(cnode, sub_func_graph, inputs, outputs, AnfAlgo::GetProcessor(old_subgraph_cnode_));
+SetNewKernelInfo(cnode, sub_func_graph, inputs, outputs);
}
}

View File

@@ -46,7 +46,7 @@ bool TensorPromotion::Run(const FuncGraphPtr &func_graph) {
inputs.insert(inputs.end(), args.begin() + 1, args.end());
kernel::GetFuncGraphOutputNodes(fg, &outputs);
auto new_cnode = CreateNewFuseCNode(func_graph, fg, inputs, outputs);
-SetNewKernelInfo(new_cnode, fg, inputs, outputs, AnfAlgo::GetProcessor(node));
+SetNewKernelInfo(new_cnode, fg, inputs, outputs);
mng->Replace(node, new_cnode);
changed = true;
}