!19467 Fix return cell grad bug
Merge pull request !19467 from zjun/fix_return_cell
This commit is contained in:
commit
d1c67aa241
|
@ -2102,9 +2102,20 @@ void GradExecutor::CreateMakeTupleNodeForMultiOut(const std::string &cell_id, co
|
|||
// get input node and value
|
||||
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||
ValuePtrList input_args;
|
||||
std::vector<int> value_index;
|
||||
for (size_t i = 0; i < out_tuple.size(); i++) {
|
||||
auto v = parse::data_converter::PyDataToValue(out_tuple[i]);
|
||||
// Graph have no define for grad
|
||||
if (v->isa<FuncGraph>()) {
|
||||
continue;
|
||||
}
|
||||
value_index.emplace_back(i);
|
||||
input_args.emplace_back(v);
|
||||
inputs.emplace_back(GetInput(out_tuple[i], false));
|
||||
input_args.emplace_back(parse::data_converter::PyDataToValue(out_tuple[i]));
|
||||
}
|
||||
py::tuple value_outs(value_index.size());
|
||||
for (size_t i = 0; i < value_index.size(); ++i) {
|
||||
value_outs[i] = out_tuple[value_index[i]];
|
||||
}
|
||||
auto cnode = curr_g_->NewCNode(inputs);
|
||||
MS_LOG(DEBUG) << "Tuple output node info " << cnode->DebugString();
|
||||
|
@ -2116,7 +2127,7 @@ void GradExecutor::CreateMakeTupleNodeForMultiOut(const std::string &cell_id, co
|
|||
return;
|
||||
}
|
||||
// run ad for maketuple node
|
||||
ValuePtr out_value = parse::data_converter::PyDataToValue(out);
|
||||
ValuePtr out_value = parse::data_converter::PyDataToValue(value_outs);
|
||||
ad::GradPynativeOp(top_cell()->k_pynative_cell_ptr(), cnode, input_args, out_value);
|
||||
}
|
||||
|
||||
|
|
|
@ -58,6 +58,24 @@ def _ones_like_sparse_tensor(x):
|
|||
return F.make_sparse_tensor(F.sparse_tensor_get_indices(x), values, F.sparse_tensor_get_dense_shape(x))
|
||||
|
||||
|
||||
newenv = base.EnvInstance_()
|
||||
|
||||
|
||||
@ones_like_leaf.register("Function")
|
||||
def _ones_like_func(x):
|
||||
"""
|
||||
Derivation of a function.
|
||||
|
||||
Args:
|
||||
x (Function): x
|
||||
|
||||
Returns:
|
||||
EnvInstance_, value is newenv.
|
||||
"""
|
||||
# Unused parameters are placeholders.
|
||||
return newenv
|
||||
|
||||
|
||||
ones_like = base.HyperMap(ones_like_leaf)
|
||||
"""
|
||||
`ones_like` is a function which can generate a graph of `ones_like` operation according to input tensor dtype.
|
||||
|
|
Loading…
Reference in New Issue