!45896 Bugfix for PyNative dynamic shape.

Merge pull request !45896 from caifubi/master-pynative-dynamic-shape-bugfix1
This commit is contained in:
i-robot 2022-11-23 09:59:39 +00:00 committed by Gitee
commit 22827c8b3a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 19 additions and 17 deletions

View File

@ -872,7 +872,7 @@ void SessionBasic::GetOpInputTensors(const CNodePtr &cnode,
}
}
input_tensor_info->input_tensors_mask.emplace_back(
(is_value_node || !is_forward_output) ? kValueNodeTensorMask : kParameterDataTensorMask);
(is_value_node && !is_forward_output) ? kValueNodeTensorMask : kParameterDataTensorMask);
} else if (real_input->isa<Parameter>()) {
tensor = GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
input_tensor_info->input_tensors_mask.emplace_back(tensor->is_parameter() ? kParameterWeightTensorMask

View File

@ -496,16 +496,22 @@ void RunControlOperator(const std::shared_ptr<GraphCompiler> &graph_compiler, co
}
}
}
void UpdateOutputAbstract(const KernelGraphPtr &kernel_graph, const session::BackendOpRunInfoPtr &op_run_info) {
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(op_run_info);
const auto &kernels = kernel_graph->execution_order();
for (const auto &kernel : kernels) {
MS_EXCEPTION_IF_NULL(kernel);
if (common::AnfAlgo::GetCNodeName(kernel) == op_run_info->base_op_run_info.op_name) {
op_run_info->base_op_run_info.abstract = kernel->abstract();
}
void UpdateOutputAbstract(const VectorRef &outputs, const session::BackendOpRunInfoPtr &op_run_info) {
auto output_size = outputs.size();
if (output_size == 1) {
auto output_tensor = utils::cast<tensor::TensorPtr>(outputs[0]);
MS_EXCEPTION_IF_NULL(output_tensor);
op_run_info->base_op_run_info.abstract = output_tensor->ToAbstract();
return;
}
AbstractBasePtrList elements;
for (size_t i = 0; i < output_size; ++i) {
auto output_tensor = utils::cast<tensor::TensorPtr>(outputs[i]);
MS_EXCEPTION_IF_NULL(output_tensor);
(void)elements.emplace_back(output_tensor->ToAbstract());
}
op_run_info->base_op_run_info.abstract = std::make_shared<abstract::AbstractTuple>(elements);
}
TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index) {
@ -1250,7 +1256,7 @@ void MindRTBackend::RunOpImpl(bool single_op_cache_hit, const OpCompilerInfoPtr
ClearInputDeviceAddress(graph, device_context);
if (is_dynamic_shape) {
UpdateOutputAbstract(graph, op_run_info);
UpdateOutputAbstract(*outputs, op_run_info);
}
if (op_compiler_info->need_erase_) {
EraseSingleOpCache(op_run_info->base_op_run_info.graph_info);
@ -1289,7 +1295,7 @@ void MindRTBackend::RunOpImplDynamic(bool single_op_cache_hit, const OpCompilerI
ClearGraphDeviceAddressDynamic(graph, device_context, op_run_info->is_gradient_out);
ClearInputDeviceAddressDynamic(graph, device_context);
if (is_dynamic_shape) {
UpdateOutputAbstract(graph, op_run_info);
UpdateOutputAbstract(*outputs, op_run_info);
}
if (op_compiler_info->need_erase_) {
EraseSingleOpCache(op_run_info->base_op_run_info.graph_info);

View File

@ -127,12 +127,8 @@ void UpdateRefNodeOutputDeviceAddress(const KernelGraphPtr &graph) {
auto output_index = output_pair.second;
auto &input_node = input_pair.first;
auto input_node_output_index = input_pair.second;
auto input_addr = AnfAlgo::GetMutableOutputAddr(input_node, input_node_output_index, false);
auto ref_node_output_addr = AnfAlgo::GetMutableOutputAddr(ref_node, output_index, false);
if (input_addr != ref_node_output_addr) {
AnfAlgo::SetOutputAddr(input_addr, output_index, ref_node.get());
}
AnfAlgo::SetOutputAddr(input_addr, output_index, ref_node.get());
}
}