Associate func_graph in CellList to manager

This commit is contained in:
huangbingjian 2022-07-28 14:51:56 +08:00
parent 89e3a499b1
commit f8d7ed29e0
5 changed files with 52 additions and 74 deletions

View File

@ -73,7 +73,6 @@ const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespac
const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class";
const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class";
const char PYTHON_MOD_GET_OBJECT_DESCRIPTION[] = "get_object_description";
const char PYTHON_MOD_IS_CELL_LIST[] = "is_cell_list";
const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor";
const char PYTHON_MOD_EVAL_PY_SCRIPT[] = "eval_script";
const char PYTHON_MOD_GET_SCRIPT_IDS[] = "get_script_ids";

View File

@ -319,35 +319,21 @@ bool ResolveObjectToNode(const AnfNodePtr &origin_node, const py::object &obj, A
return true;
}
bool IsAllFuncInValueSequence(const std::vector<ValuePtr> &value_vec) {
bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
const ValuePtr &value, AnfNodePtr *const transformed) {
MS_EXCEPTION_IF_NULL(value);
const auto &value_vec = GetValue<ValuePtrList>(value);
if (value_vec.empty()) {
return false;
}
for (auto &elem : value_vec) {
MS_EXCEPTION_IF_NULL(elem);
if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) {
const auto &vec = GetValue<ValuePtrList>(elem);
auto is_graph = IsAllFuncInValueSequence(vec);
if (!is_graph) {
return false;
}
} else if (!elem->isa<FuncGraph>() && !elem->isa<Primitive>()) {
return false;
}
}
return true;
}
AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
const std::vector<ValuePtr> &value_vec) {
std::vector<AnfNodePtr> nodes;
nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
bool is_all_func = true;
for (auto &elem : value_vec) {
MS_EXCEPTION_IF_NULL(elem);
AnfNodePtr node = nullptr;
if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) {
const auto &vec = GetValue<std::vector<ValuePtr>>(elem);
node = TransformToMakeTupleNodes(manager, func_graph, vec);
is_all_func = is_all_func && TransformVectorFuncValueNode(manager, func_graph, elem, &node);
} else if (elem->isa<FuncGraph>()) {
FuncGraphPtr new_fg = elem->cast<FuncGraphPtr>();
manager->AddFuncGraph(new_fg);
@ -355,34 +341,20 @@ AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const F
} else if (elem->isa<Primitive>()) {
node = NewValueNode(elem);
} else {
MS_LOG(EXCEPTION) << "TransformToMakeTupleNodes error, expect funcgraph, got " << elem->ToString();
is_all_func = false;
}
nodes.emplace_back(node);
(void)nodes.emplace_back(node);
}
auto cnode = func_graph->NewCNode(std::move(nodes));
return cnode;
}
// Transform the ValueTuple or ValueList of graph/primitive node to make tuple of const graph/primitive node
bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
const ValueNodePtr &value_node, AnfNodePtr *const transformed) {
MS_EXCEPTION_IF_NULL(value_node);
const auto &value_vec = GetValue<ValuePtrList>(value_node->value());
if (!IsAllFuncInValueSequence(value_vec)) {
return false;
if (is_all_func) {
// (1) The celllist or ordered_cell will be parsed as valuetuple of const graph in it,
// So if has graph in list, try to replace the node with make tuple of graph value node.
// We do this because the graph manager won't investigate the graph inside valuetuple,
// change the vector of graph to be make_tuple of graph value node.
// (2) the primitive valuetuple or valuelist may encounter to abstract error, make it all
// independent nodes.
*transformed = func_graph->NewCNode(std::move(nodes));
}
// (1) The celllist or ordered_cell will be parsed as valuetuple of const graph in it,
// So if has graph in list, try to replace the node with make tuple of graph value node.
// We do this because the graph manager won't investigate the graph inside valuetuple,
// change the vector of graph to be make_tuple of graph value node.
// (2) the primitive valuetuple or valuelist may encounter to abstract error, make it all
// independent nodes.
auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec);
// Replace the ret ptr to be make tuple of graph value node
*transformed = node_tuple_graphs;
return true;
return is_all_func;
}
// Resolve the python obj, and if the resovled node is valuenode with graphs, add the graphs to manager.
@ -402,8 +374,8 @@ AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, cons
// If the constant node is constant of vector of graph, add graph to manager.
if (IsValueNode<ValueTuple>(resolved_node) || IsValueNode<ValueList>(resolved_node)) {
(void)TransformVectorFuncValueNode(manager, node->func_graph(), resolved_node->cast<ValueNodePtr>(),
&resolved_node);
auto value = resolved_node->cast<ValueNodePtr>()->value();
(void)TransformVectorFuncValueNode(manager, node->func_graph(), value, &resolved_node);
}
return resolved_node;
}

View File

@ -23,7 +23,7 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
get_operation_symbol, get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name,
eval_script, get_script_ids, expand_expr_statement, is_class_member, parse_cb, resolve_symbol,
convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
is_class_type, check_obj_bool, is_cell_list, python_isinstance, ms_isinstance)
is_class_type, check_obj_bool, python_isinstance, ms_isinstance)
__all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope',
'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol',
@ -32,4 +32,4 @@ __all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'ge
'get_operation_symbol', 'get_operation_namespace_symbol', 'get_parse_method_of_class', 'get_scope_name',
'eval_script', 'get_script_ids', 'expand_expr_statement', 'is_class_member', 'parse_cb', 'resolve_symbol',
'convert_to_ms_tensor', 'get_object_description', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
'is_class_type', 'check_obj_bool', 'is_cell_list', 'python_isinstance', 'ms_isinstance']
'is_class_type', 'check_obj_bool', 'python_isinstance', 'ms_isinstance']

View File

@ -452,11 +452,6 @@ def ms_isinstance(x, cmp_type):
return isinstance(x, pytype_to_mstype.get(cmp_type))
def is_cell_list(obj):
"""Check if obj is nn.CellList"""
return isinstance(obj, nn.CellList)
def get_module_namespace(obj):
"""Get the module's namespace."""
logger.debug("get module namespace, module: %r", obj)

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -15,23 +15,15 @@
""" test_celllist """
import numpy as np
from mindspore import Tensor, Model
from mindspore import context
from mindspore.nn import AvgPool2d
from mindspore.nn import Cell
from mindspore.nn import Flatten
from mindspore.nn import ReLU
from mindspore.nn import SequentialCell
from mindspore import context, nn, Tensor, Model, ParameterTuple
from mindspore import dtype as mstype
from ...ut_filter import non_graph_engine
# pylint: disable=W0212
class Net3(Cell):
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.tuple = (ReLU(), ReLU())
self.tuple = (nn.ReLU(), nn.ReLU())
def construct(self, x):
for op in self.tuple:
@ -43,18 +35,38 @@ class Net3(Cell):
def test_cell_list():
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
input_me = Tensor(input_np)
net = Net3()
net = Net()
context.set_context(mode=context.GRAPH_MODE)
model = Model(net)
model.predict(input_me)
class SequenceNet(Cell):
class CellListNet(nn.Cell):
def __init__(self):
super().__init__()
self.seq = SequentialCell([AvgPool2d(3, 1), ReLU(), Flatten()])
self.values = list(self.seq._cells.values())
self.all = nn.CellList([nn.Conv2d(120, 240, 4, has_bias=False,
weight_init=Tensor(np.ones([240, 120, 4, 4]), mstype.float32)),
nn.Conv2d(240, 480, 4, has_bias=False,
weight_init=Tensor(np.ones([480, 240, 4, 4]), mstype.float32))])
self.params = ParameterTuple(self.get_parameters())
self.weight_list = [(240, 120, 4, 4), (480, 240, 4, 4)]
self.info = [self.all, self.params, self.weight_list]
def construct(self, x):
x = self.seq(x)
return x
func = None
conv, params, weight_list = self.info
for _, (_conv, _, _weight_list) in enumerate(zip(conv, params, weight_list)):
if _weight_list[0] == 240:
func = _conv
out = func(x)
return out
def test_cell_list_zip():
"""
Feature: nn.CellList
Description: Fix the problem of no manager for this func graph when using nn.CellList.
Expectation: No exception.
"""
x = Tensor(np.ones([1, 120, 1024, 640]), mstype.float32)
CellListNet()(x)