extract load input

This commit is contained in:
kswang 2021-03-31 16:03:08 +08:00
parent a7be883db4
commit 97a97e02db
10 changed files with 48 additions and 24 deletions

View File

@ -571,8 +571,6 @@ void AscendSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tens
std::set<KernelGraphPtr> memo;
SyncDataToExtraParams(NOT_NULL(kernel_graph), NOT_NULL(&memo));
memo.clear();
// load input data from user input
LoadInputData(kernel_graph, inputs);
if (debugger_) {
debugger_->PreExecute(kernel_graph, graph_sum_);
}

View File

@ -130,6 +130,32 @@ void CPUSession::SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &ker
runtime_.SyncValueNodeDeviceAddr(kernel_graph.get());
}
void CPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
const std::vector<tensor::TensorPtr> &inputs_const) const {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto &input_nodes = kernel_graph->inputs();
if (input_nodes.size() != inputs_const.size()) {
MS_LOG(EXCEPTION) << "Input size not equal to input node size!";
}
for (size_t input_idx = 0; input_idx < input_nodes.size(); ++input_idx) {
auto &item = input_nodes[input_idx];
MS_EXCEPTION_IF_NULL(item);
if (item->isa<Parameter>() && !HasAbstractMonad(item)) {
auto address = AnfAlgo::GetMutableOutputAddr(item, 0);
auto tensor = inputs_const[input_idx];
auto tensor_address = tensor->device_address();
MS_EXCEPTION_IF_NULL(address);
MS_EXCEPTION_IF_NULL(tensor);
if (tensor_address != nullptr && tensor_address != address &&
(std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address)->DeviceType() !=
device::DeviceAddressType::kCPU ||
AnfAlgo::IsParameterWeight(item->cast<ParameterPtr>()))) {
tensor->data_sync(false);
}
}
}
}
void CPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs) {
auto kernel_graph = GetGraph(graph_id);

View File

@ -44,6 +44,8 @@ class CPUSession : public SessionBasic {
const std::vector<int64_t> &tensors_mask) override;
void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors,
VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
const std::vector<tensor::TensorPtr> &inputs_const) const override;
private:
void Reorder(std::vector<CNodePtr> *node_list);

View File

@ -161,6 +161,7 @@ void RunGraphTask::Run() {
}
graph->ResetGraphRunningStatus();
try {
session_->LoadInputs(graph_id_, input_tensors_);
session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_);
UpdateOutputTensors(&outputs_, tensor_to_node_);
} catch (const std::exception &e) {

View File

@ -425,8 +425,6 @@ void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor:
MS_LOG(INFO) << "RunGraph graph_id: " << graph_id;
// In pynative mode, device addresses of tensors in value nodes change.
SyncValueNodeDeviceAddr(kernel_graph);
// Load input data from user input
LoadInputData(kernel_graph, inputs);
if (debugger_) {
debugger_->PreExecute(kernel_graph, graph_sum_);
}

View File

@ -47,6 +47,8 @@ class GPUSession : public SessionBasic {
VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) override;
std::string GetCommWorldGroup() override { return kNcclWorldGroup; }
void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
const std::vector<tensor::TensorPtr> &inputs_const) const override;
private:
void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
@ -71,9 +73,6 @@ class GPUSession : public SessionBasic {
void RunOpClearMemory(KernelGraph *kernel_graph) const;
void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
const std::vector<tensor::TensorPtr> &inputs_const) const override;
void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const;

View File

@ -180,7 +180,7 @@ std::vector<AnfNodePtr> KernelGraph::outputs() const {
return std::vector<AnfNodePtr>(1, graph_output);
}
void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
void KernelGraph::EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first) {
MS_EXCEPTION_IF_NULL(visit_queue);
MS_EXCEPTION_IF_NULL(visited_nodes);
@ -241,7 +241,7 @@ void KernelGraph::SetExecOrderByDefault() {
while (!seed_nodes.empty() || !delay_comm_stack.empty()) {
// seed nodes first, then delay comm nodes
if (seed_nodes.empty()) {
VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
EnqueueActiveNodes(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
delay_comm_stack.pop();
} else {
zero_input_nodes.push(seed_nodes.front());
@ -272,16 +272,16 @@ void KernelGraph::SetExecOrderByDefault() {
}
if (optimize_comm) {
while (!delay_comm_stack.empty()) {
VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
EnqueueActiveNodes(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
delay_comm_stack.pop();
}
delay_comm_stack.push(node);
} else if (is_fused_comm) {
delay_comm_stack.push(node);
} else if (is_communication_descendant) {
VisitNodeDescendants(node, &communication_descendants, &visited_nodes);
EnqueueActiveNodes(node, &communication_descendants, &visited_nodes);
} else {
VisitNodeDescendants(node, &zero_input_nodes, &visited_nodes);
EnqueueActiveNodes(node, &zero_input_nodes, &visited_nodes);
}
}
}

View File

@ -283,7 +283,7 @@ class KernelGraph : public FuncGraph {
void SetKernelInfoForNode(const AnfNodePtr &node) const;
void ResetInFormat(const AnfNodePtr &node, const std::string &format) const;
AnfNodePtr MakeValueNode(const AnfNodePtr &node);
void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
void EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first = true);
// update node edge list
void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes);

View File

@ -181,6 +181,13 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
const std::map<KernelWithIndex, size_t> &cnode_refcount) {}
virtual void SetSummaryNodes(KernelGraph *graph);
void LoadInputs(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs_const) {
auto kernel_graph = GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_LOG(INFO) << "Load inputs";
LoadInputData(kernel_graph, inputs_const);
}
virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
const std::vector<tensor::TensorPtr> &inputs_const) const;
void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors);

View File

@ -283,20 +283,14 @@ void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &ker
if (input_nodes.size() != inputs.size()) {
MS_LOG(EXCEPTION) << "Input size not equal to input node size!";
}
size_t input_idx = 0;
for (auto &item : input_nodes) {
for (size_t input_idx = 0; input_idx < input_nodes.size(); ++input_idx) {
auto &item = input_nodes[input_idx];
MS_EXCEPTION_IF_NULL(item);
if (item->isa<Parameter>() && !HasAbstractMonad(item)) {
auto address = AnfAlgo::GetMutableOutputAddr(item, 0);
auto tensor = inputs[input_idx];
auto tensor_address = tensor->device_address();
MS_EXCEPTION_IF_NULL(address);
MS_EXCEPTION_IF_NULL(tensor);
if (tensor_address != nullptr && tensor_address != address &&
(std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address)->DeviceType() != DeviceAddressType::kCPU ||
AnfAlgo::IsParameterWeight(item->cast<ParameterPtr>()))) {
tensor->data_sync(false);
}
if (GetTypeByte(TypeIdToType(tensor->data_type())) == GetTypeByte(TypeIdToType(address->type_id_))) {
address->ptr_ = tensor->data_c();
} else {
@ -318,7 +312,6 @@ void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &ker
address->ref_count_ = INIT_NODE_REF;
tensor->set_device_address(address);
}
input_idx++;
}
}