forked from mindspore-Ecosystem/mindspore
!3534 Expand make tuple node of graph input
Merge pull request !3534 from chenfei_mindspore/release-make-tuple-of-load-input-data
This commit is contained in:
commit
2d9133d773
|
@ -885,11 +885,6 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu
|
|||
for (auto &child_graph : graph->child_graph_order()) {
|
||||
CreateMultiBranchOutput(NOT_NULL(child_graph), memo);
|
||||
}
|
||||
// If graph has no output, the graph is the true graph of while and will call condition graph, no need insert assign
|
||||
// from condition to true graph
|
||||
if (graph->get_output_null()) {
|
||||
return;
|
||||
}
|
||||
std::map<AnfNodePtr, AnfNodePtr> need_replace_list;
|
||||
auto node_list = GetCNodes(TopoSort(graph->get_return()));
|
||||
for (auto &node : node_list) {
|
||||
|
@ -909,6 +904,11 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu
|
|||
auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node);
|
||||
for (auto &child_graph : child_graphs) {
|
||||
MS_EXCEPTION_IF_NULL(child_graph);
|
||||
// If graph has no output, the graph is the true graph of while and will call condition graph, no need insert
|
||||
// assign from condition to true graph
|
||||
if (memo->find(child_graph) != memo->end()) {
|
||||
continue;
|
||||
}
|
||||
if (child_graph->get_output_null()) {
|
||||
continue;
|
||||
}
|
||||
|
@ -927,6 +927,7 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu
|
|||
}
|
||||
}
|
||||
}
|
||||
memo->erase(graph.get());
|
||||
}
|
||||
|
||||
void AscendSession::IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
|
|
|
@ -475,7 +475,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
|
|||
cnode_inputs.emplace_back(new_value_node);
|
||||
}
|
||||
continue;
|
||||
} else if (anf->isa<Parameter>() && AnfAlgo::GetOutputTensorNum(anf) == 1) {
|
||||
} else if (anf->isa<Parameter>()) {
|
||||
auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph);
|
||||
cnode_inputs.push_back(new_parameter);
|
||||
if (GetGraphIdByNode(anf) == kInvalidGraphId) {
|
||||
|
@ -818,6 +818,25 @@ void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &para
|
|||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool TensorNeedSync(const AnfNodePtr ¶meter, const tensor::TensorPtr &tensor) {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(parameter, 0);
|
||||
if (ms_context->enable_pynative_infer()) {
|
||||
return tensor->device_address().get() == nullptr || tensor->device_address() != device_address;
|
||||
}
|
||||
if (tensor->is_dirty()) {
|
||||
return true;
|
||||
}
|
||||
if (tensor->device_address() != device_address) {
|
||||
(void)tensor->data_sync();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// run graph steps
|
||||
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||
const std::vector<tensor::TensorPtr> &inputs_const) const {
|
||||
|
@ -827,7 +846,11 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
|
|||
if (kernel_graph->input_ctrl_tensors()) {
|
||||
input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
|
||||
}
|
||||
auto input_nodes = kernel_graph->inputs();
|
||||
std::vector<AnfNodePtr> input_nodes;
|
||||
for (const auto &input_node : kernel_graph->inputs()) {
|
||||
auto params = AnfAlgo::GetAllOutput(input_node);
|
||||
std::copy(params.begin(), params.end(), std::back_inserter(input_nodes));
|
||||
}
|
||||
if ((inputs.size() + input_ctrl_size) - 2 != input_nodes.size()) {
|
||||
MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
|
||||
<< ", input_ctrl_size:" << input_ctrl_size;
|
||||
|
@ -838,33 +861,17 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
|
|||
auto tensor = inputs[i];
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
auto input_node = input_nodes[i];
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
|
||||
auto pk_node = input_node->cast<ParameterPtr>();
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
|
||||
bool need_sync = false;
|
||||
if (ms_context->enable_pynative_infer()) {
|
||||
if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) {
|
||||
need_sync = true;
|
||||
}
|
||||
} else {
|
||||
if (tensor->is_dirty()) {
|
||||
need_sync = true;
|
||||
} else if (tensor->device_address() != device_address) {
|
||||
(void)tensor->data_sync();
|
||||
need_sync = true;
|
||||
}
|
||||
if (TensorNeedSync(input_node, tensor) && input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
|
||||
if (ms_context->execution_mode() == kPynativeMode ||
|
||||
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
|
||||
tensor->set_device_address(device_address);
|
||||
}
|
||||
if (need_sync) {
|
||||
if (ms_context->execution_mode() == kPynativeMode || AnfAlgo::IsParameterWeight(pk_node)) {
|
||||
tensor->set_device_address(device_address);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(device_address);
|
||||
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
|
||||
LongToSize(tensor->data().nbytes()), tensor->data_type(),
|
||||
tensor->data_c())) {
|
||||
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(device_address);
|
||||
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0),
|
||||
LongToSize(tensor->data().nbytes()), tensor->data_type(),
|
||||
tensor->data_c())) {
|
||||
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
|
||||
}
|
||||
}
|
||||
tensor->set_dirty(false);
|
||||
|
|
Loading…
Reference in New Issue