Fix sens ref

Signed-off-by: zjun <zhangjun0@huawei.com>
This commit is contained in:
zjun 2021-06-21 15:51:04 +08:00
parent a8553078b9
commit 28ca6bb087
1 changed files with 13 additions and 9 deletions

View File

@ -2355,26 +2355,30 @@ abstract::AbstractBasePtrList GradExecutor::GetArgsSpec(const py::args &args, co
// update abstract info for input params
ValuePtr input_value = parse::data_converter::PyDataToValue(args[index]);
MS_EXCEPTION_IF_NULL(input_value);
auto abs = abstract::FromValue(input_value, true);
auto input_abs = abstract::FromValue(input_value, true);
if (param_node->abstract() != nullptr) {
auto input_shape = param_node->abstract()->Broaden()->BuildShape()->ToString();
auto ir_shape = abs->BuildShape()->ToString();
auto input_shape = input_abs->BuildShape()->ToString();
auto param_tensor_abs = param_node->abstract();
if (param_tensor_abs->isa<abstract::AbstractRef>()) {
param_tensor_abs = param_tensor_abs->cast<abstract::AbstractRefPtr>()->CloneAsTensor();
}
auto ir_shape = param_tensor_abs->BuildShape()->ToString();
// Exclude const input
if (input_shape != "()" && ir_shape != "()") {
if (input_shape != ir_shape) {
MS_EXCEPTION(ValueError) << "The shape should be " << ir_shape << ", but got " << input_shape << " ,"
MS_EXCEPTION(ValueError) << "The shape should be " << ir_shape << ", but got " << input_shape << ", "
<< param->DebugString();
}
auto input_dtype = param_node->abstract()->BuildType()->ToString();
auto ir_dtype = abs->BuildType()->ToString();
auto ir_dtype = param_tensor_abs->BuildType()->ToString();
auto input_dtype = input_abs->BuildType()->ToString();
if (input_dtype != ir_dtype) {
MS_EXCEPTION(TypeError) << "The dtype should be " << ir_dtype << ", but got " << input_dtype << " ,"
MS_EXCEPTION(TypeError) << "The dtype should be " << ir_dtype << ", but got " << input_dtype << ", "
<< param->DebugString();
}
}
}
args_spec.emplace_back(abs);
param_node->set_abstract(abs->Broaden());
args_spec.emplace_back(input_abs);
param_node->set_abstract(input_abs->Broaden());
index++;
}
}