!29089 remove device target info in session basic
Merge pull request !29089 from baihuawei/clear_code_rt1.6
This commit is contained in:
commit
64932d08fb
|
@ -387,7 +387,7 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra
|
|||
GraphId AscendSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
||||
MS_LOG(INFO) << "Status record: start compile graph.";
|
||||
// construct graph; on success, graph_sum_ is incremented by 1
|
||||
auto graph = ConstructKernelGraph(lst, outputs);
|
||||
auto graph = ConstructKernelGraph(lst, outputs, DeviceAddressType::kAscend);
|
||||
auto graph_id = graph->graph_id();
|
||||
InitAllBucket(graph);
|
||||
MS_LOG(INFO) << "Status record: end compile graph. graph id: " << graph_id;
|
||||
|
@ -397,7 +397,7 @@ GraphId AscendSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNode
|
|||
GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
|
||||
MS_LOG(INFO) << "Status record: start compile graph.";
|
||||
std::vector<KernelGraphPtr> all_graphs;
|
||||
auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
|
||||
auto root_graph = ConstructKernelGraph(func_graph, &all_graphs, DeviceAddressType::kAscend);
|
||||
for (const auto &graph : all_graphs) {
|
||||
graph->set_root_graph_id(root_graph->graph_id());
|
||||
}
|
||||
|
|
|
@ -61,8 +61,7 @@ void CPUSession::Init(uint32_t device_id) {
|
|||
InitExecutor(kCPUDevice, device_id);
|
||||
}
|
||||
|
||||
ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph,
|
||||
const std::string &) {
|
||||
ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (!anf->isa<Parameter>()) {
|
||||
|
@ -118,7 +117,7 @@ void CPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
|
|||
|
||||
GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
||||
auto graph_id = graph_sum_;
|
||||
auto graph = ConstructKernelGraph(lst, outputs);
|
||||
auto graph = ConstructKernelGraph(lst, outputs, DeviceAddressType::kCPU);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
opt::AddDynamicShapeAttrPass(graph);
|
||||
MS_LOG(INFO) << "Set kernel info";
|
||||
|
|
|
@ -42,7 +42,7 @@ class CPUSession : public SessionBasic {
|
|||
void PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
|
||||
VectorRef *const outputs) override;
|
||||
void ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) override;
|
||||
ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph, const std::string &) override;
|
||||
ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) override;
|
||||
void GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
|
||||
void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph);
|
||||
KernelGraphPtr BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
|
|
|
@ -392,14 +392,14 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
|||
|
||||
GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
||||
// Construct graph; on success, graph_sum_ is incremented by 1
|
||||
auto graph = ConstructKernelGraph(lst, outputs);
|
||||
auto graph = ConstructKernelGraph(lst, outputs, DeviceAddressType::kGPU);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
return CompileGraphImpl(graph);
|
||||
}
|
||||
|
||||
GraphId GPUSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
|
||||
std::vector<KernelGraphPtr> all_graphs;
|
||||
auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
|
||||
auto root_graph = ConstructKernelGraph(func_graph, &all_graphs, DeviceAddressType::kGPU);
|
||||
MS_EXCEPTION_IF_NULL(root_graph);
|
||||
if (all_graphs.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "Gpu backend does not support multi-graph schedule, graph num is " << all_graphs.size();
|
||||
|
|
|
@ -49,6 +49,7 @@ struct KernelWithIndexCmp {
|
|||
}
|
||||
};
|
||||
|
||||
using DeviceAddressType = device::DeviceAddressType;
|
||||
using KernelMapTensor = std::map<session::KernelWithIndex, BaseRef, session::KernelWithIndexCmp>;
|
||||
|
||||
class KernelGraph : public FuncGraph {
|
||||
|
@ -60,6 +61,7 @@ class KernelGraph : public FuncGraph {
|
|||
executable_ = true;
|
||||
summary_node_exist_ = false;
|
||||
stream_distinction_label_ = kInvalidDistincLabel;
|
||||
device_target_ = DeviceAddressType::kUnknown;
|
||||
}
|
||||
|
||||
KernelGraph(const KernelGraph &graph) : FuncGraph(graph) {
|
||||
|
@ -68,6 +70,7 @@ class KernelGraph : public FuncGraph {
|
|||
execution_order_ = graph.execution_order_;
|
||||
mem_reuse_exec_order_ = graph.mem_reuse_exec_order_;
|
||||
graph_id_ = graph.graph_id_;
|
||||
device_target_ = graph.device_target_;
|
||||
stream_distinction_label_ = graph.stream_distinction_label_;
|
||||
front_backend_anf_map_ = graph.front_backend_anf_map_;
|
||||
backend_front_anf_map_ = graph.backend_front_anf_map_;
|
||||
|
@ -148,6 +151,8 @@ class KernelGraph : public FuncGraph {
|
|||
void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; }
|
||||
uint32_t root_graph_id() const { return root_graph_id_; }
|
||||
void set_root_graph_id(uint32_t root_graph_id) { root_graph_id_ = root_graph_id; }
|
||||
DeviceAddressType device_target() const { return device_target_; }
|
||||
void set_device_target(DeviceAddressType target) { device_target_ = target; }
|
||||
|
||||
// add a new front-to-backend anf relation to the map
|
||||
void FrontBackendMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf);
|
||||
|
@ -446,6 +451,7 @@ class KernelGraph : public FuncGraph {
|
|||
std::vector<CNodePtr> mem_reuse_exec_order_;
|
||||
uint32_t graph_id_;
|
||||
uint32_t stream_distinction_label_;
|
||||
DeviceAddressType device_target_;
|
||||
uint32_t root_graph_id_{0};
|
||||
|
||||
// record the mapping between front anf and backend anf; two maps are used to implement a bidirectional map
|
||||
|
|
|
@ -136,7 +136,6 @@ bool IsUsedByRealKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &no
|
|||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
ParamInfoPtr GetParamDefaultValue(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
return nullptr;
|
||||
|
@ -645,8 +644,7 @@ AnfNodePtr SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, Kernel
|
|||
return new_parameter;
|
||||
}
|
||||
|
||||
ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph,
|
||||
const std::string &target) {
|
||||
ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
if (!anf->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
|
||||
|
@ -658,11 +656,21 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
|
|||
auto graph_inputs = graph->MutableInputs();
|
||||
MS_EXCEPTION_IF_NULL(graph_inputs);
|
||||
ParameterPtr new_parameter = nullptr;
|
||||
if (target == "CPU") {
|
||||
auto func_graph = anf->func_graph();
|
||||
if (func_graph->manager() != nullptr && func_graph->IsMultiTarget() &&
|
||||
graph->device_target() == device::DeviceAddressType::kCPU) {
|
||||
auto iter = default_param_map_.find(anf);
|
||||
if (iter != default_param_map_.end()) {
|
||||
new_parameter = iter->second;
|
||||
}
|
||||
if (new_parameter != nullptr) {
|
||||
return new_parameter;
|
||||
}
|
||||
TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
|
||||
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
|
||||
graph_inputs->push_back(new_parameter);
|
||||
valid_inputs->push_back(true);
|
||||
default_param_map_[anf] = new_parameter;
|
||||
return new_parameter;
|
||||
}
|
||||
// if the parameter's python parameter already has a backend parameter, reuse the existing parameter
|
||||
|
@ -711,8 +719,7 @@ void SessionBasic::GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *
|
|||
}
|
||||
|
||||
void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
|
||||
mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode,
|
||||
const std::string &target) {
|
||||
mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(other_graph_cnode);
|
||||
|
@ -741,7 +748,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
|
|||
}
|
||||
continue;
|
||||
} else if (anf->isa<Parameter>()) {
|
||||
auto new_parameter = CreateNewParameterFromParameter(anf, graph, target);
|
||||
auto new_parameter = CreateNewParameterFromParameter(anf, graph);
|
||||
cnode_inputs->push_back(new_parameter);
|
||||
graph->FrontBackendMapAdd(anf, new_parameter);
|
||||
continue;
|
||||
|
@ -763,15 +770,14 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
|
|||
}
|
||||
|
||||
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
|
||||
mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode,
|
||||
const std::string &target) {
|
||||
mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(other_graph_cnode);
|
||||
// get primitive of old node
|
||||
std::vector<AnfNodePtr> cnode_inputs;
|
||||
GetCNodeInfo(cnode, &cnode_inputs);
|
||||
GetNewCNodeInputs(cnode, graph, &cnode_inputs, other_graph_cnode, target);
|
||||
GetNewCNodeInputs(cnode, graph, &cnode_inputs, other_graph_cnode);
|
||||
TraceGuard trace_guard(std::make_shared<TraceCopy>(cnode->debug_info()));
|
||||
auto new_cnode = graph->NewCNodeWithInfos(cnode_inputs, cnode);
|
||||
return new_cnode;
|
||||
|
@ -1157,7 +1163,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
|
|||
}
|
||||
|
||||
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs,
|
||||
bool common_opt, const device::DeviceContext *device_context) {
|
||||
DeviceAddressType device_target, bool common_opt) {
|
||||
mindspore::HashMap<AnfNodePtr, AnfNodePtr> other_graph_cnode;
|
||||
auto graph = NewKernelGraph();
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
@ -1170,12 +1176,9 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
|
|||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::string target;
|
||||
if (device_context) {
|
||||
target = device_context->device_context_key().device_name_;
|
||||
}
|
||||
graph->set_device_target(device_target);
|
||||
// create a new cnode object
|
||||
auto new_cnode = CreateNewCNode(cnode, graph.get(), &other_graph_cnode, target);
|
||||
auto new_cnode = CreateNewCNode(cnode, graph.get(), &other_graph_cnode);
|
||||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
new_cnode->set_abstract(cnode->abstract());
|
||||
new_cnode->set_scope(cnode->scope());
|
||||
|
@ -1211,7 +1214,6 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
|
|||
SetInputNodeUsage(graph, manager);
|
||||
graph->SetOptimizerFlag();
|
||||
}
|
||||
|
||||
return graph;
|
||||
}
|
||||
|
||||
|
@ -1667,7 +1669,8 @@ bool SessionBasic::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph
|
|||
}
|
||||
|
||||
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph,
|
||||
std::vector<KernelGraphPtr> *all_out_graph) {
|
||||
std::vector<KernelGraphPtr> *all_out_graph,
|
||||
DeviceAddressType device_target) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(all_out_graph);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
|
@ -1675,6 +1678,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
front_backend_graph_map_[func_graph.get()] = graph;
|
||||
MS_LOG(INFO) << "Create graph: " << graph->graph_id();
|
||||
graph->set_device_target(device_target);
|
||||
for (const auto &node : node_list) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
|
||||
|
@ -1697,7 +1701,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
|
|||
// Create child kernel graph according to ValueNode<FuncGraph>
|
||||
FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node);
|
||||
if (front_backend_graph_map_.find(child_graph.get()) == front_backend_graph_map_.end()) {
|
||||
(void)ConstructKernelGraph(child_graph, all_out_graph);
|
||||
(void)ConstructKernelGraph(child_graph, all_out_graph, device_target);
|
||||
}
|
||||
(void)CreateValueNodeKernelGraph(node, graph.get());
|
||||
continue;
|
||||
|
|
|
@ -123,15 +123,16 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
bool CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph);
|
||||
|
||||
std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs,
|
||||
bool common_opt = true,
|
||||
const device::DeviceContext *device_context = nullptr);
|
||||
DeviceAddressType device_target = DeviceAddressType::kUnknown,
|
||||
bool common_opt = true);
|
||||
std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph,
|
||||
std::vector<KernelGraphPtr> *all_out_graph);
|
||||
std::vector<KernelGraphPtr> *all_out_graph,
|
||||
DeviceAddressType device_target);
|
||||
|
||||
void SetInputNodeUsage(const KernelGraphPtr &graph, const FuncGraphManagerPtr &manager);
|
||||
|
||||
CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
|
||||
mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode, const std::string &target);
|
||||
mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
|
||||
CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph);
|
||||
|
||||
// get graph id in child graphs by ME front anf node pointer
|
||||
|
@ -183,7 +184,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
std::vector<AnfNodePtr> CreateCallSwitchInputs(const CNodePtr &cnode, KernelGraph *graph);
|
||||
void GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs) const;
|
||||
void GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
|
||||
mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode, const std::string &target);
|
||||
mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
|
||||
std::vector<AnfNodePtr> CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph);
|
||||
void ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph, const std::vector<AnfNodePtr> &real_inputs);
|
||||
void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node,
|
||||
|
@ -307,8 +308,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
// create a new kernel graph and update the graph sum
|
||||
KernelGraphPtr NewKernelGraph();
|
||||
AnfNodePtr CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph);
|
||||
virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph,
|
||||
const std::string &target = "");
|
||||
virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph);
|
||||
ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph);
|
||||
ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);
|
||||
AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph);
|
||||
|
@ -327,7 +327,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
void GetBatchElements(const AnfNodePtr &kernel_node) const;
|
||||
void InitPsWorker(const KernelGraphPtr &kernel_graph);
|
||||
#endif
|
||||
|
||||
std::map<uint32_t, std::vector<std::shared_ptr<device::Bucket>>> bucket_map_;
|
||||
std::map<uint32_t, uint32_t> free_bucket_id_map_;
|
||||
mindspore::HashMap<GraphId, std::shared_ptr<KernelGraph>> graphs_;
|
||||
|
@ -335,6 +334,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
mindspore::HashMap<FuncGraph *, KernelGraphPtr> front_backend_graph_map_;
|
||||
mindspore::HashMap<AnfNodePtr, AnfNodePtr> partial_parameters_map_;
|
||||
mindspore::HashMap<AnfNodePtr, std::string> partial_target_map_;
|
||||
mindspore::HashMap<AnfNodePtr, ParameterPtr> default_param_map_;
|
||||
std::shared_ptr<Context> context_;
|
||||
CallBackFunc summary_callback_;
|
||||
static GraphId graph_sum_;
|
||||
|
|
|
@ -48,7 +48,7 @@ GraphId MultiGraphAclSession::CompileGraphImpl(const AnfNodePtrList &lst, const
|
|||
};
|
||||
MS_LOG(INFO) << "Start MultiGraph Compile.";
|
||||
// construct kernel graph
|
||||
auto kernel_graph = SessionBasic::ConstructKernelGraph(lst, outputs, false);
|
||||
auto kernel_graph = SessionBasic::ConstructKernelGraph(lst, outputs, device::DeviceAddressType::kUnknown, false);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>("310_multi_graph_pm");
|
||||
|
|
|
@ -804,10 +804,10 @@ DeviceAddressPtr AscendKernelRuntime::GetInternalDeviceAddress(const session::Ke
|
|||
if (graph_output.first == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (!AnfAlgo::OutputAddrExist(graph_output.first, 0)) {
|
||||
if (!AnfAlgo::OutputAddrExist(graph_output.first, graph_output.second)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto output_device_address = AnfAlgo::GetMutableOutputAddr(graph_output.first, 0);
|
||||
auto output_device_address = AnfAlgo::GetMutableOutputAddr(graph_output.first, graph_output.second);
|
||||
MS_EXCEPTION_IF_NULL(output_device_address);
|
||||
if (output_device_address->DeviceType() == DeviceAddressType::kAscend) {
|
||||
return output_device_address;
|
||||
|
|
|
@ -167,14 +167,21 @@ bool SuperKernelActor::CopyInputData(const OpContext<DeviceTensor> *context) {
|
|||
}
|
||||
auto input_node = input_nodes[device_tensor_store_key.first];
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
|
||||
auto input_param = input_node->cast<ParameterPtr>();
|
||||
if (!input_param->IsUsedByRealKernelInGraph(graph_->graph_id())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
|
||||
MS_EXCEPTION_IF_NULL(device_address);
|
||||
if (input_device_tensor->GetPtr() != device_address->GetPtr()) {
|
||||
MS_LOG(WARNING) << "The input data of node:" << input_node->DebugString()
|
||||
<< " device address:" << input_device_tensor->GetPtr()
|
||||
<< ", type:" << input_device_tensor->DeviceType()
|
||||
<< " is not equal to the graph node device address:" << device_address->GetPtr()
|
||||
<< ", type:" << device_address->DeviceType() << ".";
|
||||
MS_LOG(ERROR) << "The input data of node:" << input_node->DebugString()
|
||||
<< " device address:" << input_device_tensor->GetPtr()
|
||||
<< ", type:" << input_device_tensor->DeviceType()
|
||||
<< " is not equal to the graph node device address:" << device_address->GetPtr()
|
||||
<< ", type:" << device_address->DeviceType() << ".";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -350,8 +350,9 @@ GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment, const AnfNod
|
|||
MS_EXCEPTION_IF_NULL(segment);
|
||||
MS_LOG(INFO) << "Status record: start compile graph.";
|
||||
auto nodes = segment->nodes_;
|
||||
auto device_terget = device_context->GetDeviceAddressType();
|
||||
// Generate kernel graph.
|
||||
KernelGraphPtr graph = session_->ConstructKernelGraph(nodes, outputs, true, device_context);
|
||||
KernelGraphPtr graph = session_->ConstructKernelGraph(nodes, outputs, device_terget);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
opt::EliminateIllegalDataTypePass(graph);
|
||||
SetGraphDependency(graph, segment);
|
||||
|
@ -397,7 +398,8 @@ GraphId GraphCompiler::CompileGraph(const FuncGraphPtr &func_graph, const Device
|
|||
MS_LOG(INFO) << "Status record: start compile graph.";
|
||||
// Generate kernel graph.
|
||||
std::vector<KernelGraphPtr> all_graphs;
|
||||
KernelGraphPtr root_graph = session_->ConstructKernelGraph(func_graph, &all_graphs);
|
||||
auto device_target = device_context->GetDeviceAddressType();
|
||||
KernelGraphPtr root_graph = session_->ConstructKernelGraph(func_graph, &all_graphs, device_target);
|
||||
MS_EXCEPTION_IF_NULL(root_graph);
|
||||
for (const auto &graph : all_graphs) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
|
|
@ -43,6 +43,7 @@ FuncGraph::FuncGraph(GraphDebugInfoPtr &&debug_info)
|
|||
parameters_(),
|
||||
has_vararg_(false),
|
||||
has_kwarg_(false),
|
||||
exist_multi_target_(false),
|
||||
kwonlyargs_count_(0),
|
||||
hyper_param_count_(0),
|
||||
is_generated_(false),
|
||||
|
@ -757,13 +758,14 @@ ParameterPtr FuncGraph::add_weight(const tensor::MetaTensorPtr &meta_tensor) {
|
|||
return parameter;
|
||||
}
|
||||
|
||||
bool FuncGraph::ContainMultiTarget() const {
|
||||
bool FuncGraph::ContainMultiTarget() {
|
||||
auto graph_manager = manager();
|
||||
MS_EXCEPTION_IF_NULL(graph_manager);
|
||||
FuncGraphSet graphs = graph_manager->func_graphs();
|
||||
for (auto &g : graphs) {
|
||||
auto nodes = mindspore::TopoSort(g->get_return());
|
||||
if (mindspore::ContainMultiTarget(nodes)) {
|
||||
exist_multi_target_ = true;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -348,7 +348,8 @@ class FuncGraph : public deprecated::api::FuncGraph, public FuncGraphBase, publi
|
|||
void set_switch_layer_input(const std::shared_ptr<bool> &switch_layer_input) {
|
||||
switch_layer_input_ = switch_layer_input;
|
||||
}
|
||||
bool ContainMultiTarget() const;
|
||||
bool ContainMultiTarget();
|
||||
bool IsMultiTarget() const { return exist_multi_target_; }
|
||||
int64_t stage() const { return stage_; }
|
||||
void set_stage(int64_t stage) { stage_ = stage; }
|
||||
|
||||
|
@ -400,6 +401,7 @@ class FuncGraph : public deprecated::api::FuncGraph, public FuncGraphBase, publi
|
|||
// Whether there are *args and **kwargs, and the number of kwonlyargs.
|
||||
bool has_vararg_;
|
||||
bool has_kwarg_;
|
||||
bool exist_multi_target_;
|
||||
int kwonlyargs_count_;
|
||||
// Hyper param is placed on the top graph,
|
||||
// and positioned in the end of the param list, so we record the number to trace the position.
|
||||
|
|
Loading…
Reference in New Issue