From d2f97fdc879604baaa4530b93b36029127d07ae3 Mon Sep 17 00:00:00 2001 From: zjun Date: Tue, 6 Jul 2021 16:20:32 +0800 Subject: [PATCH] Ignore funcgraph grad Signed-off-by: zjun --- .../pipeline/pynative/pynative_execute.cc | 15 +++++++++++++-- .../composite/multitype_ops/ones_like_impl.py | 18 ++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 0b9902c6e58..eb42fc16546 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -2102,9 +2102,20 @@ void GradExecutor::CreateMakeTupleNodeForMultiOut(const std::string &cell_id, co // get input node and value std::vector inputs{NewValueNode(prim::kPrimMakeTuple)}; ValuePtrList input_args; + std::vector value_index; for (size_t i = 0; i < out_tuple.size(); i++) { + auto v = parse::data_converter::PyDataToValue(out_tuple[i]); + // Graph have no define for grad + if (v->isa()) { + continue; + } + value_index.emplace_back(i); + input_args.emplace_back(v); inputs.emplace_back(GetInput(out_tuple[i], false)); - input_args.emplace_back(parse::data_converter::PyDataToValue(out_tuple[i])); + } + py::tuple value_outs(value_index.size()); + for (size_t i = 0; i < value_index.size(); ++i) { + value_outs[i] = out_tuple[value_index[i]]; } auto cnode = curr_g_->NewCNode(inputs); MS_LOG(DEBUG) << "Tuple output node info " << cnode->DebugString(); @@ -2116,7 +2127,7 @@ void GradExecutor::CreateMakeTupleNodeForMultiOut(const std::string &cell_id, co return; } // run ad for maketuple node - ValuePtr out_value = parse::data_converter::PyDataToValue(out); + ValuePtr out_value = parse::data_converter::PyDataToValue(value_outs); ad::GradPynativeOp(top_cell()->k_pynative_cell_ptr(), cnode, input_args, out_value); } diff --git a/mindspore/ops/composite/multitype_ops/ones_like_impl.py b/mindspore/ops/composite/multitype_ops/ones_like_impl.py index 1325cbce0a8..54b2a31b51b 100644 --- a/mindspore/ops/composite/multitype_ops/ones_like_impl.py +++ b/mindspore/ops/composite/multitype_ops/ones_like_impl.py @@ -58,6 +58,24 @@ def _ones_like_sparse_tensor(x): return F.make_sparse_tensor(F.sparse_tensor_get_indices(x), values, F.sparse_tensor_get_dense_shape(x)) +newenv = base.EnvInstance_() + + +@ones_like_leaf.register("Function") +def _ones_like_func(x): + """ + Derivation of a function. + + Args: + x (Function): x + + Returns: + EnvInstance_, value is newenv. + """ + # Unused parameters are placeholders. + return newenv + + ones_like = base.HyperMap(ones_like_leaf) """ `ones_like` is a function which can generate a graph of `ones_like` operation according to input tensor dtype.