!19467 Fix return cell grad bug

Merge pull request !19467 from zjun/fix_return_cell
This commit is contained in:
i-robot 2021-07-06 13:30:34 +00:00 committed by Gitee
commit d1c67aa241
2 changed files with 31 additions and 2 deletions

View File

@ -2102,9 +2102,20 @@ void GradExecutor::CreateMakeTupleNodeForMultiOut(const std::string &cell_id, co
// get input node and value
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
ValuePtrList input_args;
std::vector<int> value_index;
for (size_t i = 0; i < out_tuple.size(); i++) {
auto v = parse::data_converter::PyDataToValue(out_tuple[i]);
// Graph have no define for grad
if (v->isa<FuncGraph>()) {
continue;
}
value_index.emplace_back(i);
input_args.emplace_back(v);
inputs.emplace_back(GetInput(out_tuple[i], false));
input_args.emplace_back(parse::data_converter::PyDataToValue(out_tuple[i]));
}
py::tuple value_outs(value_index.size());
for (size_t i = 0; i < value_index.size(); ++i) {
value_outs[i] = out_tuple[value_index[i]];
}
auto cnode = curr_g_->NewCNode(inputs);
MS_LOG(DEBUG) << "Tuple output node info " << cnode->DebugString();
@ -2116,7 +2127,7 @@ void GradExecutor::CreateMakeTupleNodeForMultiOut(const std::string &cell_id, co
return;
}
// run ad for maketuple node
ValuePtr out_value = parse::data_converter::PyDataToValue(out);
ValuePtr out_value = parse::data_converter::PyDataToValue(value_outs);
ad::GradPynativeOp(top_cell()->k_pynative_cell_ptr(), cnode, input_args, out_value);
}

View File

@ -58,6 +58,24 @@ def _ones_like_sparse_tensor(x):
return F.make_sparse_tensor(F.sparse_tensor_get_indices(x), values, F.sparse_tensor_get_dense_shape(x))
newenv = base.EnvInstance_()
@ones_like_leaf.register("Function")
def _ones_like_func(x):
"""
Derivation of a function.
Args:
x (Function): x
Returns:
EnvInstance_, value is newenv.
"""
# Unused parameters are placeholders.
return newenv
ones_like = base.HyperMap(ones_like_leaf)
"""
`ones_like` is a function which can generate a graph of `ones_like` operation according to input tensor dtype.