forked from mindspore-Ecosystem/mindspore
Associate func_graph in CellList to manager
This commit is contained in:
parent
89e3a499b1
commit
f8d7ed29e0
|
@ -73,7 +73,6 @@ const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespac
|
|||
const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class";
|
||||
const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class";
|
||||
const char PYTHON_MOD_GET_OBJECT_DESCRIPTION[] = "get_object_description";
|
||||
const char PYTHON_MOD_IS_CELL_LIST[] = "is_cell_list";
|
||||
const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor";
|
||||
const char PYTHON_MOD_EVAL_PY_SCRIPT[] = "eval_script";
|
||||
const char PYTHON_MOD_GET_SCRIPT_IDS[] = "get_script_ids";
|
||||
|
|
|
@ -319,35 +319,21 @@ bool ResolveObjectToNode(const AnfNodePtr &origin_node, const py::object &obj, A
|
|||
return true;
|
||||
}
|
||||
|
||||
bool IsAllFuncInValueSequence(const std::vector<ValuePtr> &value_vec) {
|
||||
bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
|
||||
const ValuePtr &value, AnfNodePtr *const transformed) {
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
const auto &value_vec = GetValue<ValuePtrList>(value);
|
||||
if (value_vec.empty()) {
|
||||
return false;
|
||||
}
|
||||
for (auto &elem : value_vec) {
|
||||
MS_EXCEPTION_IF_NULL(elem);
|
||||
if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) {
|
||||
const auto &vec = GetValue<ValuePtrList>(elem);
|
||||
auto is_graph = IsAllFuncInValueSequence(vec);
|
||||
if (!is_graph) {
|
||||
return false;
|
||||
}
|
||||
} else if (!elem->isa<FuncGraph>() && !elem->isa<Primitive>()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
|
||||
const std::vector<ValuePtr> &value_vec) {
|
||||
std::vector<AnfNodePtr> nodes;
|
||||
nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
bool is_all_func = true;
|
||||
for (auto &elem : value_vec) {
|
||||
MS_EXCEPTION_IF_NULL(elem);
|
||||
AnfNodePtr node = nullptr;
|
||||
if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) {
|
||||
const auto &vec = GetValue<std::vector<ValuePtr>>(elem);
|
||||
node = TransformToMakeTupleNodes(manager, func_graph, vec);
|
||||
is_all_func = is_all_func && TransformVectorFuncValueNode(manager, func_graph, elem, &node);
|
||||
} else if (elem->isa<FuncGraph>()) {
|
||||
FuncGraphPtr new_fg = elem->cast<FuncGraphPtr>();
|
||||
manager->AddFuncGraph(new_fg);
|
||||
|
@ -355,34 +341,20 @@ AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const F
|
|||
} else if (elem->isa<Primitive>()) {
|
||||
node = NewValueNode(elem);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "TransformToMakeTupleNodes error, expect funcgraph, got " << elem->ToString();
|
||||
is_all_func = false;
|
||||
}
|
||||
nodes.emplace_back(node);
|
||||
(void)nodes.emplace_back(node);
|
||||
}
|
||||
auto cnode = func_graph->NewCNode(std::move(nodes));
|
||||
return cnode;
|
||||
}
|
||||
|
||||
// Transform the ValueTuple or ValueList of graph/primitive node to make tuple of const graph/primitive node
|
||||
bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
|
||||
const ValueNodePtr &value_node, AnfNodePtr *const transformed) {
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
const auto &value_vec = GetValue<ValuePtrList>(value_node->value());
|
||||
if (!IsAllFuncInValueSequence(value_vec)) {
|
||||
return false;
|
||||
if (is_all_func) {
|
||||
// (1) The celllist or ordered_cell will be parsed as valuetuple of const graph in it,
|
||||
// So if has graph in list, try to replace the node with make tuple of graph value node.
|
||||
// We do this because the graph manager won't investigate the graph inside valuetuple,
|
||||
// change the vector of graph to be make_tuple of graph value node.
|
||||
// (2) the primitive valuetuple or valuelist may encounter to abstract error, make it all
|
||||
// independent nodes.
|
||||
*transformed = func_graph->NewCNode(std::move(nodes));
|
||||
}
|
||||
|
||||
// (1) The celllist or ordered_cell will be parsed as valuetuple of const graph in it,
|
||||
// So if has graph in list, try to replace the node with make tuple of graph value node.
|
||||
// We do this because the graph manager won't investigate the graph inside valuetuple,
|
||||
// change the vector of graph to be make_tuple of graph value node.
|
||||
// (2) the primitive valuetuple or valuelist may encounter to abstract error, make it all
|
||||
// independent nodes.
|
||||
auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec);
|
||||
// Replace the ret ptr to be make tuple of graph value node
|
||||
*transformed = node_tuple_graphs;
|
||||
|
||||
return true;
|
||||
return is_all_func;
|
||||
}
|
||||
|
||||
// Resolve the python obj, and if the resovled node is valuenode with graphs, add the graphs to manager.
|
||||
|
@ -402,8 +374,8 @@ AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, cons
|
|||
|
||||
// If the constant node is constant of vector of graph, add graph to manager.
|
||||
if (IsValueNode<ValueTuple>(resolved_node) || IsValueNode<ValueList>(resolved_node)) {
|
||||
(void)TransformVectorFuncValueNode(manager, node->func_graph(), resolved_node->cast<ValueNodePtr>(),
|
||||
&resolved_node);
|
||||
auto value = resolved_node->cast<ValueNodePtr>()->value();
|
||||
(void)TransformVectorFuncValueNode(manager, node->func_graph(), value, &resolved_node);
|
||||
}
|
||||
return resolved_node;
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
|
|||
get_operation_symbol, get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name,
|
||||
eval_script, get_script_ids, expand_expr_statement, is_class_member, parse_cb, resolve_symbol,
|
||||
convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
|
||||
is_class_type, check_obj_bool, is_cell_list, python_isinstance, ms_isinstance)
|
||||
is_class_type, check_obj_bool, python_isinstance, ms_isinstance)
|
||||
|
||||
__all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope',
|
||||
'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol',
|
||||
|
@ -32,4 +32,4 @@ __all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'ge
|
|||
'get_operation_symbol', 'get_operation_namespace_symbol', 'get_parse_method_of_class', 'get_scope_name',
|
||||
'eval_script', 'get_script_ids', 'expand_expr_statement', 'is_class_member', 'parse_cb', 'resolve_symbol',
|
||||
'convert_to_ms_tensor', 'get_object_description', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
|
||||
'is_class_type', 'check_obj_bool', 'is_cell_list', 'python_isinstance', 'ms_isinstance']
|
||||
'is_class_type', 'check_obj_bool', 'python_isinstance', 'ms_isinstance']
|
||||
|
|
|
@ -452,11 +452,6 @@ def ms_isinstance(x, cmp_type):
|
|||
return isinstance(x, pytype_to_mstype.get(cmp_type))
|
||||
|
||||
|
||||
def is_cell_list(obj):
|
||||
"""Check if obj is nn.CellList"""
|
||||
return isinstance(obj, nn.CellList)
|
||||
|
||||
|
||||
def get_module_namespace(obj):
|
||||
"""Get the module's namespace."""
|
||||
logger.debug("get module namespace, module: %r", obj)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -15,23 +15,15 @@
|
|||
""" test_celllist """
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, Model
|
||||
from mindspore import context
|
||||
from mindspore.nn import AvgPool2d
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.nn import Flatten
|
||||
from mindspore.nn import ReLU
|
||||
from mindspore.nn import SequentialCell
|
||||
from mindspore import context, nn, Tensor, Model, ParameterTuple
|
||||
from mindspore import dtype as mstype
|
||||
from ...ut_filter import non_graph_engine
|
||||
|
||||
|
||||
# pylint: disable=W0212
|
||||
|
||||
|
||||
class Net3(Cell):
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.tuple = (ReLU(), ReLU())
|
||||
self.tuple = (nn.ReLU(), nn.ReLU())
|
||||
|
||||
def construct(self, x):
|
||||
for op in self.tuple:
|
||||
|
@ -43,18 +35,38 @@ class Net3(Cell):
|
|||
def test_cell_list():
|
||||
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
|
||||
input_me = Tensor(input_np)
|
||||
net = Net3()
|
||||
net = Net()
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
model = Model(net)
|
||||
model.predict(input_me)
|
||||
|
||||
|
||||
class SequenceNet(Cell):
|
||||
class CellListNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.seq = SequentialCell([AvgPool2d(3, 1), ReLU(), Flatten()])
|
||||
self.values = list(self.seq._cells.values())
|
||||
self.all = nn.CellList([nn.Conv2d(120, 240, 4, has_bias=False,
|
||||
weight_init=Tensor(np.ones([240, 120, 4, 4]), mstype.float32)),
|
||||
nn.Conv2d(240, 480, 4, has_bias=False,
|
||||
weight_init=Tensor(np.ones([480, 240, 4, 4]), mstype.float32))])
|
||||
self.params = ParameterTuple(self.get_parameters())
|
||||
self.weight_list = [(240, 120, 4, 4), (480, 240, 4, 4)]
|
||||
self.info = [self.all, self.params, self.weight_list]
|
||||
|
||||
def construct(self, x):
|
||||
x = self.seq(x)
|
||||
return x
|
||||
func = None
|
||||
conv, params, weight_list = self.info
|
||||
for _, (_conv, _, _weight_list) in enumerate(zip(conv, params, weight_list)):
|
||||
if _weight_list[0] == 240:
|
||||
func = _conv
|
||||
out = func(x)
|
||||
return out
|
||||
|
||||
|
||||
def test_cell_list_zip():
|
||||
"""
|
||||
Feature: nn.CellList
|
||||
Description: Fix the problem of no manager for this func graph when using nn.CellList.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.ones([1, 120, 1024, 640]), mstype.float32)
|
||||
CellListNet()(x)
|
||||
|
|
Loading…
Reference in New Issue