forked from mindspore-Ecosystem/mindspore
modify CheckpointStrategy to adapt load operator
This commit is contained in:
parent
52fac12367
commit
0bbd95d7a0
|
@ -2701,14 +2701,13 @@ std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &n
|
|||
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
||||
return param_names;
|
||||
}
|
||||
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
|
||||
if (prim->name() == CAST && cnode->inputs().size() >= 1) {
|
||||
auto cast_input = cnode->inputs()[1];
|
||||
if (cast_input->isa<Parameter>()) {
|
||||
auto cast_input_parameter = cast_input->cast<ParameterPtr>();
|
||||
if (cast_input_parameter->has_default() && ParameterRequireGrad(cast_input_parameter)) {
|
||||
param_names.push_back({cast_input_parameter->name(), i});
|
||||
if ((IsPrimitiveCNode(cnode, prim::kPrimCast) && cnode->inputs().size() >= 1) ||
|
||||
IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
|
||||
auto inp = cnode->input(1);
|
||||
if (inp->isa<Parameter>()) {
|
||||
auto inp_param = inp->cast<ParameterPtr>();
|
||||
if (inp_param->has_default() && ParameterRequireGrad(inp_param)) {
|
||||
param_names.push_back({inp_param->name(), i});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue