!3048 support use valuelist or valuetuple of primitives

Merge pull request !3048 from amongo/SupportPrimitiveList
This commit is contained in:
mindspore-ci-bot 2020-07-15 10:45:36 +08:00 committed by Gitee
commit 45ad430af2
2 changed files with 73 additions and 11 deletions

View File

@ -168,15 +168,15 @@ bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj,
return true;
}
bool IsAllGraphInValueSequence(const std::vector<ValuePtr> &value_vec) {
bool IsAllFuncInValueSequence(const std::vector<ValuePtr> &value_vec) {
for (auto &elem : value_vec) {
if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) {
const auto &vec = GetValue<std::vector<ValuePtr>>(elem);
auto is_graph = IsAllGraphInValueSequence(vec);
auto is_graph = IsAllFuncInValueSequence(vec);
if (!is_graph) {
return false;
}
} else if (!elem->isa<FuncGraph>()) {
} else if (!elem->isa<FuncGraph>() && !elem->isa<Primitive>()) {
return false;
}
}
@ -196,6 +196,8 @@ AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const F
FuncGraphPtr new_fg = elem->cast<FuncGraphPtr>();
manager->AddFuncGraph(new_fg);
node = NewValueNode(new_fg);
} else if (elem->isa<Primitive>()) {
node = NewValueNode(elem);
} else {
MS_LOG(EXCEPTION) << "TransformToMakeTupleNodes error, expect funcgraph, got " << elem->ToString();
}
@ -205,19 +207,21 @@ AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const F
return cnode;
}
// transform the ValueTuple or ValueList of graph node to make tuple of const graph node
bool TransformVectorGraphValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
// transform the ValueTuple or ValueList of graph/primitve node to make tuple of const graph/primitve node
bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
const ValueNodePtr &value_node, AnfNodePtr *const transformed) {
MS_EXCEPTION_IF_NULL(value_node);
const auto &value_vec = GetValue<std::vector<ValuePtr>>(value_node->value());
if (!IsAllGraphInValueSequence(value_vec)) {
if (!IsAllFuncInValueSequence(value_vec)) {
return false;
}
// The celllist or ordered_cell will be parsed as valuetuple of const graph in it,
// (1) The celllist or ordered_cell will be parsed as valuetuple of const graph in it,
// So if has graph in list, try to replace the node with make tuple of graph value node.
// we do this because the graphmanger won't investigate the graph inside valuetuple,
// change the vector of graph to be make_tuple of graph value node
// change the vector of graph to be make_tuple of graph value node.
// (2) the primitve valuetuple or valuelist may encounter to abstract error, make it all
// independent nodes.
auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec);
// replace the ret ptr to be make tuple of graph value node
*transformed = node_tuple_graphs;
@ -251,7 +255,7 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr
// if the constant node is constant of vector of graph ,add graph to manager
if (IsValueNode<ValueTuple>(resolved_node) || IsValueNode<ValueList>(resolved_node)) {
(void)TransformVectorGraphValueNode(manager, node->func_graph(), resolved_node->cast<ValueNodePtr>(),
(void)TransformVectorFuncValueNode(manager, node->func_graph(), resolved_node->cast<ValueNodePtr>(),
&resolved_node);
}

View File

@ -17,6 +17,9 @@ from dataclasses import dataclass
import numpy as np
from mindspore import Tensor, Model, context
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.nn import Cell
from mindspore.nn import ReLU
from ...ut_filter import non_graph_engine
@ -66,3 +69,58 @@ def function_access_base(number):
def test_access_0040():
""" test_access_0040 """
function_access_base(2)
class OpSeqNet(Cell):
def __init__(self, loop_count=1):
super().__init__()
self.loop_count = loop_count
self.op_seq = (P.Sqrt(), P.Reciprocal(), P.Square())
def construct(self, x):
t = x
for op in self.op_seq:
t = op(t)
return t
def test_op_seq_test():
context.set_context(mode=context.GRAPH_MODE)
net = OpSeqNet()
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
input_me = Tensor(input_np)
net(input_me)
_grad_fusion = C.MultitypeFuncGraph("grad_fushion")
@_grad_fusion.register("Tensor", "Function")
def tensor_grad_scale(x, op):
return op(x)
class AllReduceTest(Cell):
def __init__(self, loop_count=1):
super().__init__()
self.op_list = ()
self.fushion_flag = [0, 1, 1, 0, 1, 0]
for i in self.fushion_flag:
op = P.AllReduce().add_prim_attr('fusion', i)
self.op_list = self.op_list + (op,)
self.hyper_map = C.HyperMap()
def construct(self, x):
ret = ()
for _ in self.fushion_flag:
ret = ret + (x,)
fushion_res = self.hyper_map(F.partial(_grad_fusion), ret, self.op_list)
return fushion_res
def test_allreduce_fushio_test():
context.set_context(mode=context.GRAPH_MODE)
net = AllReduceTest()
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
input_me = Tensor(input_np)
net(input_me)