!29089 remove device target info in session basic

Merge pull request !29089 from baihuawei/clear_code_rt1.6
Authored by i-robot, 2022-01-24 05:31:30 +00:00 (committed by Gitee)
commit 64932d08fb
13 changed files with 69 additions and 47 deletions

View File

@@ -387,7 +387,7 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra
 GraphId AscendSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
   MS_LOG(INFO) << "Status record: start compile graph.";
   // construct graph, if successful, graph_sum_ + 1
-  auto graph = ConstructKernelGraph(lst, outputs);
+  auto graph = ConstructKernelGraph(lst, outputs, DeviceAddressType::kAscend);
   auto graph_id = graph->graph_id();
   InitAllBucket(graph);
   MS_LOG(INFO) << "Status record: end compile graph. graph id: " << graph_id;
@@ -397,7 +397,7 @@ GraphId AscendSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNode
 GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
   MS_LOG(INFO) << "Status record: start compile graph.";
   std::vector<KernelGraphPtr> all_graphs;
-  auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
+  auto root_graph = ConstructKernelGraph(func_graph, &all_graphs, DeviceAddressType::kAscend);
   for (const auto &graph : all_graphs) {
     graph->set_root_graph_id(root_graph->graph_id());
   }

View File

@@ -61,8 +61,7 @@ void CPUSession::Init(uint32_t device_id) {
   InitExecutor(kCPUDevice, device_id);
 }
-ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph,
-                                                         const std::string &) {
+ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(anf);
   MS_EXCEPTION_IF_NULL(graph);
   if (!anf->isa<Parameter>()) {
@@ -118,7 +117,7 @@ void CPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
 GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
   auto graph_id = graph_sum_;
-  auto graph = ConstructKernelGraph(lst, outputs);
+  auto graph = ConstructKernelGraph(lst, outputs, DeviceAddressType::kCPU);
   MS_EXCEPTION_IF_NULL(graph);
   opt::AddDynamicShapeAttrPass(graph);
   MS_LOG(INFO) << "Set kernel info";

View File

@@ -42,7 +42,7 @@ class CPUSession : public SessionBasic {
   void PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
                         VectorRef *const outputs) override;
   void ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) override;
-  ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph, const std::string &) override;
+  ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) override;
   void GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
   void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph);
   KernelGraphPtr BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,

View File

@@ -392,14 +392,14 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
 GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
   // Construct graph, if successful, graph_sum_ + 1
-  auto graph = ConstructKernelGraph(lst, outputs);
+  auto graph = ConstructKernelGraph(lst, outputs, DeviceAddressType::kGPU);
   MS_EXCEPTION_IF_NULL(graph);
   return CompileGraphImpl(graph);
 }
 GraphId GPUSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
   std::vector<KernelGraphPtr> all_graphs;
-  auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
+  auto root_graph = ConstructKernelGraph(func_graph, &all_graphs, DeviceAddressType::kGPU);
   MS_EXCEPTION_IF_NULL(root_graph);
   if (all_graphs.size() != 1) {
     MS_LOG(EXCEPTION) << "Gpu backend does not support multi-graph schedule, graph num is " << all_graphs.size();

View File

@@ -49,6 +49,7 @@ struct KernelWithIndexCmp {
   }
 };
+using DeviceAddressType = device::DeviceAddressType;
 using KernelMapTensor = std::map<session::KernelWithIndex, BaseRef, session::KernelWithIndexCmp>;
 class KernelGraph : public FuncGraph {
@@ -60,6 +61,7 @@ class KernelGraph : public FuncGraph {
     executable_ = true;
     summary_node_exist_ = false;
     stream_distinction_label_ = kInvalidDistincLabel;
+    device_target_ = DeviceAddressType::kUnknown;
   }
   KernelGraph(const KernelGraph &graph) : FuncGraph(graph) {
@@ -68,6 +70,7 @@ class KernelGraph : public FuncGraph {
     execution_order_ = graph.execution_order_;
     mem_reuse_exec_order_ = graph.mem_reuse_exec_order_;
     graph_id_ = graph.graph_id_;
+    device_target_ = graph.device_target_;
     stream_distinction_label_ = graph.stream_distinction_label_;
     front_backend_anf_map_ = graph.front_backend_anf_map_;
     backend_front_anf_map_ = graph.backend_front_anf_map_;
@@ -148,6 +151,8 @@ class KernelGraph : public FuncGraph {
   void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; }
   uint32_t root_graph_id() const { return root_graph_id_; }
   void set_root_graph_id(uint32_t root_graph_id) { root_graph_id_ = root_graph_id; }
+  DeviceAddressType device_target() const { return device_target_; }
+  void set_device_target(DeviceAddressType target) { device_target_ = target; }
   // add a new front to backend anf relation to map
   void FrontBackendMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf);
@@ -446,6 +451,7 @@ class KernelGraph : public FuncGraph {
   std::vector<CNodePtr> mem_reuse_exec_order_;
   uint32_t graph_id_;
   uint32_t stream_distinction_label_;
+  DeviceAddressType device_target_;
   uint32_t root_graph_id_{0};
   // record map between front anf and backend anf, use two maps to implement a bidirectional map
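
Taken together, the hunks above give each kernel graph its own device target: a field defaulted to kUnknown, copied on clone, and exposed through a getter/setter pair. A minimal, self-contained sketch of the resulting interface (stand-in types for illustration, not the actual MindSpore headers; the enumerator set is assumed from the calls in this PR):

```cpp
#include <iostream>

// Stand-in for device::DeviceAddressType (enumerators assumed for illustration).
enum class DeviceAddressType { kUnknown, kCPU, kGPU, kAscend };

// Minimal stand-in for session::KernelGraph: the PR adds device_target_,
// defaults it to kUnknown, copies it in the copy constructor, and adds
// device_target()/set_device_target() accessors.
class KernelGraph {
 public:
  KernelGraph() = default;
  KernelGraph(const KernelGraph &graph) : device_target_(graph.device_target_) {}

  DeviceAddressType device_target() const { return device_target_; }
  void set_device_target(DeviceAddressType target) { device_target_ = target; }

 private:
  DeviceAddressType device_target_{DeviceAddressType::kUnknown};
};

int main() {
  KernelGraph graph;
  graph.set_device_target(DeviceAddressType::kCPU);
  KernelGraph copy(graph);  // the target survives graph cloning
  std::cout << (copy.device_target() == DeviceAddressType::kCPU) << '\n';
  return 0;
}
```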

View File

@@ -136,7 +136,6 @@ bool IsUsedByRealKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &no
   }
   return false;
 }
 ParamInfoPtr GetParamDefaultValue(const AnfNodePtr &node) {
   if (node == nullptr) {
     return nullptr;
@@ -645,8 +644,7 @@ AnfNodePtr SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, Kernel
   return new_parameter;
 }
-ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph,
-                                                           const std::string &target) {
+ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(anf);
   if (!anf->isa<Parameter>()) {
     MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
@@ -658,11 +656,21 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
   auto graph_inputs = graph->MutableInputs();
   MS_EXCEPTION_IF_NULL(graph_inputs);
   ParameterPtr new_parameter = nullptr;
-  if (target == "CPU") {
+  auto func_graph = anf->func_graph();
+  if (func_graph->manager() != nullptr && func_graph->IsMultiTarget() &&
+      graph->device_target() == device::DeviceAddressType::kCPU) {
+    auto iter = default_param_map_.find(anf);
+    if (iter != default_param_map_.end()) {
+      new_parameter = iter->second;
+    }
+    if (new_parameter != nullptr) {
+      return new_parameter;
+    }
     TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
     new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
     graph_inputs->push_back(new_parameter);
     valid_inputs->push_back(true);
+    default_param_map_[anf] = new_parameter;
     return new_parameter;
   }
   // if the parameter's python parameter already has a backend parameter, reuse the existing parameter
@@ -711,8 +719,7 @@ void SessionBasic::GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *
 }
 void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
-                                     mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode,
-                                     const std::string &target) {
+                                     mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(other_graph_cnode);
@@ -741,7 +748,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
       }
       continue;
     } else if (anf->isa<Parameter>()) {
-      auto new_parameter = CreateNewParameterFromParameter(anf, graph, target);
+      auto new_parameter = CreateNewParameterFromParameter(anf, graph);
       cnode_inputs->push_back(new_parameter);
       graph->FrontBackendMapAdd(anf, new_parameter);
       continue;
@@ -763,15 +770,14 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
 }
 CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
-                                      mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode,
-                                      const std::string &target) {
+                                      mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(other_graph_cnode);
   // get primitive of old node
   std::vector<AnfNodePtr> cnode_inputs;
   GetCNodeInfo(cnode, &cnode_inputs);
-  GetNewCNodeInputs(cnode, graph, &cnode_inputs, other_graph_cnode, target);
+  GetNewCNodeInputs(cnode, graph, &cnode_inputs, other_graph_cnode);
   TraceGuard trace_guard(std::make_shared<TraceCopy>(cnode->debug_info()));
   auto new_cnode = graph->NewCNodeWithInfos(cnode_inputs, cnode);
   return new_cnode;
@@ -1157,7 +1163,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
 }
 KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs,
-                                                  bool common_opt, const device::DeviceContext *device_context) {
+                                                  DeviceAddressType device_target, bool common_opt) {
   mindspore::HashMap<AnfNodePtr, AnfNodePtr> other_graph_cnode;
   auto graph = NewKernelGraph();
   MS_EXCEPTION_IF_NULL(graph);
@@ -1170,12 +1176,9 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
     }
     auto cnode = node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
-    std::string target;
-    if (device_context) {
-      target = device_context->device_context_key().device_name_;
-    }
+    graph->set_device_target(device_target);
     // create a new cnode object
-    auto new_cnode = CreateNewCNode(cnode, graph.get(), &other_graph_cnode, target);
+    auto new_cnode = CreateNewCNode(cnode, graph.get(), &other_graph_cnode);
     MS_EXCEPTION_IF_NULL(new_cnode);
     new_cnode->set_abstract(cnode->abstract());
     new_cnode->set_scope(cnode->scope());
@@ -1211,7 +1214,6 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
     SetInputNodeUsage(graph, manager);
     graph->SetOptimizerFlag();
   }
   return graph;
 }
@@ -1667,7 +1669,8 @@ bool SessionBasic::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph
 }
 std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph,
-                                                                std::vector<KernelGraphPtr> *all_out_graph) {
+                                                                std::vector<KernelGraphPtr> *all_out_graph,
+                                                                DeviceAddressType device_target) {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(all_out_graph);
   auto node_list = TopoSort(func_graph->get_return());
@@ -1675,6 +1678,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
   MS_EXCEPTION_IF_NULL(graph);
   front_backend_graph_map_[func_graph.get()] = graph;
   MS_LOG(INFO) << "Create graph: " << graph->graph_id();
+  graph->set_device_target(device_target);
   for (const auto &node : node_list) {
     MS_EXCEPTION_IF_NULL(node);
     MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
@@ -1697,7 +1701,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
       // Create child kernel graph according to ValueNode<FuncGraph>
       FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node);
       if (front_backend_graph_map_.find(child_graph.get()) == front_backend_graph_map_.end()) {
-        (void)ConstructKernelGraph(child_graph, all_out_graph);
+        (void)ConstructKernelGraph(child_graph, all_out_graph, device_target);
       }
       (void)CreateValueNodeKernelGraph(node, graph.get());
       continue;

View File

@@ -123,15 +123,16 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   bool CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph);
   std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs,
-                                                    bool common_opt = true,
-                                                    const device::DeviceContext *device_context = nullptr);
+                                                    DeviceAddressType device_target = DeviceAddressType::kUnknown,
+                                                    bool common_opt = true);
   std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph,
-                                                    std::vector<KernelGraphPtr> *all_out_graph);
+                                                    std::vector<KernelGraphPtr> *all_out_graph,
+                                                    DeviceAddressType device_target);
   void SetInputNodeUsage(const KernelGraphPtr &graph, const FuncGraphManagerPtr &manager);
   CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
-                          mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode, const std::string &target);
+                          mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
   CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph);
@@ -183,7 +184,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   std::vector<AnfNodePtr> CreateCallSwitchInputs(const CNodePtr &cnode, KernelGraph *graph);
   void GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs) const;
   void GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
-                         mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode, const std::string &target);
+                         mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
   std::vector<AnfNodePtr> CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph);
   void ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph, const std::vector<AnfNodePtr> &real_inputs);
   void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node,
@@ -307,8 +308,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   // create a new kernel graph and update the graph sum
   KernelGraphPtr NewKernelGraph();
   AnfNodePtr CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph);
-  virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph,
-                                                       const std::string &target = "");
+  virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph);
   ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph);
   ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);
   AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph);
@@ -327,7 +327,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   void GetBatchElements(const AnfNodePtr &kernel_node) const;
   void InitPsWorker(const KernelGraphPtr &kernel_graph);
 #endif
   std::map<uint32_t, std::vector<std::shared_ptr<device::Bucket>>> bucket_map_;
   std::map<uint32_t, uint32_t> free_bucket_id_map_;
   mindspore::HashMap<GraphId, std::shared_ptr<KernelGraph>> graphs_;
@@ -335,6 +334,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   mindspore::HashMap<FuncGraph *, KernelGraphPtr> front_backend_graph_map_;
   mindspore::HashMap<AnfNodePtr, AnfNodePtr> partial_parameters_map_;
   mindspore::HashMap<AnfNodePtr, std::string> partial_target_map_;
+  mindspore::HashMap<AnfNodePtr, ParameterPtr> default_param_map_;
   std::shared_ptr<Context> context_;
   CallBackFunc summary_callback_;
   static GraphId graph_sum_;
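
A hedged sketch of how the reworked ConstructKernelGraph overload is now called. The stand-in types below are for illustration only; the parameter order and defaults follow the header hunk above, and the call sites mirror the backend sessions earlier in this diff:

```cpp
#include <memory>

enum class DeviceAddressType { kUnknown, kCPU, kGPU, kAscend };

struct KernelGraph {
  DeviceAddressType device_target{DeviceAddressType::kUnknown};
};
using KernelGraphPtr = std::shared_ptr<KernelGraph>;

// Shape of the refactored entry point: device_target replaces the old
// DeviceContext pointer and now sits before common_opt in the parameter list.
KernelGraphPtr ConstructKernelGraph(DeviceAddressType device_target = DeviceAddressType::kUnknown,
                                    bool common_opt = true) {
  auto graph = std::make_shared<KernelGraph>();
  graph->device_target = device_target;  // recorded on the graph, as in the PR
  (void)common_opt;
  return graph;
}

int main() {
  // Backend sessions pick their own target explicitly:
  auto ascend_graph = ConstructKernelGraph(DeviceAddressType::kAscend);
  // The 310 ACL session keeps kUnknown but still disables common optimization,
  // so it now has to spell out the first default (as in the hunk that follows).
  auto acl_graph = ConstructKernelGraph(DeviceAddressType::kUnknown, false);
  return (ascend_graph && acl_graph) ? 0 : 1;
}
```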

View File

@@ -48,7 +48,7 @@ GraphId MultiGraphAclSession::CompileGraphImpl(const AnfNodePtrList &lst, const
   };
   MS_LOG(INFO) << "Start MultiGraph Compile.";
   // construct kernel graph
-  auto kernel_graph = SessionBasic::ConstructKernelGraph(lst, outputs, false);
+  auto kernel_graph = SessionBasic::ConstructKernelGraph(lst, outputs, device::DeviceAddressType::kUnknown, false);
   MS_EXCEPTION_IF_NULL(kernel_graph);
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>("310_multi_graph_pm");

View File

@@ -804,10 +804,10 @@ DeviceAddressPtr AscendKernelRuntime::GetInternalDeviceAddress(const session::Ke
     if (graph_output.first == nullptr) {
       continue;
     }
-    if (!AnfAlgo::OutputAddrExist(graph_output.first, 0)) {
+    if (!AnfAlgo::OutputAddrExist(graph_output.first, graph_output.second)) {
       return nullptr;
     }
-    auto output_device_address = AnfAlgo::GetMutableOutputAddr(graph_output.first, 0);
+    auto output_device_address = AnfAlgo::GetMutableOutputAddr(graph_output.first, graph_output.second);
     MS_EXCEPTION_IF_NULL(output_device_address);
     if (output_device_address->DeviceType() == DeviceAddressType::kAscend) {
       return output_device_address;
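
The runtime hunk above replaces a hard-coded output index 0 with the index carried in the (node, index) pair, so device addresses of outputs other than the first can actually be found. A standalone sketch of the failure mode being fixed (stand-in lookup table, hypothetical names):

```cpp
#include <cstddef>
#include <map>
#include <memory>
#include <utility>

// Stand-in for the (node, output index) pairs iterated in the loop above.
using KernelWithIndex = std::pair<const void *, std::size_t>;

// Stand-in address table keyed by node pointer and output index.
static std::map<KernelWithIndex, std::shared_ptr<int>> g_output_addrs;

std::shared_ptr<int> GetMutableOutputAddr(const KernelWithIndex &key) {
  auto it = g_output_addrs.find(key);
  return it == g_output_addrs.end() ? nullptr : it->second;
}

int main() {
  int node = 0;
  g_output_addrs[{&node, 1}] = std::make_shared<int>(42);
  // Before the fix, lookups always used index 0 and missed output 1:
  auto wrong = GetMutableOutputAddr({&node, 0});  // nullptr
  auto right = GetMutableOutputAddr({&node, 1});  // found
  return (wrong == nullptr && right != nullptr) ? 0 : 1;
}
```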

View File

@@ -167,14 +167,21 @@ bool SuperKernelActor::CopyInputData(const OpContext<DeviceTensor> *context) {
     }
     auto input_node = input_nodes[device_tensor_store_key.first];
     MS_EXCEPTION_IF_NULL(input_node);
+    auto input_param = input_node->cast<ParameterPtr>();
+    if (!input_param->IsUsedByRealKernelInGraph(graph_->graph_id())) {
+      continue;
+    }
     auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
     MS_EXCEPTION_IF_NULL(device_address);
     if (input_device_tensor->GetPtr() != device_address->GetPtr()) {
-      MS_LOG(WARNING) << "The input data of node:" << input_node->DebugString()
-                      << " device address:" << input_device_tensor->GetPtr()
-                      << ", type:" << input_device_tensor->DeviceType()
-                      << " is not equal to the graph node device address:" << device_address->GetPtr()
-                      << ", type:" << device_address->DeviceType() << ".";
+      MS_LOG(ERROR) << "The input data of node:" << input_node->DebugString()
+                    << " device address:" << input_device_tensor->GetPtr()
+                    << ", type:" << input_device_tensor->DeviceType()
+                    << " is not equal to the graph node device address:" << device_address->GetPtr()
+                    << ", type:" << device_address->DeviceType() << ".";
       return false;
     }
   }
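
With this hunk, CopyInputData skips graph inputs that no real kernel in the graph consumes, and logs a device-address mismatch at ERROR level before failing rather than at WARNING. A hedged, self-contained sketch of that control flow (stand-in types; the field names are hypothetical):

```cpp
#include <cstdio>
#include <vector>

// Stand-in input record: whether any real kernel uses it, plus the actor-side
// and graph-side device pointers being compared.
struct Input {
  bool used_by_real_kernel;
  const void *actor_ptr;
  const void *graph_ptr;
};

bool CopyInputData(const std::vector<Input> &inputs) {
  for (const auto &input : inputs) {
    if (!input.used_by_real_kernel) {
      continue;  // new: unused parameters are skipped before validation
    }
    if (input.actor_ptr != input.graph_ptr) {
      std::fprintf(stderr, "input device address mismatch\n");
      return false;  // mismatch on a used input is a hard failure
    }
  }
  return true;
}

int main() {
  int a = 0, b = 0;
  // The unused, mismatching input no longer trips the check.
  std::vector<Input> inputs = {{false, &a, &b}, {true, &a, &a}};
  return CopyInputData(inputs) ? 0 : 1;
}
```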

View File

@@ -350,8 +350,9 @@ GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment, const AnfNod
   MS_EXCEPTION_IF_NULL(segment);
   MS_LOG(INFO) << "Status record: start compile graph.";
   auto nodes = segment->nodes_;
+  auto device_target = device_context->GetDeviceAddressType();
   // Generate kernel graph.
-  KernelGraphPtr graph = session_->ConstructKernelGraph(nodes, outputs, true, device_context);
+  KernelGraphPtr graph = session_->ConstructKernelGraph(nodes, outputs, device_target);
   MS_EXCEPTION_IF_NULL(graph);
   opt::EliminateIllegalDataTypePass(graph);
   SetGraphDependency(graph, segment);
@@ -397,7 +398,8 @@ GraphId GraphCompiler::CompileGraph(const FuncGraphPtr &func_graph, const Device
   MS_LOG(INFO) << "Status record: start compile graph.";
   // Generate kernel graph.
   std::vector<KernelGraphPtr> all_graphs;
-  KernelGraphPtr root_graph = session_->ConstructKernelGraph(func_graph, &all_graphs);
+  auto device_target = device_context->GetDeviceAddressType();
+  KernelGraphPtr root_graph = session_->ConstructKernelGraph(func_graph, &all_graphs, device_target);
   MS_EXCEPTION_IF_NULL(root_graph);
   for (const auto &graph : all_graphs) {
     MS_EXCEPTION_IF_NULL(graph);
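
In GraphCompiler the session no longer receives a DeviceContext pointer; the context's DeviceAddressType is read once and handed down. A standalone sketch of the pattern (stand-in types; only the accessor's role is taken from the hunks above):

```cpp
#include <memory>

enum class DeviceAddressType { kUnknown, kCPU, kGPU, kAscend };

// Stand-in for device::DeviceContext, reduced to the one accessor used here.
struct DeviceContext {
  DeviceAddressType GetDeviceAddressType() const { return DeviceAddressType::kGPU; }
};

struct KernelGraph {
  DeviceAddressType device_target;
};

std::shared_ptr<KernelGraph> ConstructKernelGraph(DeviceAddressType device_target) {
  return std::make_shared<KernelGraph>(KernelGraph{device_target});
}

int main() {
  DeviceContext device_context;
  // As in GraphCompiler::CompileGraph: read the enum once, pass it down, and
  // drop the DeviceContext dependency from the session layer.
  auto device_target = device_context.GetDeviceAddressType();
  auto graph = ConstructKernelGraph(device_target);
  return graph->device_target == DeviceAddressType::kGPU ? 0 : 1;
}
```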

View File

@@ -43,6 +43,7 @@ FuncGraph::FuncGraph(GraphDebugInfoPtr &&debug_info)
       parameters_(),
       has_vararg_(false),
       has_kwarg_(false),
+      exist_multi_target_(false),
       kwonlyargs_count_(0),
       hyper_param_count_(0),
       is_generated_(false),
@@ -757,13 +758,14 @@ ParameterPtr FuncGraph::add_weight(const tensor::MetaTensorPtr &meta_tensor) {
   return parameter;
 }
-bool FuncGraph::ContainMultiTarget() const {
+bool FuncGraph::ContainMultiTarget() {
   auto graph_manager = manager();
   MS_EXCEPTION_IF_NULL(graph_manager);
   FuncGraphSet graphs = graph_manager->func_graphs();
   for (auto &g : graphs) {
     auto nodes = mindspore::TopoSort(g->get_return());
     if (mindspore::ContainMultiTarget(nodes)) {
+      exist_multi_target_ = true;
       return true;
     }
   }

View File

@@ -348,7 +348,8 @@ class FuncGraph : public deprecated::api::FuncGraph, public FuncGraphBase, publi
   void set_switch_layer_input(const std::shared_ptr<bool> &switch_layer_input) {
     switch_layer_input_ = switch_layer_input;
   }
-  bool ContainMultiTarget() const;
+  bool ContainMultiTarget();
+  bool IsMultiTarget() const { return exist_multi_target_; }
   int64_t stage() const { return stage_; }
   void set_stage(int64_t stage) { stage_ = stage; }
@@ -400,6 +401,7 @@ class FuncGraph : public deprecated::api::FuncGraph, public FuncGraphBase, publi
   // Whether there is a *args and **kwargs, and count kwonlyargs' number.
   bool has_vararg_;
   bool has_kwarg_;
+  bool exist_multi_target_;
   int kwonlyargs_count_;
   // Hyper param is placed on the top graph,
   // and positioned in the end of the param list, so we record the number to trace the position.
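
ContainMultiTarget drops its const qualifier so a positive scan can be cached in the new exist_multi_target_ flag, and IsMultiTarget() reads that flag without re-walking every graph; this is what lets CreateNewParameterFromParameter call func_graph->IsMultiTarget() cheaply in session_basic.cc. A minimal standalone sketch (the scan body is a stand-in):

```cpp
#include <vector>

class FuncGraph {
 public:
  // No longer const: a positive scan is recorded in exist_multi_target_.
  bool ContainMultiTarget() {
    for (bool node_is_other_target : node_targets_) {
      if (node_is_other_target) {
        exist_multi_target_ = true;  // cache the expensive answer
        return true;
      }
    }
    return false;
  }

  // Cheap accessor for later call sites.
  bool IsMultiTarget() const { return exist_multi_target_; }

 private:
  // Stand-in for the TopoSort-based scan over all managed graphs.
  std::vector<bool> node_targets_{false, true};
  bool exist_multi_target_{false};
};

int main() {
  FuncGraph graph;
  graph.ContainMultiTarget();            // runs the scan once
  return graph.IsMultiTarget() ? 0 : 1;  // cached flag, no rescan
}
```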