modify CheckpointStrategy to adapt load operator

This commit is contained in:
huangbingjian 2021-02-25 11:57:49 +08:00
parent 52fac12367
commit 0bbd95d7a0
1 changed files with 7 additions and 8 deletions

View File

@ -2701,14 +2701,13 @@ std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &n
if (!IsValueNode<Primitive>(cnode->input(0))) {
return param_names;
}
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
if (prim->name() == CAST && cnode->inputs().size() >= 1) {
auto cast_input = cnode->inputs()[1];
if (cast_input->isa<Parameter>()) {
auto cast_input_parameter = cast_input->cast<ParameterPtr>();
if (cast_input_parameter->has_default() && ParameterRequireGrad(cast_input_parameter)) {
param_names.push_back({cast_input_parameter->name(), i});
if ((IsPrimitiveCNode(cnode, prim::kPrimCast) && cnode->inputs().size() >= 1) ||
IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
auto inp = cnode->input(1);
if (inp->isa<Parameter>()) {
auto inp_param = inp->cast<ParameterPtr>();
if (inp_param->has_default() && ParameterRequireGrad(inp_param)) {
param_names.push_back({inp_param->name(), i});
}
}
}