forked from mindspore-Ecosystem/mindspore
move opt to build graph
This commit is contained in:
parent 64abbeaa89
commit cc54bb565d
@@ -22,28 +22,32 @@ namespace mindspore {
 namespace kernel {
 std::string KernelBuildInfo::GetInputFormat(size_t input_index) const {
   if (input_index >= inputs_format_.size()) {
-    MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node";
+    MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input node";
+    return kInvalidFormat;
   }
   return inputs_format_[input_index];
 }
 
 std::string KernelBuildInfo::GetOutputFormat(size_t output_index) const {
   if (output_index >= outputs_format_.size()) {
-    MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of input node";
+    MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of input node";
+    return kInvalidFormat;
   }
   return outputs_format_[output_index];
 }
 
 TypeId KernelBuildInfo::GetInputDeviceType(size_t input_index) const {
   if (input_index >= inputs_device_type_.size()) {
-    MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node";
+    MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input";
+    return TypeId::kNumberTypeEnd;
   }
   return inputs_device_type_[input_index];
 }
 
 TypeId KernelBuildInfo::GetOutputDeviceType(size_t output_index) const {
   if (output_index >= outputs_device_type_.size()) {
-    MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of input node";
+    MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of output";
+    return TypeId::kNumberTypeEnd;
   }
   return outputs_device_type_[output_index];
 }

@@ -82,6 +82,9 @@ class KernelBuildInfo {
 
   bool operator==(const KernelBuildInfo &other) const;
 
+ public:
+  static auto constexpr kInvalidFormat = "InvalidFormat";
+
  private:
   KernelType kernel_type_;
   std::vector<std::string> inputs_format_;
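Note on the two hunks above: the KernelBuildInfo getters no longer throw on an out-of-range index; they log an error and return a sentinel (kInvalidFormat for formats, TypeId::kNumberTypeEnd for device types), so callers must validate the result themselves. A minimal caller-side sketch (a hypothetical helper; the real enforcement added by this commit lives in AnfRuntimeAlgorithm further down):

```cpp
// Hypothetical helper, sketching the calling convention after this change:
// the getter reports failure through a sentinel instead of MS_LOG(EXCEPTION).
std::string GetCheckedOutputFormat(const kernel::KernelBuildInfo &build_info, size_t idx) {
  auto format = build_info.GetOutputFormat(idx);
  if (format == kernel::KernelBuildInfo::kInvalidFormat) {
    MS_LOG(EXCEPTION) << "Invalid output format at index " << idx;  // caller decides how to fail
  }
  return format;
}
```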
@@ -26,7 +26,7 @@
 namespace mindspore {
 namespace kernel {
 namespace {
-void FilterInvaildKernelInfo(const CNodePtr &kernel_node,
+void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
                              std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
   MS_EXCEPTION_IF_NULL(kernel_info_list);
   std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list;
@@ -63,9 +63,9 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
     HcclMetadataInfo(kernel_node, kernel_info_list);
   }
   if (kernel_info_list->empty()) {
-    MS_LOG(EXCEPTION) << "op" << kernel_node->DebugString() << "kernel query fail!";
+    MS_LOG(EXCEPTION) << "Op " << kernel_node->DebugString() << "kernel query fail!";
   }
-  FilterInvaildKernelInfo(kernel_node, kernel_info_list);
+  FilterInvalidKernelInfo(kernel_node, kernel_info_list);
 }
 }  // namespace kernel
 }  // namespace mindspore
@@ -46,24 +46,40 @@ RtKerDescFactory &RtKerDescFactory::Get() {
 
 void GetRtKelInfo(const CNodePtr &kernel_node,
                   std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
   MS_LOG(INFO) << "Mng kernel Info.";
   MS_EXCEPTION_IF_NULL(kernel_info_list);
   MS_EXCEPTION_IF_NULL(kernel_node);
   std::string opNameLower = AnfAlgo::GetCNodeName(kernel_node);
   (void)std::transform(opNameLower.begin(), opNameLower.end(), opNameLower.begin(), ::tolower);
 
   auto ker_desc_ptr = RtKerDescFactory::Create(opNameLower);
-  if (ker_desc_ptr == nullptr) {
-    MS_LOG(DEBUG) << "Mng can't find op [" << opNameLower << "].";
+  if (ker_desc_ptr != nullptr && !ker_desc_ptr->GetKernelInfo().empty()) {
+    *kernel_info_list = ker_desc_ptr->GetKernelInfo();
+    return;
   }
-  MS_EXCEPTION_IF_NULL(ker_desc_ptr);
-  auto kernel_info = ker_desc_ptr->GetKernelInfo();
-  if (kernel_info.empty()) {
-    MS_LOG(DEBUG) << "Rt dose not have op [" << opNameLower << "].";
-  }
-  *kernel_info_list = kernel_info;
+  // if can't find kernel info in kernel info database, use the default kernel info
+  auto node_name = AnfAlgo::GetCNodeName(kernel_node);
+  if (node_name == "StreamSwitch" || node_name == "StreamActive") {
+    auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+    // set input infos
+    auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    kernel_build_info_builder->SetInputsFormat(std::vector<std::string>(input_num, kOpFormat_DEFAULT));
+    std::vector<TypeId> input_types = {};
+    for (size_t i = 0; i < input_num; i++) {
+      input_types.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, i));
+    }
+    kernel_build_info_builder->SetInputsDeviceType(input_types);
+    // set output info
+    auto output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_num, kOpFormat_DEFAULT));
+    kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>(output_num, TypeId::kTypeUnknown));
+    // set ohter info
+    kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE);
+    kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE);
+    kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
+    kernel_info_list->push_back(kernel_build_info_builder->Build());
+    return;
+  }
+  MS_LOG(DEBUG) << "Rt dose not have op [" << opNameLower << "].";
 }
 }  // namespace kernel
 }  // namespace mindspore
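Note: the rewritten GetRtKelInfo queries the RT kernel database first and falls back to synthesizing a default build info for StreamSwitch/StreamActive through KernelBuildInfoBuilder. A minimal builder sketch under the assumption of a single input and output (setter names are from this diff; the concrete types are illustrative):

```cpp
// Sketch: synthesize a default KernelBuildInfo for a one-input, one-output op.
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
builder->SetInputsFormat({kOpFormat_DEFAULT});
builder->SetInputsDeviceType({TypeId::kNumberTypeFloat32});  // illustrative input type
builder->SetOutputsFormat({kOpFormat_DEFAULT});
builder->SetOutputsDeviceType({TypeId::kTypeUnknown});  // resolved later, as in the hunk above
builder->SetFusionType(kernel::FusionType::OPAQUE);
builder->SetProcessor(kernel::Processor::AICORE);
builder->SetKernelType(KernelType::RT_KERNEL);
std::shared_ptr<kernel::KernelBuildInfo> info = builder->Build();
```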
@@ -186,7 +186,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
     save_graphs_path = ".";
   }
   if (save_graphs) {
-    std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_before.ir";
+    std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_before" + "_graph_" +
+                            std::to_string(kernel_graph->graph_id()) + ".ir";
     DumpIR(file_path, kernel_graph);
     DumpIRProto(kernel_graph, "before_hwopt");
   }
@@ -208,7 +209,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
   (void)optimizer->Optimize(kernel_graph);
   kernel_graph->SetExecOrderByDefault();
   if (save_graphs) {
-    std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_after.ir";
+    std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_after" + "_graph_" +
+                            std::to_string(kernel_graph->graph_id()) + ".ir ";
     DumpIR(file_path, kernel_graph);
   }
 }
@@ -252,7 +254,8 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
     save_graphs_path = ".";
   }
   if (save_graphs) {
-    std::string file_path = save_graphs_path + "/" + "hwopt_d_before.ir";
+    std::string file_path =
+      save_graphs_path + "/" + "hwopt_d_before" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
     DumpIR(file_path, kernel_graph);
   }
   // data layout optimization
@@ -278,7 +281,8 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
   (void)optimizer->Optimize(kernel_graph);
   kernel_graph->SetExecOrderByDefault();
   if (save_graphs) {
-    std::string file_path = save_graphs_path + "/" + "hwopt_d_end.ir";
+    std::string file_path =
+      save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
     DumpIR(file_path, kernel_graph, true);
     DumpIRProto(kernel_graph, "after_hwopt");
   }

@@ -27,6 +27,7 @@
 namespace mindspore {
 namespace opt {
 void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
+  MS_LOG(INFO) << "start common opt graph:" << kernel_graph->graph_id();
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   bool save_graphs = context_ptr->save_graphs_flag();
@@ -300,7 +300,12 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t
   MS_EXCEPTION_IF_NULL(kernel_info);
   auto build_info = kernel_info->select_kernel_build_info();
   MS_EXCEPTION_IF_NULL(build_info);
-  return build_info->GetOutputFormat(output_idx);
+  auto format = build_info->GetOutputFormat(output_idx);
+  if (format == kernel::KernelBuildInfo::kInvalidFormat) {
+    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
+                      << " has a invalid output format";
+  }
+  return format;
 }
 
 std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) {
@@ -314,7 +319,12 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
   MS_EXCEPTION_IF_NULL(kernel_info);
   auto build_info = kernel_info->select_kernel_build_info();
   MS_EXCEPTION_IF_NULL(build_info);
-  return build_info->GetInputFormat(input_idx);
+  auto format = build_info->GetInputFormat(input_idx);
+  if (format == kernel::KernelBuildInfo::kInvalidFormat) {
+    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
+                      << " has a invalid input format";
+  }
+  return format;
 }
 
 KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) {
@@ -481,7 +491,12 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size
   MS_EXCEPTION_IF_NULL(kernel_info);
   auto build_info = kernel_info->select_kernel_build_info();
   MS_EXCEPTION_IF_NULL(build_info);
-  return build_info->GetOutputDeviceType(output_idx);
+  auto dtype = build_info->GetOutputDeviceType(output_idx);
+  if (dtype == TypeId::kNumberTypeEnd) {
+    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
+                      << " has a invalid dtype";
+  }
+  return dtype;
 }
 
 TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) {
@@ -494,7 +509,12 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_
   MS_EXCEPTION_IF_NULL(kernel_info);
   auto build_info = kernel_info->select_kernel_build_info();
   MS_EXCEPTION_IF_NULL(build_info);
-  return build_info->GetInputDeviceType(input_idx);
+  auto dtype = build_info->GetInputDeviceType(input_idx);
+  if (dtype == TypeId::kNumberTypeEnd) {
+    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
+                      << " has a invalid dtype";
+  }
+  return dtype;
 }
 
 TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) {
@@ -15,6 +15,9 @@
  */
 #include "session/ascend_session.h"
 #include <algorithm>
+#include <map>
+#include <tuple>
+#include <set>
 #include "operator/ops.h"
 #include "ir/meta_tensor.h"
 #include "ir/anf.h"
@@ -75,28 +78,15 @@ void DumpGraphInputArgs(const VectorRef &args) {
 
 void SetStreamDistinctionLabel(const KernelGraphPtr &graph, uint32_t label, bool is_override) {
   MS_EXCEPTION_IF_NULL(graph);
-  for (auto &node : graph->execution_order()) {
-    if (is_override || AnfAlgo::GetStreamDistinctionLabel(node.get()) == kInvalidDistincLabel) {
-      MS_EXCEPTION_IF_NULL(node);
-      AnfAlgo::SetStreamDistinctionLabel(label, node.get());
-    }
+  if (is_override || graph->stream_distinction_label() == kInvalidDistincLabel) {
+    graph->set_stream_distinction_label(label);
   }
 }
 
-GraphId GetDistinctionLabel(const KernelGraphPtr &graph) {
-  MS_EXCEPTION_IF_NULL(graph);
-  // if graph is empty,use graph id as distinction label
-  if (graph->execution_order().empty()) {
-    return graph->graph_id();
-  }
-  // else use first node of execution order as label
-  return AnfAlgo::GetStreamDistinctionLabel(graph->execution_order()[0].get());
-}
-
 std::vector<BaseRef> GetRealArgs(const KernelGraphPtr graph, const VectorRef &args) {
   MS_EXCEPTION_IF_NULL(graph);
   std::vector<AnfNodePtr> graph_inputs = graph->inputs();
-  auto valid_inputs = graph->ValidInputs();
+  auto valid_inputs = graph->valid_inputs();
   size_t real_args_size = 0;
   std::vector<BaseRef> real_args = {};
   for (size_t i = 0; i < args.size(); i++) {
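Note: GetDistinctionLabel (derive the label from the first kernel in the execution order) is deleted; the label is now a property of the KernelGraph itself. A short sketch of the new contract (both methods appear in this commit's kernel_graph changes):

```cpp
// Sketch: label the graph once, then push the label down to its kernels.
void LabelGraph(const KernelGraphPtr &graph, uint32_t label) {
  graph->set_stream_distinction_label(label);
  graph->UpdateExecuteKernelStreamLabel();  // stamps every kernel in execution_order_
}
```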
@@ -141,23 +131,9 @@ std::vector<BaseRef> GetRealArgs(const KernelGraphPtr graph, const VectorRef &ar
 
 GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
   MS_LOG(INFO) << "start";
-  auto graph_id = graph_sum_;
-  // construct graph, if successfully, graph_sum_ + 1
   auto graph = ConstructKernelGraph(lst, outputs);
   MS_EXCEPTION_IF_NULL(graph);
-  opt::AscendBackendIRFusionOptimization(graph);
-  // select kernel build info
-  SelectKernel(*graph);
-  // convert kernel Graph to model
-  predictmodel::StepConvertGraph(graph);
-  // optimize graph
-  HardwareOptimize(graph);
-  // init runtime resource
-  InitRuntimeResource();
-  // assign static memory of parameters
-  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
-  MS_EXCEPTION_IF_NULL(runtime_instance);
-  runtime_instance->AssignStaticMemoryInput(graph.get());
+  auto graph_id = graph->graph_id();
   MS_LOG(INFO) << "Compile graph " << graph_id << " success";
   return graph_id;
 }
@@ -166,16 +142,36 @@ void AscendSession::BuildGraph(GraphId graph_id) {
   MS_LOG(INFO) << "start";
   auto graph = GetGraph(graph_id);
   MS_EXCEPTION_IF_NULL(graph);
+  // resource initialize
+  InitRuntimeResource();
   // multiple graph handle
   if (graph_id == final_graph_id_) {
+    if (!graph->executable()) {
+      return;
+    }
+    // insert assigns to child graph
+    InsertAllAssigns();
     // insert switch and active to child graph
     MergeSwitchCompile();
+    // OptChildGraphs
+    auto graph_order = GetGraphOrder(final_graph_id_);
+    auto &graph_type = GetGraphOrderType(final_graph_id_);
+    for (size_t i = 0; i < graph_order.size(); i++) {
+      if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) {
+        continue;
+      }
+      MS_LOG(INFO) << "Start build child graph " << graph_order[i];
+      auto child_graph = GetGraph(graph_order[i]);
+      CompileChildGraph(child_graph);
+    }
     // merge child graph
     MergeGraphExecOrder();
   } else {
     auto single_graph = GetGraph(graph_id);
+    CompileChildGraph(single_graph);
     // set the distinction label of single graph
-    SetStreamDistinctionLabel(GetGraph(graph_id), graph_id, false);
+    single_graph->set_stream_distinction_label(graph_id);
+    single_graph->UpdateExecuteKernelStreamLabel();
   }
   // adjust execution order because merge child graph and other special operations
   AdjustKernel(graph);
@@ -197,9 +193,26 @@ void AscendSession::BuildGraph(GraphId graph_id) {
     // load task info to device if it is sink mode
     LoadTask(graph);
   }
+  // sync the inital const tensor to device
+  SyncInitialTenosrToDevice();
   MS_LOG(INFO) << "end";
 }
 
+void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
+  MS_EXCEPTION_IF_NULL(child_graph);
+  opt::AscendBackendIRFusionOptimization(child_graph);
+  // select kernel build info
+  SelectKernel(*child_graph);
+  // convert kernel Graph to model
+  predictmodel::StepConvertGraph(child_graph);
+  // optimize graph
+  HardwareOptimize(child_graph);
+  // assign static memory of parameters
+  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
+  MS_EXCEPTION_IF_NULL(runtime_instance);
+  runtime_instance->AssignStaticMemoryInput(child_graph.get());
+}
+
 void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                              VectorRef *const outputs) {
   MS_LOG(INFO) << "start";
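Note: this pair of hunks is the heart of the commit ("move opt to build graph"): the optimization pipeline that CompileGraph used to run now lives in CompileChildGraph, driven from BuildGraph once per child graph. A condensed sketch of the resulting call order (snippet only, not a full program; `session` stands for an AscendSession):

```cpp
GraphId id = session->CompileGraph(lst, outputs);  // now: construct the kernel graph only
session->BuildGraph(id);  // now: per child graph, CompileChildGraph runs
                          // IR fusion -> SelectKernel -> StepConvertGraph
                          // -> HardwareOptimize -> AssignStaticMemoryInput
```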
@@ -458,11 +471,9 @@ void AscendSession::Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const
 
 GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
   MS_LOG(INFO) << "Start! Args size " << args.size();
-  auto final_graph = std::make_shared<KernelGraph>();
-  final_graph_id_ = graph_sum_++;
-  graphs_[final_graph_id_] = final_graph;
-  final_graph->set_graph_id(final_graph_id_);
-  MS_LOG(INFO) << "Create a new final graph" << final_graph_id_ << "success";
+  auto final_graph = NewKernelGraph();
+  final_graph_id_ = final_graph->graph_id();
+  MS_LOG(INFO) << "Create a new final graph" << final_graph_id_ << " success";
   // init private variables and bind them with final_graph_id
   graph_execute_orders_[final_graph_id_] = std::vector<GraphId>();
   graph_order_types_[final_graph_id_] = std::vector<GraphType>();
@@ -498,6 +509,46 @@ GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
   return final_graph_id_;
 }
 
+AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) {
+  auto fake_graph = GetGraph(fake_graph_id);
+  auto output_item_with_index = AnfAlgo::VisitKernelWithReturnType(true_output, 0);
+  auto create_parameter = [&](const AbstractBasePtr &abstract) -> AnfNodePtr {
+    auto parameter = fake_graph->NewParameter();
+    MS_EXCEPTION_IF_NULL(parameter);
+    parameter->set_abstract(abstract);
+    auto new_parameter = fake_graph->NewParameter(parameter);
+    // Add new parameter to the graph input of fake_graph to sure that all parameters will be allocated memory.
+    auto graph_inputs = fake_graph->MutableInputs();
+    MS_EXCEPTION_IF_NULL(graph_inputs);
+    graph_inputs->push_back(new_parameter);
+    return new_parameter;
+  };
+  auto create_parameter_from_cnode = [&](const AnfNodePtr &cnode, size_t output_idx) -> AnfNodePtr {
+    MS_EXCEPTION_IF_NULL(cnode);
+    auto abstract = cnode->abstract();
+    MS_EXCEPTION_IF_NULL(abstract);
+    // create multiple parameters if is a tuple output real kernel
+    if (abstract->isa<abstract::AbstractTuple>()) {
+      auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
+      MS_EXCEPTION_IF_NULL(tuple_abstract);
+      MS_LOG(INFO) << "tuple_size [" << tuple_abstract->size() << "]";
+      return create_parameter((*tuple_abstract)[output_idx]);
+    }
+    return create_parameter(cnode->abstract());
+  };
+  if (AnfAlgo::CheckPrimitiveType(output_item_with_index.first, prim::kPrimMakeTuple)) {
+    std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
+    auto make_tuple = output_item_with_index.first->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(make_tuple);
+    for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
+      auto input = make_tuple->inputs()[i];
+      make_tuple_inputs.push_back(CreateFakeOutput(fake_graph_id, input));
+    }
+    return fake_graph->NewCNode(make_tuple_inputs);
+  }
+  return create_parameter_from_cnode(output_item_with_index.first, output_item_with_index.second);
+}
+
 void AscendSession::SetFinalGraphOutput(const BaseRef &output) {
   auto final_graph = GetGraph(final_graph_id_);
   MS_EXCEPTION_IF_NULL(final_graph);
@@ -559,12 +610,6 @@ void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true
   condition_graph->AddValueNodeToGraph(counter_const);
   // create a new switch op
   auto switch_primitive = std::make_shared<Primitive>("StreamSwitch");
-  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
-  kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
-  kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{kNumberTypeInt32});
-  kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE);
-  kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE);
-  kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
   auto cond_output_it = condition_output_.find(condition_graph_id);
   if (cond_output_it == condition_output_.end()) {
     MS_LOG(EXCEPTION) << "Can't find condition graph" << condition_graph_id;
@@ -574,11 +619,9 @@ void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true
   MS_EXCEPTION_IF_NULL(cond_output_kernel);
   std::vector<AnfNodePtr> inputs = {NewValueNode(switch_primitive), cond_output_kernel, counter_const};
   CNodePtr switch_node = condition_graph->NewCNode(inputs);
-  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), switch_node.get());
   MS_EXCEPTION_IF_NULL(switch_node);
   switch_node->set_abstract(std::make_shared<abstract::AbstractNone>());
   AnfAlgo::SetGraphId(condition_graph_id, switch_node.get());
-  AnfAlgo::SetStreamDistinctionLabel(GetDistinctionLabel(GetGraph(condition_graph_id)), switch_node.get());
   // set attr: cond_ RT_GREATER
   AnfAlgo::SetNodeAttr(kAttrSwitchCondition, MakeValue<int>(static_cast<int>(RT_GREATER)), switch_node);
   // set attr:data_type
@@ -586,9 +629,9 @@ void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true
   // set attr:true branch graph id ,which is same to stream distinction label
   AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(true_graph_id), switch_node);
   // append switch at the end of condition graph
-  std::vector<CNodePtr> exec_order = condition_graph->execution_order();
-  exec_order.push_back(switch_node);
-  condition_graph->set_execution_order(exec_order);
+  auto return_node = condition_graph->get_return();
+  MS_EXCEPTION_IF_NULL(return_node);
+  InsertControlDependToGraph(condition_graph_id, return_node->input(1), switch_node);
   MS_LOG(INFO) << "Finish!";
 }
 
@@ -615,8 +658,14 @@ void AscendSession::CopyOutputOfIf(GraphId false_graph_id) {
   MS_EXCEPTION_IF_NULL(true_last);
   MS_EXCEPTION_IF_NULL(false_last);
   MS_LOG(INFO) << "The last graph of false branch is " << false_last_id;
-  // now only consider the single output
-  InsertMultipleAssignToGraph(true_last_id, true_last->output(), false_last->output());
+  // create fake output
+  auto fake_output_graph = NewKernelGraph();
+  graph_execute_order.push_back(fake_output_graph->graph_id());
+  graph_order_type.push_back(COMMON_GRAPH);
+  fake_output_graph->set_output(CreateFakeOutput(fake_output_graph->graph_id(), final_graph->output()));
+  final_graph->set_output(fake_output_graph->output());
+  InsertMultipleAssignToGraph(true_last_id, true_last->output(), final_graph->output());
+  InsertMultipleAssignToGraph(false_last_id, false_last->output(), final_graph->output());
   // insert stream active for loop sink
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
@@ -650,14 +699,14 @@ void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id,
   if (false_graph_id != kInvalidGraphId) {
     // false graph and condition in graph same stream
     auto condition_graph = GetGraph(cond_graph_id);
-    SetStreamDistinctionLabel(GetGraph(false_graph_id), GetDistinctionLabel(condition_graph), true);
+    SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true);
     // if false graph is a condition graph and has been switch compiled before,it's false should be updated again
     auto cond_it = switches_.find(false_graph_id);
     while (cond_it != switches_.end() && cond_it->second.second != kInvalidGraphId) {
       cond_graph_id = cond_it->first;
       false_graph_id = cond_it->second.second;
       condition_graph = GetGraph(cond_graph_id);
-      SetStreamDistinctionLabel(GetGraph(false_graph_id), GetDistinctionLabel(condition_graph), true);
+      SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true);
       cond_it = switches_.find(false_graph_id);
     }
   }
@@ -691,7 +740,7 @@ void AscendSession::MergeSwitchCompile() {
     }
     // insert stream active to common graph
     if (prev_graph_id != kInvalidGraphId) {
-      InsertStreamActiveToGraph(prev_graph_id, GetDistinctionLabel(condition_graph));
+      InsertStreamActiveToGraph(prev_graph_id, condition_graph->stream_distinction_label());
     }
     // if this is a 'if' condition
     auto it = while_condition_graphs_.find(cond_graph_id);
@@ -700,12 +749,39 @@ void AscendSession::MergeSwitchCompile() {
     } else {
       // if it is a while,insert a stream active to true graph
       GraphId from_graph = it->second;
-      InsertStreamActiveToGraph(from_graph, GetDistinctionLabel(condition_graph));
+      InsertStreamActiveToGraph(from_graph, condition_graph->stream_distinction_label());
     }
   }
   MS_LOG(INFO) << "Finish!";
 }
 
+void AscendSession::InsertAllAssigns() {
+  std::set<std::pair<AnfNodePtr, AnfNodePtr>> assigns;
+  for (auto assign : assigns_) {
+    auto front_anf = std::get<0>(assign);
+    auto to_graph_id = std::get<1>(assign);
+    auto input_idx = std::get<2>(assign);
+    auto to_graph = GetGraph(to_graph_id);
+    MS_EXCEPTION_IF_NULL(to_graph);
+    std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
+    if (input_idx >= graph_inputs.size()) {
+      MS_LOG(EXCEPTION) << "input_index " << input_idx << " out of range size " << graph_inputs.size();
+    }
+    auto backend_parameter = graph_inputs[input_idx];
+    (void)assigns.insert(std::pair<AnfNodePtr, AnfNodePtr>(front_anf, backend_parameter));
+  }
+  // erase the repeat assign
+  for (auto &assign : assigns) {
+    auto front_anf = assign.first;
+    auto backend_parameter = assign.second;
+    auto from_graph_id = GetGraphIdByNode(front_anf);
+    auto from_graph = GetGraph(from_graph_id);
+    MS_EXCEPTION_IF_NULL(from_graph);
+    auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
+    InsertAssignToGraph(from_graph_id, backend_arg, backend_parameter);
+  }
+}
+
 // insert active to graph
 void AscendSession::SetActive(GraphId from, GraphId to) {
   if (while_condition_graphs_.find(to) != while_condition_graphs_.end()) {
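Note: assign insertion is now deferred; SetChildGraphParameter only records (front_anf, to_graph_id, input_idx) tuples in assigns_, and InsertAllAssigns materializes them at build time, collapsing duplicates through a std::set first. A generic sketch of that dedup step (standard library only; names hypothetical):

```cpp
#include <set>
#include <utility>
#include <vector>

// Collapse repeated (from, to) assign requests before acting on them,
// mirroring the std::set-based "erase the repeat assign" step above.
template <typename Node>
std::set<std::pair<Node, Node>> UniqueAssigns(const std::vector<std::pair<Node, Node>> &requested) {
  return {requested.begin(), requested.end()};  // std::set keeps one copy per pair
}
```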
@@ -735,20 +811,21 @@ void AscendSession::SetActive(GraphId from, GraphId to) {
   while_condition_graphs_[to] = from;
 }
 
-void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, const AnfNodePtr &backend_parameter) {
+void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx) {
   MS_LOG(INFO) << "Start!";
-  MS_EXCEPTION_IF_NULL(backend_parameter);
   MS_EXCEPTION_IF_NULL(front_anf);
-  if (!backend_parameter->isa<Parameter>()) {
-    MS_LOG(EXCEPTION) << "Backend parameter's type is not a parameter,but is " << backend_parameter->ToString();
-  }
   auto from_graph_id = GetGraphIdByNode(front_anf);
   auto from_graph = GetGraph(from_graph_id);
   MS_EXCEPTION_IF_NULL(from_graph);
-  auto to_graph_id = AnfAlgo::GetGraphId(backend_parameter.get());
   auto to_graph = GetGraph(to_graph_id);
-  auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
   MS_EXCEPTION_IF_NULL(to_graph);
+  std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
+  if (input_idx >= graph_inputs.size()) {
+    MS_LOG(EXCEPTION) << "input_index " << input_idx << " out of range size " << graph_inputs.size();
+  }
+  auto backend_parameter = graph_inputs[input_idx];
+  MS_EXCEPTION_IF_NULL(backend_parameter);
+  auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
   MS_LOG(INFO) << "Set node[" << front_anf->DebugString() << "] of graph[" << from_graph_id << "]to node["
                << backend_parameter->DebugString() << "] of graph[" << AnfAlgo::GetGraphId(backend_parameter.get())
                << "]";
@@ -759,39 +836,21 @@ void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, const An
   // if arg is the the parameter of child graph,it is parameter of final graph too
   if (front_anf->isa<Parameter>()) {
-    MS_EXCEPTION_IF_NULL(backend_arg);
-    if (!AnfAlgo::OutputAddrExist(backend_arg, 0)) {
-      // set parameter's addr in child graph to parameter in final graph
-      AnfAlgo::SetOutputAddr(AnfAlgo::GetMutableOutputAddr(backend_parameter, 0), 0, backend_arg.get());
-      MS_LOG(INFO) << "Assign mem of node" << backend_parameter->DebugString() << " of graph "
-                   << AnfAlgo::GetGraphId(backend_parameter.get()) << " to node" << backend_arg->DebugString()
-                   << "of graph " << AnfAlgo::GetGraphId(backend_arg.get());
-      return;
-    }
-    // if a parameter is a weight and not linked to any executable node,device type will be kTypeUnknown,set it's device
-    // type same to arg
-    if (AnfAlgo::GetOutputDeviceDataType(backend_parameter, 0) == kTypeUnknown) {
-      AnfAlgo::SetSelectKernelBuildInfo(AnfAlgo::GetSelectKernelBuildInfo(backend_arg), backend_parameter.get());
-    }
-    // if front anf is a parameter,we can assign the value back,because backend_parameter won't be change in it's graph
-    // unless it's a weight.If backend_parameter is a weight,we should assign the value back.
-    AnfAlgo::SetOutputAddr(AnfAlgo::GetMutableOutputAddr(backend_arg, 0), 0, backend_parameter.get());
+    MS_LOG(INFO) << "Reuse node [" << backend_arg->DebugString() << "], old node[" << backend_parameter->DebugString()
+                 << "] will be replaced.";
+    to_graph->ReplaceNode(backend_parameter, backend_arg);
     return;
   }
-  InsertAssignToGraph(from_graph_id, backend_arg, backend_parameter);
-  MS_LOG(INFO) << "Finish!";
+  MS_LOG(INFO) << "Assign of node" << backend_arg->DebugString() << " of graph " << from_graph_id << " to node"
+               << backend_parameter->DebugString() << "of graph " << to_graph_id;
+  (void)assigns_.insert(std::tuple<AnfNodePtr, GraphId, size_t>(front_anf, to_graph_id, input_idx));
 }
 
-void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor, const AnfNodePtr &backend_parameter) {
+void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id,
+                                           size_t input_idx) {
   MS_LOG(INFO) << "Start!";
-  // sync data from host to device
   MS_EXCEPTION_IF_NULL(front_tensor);
-  size_t tensor_size = front_tensor->data().nbytes();
-  auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0);
-  MS_EXCEPTION_IF_NULL(addr);
-  if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size,
-                              front_tensor->data_type(), front_tensor->data_c(false))) {
-    MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!";
-  }
+  std::pair<GraphId, size_t> graph_input_pair(to_graph_id, input_idx);
+  initial_tenosrs_[graph_input_pair] = front_tensor;
   MS_LOG(INFO) << "Finish!";
 }
 
@@ -818,10 +877,9 @@ size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const AnfN
   if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
     return input_index + output_num;
   }
-  auto &graph_inputs = graph->inputs();
-  auto &valid_inputs = graph->ValidInputs();
+  auto valid_inputs = graph->valid_inputs();
   if (valid_inputs[input_index]) {
-    SetChildGraphParameter(node, graph_inputs[input_index]);
+    SetChildGraphParameter(node, graph->graph_id(), input_index);
   } else {
     MS_LOG(DEBUG) << "Invalid input arg: " << node->DebugString();
   }
@@ -833,8 +891,7 @@ size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const Valu
   if (!value->isa<Tensor>()) {
     MS_LOG(EXCEPTION) << "Value Node should be a tensor, unexpected value: " << value->ToString();
   }
-  auto &graph_inputs = graph->inputs();
-  SetChildGraphParameter(value->cast<TensorPtr>(), graph_inputs[input_index]);
+  SetChildGraphParameter(value->cast<TensorPtr>(), graph->graph_id(), input_index);
   return ++input_index;
 }
 
@@ -905,8 +962,6 @@ GraphId AscendSession::GetGraphIdByNode(const AnfNodePtr &front_anf) const {
 
 void AscendSession::MergeGraphExecOrder() {
   MS_LOG(INFO) << "Start!";
-  // insert switch to graph
-  MergeSwitchCompile();
   // merge graph order
   auto &graph_order = GetGraphOrder(final_graph_id_);
   auto &graph_type = GetGraphOrderType(final_graph_id_);
@@ -916,6 +971,13 @@ void AscendSession::MergeGraphExecOrder() {
     MS_LOG(WARNING) << "Graph output is a lonely variable not linked to any op!";
     return;
   }
+  if (graph_order.size() > 1) {
+    auto context_ptr = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(context_ptr);
+    if (!context_ptr->enable_task_sink()) {
+      MS_LOG(INFO) << "Control sink network should run with task-sink mode!";
+    }
+  }
   // if first graph is common,the final graph has no label,then set the stream of final graph same with the first graph
   SetStreamDistinctionLabel(final_graph, graph_order[0], false);
   std::vector<CNodePtr> final_exec_order = final_graph->execution_order();
@@ -930,7 +992,11 @@ void AscendSession::MergeGraphExecOrder() {
     MS_EXCEPTION_IF_NULL(child_graph);
     auto exec_order = child_graph->execution_order();
     MS_LOG(INFO) << "Merge graph,graph_id " << graph_id;
-    (void)std::copy(exec_order.begin(), exec_order.end(), std::back_inserter(final_exec_order));
+    (void)std::transform(exec_order.begin(), exec_order.end(), std::back_inserter(final_exec_order),
+                         [&](CNodePtr node) -> CNodePtr {
+                           AnfAlgo::SetStreamDistinctionLabel(child_graph->stream_distinction_label(), node.get());
+                           return node;
+                         });
     // add all value nodes of child graphs to final graph
     for (auto &value_node : child_graph->graph_value_nodes()) {
      final_graph->AddValueNodeToGraph(value_node);
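Note: the std::copy becomes a std::transform with a side-effecting lambda: each child kernel is stamped with its graph's stream distinction label on its way into the final execution order. The equivalent plain loop, for readability (same names as the hunk above):

```cpp
// Equivalent loop form of the std::transform above (sketch):
for (const auto &node : child_graph->execution_order()) {
  AnfAlgo::SetStreamDistinctionLabel(child_graph->stream_distinction_label(), node.get());
  final_exec_order.push_back(node);
}
```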
@@ -969,15 +1035,9 @@ void AscendSession::InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from
   // generate a new cnode
   auto assign_node = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(assign_node);
-  assign_node->set_abstract(std::make_shared<abstract::AbstractNone>());
-  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
-  kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
-  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), assign_node.get());
-  AnfAlgo::SetStreamDistinctionLabel(GetDistinctionLabel(graph), assign_node.get());
+  assign_node->set_abstract(to->abstract());
   // append the assign at the end of from graph
-  auto exec_order = graph->execution_order();
-  exec_order.push_back(assign_node);
-  graph->set_execution_order(exec_order);
+  InsertDependToGraph(graph_id, assign_node);
 }
 
 void AscendSession::InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to) {
@@ -997,24 +1057,46 @@ void AscendSession::InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodeP
 
 void AscendSession::InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived_stream) {
   MS_LOG(INFO) << "Insert stream_active from " << graph_id << " to " << actived_stream;
-  auto from_graph = graphs_[graph_id];
+  auto from_graph = GetGraph(graph_id);
   MS_EXCEPTION_IF_NULL(from_graph);
   std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("StreamActive"))};
   auto active_node = from_graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(active_node);
   active_node->set_abstract(std::make_shared<abstract::AbstractNone>());
-  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
-  kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
-  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), active_node.get());
   // set the active stream id into the attr of active node
   std::vector<uint32_t> active_index_value = {};
   active_index_value.push_back(actived_stream);
   AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_value), active_node);
-  AnfAlgo::SetStreamDistinctionLabel(GetDistinctionLabel(from_graph), active_node.get());
   // append the active node at the end of from graph
-  auto exec_order = from_graph->execution_order();
-  exec_order.push_back(active_node);
-  from_graph->set_execution_order(exec_order);
+  auto return_node = from_graph->get_return();
+  MS_EXCEPTION_IF_NULL(return_node);
+  InsertControlDependToGraph(graph_id, return_node->input(1), active_node);
 }
 
+void AscendSession::InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node) {
+  MS_LOG(INFO) << "Insert depend at the end of graph, the attach node is " << attch_node->DebugString();
+  auto graph = GetGraph(graph_id);
+  MS_EXCEPTION_IF_NULL(graph);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("depend"))};
+  auto return_node = graph->get_return();
+  MS_EXCEPTION_IF_NULL(return_node);
+  inputs.push_back(return_node->input(1));
+  inputs.push_back(attch_node);
+  auto depend_node = graph->NewCNode(inputs);
+  return_node->set_input(1, depend_node);
+}
+
+void AscendSession::InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node,
+                                               const AnfNodePtr &second_node) {
+  MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString()
+               << ", the second node is " << second_node->DebugString();
+  auto graph = GetGraph(graph_id);
+  MS_EXCEPTION_IF_NULL(graph);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("ControlDepend"))};
+  inputs.push_back(first_node);
+  inputs.push_back(second_node);
+  auto control_depend = graph->NewCNode(inputs);
+  InsertDependToGraph(graph_id, control_depend);
+}
+
 size_t AscendSession::ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph) {
@@ -1043,5 +1125,29 @@ std::vector<GraphType> &AscendSession::GetGraphOrderType(GraphId final_graph_id)
   }
   return graph_type_iter->second;
 }
+
+void AscendSession::SyncInitialTenosrToDevice() {
+  for (auto &item : initial_tenosrs_) {
+    auto to_graph_id = item.first.first;
+    auto input_idx = item.first.second;
+    auto front_tensor = item.second;
+    auto to_graph = GetGraph(to_graph_id);
+    MS_EXCEPTION_IF_NULL(to_graph);
+    std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
+    if (input_idx >= graph_inputs.size()) {
+      MS_LOG(EXCEPTION) << "input_index " << input_idx << " out of range size " << graph_inputs.size();
+    }
+    auto backend_parameter = graph_inputs[input_idx];
+    // sync data from host to device
+    MS_EXCEPTION_IF_NULL(front_tensor);
+    size_t tensor_size = front_tensor->data().nbytes();
+    auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0);
+    MS_EXCEPTION_IF_NULL(addr);
+    if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size,
+                                front_tensor->data_type(), front_tensor->data_c(false))) {
+      MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!";
+    }
+  }
+}
 }  // namespace session
 }  // namespace mindspore
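Note: host-to-device syncs for initial tensors are no longer performed eagerly in SetChildGraphParameter; they are keyed by (graph_id, input_idx) in initial_tenosrs_ (the misspelling is the commit's own) and replayed once here from BuildGraph. A generic sketch of the defer-then-flush pattern (standard library only; names hypothetical):

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <utility>

using Key = std::pair<uint32_t, std::size_t>;        // (graph_id, input_idx)
std::map<Key, std::function<void()>> pending_syncs;  // stand-in for initial_tenosrs_

// Replay every recorded sync exactly once at build time.
void FlushPendingSyncs() {
  for (auto &item : pending_syncs) {
    item.second();  // e.g. SyncHostToDevice for the recorded tensor
  }
  pending_syncs.clear();
}
```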
@@ -21,6 +21,9 @@
 #include <vector>
 #include <utility>
 #include <stack>
+#include <map>
+#include <tuple>
+#include <set>
 #include "session/session_basic.h"
 #include "session/kernel_graph.h"
 #include "kernel/kernel.h"
@@ -60,6 +63,8 @@ class AscendSession : public SessionBasic {
   GraphId GetFinalRunGraph() const override { return final_graph_id_; }
   // insert active to graph
   void SetActive(GraphId, GraphId) override;
+  // compile child graph when session have multiple child graphs
+  void CompileChildGraph(const KernelGraphPtr &child_graph);
 
  private:
   void InitRuntimeResource();
@@ -95,12 +100,16 @@ class AscendSession : public SessionBasic {
   size_t ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph);
   // handle condition graph from vm
   void InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id);
+  // insert depend to graph, used to attch control nodes to graph
+  void InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node);
+  // insert depend to graph, used to attch control nodes to graph
+  void InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, const AnfNodePtr &second_node);
   // Get graph by graph id ,if not exist return null ptr
   KernelGraphPtr GetGraph(GraphId graph_id);
   // set child graph parameter if front arg is a anf
-  void SetChildGraphParameter(const AnfNodePtr &front_anf, const AnfNodePtr &backend_parameter);
+  void SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx);
   // set child graph parameter if front arg is a tensor
-  void SetChildGraphParameter(const tensor::TensorPtr &front_tensor, const AnfNodePtr &backend_parameter);
+  void SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id, size_t input_idx);
   // update the execution order of all child graphs
   void UpdateGraphOrder(GraphId to_graph);
   // handle switch when merge
@@ -113,6 +122,12 @@ class AscendSession : public SessionBasic {
   void CopyOutputOfIf(GraphId false_graph_id);
   // check if graph cache exist
   bool GraphCacheExist(const GraphInfo &graph_info) const;
+  // insert all assign to child graph
+  void InsertAllAssigns();
+  // create fake output of final graph
+  AnfNodePtr CreateFakeOutput(GraphId final_graph_id, const AnfNodePtr &true_output);
+  // sync intial tensors' data to device
+  void SyncInitialTenosrToDevice();
 
   // member variables
   // key is final_graph_id,value is child graph execute order of final graph
@@ -124,6 +139,10 @@ class AscendSession : public SessionBasic {
   // record all conditions
   std::unordered_map<GraphId, std::pair<GraphId, GraphId>> switches_;
   std::unordered_map<GraphId, AnfNodePtr> condition_output_;
+  // share parameters
+  std::set<std::tuple<AnfNodePtr, GraphId, size_t>> assigns_;
+  // initial tensors, these tensor will sync data to device before run graph
+  std::map<std::pair<GraphId, size_t>, tensor::TensorPtr> initial_tenosrs_;
   // final_graph_id is used in every root graph has it's own session situation
   GraphId final_graph_id_;
 };
@@ -295,10 +295,7 @@ ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) {
   // set the format of value_node to DEFAULT_FORMAT
   kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
   // set value node initial device data type = infer data type
-  std::vector<TypeId> types;
-  for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) {
-    types.push_back(kTypeUnknown);
-  }
+  std::vector<TypeId> types = std::vector<TypeId>(AnfAlgo::GetOutputTensorNum(value_node), kTypeUnknown);
   kernel_build_info_builder->SetOutputsDeviceType(types);
   AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
   AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
@@ -330,10 +327,11 @@ void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, cons
     MS_LOG(EXCEPTION) << "old can't be same with new";
   }
   if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) {
-    MS_LOG(EXCEPTION) << "old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map";
+    MS_LOG(DEBUG) << "old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map";
+    return;
   }
   if (front_backend_anf_map_.find(backend_front_anf_map_[old_backend_anf]) == front_backend_anf_map_.end()) {
-    MS_LOG(EXCEPTION) << "anf is not exist in the mape ,old " << old_backend_anf->DebugString();
+    MS_LOG(EXCEPTION) << "anf is not exist in the map ,old " << old_backend_anf->DebugString();
   }
   front_backend_anf_map_[backend_front_anf_map_[old_backend_anf]] = new_backend_anf;
   backend_front_anf_map_[new_backend_anf] = backend_front_anf_map_[old_backend_anf];
@@ -528,5 +526,44 @@ bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) {
   }
   return false;
 }
+
+void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf_node) {
+  MS_EXCEPTION_IF_NULL(old_anf_node);
+  MS_EXCEPTION_IF_NULL(new_anf_node);
+  MS_EXCEPTION_IF_NULL(inputs_);
+  auto it = node_output_edges_.find(old_anf_node);
+  if (it == node_output_edges_.end()) {
+    MS_LOG(EXCEPTION) << "Can't find anf node in node_output_edges map";
+  }
+  auto &outputs = it->second;
+  for (auto &output_node : outputs) {
+    auto output_cnode = output_node.first->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(output_cnode);
+    auto &output_node_inputs = output_cnode->inputs();
+    for (size_t i = 1; i < output_node_inputs.size(); i++) {
+      if (output_node_inputs[i] == old_anf_node) {
+        output_cnode->set_input(i, new_anf_node);
+      }
+    }
+    // update graph inputs
+    for (size_t i = 0; i < inputs_->size(); i++) {
+      if ((*inputs_)[i] == old_anf_node) {
+        (*inputs_)[i] = new_anf_node;
+        break;
+      }
+    }
+  }
+  // update front to backend map
+  FrontBackendlMapUpdate(old_anf_node, new_anf_node);
+  // update output depend relations
+  node_output_edges_[new_anf_node] = it->second;
+  (void)node_output_edges_.erase(old_anf_node);
+}
+
+void KernelGraph::UpdateExecuteKernelStreamLabel() {
+  for (auto &kernel : execution_order_) {
+    AnfAlgo::SetStreamDistinctionLabel(stream_distinction_label_, kernel.get());
+  }
+}
 }  // namespace session
 }  // namespace mindspore
@@ -27,6 +27,7 @@
 #include "ir/func_graph.h"
 #include "ir/anf.h"
 #include "utils/graph_utils.h"
+#include "device/kernel_info.h"
 
 namespace mindspore {
 namespace session {
@@ -37,6 +38,7 @@ class KernelGraph : public FuncGraph {
     inputs_ = std::make_shared<std::vector<AnfNodePtr>>();
     execution_order_ = {};
     executable_ = true;
+    stream_distinction_label_ = kInvalidDistincLabel;
   }
   ~KernelGraph() override = default;
 
@@ -88,7 +90,15 @@ class KernelGraph : public FuncGraph {
   void set_executable(bool executable) { executable_ = executable; }
   // set invalid inputs for control sink
   std::vector<bool> *MutableValidInputs() { return &valid_inputs_; }
-  const std::vector<bool> &ValidInputs() const { return valid_inputs_; }
+  std::vector<bool> valid_inputs() const { return valid_inputs_; }
+  // replace node in graph
+  void ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf_node);
+  // set stream label of graph
+  void set_stream_distinction_label(uint32_t stream_label) { stream_distinction_label_ = stream_label; }
+  // get stream label of graph
+  uint32_t stream_distinction_label() { return stream_distinction_label_; }
+  // refresh execute kernel stream label
+  void UpdateExecuteKernelStreamLabel();
 
  private:
   // remove value node form graph
@@ -108,6 +118,7 @@ class KernelGraph : public FuncGraph {
   std::shared_ptr<std::vector<AnfNodePtr>> inputs_;
   std::vector<CNodePtr> execution_order_;
   uint32_t graph_id_;
+  uint32_t stream_distinction_label_;
 
   // record map bettween front anf and backend anf,use two map implement bidirectional map
   std::unordered_map<AnfNodePtr, AnfNodePtr> front_backend_anf_map_;
@@ -417,9 +417,8 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
 
 KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
   std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
-  auto graph = std::make_shared<KernelGraph>();
-  graph->set_graph_id(graph_sum_);
-  MS_LOG(INFO) << "Create graph: " << graph_sum_;
+  auto graph = NewKernelGraph();
+  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
   size_t from_other_graph_depend_num = 0;
   for (const auto &node : lst) {
     MS_EXCEPTION_IF_NULL(node);
@@ -456,7 +455,6 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
   }
   graph->SetExecOrderByDefault();
   opt::BackendCommonOptimization(graph);
-  graphs_[graph_sum_++] = graph;
   return graph;
 }
 
@@ -588,14 +586,14 @@ void SessionBasic::Summary(KernelGraph *graph) {
 CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
   MS_EXCEPTION_IF_NULL(graph);
   std::vector<AnfNodePtr> output_args;
-  for (const auto &output : outputs) {
-    MS_LOG(INFO) << "output:" << output->DebugString();
-  }
   auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
     auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
     if (backend_anf != nullptr) {
       return backend_anf;
     }
+    for (const auto &output : outputs) {
+      MS_LOG(INFO) << "output:" << output->DebugString();
+    }
     MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
   };
   output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
@@ -695,5 +693,12 @@ BaseRef SessionBasic::TransformBaseRefListToTuple(const BaseRef &base_ref) {
     MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
   }
 }
+
+KernelGraphPtr SessionBasic::NewKernelGraph() {
+  auto graph = std::make_shared<KernelGraph>();
+  graph->set_graph_id(graph_sum_);
+  graphs_[graph_sum_++] = graph;
+  return graph;
+}
 }  // namespace session
 }  // namespace mindspore
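Note: NewKernelGraph centralizes graph-id allocation and registration, replacing the hand-rolled make_shared/set_graph_id/graphs_ bookkeeping at each call site (ConstructKernelGraph, SetFinalGraphInput, CopyOutputOfIf). A self-contained analogue of the factory idea (not MindSpore API):

```cpp
#include <cstdint>
#include <map>
#include <memory>

struct Graph {
  uint32_t id = 0;
};

class SessionSketch {
 public:
  // Allocate the id and register the graph in one place so the counter and
  // the map can never drift apart, as they could with per-call-site copies.
  std::shared_ptr<Graph> NewGraph() {
    auto g = std::make_shared<Graph>();
    g->id = graph_sum_;
    graphs_[graph_sum_++] = g;
    return g;
  }

 private:
  uint32_t graph_sum_ = 0;
  std::map<uint32_t, std::shared_ptr<Graph>> graphs_;
};
```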
@@ -104,6 +104,8 @@ class SessionBasic {
                          const std::vector<bool> &tensors_mask);
   // trans BaseRef list to py::tuple
   BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref);
+  // create a new kernel graph and update the graph sum
+  KernelGraphPtr NewKernelGraph();
 
   std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
   std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;

@@ -27,6 +27,7 @@ assign_op_info = TBERegOp("Assign") \
     .input(1, "value", False, "required", "all") \
     .output(0, "y", False, "required", "all") \
     .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
     .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
     .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
     .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \