forked from mindspore-Ecosystem/mindspore
!3048 support use valuelist or valuetuple of primitives
Merge pull request !3048 from amongo/SupportPrimitiveList
This commit is contained in:
commit
45ad430af2
|
@ -168,15 +168,15 @@ bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj,
|
|||
return true;
|
||||
}
|
||||
|
||||
bool IsAllGraphInValueSequence(const std::vector<ValuePtr> &value_vec) {
|
||||
bool IsAllFuncInValueSequence(const std::vector<ValuePtr> &value_vec) {
|
||||
for (auto &elem : value_vec) {
|
||||
if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) {
|
||||
const auto &vec = GetValue<std::vector<ValuePtr>>(elem);
|
||||
auto is_graph = IsAllGraphInValueSequence(vec);
|
||||
auto is_graph = IsAllFuncInValueSequence(vec);
|
||||
if (!is_graph) {
|
||||
return false;
|
||||
}
|
||||
} else if (!elem->isa<FuncGraph>()) {
|
||||
} else if (!elem->isa<FuncGraph>() && !elem->isa<Primitive>()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -196,6 +196,8 @@ AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const F
|
|||
FuncGraphPtr new_fg = elem->cast<FuncGraphPtr>();
|
||||
manager->AddFuncGraph(new_fg);
|
||||
node = NewValueNode(new_fg);
|
||||
} else if (elem->isa<Primitive>()) {
|
||||
node = NewValueNode(elem);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "TransformToMakeTupleNodes error, expect funcgraph, got " << elem->ToString();
|
||||
}
|
||||
|
@ -205,19 +207,21 @@ AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const F
|
|||
return cnode;
|
||||
}
|
||||
|
||||
// transform the ValueTuple or ValueList of graph node to make tuple of const graph node
|
||||
bool TransformVectorGraphValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
|
||||
// transform the ValueTuple or ValueList of graph/primitve node to make tuple of const graph/primitve node
|
||||
bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
|
||||
const ValueNodePtr &value_node, AnfNodePtr *const transformed) {
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
const auto &value_vec = GetValue<std::vector<ValuePtr>>(value_node->value());
|
||||
if (!IsAllGraphInValueSequence(value_vec)) {
|
||||
if (!IsAllFuncInValueSequence(value_vec)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// The celllist or ordered_cell will be parsed as valuetuple of const graph in it,
|
||||
// (1) The celllist or ordered_cell will be parsed as valuetuple of const graph in it,
|
||||
// So if has graph in list, try to replace the node with make tuple of graph value node.
|
||||
// we do this because the graphmanger won't investigate the graph inside valuetuple,
|
||||
// change the vector of graph to be make_tuple of graph value node
|
||||
// change the vector of graph to be make_tuple of graph value node.
|
||||
// (2) the primitve valuetuple or valuelist may encounter to abstract error, make it all
|
||||
// independent nodes.
|
||||
auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec);
|
||||
// replace the ret ptr to be make tuple of graph value node
|
||||
*transformed = node_tuple_graphs;
|
||||
|
@ -251,7 +255,7 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr
|
|||
|
||||
// if the constant node is constant of vector of graph ,add graph to manager
|
||||
if (IsValueNode<ValueTuple>(resolved_node) || IsValueNode<ValueList>(resolved_node)) {
|
||||
(void)TransformVectorGraphValueNode(manager, node->func_graph(), resolved_node->cast<ValueNodePtr>(),
|
||||
(void)TransformVectorFuncValueNode(manager, node->func_graph(), resolved_node->cast<ValueNodePtr>(),
|
||||
&resolved_node);
|
||||
}
|
||||
|
||||
|
|
|
@ -17,6 +17,9 @@ from dataclasses import dataclass
|
|||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, Model, context
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.nn import ReLU
|
||||
from ...ut_filter import non_graph_engine
|
||||
|
@ -66,3 +69,58 @@ def function_access_base(number):
|
|||
def test_access_0040():
|
||||
""" test_access_0040 """
|
||||
function_access_base(2)
|
||||
|
||||
|
||||
class OpSeqNet(Cell):
|
||||
def __init__(self, loop_count=1):
|
||||
super().__init__()
|
||||
self.loop_count = loop_count
|
||||
self.op_seq = (P.Sqrt(), P.Reciprocal(), P.Square())
|
||||
|
||||
def construct(self, x):
|
||||
t = x
|
||||
for op in self.op_seq:
|
||||
t = op(t)
|
||||
return t
|
||||
|
||||
|
||||
def test_op_seq_test():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = OpSeqNet()
|
||||
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
|
||||
input_me = Tensor(input_np)
|
||||
net(input_me)
|
||||
|
||||
|
||||
_grad_fusion = C.MultitypeFuncGraph("grad_fushion")
|
||||
|
||||
|
||||
@_grad_fusion.register("Tensor", "Function")
|
||||
def tensor_grad_scale(x, op):
|
||||
return op(x)
|
||||
|
||||
|
||||
class AllReduceTest(Cell):
|
||||
def __init__(self, loop_count=1):
|
||||
super().__init__()
|
||||
self.op_list = ()
|
||||
self.fushion_flag = [0, 1, 1, 0, 1, 0]
|
||||
for i in self.fushion_flag:
|
||||
op = P.AllReduce().add_prim_attr('fusion', i)
|
||||
self.op_list = self.op_list + (op,)
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
def construct(self, x):
|
||||
ret = ()
|
||||
for _ in self.fushion_flag:
|
||||
ret = ret + (x,)
|
||||
fushion_res = self.hyper_map(F.partial(_grad_fusion), ret, self.op_list)
|
||||
return fushion_res
|
||||
|
||||
|
||||
def test_allreduce_fushio_test():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = AllReduceTest()
|
||||
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
|
||||
input_me = Tensor(input_np)
|
||||
net(input_me)
|
||||
|
|
Loading…
Reference in New Issue