!21782 fix python 3.8 args error

Merge pull request !21782 from chujinjin/fix_python_3.8_args_error
This commit is contained in:
i-robot 2021-08-13 08:20:21 +00:00 committed by Gitee
commit 4b0918c64b
2 changed files with 4 additions and 4 deletions

View File

@ -347,7 +347,7 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, const std::v
return graph_info;
}
py::args FilterTensorArgs(const py::args &args, bool has_sens = false) {
py::list FilterTensorArgs(const py::args &args, bool has_sens = false) {
size_t size = args.size();
if (size == 0 && has_sens) {
MS_LOG(EXCEPTION) << "The size of args is 0, when the flag of sens is set to True";
@ -1943,7 +1943,7 @@ void GradExecutor::HandleInputArgsForTopCell(const py::args &args, bool is_bprop
}
// Convert input args to parameters for top cell graph in construct.
std::vector<ValuePtr> input_param_values;
py::args only_tensors = FilterTensorArgs(args);
py::list only_tensors = FilterTensorArgs(args);
auto df_builder = GetDfbuilder(top_cell_->cell_id());
MS_EXCEPTION_IF_NULL(df_builder);
for (size_t i = 0; i < only_tensors.size(); ++i) {
@ -2369,7 +2369,7 @@ std::vector<AnfNodePtr> GradExecutor::GetWeightsArgs(const py::object &weights,
return w_args;
}
abstract::AbstractBasePtrList GradExecutor::GetArgsSpec(const py::args &args, const FuncGraphPtr &bprop_graph) {
abstract::AbstractBasePtrList GradExecutor::GetArgsSpec(const py::list &args, const FuncGraphPtr &bprop_graph) {
MS_EXCEPTION_IF_NULL(bprop_graph);
std::size_t size = args.size();
abstract::AbstractBasePtrList args_spec;

View File

@ -233,7 +233,7 @@ class GradExecutor {
FuncGraphPtr GetBpropGraph(const prim::GradOperationPtr &grad, const py::object &cell,
const std::vector<AnfNodePtr> &weights, size_t arg_size, const py::args &args);
std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder);
abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &bprop_graph);
abstract::AbstractBasePtrList GetArgsSpec(const py::list &args, const FuncGraphPtr &bprop_graph);
// Manage resource for construct forward graph.
std::string &graph_phase() { return graph_phase_; }
AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id);