forked from mindspore-Ecosystem/mindspore
!21782 fix python 3.8 args error
Merge pull request !21782 from chujinjin/fix_python_3.8_args_error
This commit is contained in:
commit
4b0918c64b
|
@ -347,7 +347,7 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, const std::v
|
||||||
return graph_info;
|
return graph_info;
|
||||||
}
|
}
|
||||||
|
|
||||||
py::args FilterTensorArgs(const py::args &args, bool has_sens = false) {
|
py::list FilterTensorArgs(const py::args &args, bool has_sens = false) {
|
||||||
size_t size = args.size();
|
size_t size = args.size();
|
||||||
if (size == 0 && has_sens) {
|
if (size == 0 && has_sens) {
|
||||||
MS_LOG(EXCEPTION) << "The size of args is 0, when the flag of sens is set to True";
|
MS_LOG(EXCEPTION) << "The size of args is 0, when the flag of sens is set to True";
|
||||||
|
@ -1943,7 +1943,7 @@ void GradExecutor::HandleInputArgsForTopCell(const py::args &args, bool is_bprop
|
||||||
}
|
}
|
||||||
// Convert input args to parameters for top cell graph in construct.
|
// Convert input args to parameters for top cell graph in construct.
|
||||||
std::vector<ValuePtr> input_param_values;
|
std::vector<ValuePtr> input_param_values;
|
||||||
py::args only_tensors = FilterTensorArgs(args);
|
py::list only_tensors = FilterTensorArgs(args);
|
||||||
auto df_builder = GetDfbuilder(top_cell_->cell_id());
|
auto df_builder = GetDfbuilder(top_cell_->cell_id());
|
||||||
MS_EXCEPTION_IF_NULL(df_builder);
|
MS_EXCEPTION_IF_NULL(df_builder);
|
||||||
for (size_t i = 0; i < only_tensors.size(); ++i) {
|
for (size_t i = 0; i < only_tensors.size(); ++i) {
|
||||||
|
@ -2369,7 +2369,7 @@ std::vector<AnfNodePtr> GradExecutor::GetWeightsArgs(const py::object &weights,
|
||||||
return w_args;
|
return w_args;
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract::AbstractBasePtrList GradExecutor::GetArgsSpec(const py::args &args, const FuncGraphPtr &bprop_graph) {
|
abstract::AbstractBasePtrList GradExecutor::GetArgsSpec(const py::list &args, const FuncGraphPtr &bprop_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(bprop_graph);
|
MS_EXCEPTION_IF_NULL(bprop_graph);
|
||||||
std::size_t size = args.size();
|
std::size_t size = args.size();
|
||||||
abstract::AbstractBasePtrList args_spec;
|
abstract::AbstractBasePtrList args_spec;
|
||||||
|
|
|
@ -233,7 +233,7 @@ class GradExecutor {
|
||||||
FuncGraphPtr GetBpropGraph(const prim::GradOperationPtr &grad, const py::object &cell,
|
FuncGraphPtr GetBpropGraph(const prim::GradOperationPtr &grad, const py::object &cell,
|
||||||
const std::vector<AnfNodePtr> &weights, size_t arg_size, const py::args &args);
|
const std::vector<AnfNodePtr> &weights, size_t arg_size, const py::args &args);
|
||||||
std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder);
|
std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder);
|
||||||
abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &bprop_graph);
|
abstract::AbstractBasePtrList GetArgsSpec(const py::list &args, const FuncGraphPtr &bprop_graph);
|
||||||
// Manage resource for construct forward graph.
|
// Manage resource for construct forward graph.
|
||||||
std::string &graph_phase() { return graph_phase_; }
|
std::string &graph_phase() { return graph_phase_; }
|
||||||
AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id);
|
AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id);
|
||||||
|
|
Loading…
Reference in New Issue