forked from mindspore-Ecosystem/mindspore
parent
a8553078b9
commit
28ca6bb087
|
@ -2355,26 +2355,30 @@ abstract::AbstractBasePtrList GradExecutor::GetArgsSpec(const py::args &args, co
|
|||
// update abstract info for input params
|
||||
ValuePtr input_value = parse::data_converter::PyDataToValue(args[index]);
|
||||
MS_EXCEPTION_IF_NULL(input_value);
|
||||
auto abs = abstract::FromValue(input_value, true);
|
||||
auto input_abs = abstract::FromValue(input_value, true);
|
||||
if (param_node->abstract() != nullptr) {
|
||||
auto input_shape = param_node->abstract()->Broaden()->BuildShape()->ToString();
|
||||
auto ir_shape = abs->BuildShape()->ToString();
|
||||
auto input_shape = input_abs->BuildShape()->ToString();
|
||||
auto param_tensor_abs = param_node->abstract();
|
||||
if (param_tensor_abs->isa<abstract::AbstractRef>()) {
|
||||
param_tensor_abs = param_tensor_abs->cast<abstract::AbstractRefPtr>()->CloneAsTensor();
|
||||
}
|
||||
auto ir_shape = param_tensor_abs->BuildShape()->ToString();
|
||||
// Exclude const input
|
||||
if (input_shape != "()" && ir_shape != "()") {
|
||||
if (input_shape != ir_shape) {
|
||||
MS_EXCEPTION(ValueError) << "The shape should be " << ir_shape << ", but got " << input_shape << " ,"
|
||||
MS_EXCEPTION(ValueError) << "The shape should be " << ir_shape << ", but got " << input_shape << ", "
|
||||
<< param->DebugString();
|
||||
}
|
||||
auto input_dtype = param_node->abstract()->BuildType()->ToString();
|
||||
auto ir_dtype = abs->BuildType()->ToString();
|
||||
auto ir_dtype = param_tensor_abs->BuildType()->ToString();
|
||||
auto input_dtype = input_abs->BuildType()->ToString();
|
||||
if (input_dtype != ir_dtype) {
|
||||
MS_EXCEPTION(TypeError) << "The dtype should be " << ir_dtype << ", but got " << input_dtype << " ,"
|
||||
MS_EXCEPTION(TypeError) << "The dtype should be " << ir_dtype << ", but got " << input_dtype << ", "
|
||||
<< param->DebugString();
|
||||
}
|
||||
}
|
||||
}
|
||||
args_spec.emplace_back(abs);
|
||||
param_node->set_abstract(abs->Broaden());
|
||||
args_spec.emplace_back(input_abs);
|
||||
param_node->set_abstract(input_abs->Broaden());
|
||||
index++;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue