!2817 add internal output

Merge pull request !2817 from kisnwang/optimize-sub-graph-memcpy
This commit is contained in:
mindspore-ci-bot 2020-07-03 10:28:24 +08:00 committed by Gitee
commit 32405f9ab3
12 changed files with 269 additions and 97 deletions

View File

@ -340,7 +340,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
}
}
void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph) {
void KernelRuntime::AssignStaticMemoryOutput(session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
std::vector<session::KernelWithIndex> non_communication_op;
@ -351,6 +351,7 @@ void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph)
if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) {
continue;
}
graph->AddFinalOutputKernel(item_with_index.first);
if (AnfAlgo::IsCommunicationOp(item_with_index.first)) {
AssignCommunicationNodeMem(kStaticMem, item_with_index.first);
} else {

View File

@ -95,7 +95,7 @@ class KernelRuntime {
#endif
private:
void AssignStaticMemoryOutput(const session::KernelGraph *graph);
void AssignStaticMemoryOutput(session::KernelGraph *graph);
void GenLaunchArgs(const session::KernelGraph &graph, const AnfNodePtr &kernel, AddressPtrList *kernel_inputs,
AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs);
bool LaunchKernelMod(const session::KernelGraph &graph);

View File

@ -25,7 +25,7 @@
#include "ir/func_graph.h"
#include "ir/primitive_base.h"
#include "utils/context/ms_context.h"
#include "operator/ops.h"
namespace mindspore {
@ -179,4 +179,43 @@ std::string get_id(const AnfNodePtr &node) {
void reset_id() { node_ids.clear(); }
} // namespace id_generator
// Return the device target (CPU/GPU/Ascend) that `node` should execute on.
// A CNode may carry a "primitive_target" attribute on its primitive; when that
// attribute is present it must be one of the validated target strings,
// otherwise the context's default device target is used.
std::string GetCNodeTarget(const AnfNodePtr &node) {
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  const std::string default_target = context->device_target();
  if (!node->isa<CNode>()) {
    // Only CNodes can carry a primitive with a target attribute.
    return default_target;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto primitive_input = cnode->input(0);
  if (primitive_input == nullptr) {
    return default_target;
  }
  auto prim_value_node = primitive_input->cast<ValueNodePtr>();
  if (prim_value_node == nullptr) {
    return default_target;
  }
  auto prim_value = prim_value_node->value();
  if (prim_value == nullptr || !prim_value->isa<Primitive>()) {
    return default_target;
  }
  auto primitive = prim_value->cast<PrimitivePtr>();
  auto target_attr = primitive->GetAttr("primitive_target");
  if (target_attr == nullptr) {
    return default_target;
  }
  if (!target_attr->isa<StringImm>()) {
    MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
  }
  auto target = GetValue<std::string>(target_attr);
  if (kTargetSet.find(target) == kTargetSet.end()) {
    MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
  }
  return target;
}
} // namespace mindspore

View File

@ -448,7 +448,7 @@ void reset_id();
} // namespace id_generator
using TaggedNodeMap = std::unordered_map<AnfNodePtr, size_t>;
using TaggedGraph = std::pair<FuncGraphPtr, TaggedNodeMap>;
std::string GetCNodeTarget(const AnfNodePtr &node);
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_ANF_H_

View File

@ -46,6 +46,11 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
if (node == nullptr || !AnfAlgo::IsRealKernel(node)) {
return nullptr;
}
AnfNodePtr front_node;
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
front_node = kernel_graph->GetFrontNodeByInternalOutput(node);
}
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
MS_LOG(DEBUG) << "====process op: " << node->DebugString();
AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_);
@ -56,7 +61,12 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
return new_node;
}
}
return InsertTransOpForOutput(func_graph, new_node, kernel_select_);
auto final_node = InsertTransOpForOutput(func_graph, new_node, kernel_select_);
if (kernel_graph != nullptr && front_node != nullptr) {
auto old_node = kernel_graph->GetInternalOutputByFrontNode(front_node);
kernel_graph->ReplaceInternalOutput(old_node, final_node);
}
return final_node;
}
} // namespace opt
} // namespace mindspore

View File

@ -987,15 +987,6 @@ void AscendSession::SetFinalGraphOutput(const BaseRef &output) {
}
}
// Look up a kernel graph by its id; logs a warning and returns nullptr
// when no graph with that id has been registered.
KernelGraphPtr AscendSession::GetGraph(mindspore::GraphId graph_id) {
  const auto iter = graphs_.find(graph_id);
  if (iter != graphs_.end()) {
    return iter->second;
  }
  MS_LOG(WARNING) << "Can't find graph " << graph_id;
  return nullptr;
}
void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id) {
MS_LOG(INFO) << "Start!";
MS_LOG(INFO) << "Condition graph id[" << condition_graph_id << "],true graph id[" << true_graph_id << "]";

View File

@ -128,8 +128,6 @@ class AscendSession : public SessionBasic {
void InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node);
// insert depend to graph, used to attach control nodes to graph
void InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, const AnfNodePtr &second_node);
// Get graph by graph id; if it does not exist, return nullptr
KernelGraphPtr GetGraph(GraphId graph_id);
// set child graph parameter if front arg is a anf
void SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx);
// set child graph parameter if front arg is a tensor

View File

@ -329,6 +329,9 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
FrontBackendlMapUpdate(cnode, new_cnode);
}
AnfAlgo::SetGraphId(graph_id_, cnode.get());
if (IsInternalOutput(cnode)) {
ReplaceInternalOutput(cnode, new_cnode);
}
return new_cnode;
}
@ -872,6 +875,76 @@ void KernelGraph::PrintGraphExecuteOrder() const {
}
}
// Register `node` as an internal output of this graph, keyed by its
// front-end counterpart `front_node`. The two maps are maintained as
// mutual inverses so the mapping can be queried from either direction.
void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node) {
if (front_node == nullptr || node == nullptr) {
// Incomplete registrations are ignored rather than treated as errors.
MS_LOG(INFO) << "Front node or node is nullptr";
return;
}
MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString();
front_to_internal_outputs_map_[front_node] = node;
internal_outputs_to_front_map_[node] = front_node;
}
// Re-key the internal-output bookkeeping from `node` to `new_node`, keeping
// both direction maps consistent. No-ops (with an INFO log) when either
// argument is null, when the two nodes are identical, or when `node` was
// never registered as an internal output.
void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node) {
  if (node == nullptr || new_node == nullptr) {
    MS_LOG(INFO) << "New node or node is nullptr";
    return;
  }
  if (node == new_node) {
    MS_LOG(INFO) << "New node and node is the same";
    return;
  }
  const auto entry = internal_outputs_to_front_map_.find(node);
  if (entry == internal_outputs_to_front_map_.end()) {
    MS_LOG(INFO) << "Node is not internal output";
    return;
  }
  MS_LOG(INFO) << "Replace internal node " << node->DebugString() << " To " << new_node->DebugString();
  // Capture the front node, drop the old entry, then insert the new key.
  // Safe because node != new_node was checked above.
  const auto front_node = entry->second;
  internal_outputs_to_front_map_.erase(entry);
  internal_outputs_to_front_map_[new_node] = front_node;
  front_to_internal_outputs_map_[front_node] = new_node;
}
// Return the backend internal-output node registered for `front_node`,
// or nullptr when no mapping exists.
AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const {
  const auto iter = front_to_internal_outputs_map_.find(front_node);
  return iter == front_to_internal_outputs_map_.end() ? nullptr : iter->second;
}
// Whether `node` is currently tracked as an internal output of this graph.
bool KernelGraph::IsInternalOutput(const AnfNodePtr &node) const {
  // Return the membership test directly instead of the verbose
  // if-return-true/return-false form.
  return internal_outputs_to_front_map_.find(node) != internal_outputs_to_front_map_.end();
}
// Return the front-end node registered for backend internal output `node`,
// or nullptr when no mapping exists.
AnfNodePtr KernelGraph::GetFrontNodeByInternalOutput(const AnfNodePtr &node) const {
  const auto iter = internal_outputs_to_front_map_.find(node);
  return iter == internal_outputs_to_front_map_.end() ? nullptr : iter->second;
}
// Register `node` as one of the graph's final output kernels.
// nullptr is ignored; duplicate insertions are absorbed by the set.
void KernelGraph::AddFinalOutputKernel(const AnfNodePtr &node) {
if (node == nullptr) {
return;
}
(void)final_output_kernels_.insert(node);
}
bool KernelGraph::IsFinalOutputKernel(const AnfNodePtr &node) const {
if (node == nullptr) {
return false;
}
if (final_output_kernels_.find(node) != final_output_kernels_.end()) {
return true;
}
return false;
}
// Human-readable identifier for logging: "kernel_graph_<id>".
std::string KernelGraph::ToString() const { return "kernel_graph_" + std::to_string(graph_id_); }
// Release the runtime resources registered under this graph's id.
KernelGraph::~KernelGraph() { device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_); }

View File

@ -144,6 +144,13 @@ class KernelGraph : public FuncGraph {
void PrintGraphExecuteOrder() const;
const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes() const { return summary_nodes_; }
void set_summary_nodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &nodes) { summary_nodes_ = nodes; }
void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node);
void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node);
AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const;
bool IsInternalOutput(const AnfNodePtr &node) const;
AnfNodePtr GetFrontNodeByInternalOutput(const AnfNodePtr &node) const;
void AddFinalOutputKernel(const AnfNodePtr &node);
bool IsFinalOutputKernel(const AnfNodePtr &node) const;
private:
// remove value node form graph
@ -202,6 +209,9 @@ class KernelGraph : public FuncGraph {
CNodePtr start_label_;
CNodePtr end_goto_;
bool null_output_;
std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_internal_outputs_map_;
std::unordered_map<AnfNodePtr, AnfNodePtr> internal_outputs_to_front_map_;
std::set<AnfNodePtr> final_output_kernels_;
};
} // namespace session
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;

View File

@ -95,6 +95,13 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
TypeId type_id = kNumberTypeFloat32;
type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
std::vector<int> temp_shape;
if (graph.IsInternalOutput(node)) {
temp_shape.emplace_back(1);
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
tensor->set_device_address(address);
tensor->set_dirty(false);
return tensor;
}
(void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
// in PyNative mode, data is only copied to host when the user wants to print it
@ -172,48 +179,6 @@ ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
return new_value_node;
}
// Create backend parameters for every output of `node` and append them to
// `graph`'s inputs; returns the new parameters in output order.
// NOTE(review): this free function is duplicated by
// SessionBasic::CreateParameterFromTuple elsewhere in this change — confirm
// which copy is the one being kept.
std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> parameters;
std::vector<AnfNodePtr> pre_graph_out = {node};
// If a cnode is a call, its input0 is a cnode too, so it doesn't have a primitive
if (!AnfAlgo::IsRealKernel(node)) {
pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
}
auto valid_inputs = graph->MutableValidInputs();
MS_EXCEPTION_IF_NULL(valid_inputs);
auto graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
// Build one parameter for `abstract` and record it in the result list plus
// the graph's valid-input/input bookkeeping (all captured by reference).
auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
auto parameter = graph->NewParameter();
MS_EXCEPTION_IF_NULL(parameter);
parameter->set_abstract(abstract);
// NewParameter(parameter) re-creates the freshly built parameter inside the graph.
auto new_parameter = graph->NewParameter(parameter);
parameters.push_back(new_parameter);
valid_inputs->push_back(valid_input);
graph_inputs->push_back(new_parameter);
};
for (const auto &out_node : pre_graph_out) {
MS_EXCEPTION_IF_NULL(out_node);
auto abstract = out_node->abstract();
MS_EXCEPTION_IF_NULL(abstract);
// create multiple parameters if is a tuple output real kernel
if (abstract->isa<abstract::AbstractTuple>() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]";
for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) {
create_parameter((*tuple_abstract)[output_idx]);
}
continue;
}
// create single parameter if is a abstract real kernel
create_parameter(out_node->abstract());
}
return parameters;
}
size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Load kInputCtrlTensors";
@ -323,6 +288,103 @@ bool ExistSummaryNode(const KernelGraph *graph) {
} // namespace
GraphId SessionBasic::graph_sum_ = 0;
// Look up a kernel graph by its id; logs a warning and returns nullptr
// when no graph with that id has been registered.
KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) {
  const auto iter = graphs_.find(graph_id);
  if (iter != graphs_.end()) {
    return iter->second;
  }
  MS_LOG(WARNING) << "Can't find graph " << graph_id;
  return nullptr;
}
// If `out_node` (a front-end output of a previously compiled graph) maps to an
// internal output kernel of that graph, make `parameter` alias the kernel's
// output: copy the producer's device format/type into a fresh kernel info and
// point output 0 of the parameter at the same device address — avoiding a
// device-to-device copy between sub-graphs (per this change's intent; confirm
// against the runtime's memory-assignment passes).
void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter) {
auto graph_id = GetGraphIdByNode(out_node);
if (graph_id == kInvalidGraphId) {
return;
}
auto node_graph = GetGraph(graph_id);
if (node_graph == nullptr) {
return;
}
MS_LOG(INFO) << "Init parameter with pre graph output node: " << out_node->DebugString();
auto ref_node = node_graph->GetInternalOutputByFrontNode(out_node);
if (ref_node == nullptr) {
MS_LOG(INFO) << "No corresponding internal output for output node";
return;
}
// Resolve through wrappers (e.g. TupleGetItem) to the kernel and output
// index that actually produce the value.
auto real_kernel = AnfAlgo::VisitKernel(ref_node, 0);
auto ref_real_node = real_kernel.first;
auto ref_real_node_index = real_kernel.second;
// Only alias memory when the producer is a CNode that is both an internal
// output and one of the pre-graph's final output kernels.
if (ref_real_node->isa<CNode>() && node_graph->IsInternalOutput(ref_real_node) &&
node_graph->IsFinalOutputKernel(ref_real_node)) {
auto kernel_info = ref_real_node->kernel_info();
if (kernel_info == nullptr || kernel_info->select_kernel_build_info() == nullptr) {
MS_LOG(INFO) << "No kernel info";
return;
}
auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index);
if (address == nullptr) {
MS_LOG(INFO) << "No kernel address";
return;
}
auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index);
auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index);
// Mirror the producer's device format/type on the parameter so downstream
// passes see it as already laid out correctly.
parameter->set_kernel_info(std::make_shared<device::KernelInfo>());
auto d_kernel_info = parameter->kernel_info();
MS_EXCEPTION_IF_NULL(d_kernel_info);
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetOutputsDeviceType({type});
builder.SetOutputsFormat({format});
d_kernel_info->set_select_kernel_build_info(builder.Build());
// Share the producer's output buffer; no data copy happens here.
AnfAlgo::SetOutputAddr(address, 0, parameter.get());
}
}
std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input,
KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> parameters;
std::vector<AnfNodePtr> pre_graph_out = {node};
// If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
if (!AnfAlgo::IsRealKernel(node)) {
pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
}
auto valid_inputs = graph->MutableValidInputs();
MS_EXCEPTION_IF_NULL(valid_inputs);
auto graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
auto parameter = graph->NewParameter();
MS_EXCEPTION_IF_NULL(parameter);
parameter->set_abstract(abstract);
auto new_parameter = graph->NewParameter(parameter);
parameters.push_back(new_parameter);
valid_inputs->push_back(valid_input);
graph_inputs->push_back(new_parameter);
};
for (const auto &out_node : pre_graph_out) {
MS_EXCEPTION_IF_NULL(out_node);
auto abstract = out_node->abstract();
MS_EXCEPTION_IF_NULL(abstract);
// create multiple parameters if is a tuple output real kernel
if (abstract->isa<abstract::AbstractTuple>() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]";
for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) {
create_parameter((*tuple_abstract)[output_idx]);
}
continue;
}
// create single parameter if is a abstract real kernel
create_parameter(out_node->abstract());
InitInternalOutputParameter(out_node, parameters[parameters.size() - 1]);
}
return parameters;
}
ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input,
KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf);
@ -877,6 +939,29 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
if (backend_anf != nullptr) {
auto front_real_kernel = AnfAlgo::VisitKernel(out, 0);
auto backend_real_kernel = AnfAlgo::VisitKernel(backend_anf, 0);
MS_EXCEPTION_IF_NULL(out);
auto out_func_graph = out->func_graph();
MS_EXCEPTION_IF_NULL(out_func_graph);
auto out_func_graph_manager = out_func_graph->manager();
if (out_func_graph_manager == nullptr) {
return backend_anf;
}
auto node_users = out_func_graph_manager->node_users();
auto users = node_users[out];
bool internal_output = true;
std::string kernel_target = GetCNodeTarget(front_real_kernel.first);
for (auto user : users) {
if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) {
internal_output = false;
break;
}
}
if (internal_output) {
MS_LOG(INFO) << "Internal output1: " << out->DebugString() << "To " << backend_real_kernel.first->DebugString();
graph->AddInternalOutput(out, backend_real_kernel.first);
}
return backend_anf;
}
MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";

View File

@ -110,6 +110,8 @@ class SessionBasic {
#endif
protected:
// Get graph by graph id; if it does not exist, return nullptr
KernelGraphPtr GetGraph(GraphId graph_id);
virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
const std::vector<tensor::TensorPtr> &inputs_const) const;
void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
@ -127,11 +129,13 @@ class SessionBasic {
BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref);
// create a new kernel graph and update the graph sum
KernelGraphPtr NewKernelGraph();
std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph);
virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph);
ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);
AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
void AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph);
void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter);
std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;

View File

@ -52,45 +52,6 @@ const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
}
namespace {
// Return the device target (CPU/GPU/Ascend) a CNode should run on: the value
// of its primitive's "primitive_target" attribute when present and valid,
// otherwise the context's default device target.
// NOTE(review): a function with the same name and logic exists in ir/anf.cc in
// this change — confirm whether this anonymous-namespace copy is still needed.
std::string GetCNodeTarget(const AnfNodePtr &node) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string default_target = context_ptr->device_target();
// Non-CNodes carry no primitive attributes; fall back to the default.
if (!node->isa<CNode>()) {
return default_target;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto attr_input = cnode->input(kAnfPrimitiveIndex);
if (attr_input == nullptr) {
return default_target;
}
auto value_node = attr_input->cast<ValueNodePtr>();
if (value_node == nullptr) {
return default_target;
}
auto value = value_node->value();
if (value == nullptr) {
return default_target;
}
if (!value->isa<Primitive>()) {
return default_target;
}
auto primitive = value->cast<PrimitivePtr>();
auto att_target = primitive->GetAttr("primitive_target");
if (att_target != nullptr) {
// An explicit target must be a string drawn from the validated target set.
if (!att_target->isa<StringImm>()) {
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
}
auto target = GetValue<std::string>(att_target);
if (kTargetSet.find(target) == kTargetSet.end()) {
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
}
return target;
}
return default_target;
}
bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);