forked from OSSInnovation/mindspore
!14484 extract session load inputs
From: @kisnwang Reviewed-by: @zhoufeng54,@jjfeing Signed-off-by: @jjfeing
This commit is contained in:
commit
2a2cbbfa4c
|
@ -571,8 +571,6 @@ void AscendSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tens
|
|||
std::set<KernelGraphPtr> memo;
|
||||
SyncDataToExtraParams(NOT_NULL(kernel_graph), NOT_NULL(&memo));
|
||||
memo.clear();
|
||||
// load input data from user input
|
||||
LoadInputData(kernel_graph, inputs);
|
||||
if (debugger_) {
|
||||
debugger_->PreExecute(kernel_graph, graph_sum_);
|
||||
}
|
||||
|
|
|
@ -130,6 +130,32 @@ void CPUSession::SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &ker
|
|||
runtime_.SyncValueNodeDeviceAddr(kernel_graph.get());
|
||||
}
|
||||
|
||||
void CPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||
const std::vector<tensor::TensorPtr> &inputs_const) const {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto &input_nodes = kernel_graph->inputs();
|
||||
if (input_nodes.size() != inputs_const.size()) {
|
||||
MS_LOG(EXCEPTION) << "Input size not equal to input node size!";
|
||||
}
|
||||
for (size_t input_idx = 0; input_idx < input_nodes.size(); ++input_idx) {
|
||||
auto &item = input_nodes[input_idx];
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
if (item->isa<Parameter>() && !HasAbstractMonad(item)) {
|
||||
auto address = AnfAlgo::GetMutableOutputAddr(item, 0);
|
||||
auto tensor = inputs_const[input_idx];
|
||||
auto tensor_address = tensor->device_address();
|
||||
MS_EXCEPTION_IF_NULL(address);
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
if (tensor_address != nullptr && tensor_address != address &&
|
||||
(std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address)->DeviceType() !=
|
||||
device::DeviceAddressType::kCPU ||
|
||||
AnfAlgo::IsParameterWeight(item->cast<ParameterPtr>()))) {
|
||||
tensor->data_sync(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||
VectorRef *outputs) {
|
||||
auto kernel_graph = GetGraph(graph_id);
|
||||
|
|
|
@ -44,6 +44,8 @@ class CPUSession : public SessionBasic {
|
|||
const std::vector<int64_t> &tensors_mask) override;
|
||||
void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors,
|
||||
VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
|
||||
void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||
const std::vector<tensor::TensorPtr> &inputs_const) const override;
|
||||
|
||||
private:
|
||||
void Reorder(std::vector<CNodePtr> *node_list);
|
||||
|
|
|
@ -161,6 +161,7 @@ void RunGraphTask::Run() {
|
|||
}
|
||||
graph->ResetGraphRunningStatus();
|
||||
try {
|
||||
session_->LoadInputs(graph_id_, input_tensors_);
|
||||
session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_);
|
||||
UpdateOutputTensors(&outputs_, tensor_to_node_);
|
||||
} catch (const std::exception &e) {
|
||||
|
|
|
@ -425,8 +425,6 @@ void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor:
|
|||
MS_LOG(INFO) << "RunGraph graph_id: " << graph_id;
|
||||
// In pynative mode, device addresses of tensors in value nodes change.
|
||||
SyncValueNodeDeviceAddr(kernel_graph);
|
||||
// Load input data from user input
|
||||
LoadInputData(kernel_graph, inputs);
|
||||
if (debugger_) {
|
||||
debugger_->PreExecute(kernel_graph, graph_sum_);
|
||||
}
|
||||
|
|
|
@ -47,6 +47,8 @@ class GPUSession : public SessionBasic {
|
|||
VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
|
||||
std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) override;
|
||||
std::string GetCommWorldGroup() override { return kNcclWorldGroup; }
|
||||
void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||
const std::vector<tensor::TensorPtr> &inputs_const) const override;
|
||||
|
||||
private:
|
||||
void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
|
@ -71,9 +73,6 @@ class GPUSession : public SessionBasic {
|
|||
|
||||
void RunOpClearMemory(KernelGraph *kernel_graph) const;
|
||||
|
||||
void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||
const std::vector<tensor::TensorPtr> &inputs_const) const override;
|
||||
|
||||
void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
|
||||
void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
|
|
|
@ -180,8 +180,8 @@ std::vector<AnfNodePtr> KernelGraph::outputs() const {
|
|||
return std::vector<AnfNodePtr>(1, graph_output);
|
||||
}
|
||||
|
||||
void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
|
||||
std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first) {
|
||||
void KernelGraph::EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
|
||||
std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first) {
|
||||
MS_EXCEPTION_IF_NULL(visit_queue);
|
||||
MS_EXCEPTION_IF_NULL(visited_nodes);
|
||||
auto it = node_output_edges_.find(node);
|
||||
|
@ -241,7 +241,7 @@ void KernelGraph::SetExecOrderByDefault() {
|
|||
while (!seed_nodes.empty() || !delay_comm_stack.empty()) {
|
||||
// seed nodes first, then delay comm nodes
|
||||
if (seed_nodes.empty()) {
|
||||
VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
|
||||
EnqueueActiveNodes(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
|
||||
delay_comm_stack.pop();
|
||||
} else {
|
||||
zero_input_nodes.push(seed_nodes.front());
|
||||
|
@ -272,16 +272,16 @@ void KernelGraph::SetExecOrderByDefault() {
|
|||
}
|
||||
if (optimize_comm) {
|
||||
while (!delay_comm_stack.empty()) {
|
||||
VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
|
||||
EnqueueActiveNodes(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
|
||||
delay_comm_stack.pop();
|
||||
}
|
||||
delay_comm_stack.push(node);
|
||||
} else if (is_fused_comm) {
|
||||
delay_comm_stack.push(node);
|
||||
} else if (is_communication_descendant) {
|
||||
VisitNodeDescendants(node, &communication_descendants, &visited_nodes);
|
||||
EnqueueActiveNodes(node, &communication_descendants, &visited_nodes);
|
||||
} else {
|
||||
VisitNodeDescendants(node, &zero_input_nodes, &visited_nodes);
|
||||
EnqueueActiveNodes(node, &zero_input_nodes, &visited_nodes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -283,8 +283,8 @@ class KernelGraph : public FuncGraph {
|
|||
void SetKernelInfoForNode(const AnfNodePtr &node) const;
|
||||
void ResetInFormat(const AnfNodePtr &node, const std::string &format) const;
|
||||
AnfNodePtr MakeValueNode(const AnfNodePtr &node);
|
||||
void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
|
||||
std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first = true);
|
||||
void EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
|
||||
std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first = true);
|
||||
// update node edge list
|
||||
void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes);
|
||||
// add node depend edge by data edge or control depend
|
||||
|
|
|
@ -188,6 +188,13 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
const std::map<KernelWithIndex, size_t> &cnode_refcount) {}
|
||||
virtual void SetSummaryNodes(KernelGraph *graph);
|
||||
|
||||
void LoadInputs(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs_const) {
|
||||
auto kernel_graph = GetGraph(graph_id);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
MS_LOG(INFO) << "Load inputs";
|
||||
LoadInputData(kernel_graph, inputs_const);
|
||||
}
|
||||
|
||||
virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||
const std::vector<tensor::TensorPtr> &inputs_const) const;
|
||||
void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
|
||||
|
|
|
@ -283,20 +283,14 @@ void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &ker
|
|||
if (input_nodes.size() != inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "Input size not equal to input node size!";
|
||||
}
|
||||
size_t input_idx = 0;
|
||||
for (auto &item : input_nodes) {
|
||||
for (size_t input_idx = 0; input_idx < input_nodes.size(); ++input_idx) {
|
||||
auto &item = input_nodes[input_idx];
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
if (item->isa<Parameter>() && !HasAbstractMonad(item)) {
|
||||
auto address = AnfAlgo::GetMutableOutputAddr(item, 0);
|
||||
auto tensor = inputs[input_idx];
|
||||
auto tensor_address = tensor->device_address();
|
||||
MS_EXCEPTION_IF_NULL(address);
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
if (tensor_address != nullptr && tensor_address != address &&
|
||||
(std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address)->DeviceType() != DeviceAddressType::kCPU ||
|
||||
AnfAlgo::IsParameterWeight(item->cast<ParameterPtr>()))) {
|
||||
tensor->data_sync(false);
|
||||
}
|
||||
if (GetTypeByte(TypeIdToType(tensor->data_type())) == GetTypeByte(TypeIdToType(address->type_id_))) {
|
||||
address->ptr_ = tensor->data_c();
|
||||
} else {
|
||||
|
@ -318,7 +312,6 @@ void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &ker
|
|||
address->ref_count_ = INIT_NODE_REF;
|
||||
tensor->set_device_address(address);
|
||||
}
|
||||
input_idx++;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue