forked from mindspore-Ecosystem/mindspore
!2817 add internal output
Merge pull request !2817 from kisnwang/optimize-sub-graph-memcpy
This commit is contained in:
commit
32405f9ab3
|
@ -340,7 +340,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
|
|||
}
|
||||
}
|
||||
|
||||
void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph) {
|
||||
void KernelRuntime::AssignStaticMemoryOutput(session::KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
|
||||
std::vector<session::KernelWithIndex> non_communication_op;
|
||||
|
@ -351,6 +351,7 @@ void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph)
|
|||
if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) {
|
||||
continue;
|
||||
}
|
||||
graph->AddFinalOutputKernel(item_with_index.first);
|
||||
if (AnfAlgo::IsCommunicationOp(item_with_index.first)) {
|
||||
AssignCommunicationNodeMem(kStaticMem, item_with_index.first);
|
||||
} else {
|
||||
|
|
|
@ -95,7 +95,7 @@ class KernelRuntime {
|
|||
#endif
|
||||
|
||||
private:
|
||||
void AssignStaticMemoryOutput(const session::KernelGraph *graph);
|
||||
void AssignStaticMemoryOutput(session::KernelGraph *graph);
|
||||
void GenLaunchArgs(const session::KernelGraph &graph, const AnfNodePtr &kernel, AddressPtrList *kernel_inputs,
|
||||
AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs);
|
||||
bool LaunchKernelMod(const session::KernelGraph &graph);
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/primitive_base.h"
|
||||
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "operator/ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -179,4 +179,43 @@ std::string get_id(const AnfNodePtr &node) {
|
|||
|
||||
void reset_id() { node_ids.clear(); }
|
||||
} // namespace id_generator
|
||||
|
||||
std::string GetCNodeTarget(const AnfNodePtr &node) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
std::string default_target = context_ptr->device_target();
|
||||
if (!node->isa<CNode>()) {
|
||||
return default_target;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto attr_input = cnode->input(0);
|
||||
if (attr_input == nullptr) {
|
||||
return default_target;
|
||||
}
|
||||
auto value_node = attr_input->cast<ValueNodePtr>();
|
||||
if (value_node == nullptr) {
|
||||
return default_target;
|
||||
}
|
||||
auto value = value_node->value();
|
||||
if (value == nullptr) {
|
||||
return default_target;
|
||||
}
|
||||
if (!value->isa<Primitive>()) {
|
||||
return default_target;
|
||||
}
|
||||
auto primitive = value->cast<PrimitivePtr>();
|
||||
auto att_target = primitive->GetAttr("primitive_target");
|
||||
if (att_target != nullptr) {
|
||||
if (!att_target->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
|
||||
}
|
||||
auto target = GetValue<std::string>(att_target);
|
||||
if (kTargetSet.find(target) == kTargetSet.end()) {
|
||||
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
|
||||
}
|
||||
return target;
|
||||
}
|
||||
return default_target;
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -448,7 +448,7 @@ void reset_id();
|
|||
} // namespace id_generator
|
||||
using TaggedNodeMap = std::unordered_map<AnfNodePtr, size_t>;
|
||||
using TaggedGraph = std::pair<FuncGraphPtr, TaggedNodeMap>;
|
||||
|
||||
std::string GetCNodeTarget(const AnfNodePtr &node);
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_IR_ANF_H_
|
||||
|
|
|
@ -46,6 +46,11 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
|
|||
if (node == nullptr || !AnfAlgo::IsRealKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
AnfNodePtr front_node;
|
||||
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
|
||||
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
|
||||
front_node = kernel_graph->GetFrontNodeByInternalOutput(node);
|
||||
}
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
|
||||
MS_LOG(DEBUG) << "====process op: " << node->DebugString();
|
||||
AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_);
|
||||
|
@ -56,7 +61,12 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
|
|||
return new_node;
|
||||
}
|
||||
}
|
||||
return InsertTransOpForOutput(func_graph, new_node, kernel_select_);
|
||||
auto final_node = InsertTransOpForOutput(func_graph, new_node, kernel_select_);
|
||||
if (kernel_graph != nullptr && front_node != nullptr) {
|
||||
auto old_node = kernel_graph->GetInternalOutputByFrontNode(front_node);
|
||||
kernel_graph->ReplaceInternalOutput(old_node, final_node);
|
||||
}
|
||||
return final_node;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -987,15 +987,6 @@ void AscendSession::SetFinalGraphOutput(const BaseRef &output) {
|
|||
}
|
||||
}
|
||||
|
||||
// Look up a kernel graph by id in graphs_.
// Returns the stored graph, or nullptr (after logging a warning) when the id is unknown.
KernelGraphPtr AscendSession::GetGraph(mindspore::GraphId graph_id) {
  const auto found = graphs_.find(graph_id);
  if (found != graphs_.end()) {
    return found->second;
  }
  MS_LOG(WARNING) << "Can't find graph " << graph_id;
  return nullptr;
}
|
||||
|
||||
void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id) {
|
||||
MS_LOG(INFO) << "Start!";
|
||||
MS_LOG(INFO) << "Condition graph id[" << condition_graph_id << "],true graph id[" << true_graph_id << "]";
|
||||
|
|
|
@ -128,8 +128,6 @@ class AscendSession : public SessionBasic {
|
|||
void InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node);
|
||||
// insert depend to graph, used to attch control nodes to graph
|
||||
void InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, const AnfNodePtr &second_node);
|
||||
// Get graph by graph id ,if not exist return null ptr
|
||||
KernelGraphPtr GetGraph(GraphId graph_id);
|
||||
// set child graph parameter if front arg is a anf
|
||||
void SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx);
|
||||
// set child graph parameter if front arg is a tensor
|
||||
|
|
|
@ -329,6 +329,9 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
|
|||
FrontBackendlMapUpdate(cnode, new_cnode);
|
||||
}
|
||||
AnfAlgo::SetGraphId(graph_id_, cnode.get());
|
||||
if (IsInternalOutput(cnode)) {
|
||||
ReplaceInternalOutput(cnode, new_cnode);
|
||||
}
|
||||
return new_cnode;
|
||||
}
|
||||
|
||||
|
@ -872,6 +875,76 @@ void KernelGraph::PrintGraphExecuteOrder() const {
|
|||
}
|
||||
}
|
||||
|
||||
void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node) {
|
||||
if (front_node == nullptr || node == nullptr) {
|
||||
MS_LOG(INFO) << "Front node or node is nullptr";
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString();
|
||||
front_to_internal_outputs_map_[front_node] = node;
|
||||
internal_outputs_to_front_map_[node] = front_node;
|
||||
}
|
||||
|
||||
void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node) {
|
||||
if (new_node == nullptr || node == nullptr) {
|
||||
MS_LOG(INFO) << "New node or node is nullptr";
|
||||
return;
|
||||
}
|
||||
if (node == new_node) {
|
||||
MS_LOG(INFO) << "New node and node is the same";
|
||||
return;
|
||||
}
|
||||
auto iter = internal_outputs_to_front_map_.find(node);
|
||||
if (iter == internal_outputs_to_front_map_.end()) {
|
||||
MS_LOG(INFO) << "Node is not internal output";
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Replace internal node " << node->DebugString() << " To " << new_node->DebugString();
|
||||
internal_outputs_to_front_map_[new_node] = iter->second;
|
||||
front_to_internal_outputs_map_[iter->second] = new_node;
|
||||
internal_outputs_to_front_map_.erase(iter);
|
||||
}
|
||||
|
||||
// Return the internal output registered for `front_node`, or nullptr when none is registered.
AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const {
  const auto found = front_to_internal_outputs_map_.find(front_node);
  if (found == front_to_internal_outputs_map_.end()) {
    return nullptr;
  }
  return found->second;
}
|
||||
|
||||
bool KernelGraph::IsInternalOutput(const AnfNodePtr &node) const {
|
||||
if (internal_outputs_to_front_map_.find(node) != internal_outputs_to_front_map_.end()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Return the front node registered for internal output `node`, or nullptr when `node` is not registered.
AnfNodePtr KernelGraph::GetFrontNodeByInternalOutput(const AnfNodePtr &node) const {
  const auto found = internal_outputs_to_front_map_.find(node);
  if (found == internal_outputs_to_front_map_.end()) {
    return nullptr;
  }
  return found->second;
}
|
||||
|
||||
void KernelGraph::AddFinalOutputKernel(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
return;
|
||||
}
|
||||
(void)final_output_kernels_.insert(node);
|
||||
}
|
||||
|
||||
bool KernelGraph::IsFinalOutputKernel(const AnfNodePtr &node) const {
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (final_output_kernels_.find(node) != final_output_kernels_.end()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Human-readable graph name: "kernel_graph_" followed by the decimal graph id.
std::string KernelGraph::ToString() const {
  std::string label = "kernel_graph_";
  label += std::to_string(graph_id_);
  return label;
}
|
||||
|
||||
// On destruction, ask the KernelRuntimeManager singleton to clear the resources held for this graph id.
KernelGraph::~KernelGraph() { device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_); }
|
||||
|
|
|
@ -144,6 +144,13 @@ class KernelGraph : public FuncGraph {
|
|||
void PrintGraphExecuteOrder() const;
|
||||
const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes() const { return summary_nodes_; }
|
||||
void set_summary_nodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &nodes) { summary_nodes_ = nodes; }
|
||||
void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node);
|
||||
void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node);
|
||||
AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const;
|
||||
bool IsInternalOutput(const AnfNodePtr &node) const;
|
||||
AnfNodePtr GetFrontNodeByInternalOutput(const AnfNodePtr &node) const;
|
||||
void AddFinalOutputKernel(const AnfNodePtr &node);
|
||||
bool IsFinalOutputKernel(const AnfNodePtr &node) const;
|
||||
|
||||
private:
|
||||
// remove value node form graph
|
||||
|
@ -202,6 +209,9 @@ class KernelGraph : public FuncGraph {
|
|||
CNodePtr start_label_;
|
||||
CNodePtr end_goto_;
|
||||
bool null_output_;
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_internal_outputs_map_;
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> internal_outputs_to_front_map_;
|
||||
std::set<AnfNodePtr> final_output_kernels_;
|
||||
};
|
||||
} // namespace session
|
||||
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;
|
||||
|
|
|
@ -95,6 +95,13 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
|
|||
TypeId type_id = kNumberTypeFloat32;
|
||||
type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
|
||||
std::vector<int> temp_shape;
|
||||
if (graph.IsInternalOutput(node)) {
|
||||
temp_shape.emplace_back(1);
|
||||
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
|
||||
tensor->set_device_address(address);
|
||||
tensor->set_dirty(false);
|
||||
return tensor;
|
||||
}
|
||||
(void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
|
||||
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
|
||||
// if in paynative mode,data only copyed to host when user want to print data
|
||||
|
@ -172,48 +179,6 @@ ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
|
|||
return new_value_node;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::vector<AnfNodePtr> parameters;
|
||||
std::vector<AnfNodePtr> pre_graph_out = {node};
|
||||
// If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
|
||||
if (!AnfAlgo::IsRealKernel(node)) {
|
||||
pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
|
||||
}
|
||||
auto valid_inputs = graph->MutableValidInputs();
|
||||
MS_EXCEPTION_IF_NULL(valid_inputs);
|
||||
auto graph_inputs = graph->MutableInputs();
|
||||
MS_EXCEPTION_IF_NULL(graph_inputs);
|
||||
auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
|
||||
auto parameter = graph->NewParameter();
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
parameter->set_abstract(abstract);
|
||||
auto new_parameter = graph->NewParameter(parameter);
|
||||
parameters.push_back(new_parameter);
|
||||
valid_inputs->push_back(valid_input);
|
||||
graph_inputs->push_back(new_parameter);
|
||||
};
|
||||
for (const auto &out_node : pre_graph_out) {
|
||||
MS_EXCEPTION_IF_NULL(out_node);
|
||||
auto abstract = out_node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
// create multiple parameters if is a tuple output real kernel
|
||||
if (abstract->isa<abstract::AbstractTuple>() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
|
||||
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]";
|
||||
for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) {
|
||||
create_parameter((*tuple_abstract)[output_idx]);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
// create single parameter if is a abstract real kernel
|
||||
create_parameter(out_node->abstract());
|
||||
}
|
||||
return parameters;
|
||||
}
|
||||
|
||||
size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_LOG(INFO) << "Load kInputCtrlTensors";
|
||||
|
@ -323,6 +288,103 @@ bool ExistSummaryNode(const KernelGraph *graph) {
|
|||
} // namespace
|
||||
|
||||
GraphId SessionBasic::graph_sum_ = 0;
|
||||
|
||||
// Look up a kernel graph by id in graphs_.
// Returns the stored graph, or nullptr (after logging a warning) when the id is unknown.
KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) {
  const auto found = graphs_.find(graph_id);
  if (found != graphs_.end()) {
    return found->second;
  }
  MS_LOG(WARNING) << "Can't find graph " << graph_id;
  return nullptr;
}
|
||||
|
||||
void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr ¶meter) {
|
||||
auto graph_id = GetGraphIdByNode(out_node);
|
||||
if (graph_id == kInvalidGraphId) {
|
||||
return;
|
||||
}
|
||||
auto node_graph = GetGraph(graph_id);
|
||||
if (node_graph == nullptr) {
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Init parameter with pre graph output node: " << out_node->DebugString();
|
||||
auto ref_node = node_graph->GetInternalOutputByFrontNode(out_node);
|
||||
if (ref_node == nullptr) {
|
||||
MS_LOG(INFO) << "No corresponding internal output for output node";
|
||||
return;
|
||||
}
|
||||
auto real_kernel = AnfAlgo::VisitKernel(ref_node, 0);
|
||||
auto ref_real_node = real_kernel.first;
|
||||
auto ref_real_node_index = real_kernel.second;
|
||||
if (ref_real_node->isa<CNode>() && node_graph->IsInternalOutput(ref_real_node) &&
|
||||
node_graph->IsFinalOutputKernel(ref_real_node)) {
|
||||
auto kernel_info = ref_real_node->kernel_info();
|
||||
if (kernel_info == nullptr || kernel_info->select_kernel_build_info() == nullptr) {
|
||||
MS_LOG(INFO) << "No kernel info";
|
||||
return;
|
||||
}
|
||||
auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index);
|
||||
if (address == nullptr) {
|
||||
MS_LOG(INFO) << "No kernel address";
|
||||
return;
|
||||
}
|
||||
auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index);
|
||||
auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index);
|
||||
parameter->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
auto d_kernel_info = parameter->kernel_info();
|
||||
MS_EXCEPTION_IF_NULL(d_kernel_info);
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
builder.SetOutputsDeviceType({type});
|
||||
builder.SetOutputsFormat({format});
|
||||
d_kernel_info->set_select_kernel_build_info(builder.Build());
|
||||
AnfAlgo::SetOutputAddr(address, 0, parameter.get());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input,
|
||||
KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::vector<AnfNodePtr> parameters;
|
||||
std::vector<AnfNodePtr> pre_graph_out = {node};
|
||||
// If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
|
||||
if (!AnfAlgo::IsRealKernel(node)) {
|
||||
pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
|
||||
}
|
||||
auto valid_inputs = graph->MutableValidInputs();
|
||||
MS_EXCEPTION_IF_NULL(valid_inputs);
|
||||
auto graph_inputs = graph->MutableInputs();
|
||||
MS_EXCEPTION_IF_NULL(graph_inputs);
|
||||
auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
|
||||
auto parameter = graph->NewParameter();
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
parameter->set_abstract(abstract);
|
||||
auto new_parameter = graph->NewParameter(parameter);
|
||||
parameters.push_back(new_parameter);
|
||||
valid_inputs->push_back(valid_input);
|
||||
graph_inputs->push_back(new_parameter);
|
||||
};
|
||||
for (const auto &out_node : pre_graph_out) {
|
||||
MS_EXCEPTION_IF_NULL(out_node);
|
||||
auto abstract = out_node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
// create multiple parameters if is a tuple output real kernel
|
||||
if (abstract->isa<abstract::AbstractTuple>() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
|
||||
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]";
|
||||
for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) {
|
||||
create_parameter((*tuple_abstract)[output_idx]);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
// create single parameter if is a abstract real kernel
|
||||
create_parameter(out_node->abstract());
|
||||
InitInternalOutputParameter(out_node, parameters[parameters.size() - 1]);
|
||||
}
|
||||
return parameters;
|
||||
}
|
||||
|
||||
ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input,
|
||||
KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
|
@ -877,6 +939,29 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
|
|||
auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
|
||||
auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
|
||||
if (backend_anf != nullptr) {
|
||||
auto front_real_kernel = AnfAlgo::VisitKernel(out, 0);
|
||||
auto backend_real_kernel = AnfAlgo::VisitKernel(backend_anf, 0);
|
||||
MS_EXCEPTION_IF_NULL(out);
|
||||
auto out_func_graph = out->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(out_func_graph);
|
||||
auto out_func_graph_manager = out_func_graph->manager();
|
||||
if (out_func_graph_manager == nullptr) {
|
||||
return backend_anf;
|
||||
}
|
||||
auto node_users = out_func_graph_manager->node_users();
|
||||
auto users = node_users[out];
|
||||
bool internal_output = true;
|
||||
std::string kernel_target = GetCNodeTarget(front_real_kernel.first);
|
||||
for (auto user : users) {
|
||||
if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) {
|
||||
internal_output = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (internal_output) {
|
||||
MS_LOG(INFO) << "Internal output1: " << out->DebugString() << "To " << backend_real_kernel.first->DebugString();
|
||||
graph->AddInternalOutput(out, backend_real_kernel.first);
|
||||
}
|
||||
return backend_anf;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
|
||||
|
|
|
@ -110,6 +110,8 @@ class SessionBasic {
|
|||
#endif
|
||||
|
||||
protected:
|
||||
// Get graph by graph id ,if not exist return null ptr
|
||||
KernelGraphPtr GetGraph(GraphId graph_id);
|
||||
virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||
const std::vector<tensor::TensorPtr> &inputs_const) const;
|
||||
void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
|
||||
|
@ -127,11 +129,13 @@ class SessionBasic {
|
|||
BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref);
|
||||
// create a new kernel graph and update the graph sum
|
||||
KernelGraphPtr NewKernelGraph();
|
||||
std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph);
|
||||
virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
|
||||
ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph);
|
||||
ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);
|
||||
AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
|
||||
// NOTE(review): definition not visible here; presumably appends `parameters` to the inputs of `graph` — confirm.
void AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph);
|
||||
// If `out_node` has an internal output in its owning graph whose real kernel is also a final output
// kernel, give `parameter` that kernel output's format, device data type and device address.
void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter);
|
||||
|
||||
std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
|
||||
std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
|
||||
|
|
|
@ -52,45 +52,6 @@ const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
|
|||
}
|
||||
|
||||
namespace {
|
||||
std::string GetCNodeTarget(const AnfNodePtr &node) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
std::string default_target = context_ptr->device_target();
|
||||
if (!node->isa<CNode>()) {
|
||||
return default_target;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto attr_input = cnode->input(kAnfPrimitiveIndex);
|
||||
if (attr_input == nullptr) {
|
||||
return default_target;
|
||||
}
|
||||
auto value_node = attr_input->cast<ValueNodePtr>();
|
||||
if (value_node == nullptr) {
|
||||
return default_target;
|
||||
}
|
||||
auto value = value_node->value();
|
||||
if (value == nullptr) {
|
||||
return default_target;
|
||||
}
|
||||
if (!value->isa<Primitive>()) {
|
||||
return default_target;
|
||||
}
|
||||
auto primitive = value->cast<PrimitivePtr>();
|
||||
auto att_target = primitive->GetAttr("primitive_target");
|
||||
if (att_target != nullptr) {
|
||||
if (!att_target->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
|
||||
}
|
||||
auto target = GetValue<std::string>(att_target);
|
||||
if (kTargetSet.find(target) == kTargetSet.end()) {
|
||||
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
|
||||
}
|
||||
return target;
|
||||
}
|
||||
return default_target;
|
||||
}
|
||||
|
||||
bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
|
|
Loading…
Reference in New Issue