forked from mindspore-Ecosystem/mindspore
!21782 fix python 3.8 args error
Merge pull request !21782 from chujinjin/fix_python_3.8_args_error
This commit is contained in:
commit
4b0918c64b
|
@ -347,7 +347,7 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, const std::v
|
|||
return graph_info;
|
||||
}
|
||||
|
||||
py::args FilterTensorArgs(const py::args &args, bool has_sens = false) {
|
||||
py::list FilterTensorArgs(const py::args &args, bool has_sens = false) {
|
||||
size_t size = args.size();
|
||||
if (size == 0 && has_sens) {
|
||||
MS_LOG(EXCEPTION) << "The size of args is 0, when the flag of sens is set to True";
|
||||
|
@ -1943,7 +1943,7 @@ void GradExecutor::HandleInputArgsForTopCell(const py::args &args, bool is_bprop
|
|||
}
|
||||
// Convert input args to parameters for top cell graph in construct.
|
||||
std::vector<ValuePtr> input_param_values;
|
||||
py::args only_tensors = FilterTensorArgs(args);
|
||||
py::list only_tensors = FilterTensorArgs(args);
|
||||
auto df_builder = GetDfbuilder(top_cell_->cell_id());
|
||||
MS_EXCEPTION_IF_NULL(df_builder);
|
||||
for (size_t i = 0; i < only_tensors.size(); ++i) {
|
||||
|
@ -2369,7 +2369,7 @@ std::vector<AnfNodePtr> GradExecutor::GetWeightsArgs(const py::object &weights,
|
|||
return w_args;
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtrList GradExecutor::GetArgsSpec(const py::args &args, const FuncGraphPtr &bprop_graph) {
|
||||
abstract::AbstractBasePtrList GradExecutor::GetArgsSpec(const py::list &args, const FuncGraphPtr &bprop_graph) {
|
||||
MS_EXCEPTION_IF_NULL(bprop_graph);
|
||||
std::size_t size = args.size();
|
||||
abstract::AbstractBasePtrList args_spec;
|
||||
|
|
|
@ -233,7 +233,7 @@ class GradExecutor {
|
|||
FuncGraphPtr GetBpropGraph(const prim::GradOperationPtr &grad, const py::object &cell,
|
||||
const std::vector<AnfNodePtr> &weights, size_t arg_size, const py::args &args);
|
||||
std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder);
|
||||
abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &bprop_graph);
|
||||
abstract::AbstractBasePtrList GetArgsSpec(const py::list &args, const FuncGraphPtr &bprop_graph);
|
||||
// Manage resource for construct forward graph.
|
||||
std::string &graph_phase() { return graph_phase_; }
|
||||
AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id);
|
||||
|
|
Loading…
Reference in New Issue