forked from mindspore-Ecosystem/mindspore
!19467 Fix return cell grad bug
Merge pull request !19467 from zjun/fix_return_cell
This commit is contained in:
commit
d1c67aa241
|
@ -2102,9 +2102,20 @@ void GradExecutor::CreateMakeTupleNodeForMultiOut(const std::string &cell_id, co
|
||||||
// get input node and value
|
// get input node and value
|
||||||
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
|
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||||
ValuePtrList input_args;
|
ValuePtrList input_args;
|
||||||
|
std::vector<int> value_index;
|
||||||
for (size_t i = 0; i < out_tuple.size(); i++) {
|
for (size_t i = 0; i < out_tuple.size(); i++) {
|
||||||
|
auto v = parse::data_converter::PyDataToValue(out_tuple[i]);
|
||||||
|
// Graph have no define for grad
|
||||||
|
if (v->isa<FuncGraph>()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
value_index.emplace_back(i);
|
||||||
|
input_args.emplace_back(v);
|
||||||
inputs.emplace_back(GetInput(out_tuple[i], false));
|
inputs.emplace_back(GetInput(out_tuple[i], false));
|
||||||
input_args.emplace_back(parse::data_converter::PyDataToValue(out_tuple[i]));
|
}
|
||||||
|
py::tuple value_outs(value_index.size());
|
||||||
|
for (size_t i = 0; i < value_index.size(); ++i) {
|
||||||
|
value_outs[i] = out_tuple[value_index[i]];
|
||||||
}
|
}
|
||||||
auto cnode = curr_g_->NewCNode(inputs);
|
auto cnode = curr_g_->NewCNode(inputs);
|
||||||
MS_LOG(DEBUG) << "Tuple output node info " << cnode->DebugString();
|
MS_LOG(DEBUG) << "Tuple output node info " << cnode->DebugString();
|
||||||
|
@ -2116,7 +2127,7 @@ void GradExecutor::CreateMakeTupleNodeForMultiOut(const std::string &cell_id, co
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// run ad for maketuple node
|
// run ad for maketuple node
|
||||||
ValuePtr out_value = parse::data_converter::PyDataToValue(out);
|
ValuePtr out_value = parse::data_converter::PyDataToValue(value_outs);
|
||||||
ad::GradPynativeOp(top_cell()->k_pynative_cell_ptr(), cnode, input_args, out_value);
|
ad::GradPynativeOp(top_cell()->k_pynative_cell_ptr(), cnode, input_args, out_value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -58,6 +58,24 @@ def _ones_like_sparse_tensor(x):
|
||||||
return F.make_sparse_tensor(F.sparse_tensor_get_indices(x), values, F.sparse_tensor_get_dense_shape(x))
|
return F.make_sparse_tensor(F.sparse_tensor_get_indices(x), values, F.sparse_tensor_get_dense_shape(x))
|
||||||
|
|
||||||
|
|
||||||
|
newenv = base.EnvInstance_()
|
||||||
|
|
||||||
|
|
||||||
|
@ones_like_leaf.register("Function")
|
||||||
|
def _ones_like_func(x):
|
||||||
|
"""
|
||||||
|
Derivation of a function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Function): x
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EnvInstance_, value is newenv.
|
||||||
|
"""
|
||||||
|
# Unused parameters are placeholders.
|
||||||
|
return newenv
|
||||||
|
|
||||||
|
|
||||||
ones_like = base.HyperMap(ones_like_leaf)
|
ones_like = base.HyperMap(ones_like_leaf)
|
||||||
"""
|
"""
|
||||||
`ones_like` is a function which can generate a graph of `ones_like` operation according to input tensor dtype.
|
`ones_like` is a function which can generate a graph of `ones_like` operation according to input tensor dtype.
|
||||||
|
|
Loading…
Reference in New Issue