From 28ca6bb0873893721c73d51fb7bf00ec8061b47a Mon Sep 17 00:00:00 2001 From: zjun Date: Mon, 21 Jun 2021 15:51:04 +0800 Subject: [PATCH] Fix sens ref Signed-off-by: zjun --- .../pipeline/pynative/pynative_execute.cc | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 297b4f08f1d..26fd137f204 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -2355,26 +2355,30 @@ abstract::AbstractBasePtrList GradExecutor::GetArgsSpec(const py::args &args, co // update abstract info for input params ValuePtr input_value = parse::data_converter::PyDataToValue(args[index]); MS_EXCEPTION_IF_NULL(input_value); - auto abs = abstract::FromValue(input_value, true); + auto input_abs = abstract::FromValue(input_value, true); if (param_node->abstract() != nullptr) { - auto input_shape = param_node->abstract()->Broaden()->BuildShape()->ToString(); - auto ir_shape = abs->BuildShape()->ToString(); + auto input_shape = input_abs->BuildShape()->ToString(); + auto param_tensor_abs = param_node->abstract(); + if (param_tensor_abs->isa()) { + param_tensor_abs = param_tensor_abs->cast()->CloneAsTensor(); + } + auto ir_shape = param_tensor_abs->BuildShape()->ToString(); // Exclude const input if (input_shape != "()" && ir_shape != "()") { if (input_shape != ir_shape) { - MS_EXCEPTION(ValueError) << "The shape should be " << ir_shape << ", but got " << input_shape << " ," + MS_EXCEPTION(ValueError) << "The shape should be " << ir_shape << ", but got " << input_shape << ", " << param->DebugString(); } - auto input_dtype = param_node->abstract()->BuildType()->ToString(); - auto ir_dtype = abs->BuildType()->ToString(); + auto ir_dtype = param_tensor_abs->BuildType()->ToString(); + auto input_dtype = input_abs->BuildType()->ToString(); if (input_dtype != ir_dtype) { - MS_EXCEPTION(TypeError) << "The dtype should be " << ir_dtype << ", but got " << input_dtype << " ," + MS_EXCEPTION(TypeError) << "The dtype should be " << ir_dtype << ", but got " << input_dtype << ", " << param->DebugString(); } } } - args_spec.emplace_back(abs); - param_node->set_abstract(abs->Broaden()); + args_spec.emplace_back(input_abs); + param_node->set_abstract(input_abs->Broaden()); index++; } }